機器之心編輯部 注意力是 Transformer 架構的關鍵部分,負責將每個序列元素轉換為值的加權和。將查詢與所有鍵進行點積,然后通過 softmax 函數歸一化,會得到每個鍵對應的注意力權重。 盡管 SoftmaxAttn 中的 softmax 具有廣泛的用途和有效性,但它并非沒有局限性。例如,softmax 函數有時會導致注意力集中在少數幾個特征,而忽略了其他信息。 近來,一些研究探索了 Transformer 中 softmax 注意力的替代方案,例如 ReLU 和 sigmoid 激活函數。最近,來自蘋果的研究者重新審視了 sigmoid 注意力并進行了深入的理論和實驗分析。 該研究證明:從理論上講,與 softmax 注意力相比,具有 sigmoid 注意力的 Transformer 是通用函數逼近器,并且受益于改進的正則化。
該研究還提出了一種硬件感知且內存高效的 sigmoid 注意力實現 ——FLASHSIGMOID。FLASHSIGMOID 在 H100 GPU 上的推理內核速度比 FLASHATTENTION2 提高了 17%。 跨語言、視覺和語音的實驗表明,合理歸一化的 sigmoid 注意力與 softmax 注意力在廣泛的領域和規模上性能相當,而之前的 sigmoid 注意力嘗試無法實現這一點。 此外,該研究還用 sigmoid 內核擴展了 FLASHATTENTION2,將內核推理掛鐘時間減少了 17%,將現實世界推理時間減少了 8%。 論文作者 Jason Ramapuram 表示:如果想讓注意力快 18% 左右,你不妨試試 Sigmoid 注意力機制。他們用 Sigmoid 和基于序列長度的常量標量偏置取代了注意力機制中的傳統 softmax。 Sigmoid 注意力 假設 根據先前的研究,自注意力可以簡寫為: 其中 Softmax 函數將輸入矩陣的每一行進行了歸一化。該研究將 Softmax 做了以下替換: 實際上,將 將多個 SigmoidAttn 輸出進行組合,得到多個頭的形式,如下所示: Sigmoid 注意力理論基礎 該研究對 SigmoidAttn 進行了分析,分析的目的主要有兩個:(1)證明當 SigmoidAttn 取代 SoftmaxAttn 時,Transformer 架構仍然是一個通用函數逼近器;(2)通過計算 SigmoidAttn 的 Lipschitz 常數來恢復其規律性。 具有 Sigmoid 注意力的 Transformer 是通用逼近器嗎? 經典 Transformer 可以將連續的序列到序列函數近似到任意精度,這一特性稱為通用近似特性 (UAP,Universal Approximation Property)。UAP 非常受歡迎,因為它證明了架構的通用性和表示能力。由于 SigmoidAttn 修改了 Transformer 架構,因此從理論上保證這種修改不會影響表示能力并保留 UAP 的性能至關重要。該研究通過以下定理提供此保證。 結果表明,即使使用 SigmoidAttn,一系列 transformer 塊也可以實現上下文映射。 Sigmoid 注意力的正則性 與神經網絡中的任何層一樣,SigmoidAttn 的正則性值得研究,因為它可以深入了解相應網絡的魯棒性及其優化的難易程度。 SigmoidAttn 正則性定理為: 結果證明,SigmoidAttn 的局部 Lipschitz 常數遠低于 SoftmaxAttn 的最差局部 Lipschitz 常數。 FLASHSIGMOID:硬件感知實現 現代架構上的注意力計算往往會受到內存訪問 IO 的限制。FLASHATTENTION 和 FLASHATTENTION2 通過優化 GPU 內存層次結構利用率來加速注意力計算。得益于這些方法提供的速度提升,該研究開發了 SigmoidAttn 的硬件感知實現 ——FLASHSIGMOID,采用了三個核心思路:
實驗 為了實驗驗證 SigmoidAttn,該研究在多個領域進行了評估:使用視覺 transformer 進行監督圖像分類、使用 SimCLR 進行自監督圖像表示學習、BYOL(Bootstrap Your Own Latent)和掩碼自動編碼器 (MAE) 以及自動語音識別 (ASR) 和自回歸語言建模 (LM)。 該研究還在 TED-LIUM v3 上驗證了 ASR 的序列長度泛化,在所有這些領域和算法中,該研究證明 SigmoidAttn 的性能與 SoftmaxAttn 相當(圖 2 和 21),同時提供訓練和推理加速。 該研究得出以下觀察結果: SigmoidAttn 對于沒有偏置的視覺任務是有效的(MAE 除外),但依賴于 LayerScale 以無超參數的方式匹配基線 SoftmaxAttn(圖 9-a)的性能。除非另有說明,否則為 SoftmaxAttn 呈現的所有結果也公平地添加了 LayerScale。 LM 和 ASR 對初始范數 |
|