CTC loss
01 Mar 2020 - trungthanhnguyenTrong các bài toán machine learning, Seq2Seq là bài toán phổ biến cả trong NLP và CV, có input và output đều dạng chuỗi (sequence). Những bài toán Seq2Seq như Machine Translate, Auto Tagging, Speech to Text, Text to Speech, Handwriting recognition khá quen thuộc và hầu hết giải pháp tối ưu nhất cho những bài toán này đều dựa vào Deep learning. Trong đó có hai bài toán đặc biệt khó: Speech to Text và Handwriting recognition. Để hiểu được lý do tại sao, hãy quan sát ví dụ dưới đây:
Định nghĩa:
Trước đây, trước khi có CTC người ta dùng các giải pháp để cố gắng tách chuỗi input thành từng đoạn tương ứng với $y_i$, tức cố gắng chia input sao cho độ dài chuỗi X bằng độ dài chuỗi Y. Tuy nhiên cách này không thực sự hiệu quả. CTC ra đời nhằm giải quyết những bài toán dạng này.
Thay vì tìm cách phân chia input thành các đoạn tương ứng với từng $y_i$ (bước align) bằng các thuật toán tự định nghĩa, CTC sẽ tự học và làm việc đó. Hãy quan sát hình dưới đây.
Input là một audio, giả sử audio này được chia thành 10 phần bằng nhau time_step = 10, label của audio là “hello”.
Thay vì tìm cách predict trực tiếp ra “hello”, ta sẽ predict ra các biến thể của nó là “hhheel_llo”, “h_e_ll_l_o”, … mỗi biến thể này ta gọi là một alignment. Độ dài của một alignment bằng với độ dài input = T ($=timesteps = 10$). Như vậy model deep learning nhận input đầu vào và output ra các alignment ( “hhheel_llo” hoặc h_e_ll_l_o, … ) tương ứng với nhãn “hello”. Từ các alignment, bằng việc loại bỏ các kí tự lặp ta thu được kết quả mong muốn. xác suất để predict ra “hello” bằng tổng xác suất tất cả alignment.
Tạo các alignment cho “hello” bằng cách lặp lại các kí tự với sao cho len(y) = len(x) = T = 10. Có một vấn đề đặt ra là: giả sử model trả về output “heeeeellloo”, loại bỏ các kí tự lặp ta chỉ thu được “helo” thay vì “hello”.
Để giải quyết vấn đề này, ta sử dụng kí tự đặc biệt blank token - $ϵ$ (đừng nhầm với kí tự khoảng trắng, khoảng cách). $ϵ$ có thể được thêm vào bất kì chỗ nào trong chuỗi “hello”, và giữa hai kí tự giống nhau liền kề bắt buộc phải có $ϵ$ ở giữa. Như vậy, “hello” sẽ có các alignment: “heeeel–llo”, ‘-hell-l–oo’ … luôn có ít nhất một kí tự “-“ giữa hai kí tự “l” ( “-“ là biểu diễn cho $ϵ$)
Đặt $A_{XY}$ là tập tất cả các alignment A của Y. xác suất model có thể predict ra Y với input X là tổng xác suất của tất cả A trong $A_{XY}$. Xác suất của A lại bằng tổng xác suất từng kí tự $a_t$ trong A. CTC loss chính là P(Y|X) trong hình:
Note: Thuật toán này đọc hơi trìu tượng, hơi khó diễn tả, bạn có thể skip phần này. CTC_loss được hỗ trợ trong keras, tensorflow cho bạn dùng sẵn.
Về lý thuyết, ta cần thống kê ra tất cả alignment có thể có (kể cả đúng và sai). Giả sử với tập 9 kí tự, time_step = 10, số lượng tổ hợp lên tới $10^9$. Số lượng tổ hợp tăng cấp số nhân theo số lượng kí tự và time_step. Như vậy, việc liệt kê tất cả các alignment và tính tổng xác suất các alignment hợp lệ là việc không khả thi. Để cải thiện tốc độ, người ta đã sử dụng thuật toán dynamic programing, thuật toán như sau:
Do blank token $ϵ$ có thể xuất hiện tại bất kì vị trí nào trong chuỗi Y, để dễ dàng mô tả thuật toán, ta đặt target output là Z = $[ϵ, y_1, ϵ, y_2 … y_i, ϵ]$.
Như vậy để tính CTC_loss, ta lần lượt tính $a_{1,1}$, $a_{1,2}$, $a_{2,1}$… để cuối cùng tính ra được $a_{s,t}$ cuối cùng với s=S, t=T. Trong quá trình tính $a_{s,t}$ tại 1 thời điểm bất kì, sẽ xảy ra 2 trường hợp:
Case: Z[s-1] = $ϵ$ và nằm giữa hai kí tự giống nhau. Trong trường hợp này, kí tự $ϵ$ là kí tự không thể thiếu, bắt buộc phải có để phân biệt hai kí tự giống nhau. Vậy sẽ có hai trường hợp hợp lệ xảy ra taị thời điểm t-1 là: M[t-1] = Z[s-1] = $ϵ$ và M[t-1] = Z[s] (do Z[s-1] = $ϵ$ nên Z[s] != $ϵ$)
\[a_{s,t} = (a_{s-1,t-1} + a_{s,t-1} ) * p_{s,t}\]Case: Z[s-1] = $ϵ$ và nằm giữa hai kí tự khác nhau. Đặt hai kí tự này là b và c. Trong trường hợp này, kí tự $ϵ$ không còn quan trọng (tức không giữ vai trò phân biệt hai kí tự lặp) nên kết quả tính ra có thể không cần $ϵ$. Vậy tại 1 thời điểm t với M[t] = b, sẽ có 3 khả năng xảy ra tại thời điểm t-1:
Vậy, $a_{s,t}$ được tính theo công thức:
\[a_{s,t} = (a_{s-2,t-1} + a_{s-1,t-1} + a_{s,t-1}) * p_{s,t}\]Như vậy mình đã giới thiệu qua về tư tưởng của CTC. Dù hơi khó hiểu, mình vẫn hi vọng bài viết sẽ giúp bạn có cái nhìn rõ hơn chút về CTC. Cảm ơn bạn đã đọc bài viết.