[論文メモ] SELF-ATTENTION DOES NOT NEED O(n^2) MEMORY

arxiv.org

self-attentionの計算にメモリ O(n^2)は必要ない

self-attentionはクエリ q \in \mathbb{R}^d、長さ nのキーとバリューをそれぞれ k_1,...k_n v_1,...v_n \in \mathbb{R}^dとして次の式で表せる(ただしクエリが1つのとき)。
f:id:Ninhydrin:20211216092016p:plain

普通に実装すると s_iの計算・保存ために O(n)の計算量とメモリが必要。そしてself-attentionは O(n^2)必要。
これを改善し、

attentionの計算について。
まずsoftmax(式の s'_iの部分)で \Sigma_{j} e^{s_j}をattentionの最後に移動する。
f:id:Ninhydrin:20211216092728p:plain

これは定数メモリで計算出来る。
attentionの除算部分のために v^* \in \mathbb{R}^d s^* \in \mathbb{R}を用意する。
キーとバリューのペア k_i, v_iを取り出したら、 s_iを計算し v^* s^*を蓄積していく( v^* \leftarrow v^* + v_ie^{s_i} s^* \leftarrow s^* + e^{s_i})。
そして最後に \frac{v^*}{s^*}で割ればよい。

これにより O(1)のメモリ消費。
これをself-attentionに拡張するにはクエリにインデックスを1つ追加するだけで、それでも O(\log n)のメモリに抑えられる。
ただ、入出力は O(n)必要でこれは仕方なし。

所感

JAXの実装が乗っているけどJAXをあまり使わないので悲しい。
JAXを使う人は使ってみてもいいのかもしれない。