[論文メモ] Sparse is Enough in Scaling Transformers

arxiv.org

NeurIPS2021

あくまでメモ。ちょっとわからないところがあり間違っているかもしれない。


でかいTransformerがいろんなタスクで性能を発揮しているが、学習に時間がかかりfine-tuningも時間がかかる。実際に使用するときも遅く実用的でない。
そこでTransformerをスパースにすることで性能をそれほど落とさず高速化させる。


BERTの学習コストは$4k~$12kで一人の4時間のフライトと同量のCO2を排出し、GPT-3はサンフランシスコ-ニューヨークの往復の3倍のCO2を排出するらしい。
これらのモデルのpopularityや今後もモデルが巨大化していくことを考えるとモデルの効率化は非常に重要な問題。

Transformerのパラメータ数を d_{\verb|model|} d_{\verb|ff|}で表す。
 d_{\verb|model|}はself-attention(SA)のパラメータ数。 d_{\verb|model|}=512が一般的。
BERTは d_{\verb|model|}=768、large BERTは d_{\verb|model|}=1024、GPT-3に至っては d_{\verb|model|}=12288
 d_{\verb|ff|}はfeedforward block(FFB)のパラメータで d_{\verb|ff|}= 4d_{\verb|model|} が一般的。

SAではqueryやkey等で d_{\verb|model|} \times d_{\verb|model|}の計算が4回必要になる。
またdecoder部分ではqueryの計算とattentionの計算で d_{\verb|model|} \times d_{\verb|model|}が2回必要。
FFBでは2回の d_{\verb|model|} \times d_{\verb|ff|}の計算が必要になる。
トータルで 4d^2_{\verb|model|} + 2d^2_{\verb|model|} + 2d_{\verb|model|}d_{\verb|ff|}の計算コストがかる。

つまり d^2_{\verb|model|}の計算コストと学習パラメータが必要になる。

提案手法ではこれを 8d^{1.5}_{\verb|model|} + 4d^{1.5}_{\verb|model|} + 4d^{1.5}_{\verb|model|}まで減らす。

手法

Sparse is Enough
Feedforward Layer(FFL)とSAをそれぞれスパースにしていく。

Sparse Feedforward Layer

FFLは2つのfully-connected(FC) layerとReLUからなる。そして大体その間の次元数は d_{\verb|ff|} = (4 or 8)d_{\verb|model|}
ReLUにより2つのFC layerの中間は0が多くなる。そこで中間表現を N個のブロックに分け、各ブロックから1つだけが非ゼロになるようにする。
既存手法として枝刈りがあるが、ここでは枝刈りせず、すべてのweight matrixを学習し、利用する中間表現を動的に決定する(dynamic sparsity)。

決定方法は別途Controllerを学習する。ControllerはFF layerと同じ入力を受け取る非常に小さい2 layerのMLP。出力は入力と同じサイズでそれをマスクとして利用する。式にすると下記。
f:id:Ninhydrin:20211201100050p:plain

まずControllerを計算しバイナリマスクを生成し、N個のブロックでマスクされていない箇所だけを計算をする(図2を参考)。
式は下記。
f:id:Ninhydrin:20211201101836p:plain
f:id:Ninhydrin:20211201100433p:plain

実際に利用するときはargmaxでマスクを決定するが、学習時はsoftmaxしてサンプリングする。微分可能にするために Gumbel-Softmax trikを利用する。
T5-largeで試した結果が表2と図3。 N=64のときbaselineと変わらないレベルの性能で高速化。
f:id:Ninhydrin:20211201101756p:plain

Sparse QKV Layer

次にSAで使われるFC layerの部分(query、key、value、FC)を調整する。これらをまとめてQKVと呼ぶ
QKVはすべて d^2_{\verb|model|}のパラメータと計算コストが必要。ただFFLと違ってReLUが無いので先程のは使えない。

