[論文メモ] Hydra Attention: Efficient Attention with Many Heads

arxiv.org

CADL2022

効率的なmulti-head attentionの提案

transformerのattentionはトークンの数の2乗オーダーの計算コストを必要とする。
そのためVision Transformer(ViT)などで高解像度の画像を扱うとトークン数が膨大になり、計算のほとんどをattention matrixの生成と適用に費やすことになる。
これをなんとかしたいというお気持ち。

提案手法

一般的なsoftmax self-attentionは以下の式(1)

トークン数を T、特徴量の次元数を Dとすると計算量は O(T^2D)
softmaxを Q Kの類似度を図る関数 simとして一般化したのが以下の式

 A(Q, K, V) = sim(Q, K)V

非線形関数 \phi(\cdot) simを分解する。

 A(Q, K, V; \phi) = (\phi(Q) \phi(K))V = \phi(Q)(\phi(K))V

 \phi(K)^T Vを先に計算することで、計算量は O(TD^2)になる。

これが1つのヘッドに相当。計算量はTについて線形になったが一般に D \geq 768なのでまだ高コスト。

基本self-attentionはmulti-headで扱われる(MSA)。
ヘッド数を Hは大体6~16で、 Q, Kの特徴量を D/Hに分割して行う。

 A(Q_h, K_h, V_h) = softmax(\frac{Q_h K^{T}_{h}}{\sqrt{D}})V_h \quad \forall_h \in \{1,...,H\}

 Q_h,K_h,V_h \in \mathbb{R}^{T \times \frac{D}{H}}

コレがmulti-head linear attention(MLA)
multi-headはもとのattentionと計算量は変わらないが、先程のように非線形関数で分解することで計算量を O(HT(D/H)^2) = O(TD^2/H)に抑えられる。

 A(Q_h, K_h, V_h; \phi) = \phi(Q_h)(\phi(K_h)^T V_h) \quad \forall_h \in \{1,...,H\}

 O(TD^2/H)なのでヘッド数を増やすと高速化できるが精度とのトレードオフで、実際いくつくらいまで増やしていいのか。

調査のためImageNet-1kをDeiT-Bで学習した結果が以下の図2。横軸が H

MSAは H \gt 96で、MLAは H \lt 3でメモリ不足。

MLAは H=768でもある程度精度を保っているが、これは H = Dでただのスカラ特徴。

類似度関数としてsoftmaxを使わなければ Hをスケールアップできそう(ここではcosine similarityを採用)。


そこで H = Dとした hydra trick を導入する。

 A(Q_h, K_h, V_h; \phi) = \phi(Q_h)(\phi(K_h)^T V_h) \quad \forall_h \in \{1,...,D\}
なお、 Q_h,K_h,V_h \in \mathbb{R}^{T \times 1}


 \odotアダマール積として
 Hydra(Q, K, V; \phi) = \phi(Q) \odot \Sigma^{T}_{t=1} \phi(K)^t \odot V^t



 \phi Q, K全体に適用することに注意( Q_h,K_hは列ベクトルなので)。

HydraはMSAとは全く異なる動作で、すべてのトークンを集約したグローバルな特徴ベクトル \Sigma^{T}_{t=1} \phi(K)^t \odot V^tに対して \phi(Q)ゲーティングしている。

計算量は O(TD(D/H)) = O(TD)

その他の O(TD)の手法でAttention-Free TransformerやPloyNLなどがあるが、Hydra Attentionはこれらの一般化と捉えることができる(論文参照)。

実験・結果

アーキテクチャは基本的にDeiT-B、データセットはImageNet-1k。
 sim(\cdot, \cdot)としてcosine similarityを採用( \phiはL2 normになる)。


cosine similarity以外について調査した結果が表1。

cosine similarityが最もよく、MSAのそもそもの性質を変化させてるのが原因と考えられる。
MSAは重みの和が1になるようになっているがそれらがそもそも望ましい性質では無いのかもしれない。

Hydra Attentionでの置き換え位置の調査。
すべてを置き換えるのではなく一部を置き換えた方がいいのではというお気持ち(よくあるグローバルを扱うAttention系は後半の層を置き換えると良い的なのが多い)。
Hydra Attentionはグローバルな情報を扱うためその可能性は高い。
実験結果が図4。

はじめの層の置き換え(forward)は精度が低下しているが、後ろの層の置き換え(Backward)は精度が1%近い改善もあった。

既存手法との比較

他の O(TD)の手法に比べ精度低下が少ない。
後半2層の置き換えは速度向上は少ないが精度が向上。

所感

ヒドラという名前がセンスがある。
すべてを置き換えると精度に影響が出るが、少し利用するなら精度・速度両方に恩恵があるそうなので良さそう。
実装も手軽そうなのも良い。
ただViTによる画像分類の結果だけなので他のタスクでうまくいくのかは気になるところ。