[論文メモ] Swin Transformer V2: Scaling Up Capacity and Resolution

arxiv.org

Swin Transformerの改良。著者はSwin Transformerとだいたい同じ。
f:id:Ninhydrin:20211209093632p:plain

言語モデルは大量のパラメータ(530billion)で高いパフォーマンスを出しており、パラメータ数が多いと大体パフォーマンスも改善するのはわかっているが、画像系モデルに関してはせいぜい1billion程度にとどまっている。

Swin Transformerの問題点

モデルを大きくすると不安定になる

元のSwin Transformerを大きくしたときの各層の活性化後の値をプロットしたのが図2。*-Preが元のSwin Transformerで総が深くなるにつれ分散が大きくなり、は最初の層と最後の層で差が大きいものは 10^4近い差がある(縦軸は対数スケール)。

f:id:Ninhydrin:20211209092441p:plain

更にスケールアップ(658million param)にしたら学習ができなかった(図3)。
f:id:Ninhydrin:20211209092857p:plain

window sizeを変更すると劣化する

ImageNet1Kを入力サイズ256x256、window size 8x8で学習したモデルをより大きな入力サイズ・window sizeでテストしてみるとaccuracyが大きく低下。
relative position biasを再考する必要がありそう。
表1のParameterized position biasがそれ。
f:id:Ninhydrin:20211209100854p:plain

Swin Transformer V2

Scaling Up Model Capacity

Post normalization

Swin Transformerは普通のTransformerを踏襲しており図1の上の通りpre-normを採用している。
pre-normではactivation後の値がresidual connectionを通して利用されるため層が深くなるにつれ値が大きくなっていく。
そこで図1下のようにpost-normに変更する。これによりnormalizeされた値がresidual connectionで利用されるため値が大きくなっていくのを防げる。
実際試した結果が図2の*-Postのグラフ。*-Preに比べ分散が小さいことがわかる。

Scaled cosine attention

self-attentionを調べるとあるピクセルペアが支配的になっている。特にpost-normのセッティングで顕著。
これを緩和するためにscaled cosine attentionを行う。
f:id:Ninhydrin:20211209094956p:plain

 B_{ij}ピクセル i, jのrelative position bias、 \tauは学習可能なスケール(ヘッドやレイヤーで共有しない)。
cosineによりnormalizeされるのでattentionがマイルドになる。

Scaling Up Window Resolution

windowの解像度変更に対応するために

Continuous relative position bias

Swin Transformerではbiasの値そのものを最適化してきたが、連続的なposition biasにするために小さいネットワーク(2-layer MLP)を使ってbiasを生成する。
f:id:Ninhydrin:20211209095556p:plain
 \mathcal{G}MLP。relative coordinateを入力してbiasを出力する。
これによりfine-tuning等でwindow sizeが変化しても自然に対応できる。
また実際に利用するときには事前に計算しておけば良いのでparameterizedなbiasと比較してinference timeに違いは無い。

Log-spaced coordinates

window sizeが大きく変化する場合、学習で利用したことのない外挿のrelative positionが発生する。
これを緩和するために線形の座標からlogスケール座標で座標を扱う。
f:id:Ninhydrin:20211209100207p:plain
 \Delta{x}, \Delta{y}が線形の座標、 \hat{\Delta{x}}, \hat{\Delta{y}}が変換後のlogスケールでの座標。
元が8x8のwindow sizeのものを16x16のwindow sizeにしてfine-tuningした場合、入力の座標範囲は[-7, 7] x [-7, 7]から[-15, 15] x [-15 x 15]に変化し、範囲が8/7 = 1.14倍になる。
それに対してlogスケールだと[-2.079, 2.079] x [-2.079, 2.079]から[−2.773, 2.773] x [−2.773, 2.773]への変化になり、元のときの拡大率の0.33倍程度で済む。
表1のLinear-Spaced CPB とLog-Spaced CPBが実験結果(CPBはcontinuous position bias)。

その他工夫

GPUのメモリ消費を抑えるテクニックや精度を上げるための工夫。

Zero-Redundancy Optimizer (ZeRO)

よくあるdata-parallelの実装はモデルや最適化の状態をすべてのGPUにコピーするためGPUメモリの消費が激しい。
ZeROはモデル等を分割して各GPUに分配するのでGPUメモリの消費を大きく抑えられる。
ここではZeRO stage-1を使ったDeepSpeedフレームワークを利用。

Activation check-pointing

Transformerの特徴マップのメモリ消費を抑える。ただ最大30%程度速度が低下。

Sequential self-attention computation

画像サイズ1536x1536、window size 32 x 32とかだと上記の最適化を施してもまだ40GBぐらい消費する。
これはself-attentionがボトルネックになっており、これをバッチ処理ではなくシーケンシャルに処理するように実装しはじめの2ステージに適用。
少し速度に影響。

これらによりNvidia A100-40GでCOCO object detectionの入力画像サイズ1536 x 1536の3 billionのモデルを学習可能にした。

Joining with a self-supervised approach

大きなモデルはデータも大量に必要で、一般的にはJFT-3Bのような巨大なデータセットで学習したりself-supervisedで事前学習たりする。
ここでは両方を採用。
ImageNet22Kを5倍程度に拡大し、ラベルにノイズのある7000万枚のデータセットを作成。JFT-3Bには程遠い量だが、更にこれをself-supervisedで学習する。
これらにより3 billionの巨大なSwin TransformerをSOTA レベルにした。

Model configurations

例の如くサイズ違いの複数のモデルを用意。 Cはチャンネル数。
f:id:Ninhydrin:20211210093249p:plain

更に巨大なモデル(658millionパラメータと3billionパラメータ)を用意。それぞれ以下。
f:id:Ninhydrin:20211210093315p:plain

実験・結果

ImageNet-1KのV1、V2 val、object detection COCO、semantic segmentation ADE20K、video action classification Kinetics-400。
詳しい設定等は論文を参照。

f:id:Ninhydrin:20211210093829p:plain
f:id:Ninhydrin:20211210093851p:plain
f:id:Ninhydrin:20211210093909p:plain
f:id:Ninhydrin:20211210093926p:plain

post-normとcosine attentionのablation。僅かだが改善。このaccuracyでの0.1%は大きい?
f:id:Ninhydrin:20211210094040p:plain

Log CPBのablation。読んだ当時、図7を参照している部分が存在しない...。
f:id:Ninhydrin:20211210094140p:plain

所感

最近良くみかけるpost-norm。pre-normと比べ何が良いのかよくわからなかったが、residual connectionで伝播していく値をnormすることで深い層での分散を小さくするということがわかったのは収穫。確かにそのとおり。情けなし。
relative PEをlogスケールにするのはfine-tuning等でwindow sizeを変更するという少々稀?なタスク用で、普通に使う分にはあまり必要はなさそう(表1を見ても)。
メモリ消費を抑えたとはいえ、一番大きなモデルはまだ少々厳しい。メモリ消費を抑えるテクニックは参考にはなった。