[論文メモ] Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?

arxiv.org
あくまで個人的メモレベル

概要

ViTのように近年ではTransformerのCV分野への応用が活発だが、画像認識においてself-attentionが高パフォーマンスを得るための鍵なのか調査し、最近また再燃しているMLPモデルを改良したattention-freeなsMLPNetを提案。

導入

ViTやDeiTなど、適切な事前学習をしたtransformer系のビジョンモデルは画像認識でSOTAを達成した。
DeiTなどの畳み込みのないViTは2つの固定観念的なものを生み出した
1) グローバルな依存関係のモデリングが重要で、畳み込みのようなローカルな依存関係を置き換えられる
2) self-attentionが重要

1番目について、多くの研究者はローカルな依存関係のモデリングをなくし、わざわざself-attentionにローカルな依存関係を学習させようとしてきた。
確かにself-attetnionは強力だが計算コストが高く高解像度画像への適用の欠点とされている。Swin-Transformerはピラミッド構造・ローカルな依存関係が明示的に構造に入っており高いパフォーマンスを示している。

2番目について、近年ではself-attention freeなMLPモデルで画像認識をしようとしている研究者が出てきている。MLP-MixerはViTのように画像をパッチに分けて入力し、spatial mixingとchannel mixingをするというのが特徴。ただ、パラメータが多い分過学習しやすい。
MLP-MixerとSOTAモデルとの精度の差は大きいがこれがself-attentionの有無ということにはならない。

提案するsMLPNetはこの2つを否定する。

手法

CNNの重要なデザインアイディアを保ちつつ、Transformerから着想を得た新しいモジュールを導入する。
ただし、
1) ViTやMLP-Mixerと公平な比較のため、それらと似た構造にする
2) 明示的にローカルバイアスを導入する
3) self-attentionを使用しないでもグローバルな依存関係を得られるか調査
4) ピラミッド構造によるマルチステージの処理

全体像

f:id:Ninhydrin:20210915094335p:plain
ViTやMLP-Mixerとかと同様に画像をパッチに分割して入力する。パッチサイズは4x4と小さめ(MLP-Mixerは16x16)。
同じ画像サイズのとき計算コストはMLP-Mixerの16倍であまりに入力が多いとMLP-Mixerでは学習が破綻する。
4つのステージに別れ、1ステージ目はパッチの埋め込みになっているが、他ではパッチを統合し解像度を1/2倍にしチャンネルを2倍するSwinのブロックを導入。
実装は周辺の2x2のバッチをconcatしてLinear Layerに通す。
図2(b)はtoken-mixing moduleでdepth-wise convによりローカルバイアスを導入。パラメータも少ないし低コストで提案するsMLPにも導入。
channel-mixing moduleはMLP-Mxierと同じfeed-forward network。

Sparse MLP(sMLP)

MLPの2つ問題点、1)パラメータが多く過学習しやすい、 2)計算コストが高い(特に入力が高次元)を解決する。
そのためにsMLPでは重みの共有を行う(図1)。全てのトークンではなく、同じ行と列のトークンのみと相互作用する。
f:id:Ninhydrin:20210915101031p:plain
仕組みとしてはH x W x Cを(HC) x W or (WC) x HにしてFCレイヤーに入れるだけ。それぞれで得られた特徴を X_H X_Wとする。もとのHxWxCの入力を XとしてsMLPの出力[X^{out}]は[X^{out}=FC(concat(X_H, X_W, X))]となる(図3)。pytorchの擬似コードもある。
f:id:Ninhydrin:20210915101620p:plain
f:id:Ninhydrin:20210915101659p:plain

パラメータは H^2+W^2+3C^2MLP 2\alpha (HW)^2 \alphaは拡大率で4が多い。入力が224x224、 C=80とすると3000xほどのパラメータ削減になる。
計算コストは
sMLP)  \Omega(sMLP)=HWC(H+W)+3HWC^2
MLP-Mixer)  \Omega(MLP)=2\alpha (HW)^2C

アーキテクチャの設定

サイズ違いで3種類のモデルを用意
f:id:Ninhydrin:20210915102158p:plain

実験

実験は一部省略

Ablation Study

Local and global modeling

