©PaperWeekly 原創 · 作者|鄧雲天
學校|哈佛大學計算機系博士生
研究方向|自然語言處理
論文標題: Cascaded Decoding with Markov Transformers
論文連結: https://arxiv.org/abs/2006.01112
論文代碼: https://github.com/harvardnlp/cascaded-generation
引言 目前文本生成最常用的算法基於 fully autoregressive 模型,比如 RNN 和 transformer。在 fully autoregressive 模型中,生成下一個詞的概率取決於之前所有的詞。給定一個 fully autoregressive 模型,文本生成通常使用 beam search 從左到右搜索概率最大的句子。但由於 beam search 是一個順序的過程,我們無法在 GPU 上進行並行加速。近年來,為了加速文本生成,Gu et al 2017 提出了 non-autoregerssive 模型。在 non-autoregressive 模型中,不同位置的詞的生成是相互獨立的,因此可以使用 GPU 同時生成所有詞。但是這個獨立假設太強,經常導致一些明顯的問題,比如重複生成相同的詞。我們指出 non-autoregressive 模型是並行生成的充分但不必要條件。如果我們考慮 m 階 Markov 模型的概率分布(每個詞的概率取決於過去 m 個詞稱為 m 階 Markov 模型),那麼從這個分布中採樣也是可以並行計算的(Rush et al 2020),而 non-autoregressive 模型只是 0 階 Markov 模型的特殊情況。在這個工作中,我們利用這個有限階數 Markov 模型的性質提出 cascaded decoding(Weiss et al 2010)。Cascaded decoding 的核心是從 0 階 Markov 模型開始,逐漸引入高階 Markov 模型,從而逐步縮小搜索空間。
為了支持這個搜索算法,我們需要一組不同階數的 Markov 模型。為此我們提出 transformer 的一個變種 Markov transformer,由此通過一個 Markov transformer 實現一組不同階數的 Markov 模型。
值得一提的是,我們方法的速度與 non-autorgressive 方法相當,並且能夠同時考慮到不同位置的詞之間的關聯從而達到很好的生成質量。
搜索算法:Cascaded Decoding
我們用 Conditional Random Field (CRF)來描述文本生成模型
,其中 是第 個單詞, 是句子的長度。一個 m 階 CRF 模型為:上式中的 是帶有參數的 log potential,它可以建模相鄰 個單詞之間的關係。當 ,得到了一個 non-autoregressive 模型,而當 ,得到 fully autoregressive 模型。生成文本時,我們需要找到概率最高的句子 。我們可以使用動態算法計算,但時間複雜度是 ,即使 都很不現實,因為 一般是幾萬量級的。常用的做法是用 beam search 去找到近似的最優解,但 beam search 無法並行,而人們還很少考慮能替代 beam search 的算法。我們提出的 cascaded decoding 的思路與 beam search 的從左到右不同,是基於對整個解空間 的不斷過濾。我們考慮每個位置的可能的 n-gram,把不太可能的 n-gram 過濾掉,從而保留 個最可能的 n-gram。首先,我們用一個 0 階模型去過濾掉每個位置不太可能的 unigram,然後用一個 1 階模型過濾掉每個位置不太可能的 bigram,再用一個 2 階模型過濾掉每個位置不太可能的 trigram,直到最後得到一個高階模型,並使用動態算法去找出過濾後的空間裡的最優解。為了便於理解,下圖中展示了一個序列長度 並且過濾 3 次的例子,這裡我們使用 ,也就是每次保留前 10 個 n-gram。首先,我們使用一個 0 階的 Markov 模型 (non-autoregressive 模型)去過濾掉每個位置不太可能的 unigram,每個位置只保留最可能的 10 個 unigram,所以過濾後的解空間如下圖所示是一個 10 10 10 的立方體(過濾前的解空間大小為 V V V)。這裡的坐標軸分別代表 , 和 。在此基礎上,我們使用一個 1 階的 Markov 模型 去過濾掉每個位置不太可能的 bigram,所以解空間進一步縮小為每個位置只考慮 10 個最可能的 bigram,可以從下圖中的水平平面投影和左面垂直平面的投影中看出:每個平面上恰好有 10 個陰影小方塊,代表 10 個被保留的 bigram(10 個可能的 和 10 個可能的 )。
最後,我們使用一個 2 階的 Markov 模型 去過濾掉每個位置不太可能的 trigram,使得解空間縮小為每個位置只考慮 10 個最可能的 trigram,可以從下圖中的 10 個小立方看出。
我們重複上述過程的次數越多,就能使用越高階的模型從而更接近 fully autoregressive 模型。在最後的縮小的解空間裡,我們可以使用動態算法去找出最可能的一句話。
在上面的過程中,我們用到了每個位置「最可能」的 n-gram,這個「最可能」的評判方式有很多,比如每個 n-gram 的 marginal probability,但我們實際使用的是 max-marginal(Weiss et al 2020),具體細節參見我們的論文。
變長生成
目前為止我們假定已知生成序列的長度,但實際應用中我們很難準確預測生成序列的長度,因此我們提出一個可以同時考慮不同可能長度的算法。我們先估計一個最大長度,然後在搜索中考慮所有比這個最大長度短的序列。
這種變長搜索在 CRF 中的實現非常簡單:我們只需要在詞表中引入一個佔位字符 pad,同時改寫 log potential 使得句尾 eos 和 pad 的下一個詞必須為 pad,那麼我們在生成時只需要使用一個最大長度,就可以同時考慮不同長度的句子:不同長度的句子只是句尾 pad 的個數不同而已,但 pad 的存在不會影響分數。
下表中我們展示一個 cascaded decoding 和變長生成的例子,這裡我們考慮最大長度 8,並使用 ,也就是只保留最可能的 5 個 unigram,birgram,trigram,在每個 table 的 5 行中按分數由大到小排序。首先,我們使用一個 0 階模型,並在下表中展示出每個位置最可能的 unigram。如果我們只使用一個 0 階模型(non-autoregressive),那麼得到的解將會是「an amzing woman woman eos」(第一行),重複了單詞「woman」,這也是 non-autoregressive 模型的常見問題。
在我們的算法中,之後引入的高階模型可以修正這個問題。這裡一個小細節是我們限制最後一個單詞為佔位字符 pad,以確保每句話都有結束符 eos(end-of-sentence)。下一步,我們使用一個 1 階模型,並在下表中展示出每個位置最可能的 bigram。現在已經修正了之前的重複問題:按照第一行的最可能 bigram,最可能的解已經是「an amazing women . eos pad pad pad」。
同時注意到由於佔位字符 pad 的存在,我們可以考慮長度小於最大長度 8 的句子,這在很多其他的 non-autoregressive 工作中是很難做到的。
然後,我們使用一個 2 階模型,並在下表中展示出每個位置最可能的 trigram。
我們可以重複上述過程來引入越來越高階的模型,最後使用動態算法得到最可能的解。
並行化
計算不同位置 log potential 的過程是互相獨立的,因此我們可以使用 GPU 並行計算所有位置的 log potential。除了 log potential 外,另一個問題是如何並行計算我們使用的過濾 n-gram 的指標 max-marginal。
實際上,Rush et al 2020 中已經指出 CRF 中的 max-marginal 和 marginal 都可以使用並行的動態算法計算,核心思路是建一個以句子的每個位置為葉子節點的二叉樹並從下向上再從上到下計算,而不像傳統的動態算法那樣從句子的最左到最右再從右至左。這個算法已經在 torch-struct [6] 包裡實現。我們前面使用了很多不同階的 Markov 模型,然而實際上我們可以修改 transformer 的訓練過程,使一個 transformer 可以被當做不同階的 Markov 模型使用,即 Markov transformer。
這裡的核心思路是:如果在訓練時每 M 個單詞就重置 transformer 的 hidden state,並隨機選擇第一個重置位置,那麼 transformer 就可以在測試中被當做任何小於 階的模型,如下圖所示( )。
在上圖中,綠色分割線代表重置 hidden state(在 transformer 中我們只需要要求灰色線條表示的 self attention 不穿過分割線即可,同時我們使用空白字符 去重置分割線後一個位置的 state)。第 1、4、7 個位置的輸出沒有使用任何其他單詞的信息,因此相當於使用了 0 階模型;第 2、5、8 個位置的輸出使用了前一個單詞,因此相當於使用了 1 階模型;第 3、6、9 個位置的輸出使用了前兩個單詞,因此相當於使用了 2 階模型。
綜上,這個模型在測試時可以在任何位置被當做 0 階、1 階或者 2 階模型使用。(我們需要隨機選擇第一個重置位置,否則比如上圖中第 3 個位置無法被用作 0 階或者 1 階模型)。實驗結果與分析
使用 knowledge distillation,我們在 WMT14 En-De 上可以達到常規的 fully autoregressive transformer 速度的 2.4 倍,BLEU 只低 0.5。在 IWSLT14 De-En 上,我們的速度是 transformer 的 5.88 倍速度,BLEU 只損失 0.54。這個 BLEU 分數比去年的 FlowSeq(Ma et al 2019)高 6 分。
與 beam search 相比,cascaded decoding的另一個優勢是在搜索過程中考慮了非常多的序列。雖然每個位置只考慮了 個 n-gram,但考慮的序列個數最多是以序列長度的指數增長的:比如如果每個位置只考慮 個 unigram,那麼對於長度為 的序列就考慮了 個可能的序列。下圖中我們用一個 box plot 展示實際能夠考慮的序列個數。上圖中的 , , 是指 cascaded decoding 最終使用 2 階、3 階、4 階 CRF 的結果, 而 展示的是 beam search 的結果。由此可見,即使我們使用 4 階 CRF,依然可以比 beam search 考慮多一個量級的序列個數。Beam search 在文本生成中的地位幾十年來未被撼動。我們提出一種新的文本生成搜索算法 cascaded decoding,不僅形式簡潔優美,而且性能優異。Cascaded decoding 可以衍生出很多新的研究方向,比如我們可以進行長文本生成,或者引入 latent variable 去考慮全局信息以彌補目前算法只能考慮局部關聯的不足。
此外,我們提出的 Markov transformer 的思路可以被用來學習任何結構的概率圖模型。最後,我們這裡使用了一個 locally normalized 的語言模型作為 log potentials,實際上我們可以用更強大的 globally normalized 模型(Deng et al 2019)。[1] Gu et al 2017:https://arxiv.org/pdf/1711.02281.pdf
[2] Rush et al 2020:https://arxiv.org/pdf/2002.00876.pdf[3] Weiss et al 2010:http://proceedings.mlr.press/v9/weiss10a/weiss10a.pdf[4] Ma et al 2019:https://arxiv.org/pdf/1909.02480.pdf)[5] Deng et al 2019:https://openreview.net/pdf?id=B1l4SgHKDH[6] https://github.com/harvardnlp/pytorch-struct如何才能讓更多的優質內容以更短路逕到達讀者群體,縮短讀者尋找優質內容的成本呢?答案就是:你不認識的人。
總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。
PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀 ,也可以是學習心得 或技術乾貨 。我們的目的只有一個,讓知識真正流動起來。
📝 來稿標準:
• 稿件確係個人原創作品 ,來稿需註明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向)
• 如果文章並非首發,請在投稿時提醒並附上所有已發布連結
• PaperWeekly 默認每篇文章都是首發,均會添加「原創」標誌
📬 投稿郵箱:
• 投稿郵箱:hr@paperweekly.site
• 所有文章配圖,請單獨在附件中發送
• 請留下即時聯繫方式(微信或手機),以便我們在編輯發布時和作者溝通
🔍
現在,在「知乎」 也能找到我們了
進入知乎首頁搜索「PaperWeekly」
點擊「關注」 訂閱我們的專欄吧
關於PaperWeekly
PaperWeekly 是一個推薦、解讀、討論、報導人工智慧前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號後臺點擊「交流群」 ,小助手將把你帶入 PaperWeekly 的交流群裡。