ここでも入力をサイズ M S個のブロックに分ける( M = d_{\verb|model|} / S)。ただこのまま処理すると他のブロックに干渉できないので multiplicative dense layer(mult)を導入する。
multは d^2_{\verb|model|}/S + d_{\verb|model|}Sのパラメータを使う。入力を \textbf{x} \in \mathbb{R}^{d_{\verb|model|}}として以下の式。
f:id:Ninhydrin:20211202091542p:plain
ただし \textbf{y} \in \mathbb{R}^{S \times M} D \in \mathbb{R}^{d_{\verb|model|} \: \times S} M \in \mathbb{R}^{d_{\verb|model|} \: \times M}。基本的に S=16を利用。図4を見ると式のイメージがつかめる。
f:id:Ninhydrin:20211202091854p:plain

multで2次元( S \times M)のマップが得られ、シーケンス長やミニバッチを考えると \mathbb{R}^{\verb|batch| \times \verb|length| \times S \times M}の形をしている。lengthと Sを画像で言うheightとwidth、 Mをチャンネルとしてconvolutionを行う。

この2つを組み合わせ、FC layerを置き換える。
入力のQ、K、Vに関してはmultを共有する。つまり Q = conv_Q(mult(x)), K = conv_K(mult(x)), V= conv_V(mult(x))。そして出力部分のFC layerに関しては取り除く。図4(b)の一番下の図。
 S = \sqrt{d_{\verb|model|}}にすることで計算コストを d^2_{\verb|model|}から d^{1.5}_{\verb|model|}のオーダーに減らせる。

実際QKVを置き換えた結果が図5。FFも置き換えた結果が表3。
f:id:Ninhydrin:20211202093619p:plain


最終的な出力はボキャブラリー数でFC layerで出力するが、そこに関してもmultiplicative dense layerに置き換える。速度は速くなるが少々劣化する。

Sparsity for Long Sequences

これらの工夫によりdense layerは効率化されたが、実際に利用するときに長いシーケンスが来るとattention部分の計算量が支配的になりこれまでの高速化が意味をなさない。
そこで長いシーケンスに対応出来るように調整する。
具体的にはReformerで提案されたLSH(Locality-Sentsitive Hashing)に注目。これを組み込んだアーキテクチャ、Terraformerとする。
arxiv.org

Architecture for Long Sequences

Transformerのdecoderは準最適で改良の余地がある。
decoderのSAと encoder-decoder attentionに分かれているが、その必要がないのでencoder-decoder attentionを取り除く。
そしてdecoderの入力にencoderの出力をシーケンス方向でconcatして同時に入れてしまう。これによりTransformerと同じ位の速度でより良いパフォーマンスが得られる。アーキテクチャは図8を参照。
f:id:Ninhydrin:20211202094552p:plain

これがどうしてLong Sequenceに効果的なのかちょっとわからない。

Reversibility for Memory Efficiency

巨大なバッチサイズで巨大なモデルを一つのマシンで学習するためにReformerの手法を取り入れる。
Reformerはfeedforwardとattentionが1:1だったが、Terraformerは2つのattentionがあるのでswapが3回になる。
また、連続関数でないと効果が発揮できない。スパースにするためにマスクを作っているがこれをbackward時に再計算するのではなく保存しておく必要がある。

Recurrence for Generalization

FF Blockに速度を落とさないようにrecurrentを導入する。そのためにsimple recurrent units(SRUs)を利用。
SRUsの次元は小さくても機能することが実験でわかった(32次元)。
SRUsを導入したことでより長いシーケンスを扱えるようになった。

実験・結果

細かい内容は論文とそのAppendixを参照。
f:id:Ninhydrin:20211203090150p:plain

所感

Terraformerという強い名前。簡単に調べたがまだ使われてはなさそうだった。Transformerに語感は似てるし良いネーミングセンス。
スパース化・高速化するときに単純な行列積ではなくアダマール積と組み合わせていたが、行列積二回で途中で次元を下げるのとどちらが良いのか気になる。
読んでいてちょっとわからなかったのがlong sequenceの扱いの部分。
encoder自体は残っていて、図8はあくまでdecoderのアーキテクチャで正しい?それとも図8のencoder embeddingはencoderへの入力トークン列で図8がTerraformerの全体を示している?(Appendixには「Figure 8 shows the whole architecture of Terraformer model」と書いてあるし)

実際このencoderもdecoderもわけないというのは使い勝手がありそう。
Recurrentを導入したのは少々ずるい気もする。