DWConvがローカルの情報、sMLPがグローバルの情報を担当しているがこれらが必要なのかを確認(表1)。
まずはsMLPのみ。DWConvは軽量のため取り除いても0.1MB程度しかモデルサイズは変わらないそう(FLOPsは0.1B程度)。しかしaccuracyは0.7%程度も変わる。
次はDWConvのみ。sMLPは重めのモジュールなのでモデルサイズ、FLOPsをベースモデルに揃えるためにチャンネルサイズを80から112にしたがsMLPのみとほとんど変わらないaccuracy。
f:id:Ninhydrin:20210916090022p:plain

次にステージ毎のsMLPモジュールの役割について調査。ベースはsMLPNet-B。結果は表2を参照。1行目がベースモデルそのまま。
sMLPを取り除くごとに徐々にaccuracyが落ちていくが、特にステージ3からを取り除いたときに大きく落ち込んだ。
ステージ3は一番レイヤー数が多く、FCのサイズも大きいためパラメータも多いし計算コストも高い。
ステージ1~3のsMLPを取り除いたモデルのモデルサイズ的にsMLPNet-S(accuracyは83.1%)と同じ。しかしaccuracyでは劣るので入力の早い段階でグローバル情報を扱うのが重要。
f:id:Ninhydrin:20210916091048p:plain

Fusion in sMLP

sMLPは特徴をFCに入れてsumしているが他の軽量な方法ではどうなのか調査。
sum) 単純に足し合わせる。
weighted sum) 学習可能パラメータで乗算して足し合わせる
比較対象のベースモデルはsMLPNet-S。結果が表3。
smuもweighted sumもパラメータ数、計算コストともに下がったがaccuracyも低下した。より軽量でconcat + FCのsMLPNet-Tと比べてもaccuracyが落ちている。
パラメータ数・計算コストとaccuracyのトレードオフを考えたとき、concat + FCの方が優れていそう
f:id:Ninhydrin:20210916092551p:plain

Branches in sMLP

sMLPモジュールでは3つのブランチをconcatしているが他についても調査。2つのパターンの組み合わせ。

  • ベースと同じ並列処理と順次処理( X^{out}=FC(X_H) or  X^{out} = FC(X_W)的な。どちらもほぼ同じaccuracyだったらしい)
  • Identity mapping(入力を足し合わせる)

これらのある無しで4通りを実験。結果が表4。同じ行で比較(並列 or 順次)だと並列のが良く、同じ列で比較(identityありなし)だとありが良い。
f:id:Ninhydrin:20210916093733p:plain

Multi-stage processing in pyramid structure

ローカル情報とグローバル情報を組み合わせるときにピラミッド構造は有効そうだが、実際にピラミッド構造が有効なのか調査の必要がある。
MLPのモデルでシングルステージとマルチステージの比較をした。ただし、ベースはsMLPNetでsMLPブロックを調整する。
小さいsMLPNetを構築し、ステージ1のsMLPブロックをDWConvに置き換えステージ2~4のsMLPブロックをMLPブロックに置き換えた。これがマルチステージのMLPモデル。
結果が表5。マルチステージのほうが少ないパラメータでより良いaccuracyとなった。
f:id:Ninhydrin:20210916094735p:plain

Comparison with state-of-the-art

SOTAモデルとの比較。データセットはImageNet-1K。入力は224x224で結果は表6。
注目はsMLPNet-BでSwin-Bに対してより小さいパラメータ数、FLOPsで同じaccuracyを達成した。過学習も見られない。
attention-freeでもSOTAと並ぶパフォーマンスを出すことができ、attentionが最高パフォーマンスを出すのに必ずしも必要では無いことが示せた。
なおtransformer-basedなモデルでCSwinがあるそうでこちらは84.2%のtop-1を叩き出しているが、あくまでattention-freeでSOTAレベルのパフォーマンスを示すのが目的なので、CSwinの結果がこの結果に影響はしないとのこと。
f:id:Ninhydrin:20210916095507p:plain

所感

attention-freeでもSOTAレベルのaccuracyを獲得できたのはすごいと思うが、Convolutionのようなバイアスや制約をモデルに課す必要がありうーんといった感じ。
学術的には面白いと思うが、実際に適用するときにMLPモデルを採用するかは微妙。
ただ、解きたい問題・ドメインにあった帰納バイアスを持つアーキテクチャを選ぶ・設計することの重要性を再認識させられた。
MLPという自由度の高いモデルで如何に精度を出すかはこの手の改良が必要で非常に参考になる。
そういう意味でも、今後もMLP系のモデルの発展に注目していきたい。