[論文メモ] Sparse is Enough in Scaling Transformers
NeurIPS2021
あくまでメモ。ちょっとわからないところがあり間違っているかもしれない。
でかいTransformerがいろんなタスクで性能を発揮しているが、学習に時間がかかりfine-tuningも時間がかかる。実際に使用するときも遅く実用的でない。
そこでTransformerをスパースにすることで性能をそれほど落とさず高速化させる。
BERTの学習コストは$4k~$12kで一人の4時間のフライトと同量のCO2を排出し、GPT-3はサンフランシスコ-ニューヨークの往復の3倍のCO2を排出するらしい。
これらのモデルのpopularityや今後もモデルが巨大化していくことを考えるとモデルの効率化は非常に重要な問題。
Transformerのパラメータ数をとで表す。
はself-attention(SA)のパラメータ数。が一般的。
BERTは、large BERTは、GPT-3に至っては。
はfeedforward block(FFB)のパラメータで が一般的。
SAではqueryやkey等での計算が4回必要になる。
またdecoder部分ではqueryの計算とattentionの計算でが2回必要。
FFBでは2回のの計算が必要になる。
トータルでの計算コストがかる。
つまりの計算コストと学習パラメータが必要になる。
提案手法ではこれをまで減らす。
手法
Sparse is Enough
Feedforward Layer(FFL)とSAをそれぞれスパースにしていく。
Sparse Feedforward Layer
FFLは2つのfully-connected(FC) layerとReLUからなる。そして大体その間の次元数は。
ReLUにより2つのFC layerの中間は0が多くなる。そこで中間表現を個のブロックに分け、各ブロックから1つだけが非ゼロになるようにする。
既存手法として枝刈りがあるが、ここでは枝刈りせず、すべてのweight matrixを学習し、利用する中間表現を動的に決定する(dynamic sparsity)。
決定方法は別途Controllerを学習する。ControllerはFF layerと同じ入力を受け取る非常に小さい2 layerのMLP。出力は入力と同じサイズでそれをマスクとして利用する。式にすると下記。
まずControllerを計算しバイナリマスクを生成し、N個のブロックでマスクされていない箇所だけを計算をする(図2を参考)。
式は下記。
実際に利用するときはargmaxでマスクを決定するが、学習時はsoftmaxしてサンプリングする。微分可能にするために Gumbel-Softmax trikを利用する。
T5-largeで試した結果が表2と図3。のときbaselineと変わらないレベルの性能で高速化。
Sparse QKV Layer
次にSAで使われるFC layerの部分(query、key、value、FC)を調整する。これらをまとめてQKVと呼ぶ
QKVはすべてのパラメータと計算コストが必要。ただFFLと違ってReLUが無いので先程のは使えない。
ここでも入力をサイズの個のブロックに分ける()。ただこのまま処理すると他のブロックに干渉できないので multiplicative dense layer(mult)を導入する。
multはのパラメータを使う。入力をとして以下の式。
ただし、、。基本的にを利用。図4を見ると式のイメージがつかめる。
multで2次元()のマップが得られ、シーケンス長やミニバッチを考えるとの形をしている。lengthとを画像で言うheightとwidth、をチャンネルとしてconvolutionを行う。
この2つを組み合わせ、FC layerを置き換える。
入力のQ、K、Vに関してはmultを共有する。つまり。そして出力部分のFC layerに関しては取り除く。図4(b)の一番下の図。
にすることで計算コストをからのオーダーに減らせる。
実際QKVを置き換えた結果が図5。FFも置き換えた結果が表3。
最終的な出力はボキャブラリー数でFC layerで出力するが、そこに関してもmultiplicative dense layerに置き換える。速度は速くなるが少々劣化する。
Sparsity for Long Sequences
これらの工夫によりdense layerは効率化されたが、実際に利用するときに長いシーケンスが来るとattention部分の計算量が支配的になりこれまでの高速化が意味をなさない。
そこで長いシーケンスに対応出来るように調整する。
具体的にはReformerで提案されたLSH(Locality-Sentsitive Hashing)に注目。これを組み込んだアーキテクチャ、Terraformerとする。
arxiv.org
Architecture for Long Sequences
Transformerのdecoderは準最適で改良の余地がある。
decoderのSAと encoder-decoder attentionに分かれているが、その必要がないのでencoder-decoder attentionを取り除く。
そしてdecoderの入力にencoderの出力をシーケンス方向でconcatして同時に入れてしまう。これによりTransformerと同じ位の速度でより良いパフォーマンスが得られる。アーキテクチャは図8を参照。
これがどうしてLong Sequenceに効果的なのかちょっとわからない。
Reversibility for Memory Efficiency
巨大なバッチサイズで巨大なモデルを一つのマシンで学習するためにReformerの手法を取り入れる。
Reformerはfeedforwardとattentionが1:1だったが、Terraformerは2つのattentionがあるのでswapが3回になる。
また、連続関数でないと効果が発揮できない。スパースにするためにマスクを作っているがこれをbackward時に再計算するのではなく保存しておく必要がある。
Recurrence for Generalization
FF Blockに速度を落とさないようにrecurrentを導入する。そのためにsimple recurrent units(SRUs)を利用。
SRUsの次元は小さくても機能することが実験でわかった(32次元)。
SRUsを導入したことでより長いシーケンスを扱えるようになった。
実験・結果
細かい内容は論文とそのAppendixを参照。
所感
Terraformerという強い名前。簡単に調べたがまだ使われてはなさそうだった。Transformerに語感は似てるし良いネーミングセンス。
スパース化・高速化するときに単純な行列積ではなくアダマール積と組み合わせていたが、行列積二回で途中で次元を下げるのとどちらが良いのか気になる。
読んでいてちょっとわからなかったのがlong sequenceの扱いの部分。
encoder自体は残っていて、図8はあくまでdecoderのアーキテクチャで正しい?それとも図8のencoder embeddingはencoderへの入力トークン列で図8がTerraformerの全体を示している?(Appendixには「Figure 8 shows the whole architecture of Terraformer model」と書いてあるし)
実際このencoderもdecoderもわけないというのは使い勝手がありそう。
Recurrentを導入したのは少々ずるい気もする。