[論文メモ] NOT ALL PATCHES ARE WHAT YOU NEED: EXPEDITING VISION TRANSFORMERS VIA TOKEN REORGANIZATIONS

arxiv.org
github.com

ICLR2022

Vision Transformer(ViT)においてすべてのパッチ(トークン)は必要ないので注意の少ないトークンをマージすることで精度を保ちつつ高速化する。

f:id:Ninhydrin:20220228091813p:plain
図1(a)のようにランダムにパッチをマスクしてもViTの予測に影響しないが、図2(b)のようにメインとなるオブジェクト領域をマスクすると予測が狂う。
またViTは画像を切り出したパッチをSelf-Attention(SA)で処理するので固定長の必要がない、というのがお気持ち

手法

Attentionはいつもの通り
f:id:Ninhydrin:20220228093301p:plain

ViTは画像を切り出したパッチからのトークンと最後にクラス予測に使うCLSトークンからなる。
CLSトークンと他トークンとの相互作用は以下の式。
f:id:Ninhydrin:20220228093634p:plain
 \boldsymbol{q}_{class}はCLSトークンのクエリベクトルで、Softmaxの部分をまとめ \boldsymbol{a}とするとCLSトークンの出力 \boldsymbol{x}_{class} \boldsymbol{a}係数としたバリューベクトルの線形結合と見ることができる。

CLSトークンによるAttentionはクラス予測に影響を与えるトークンほど大きくなるという既存研究の報告がある(自己教師あり学習のDINOの論文の図1を参照)。
arxiv.org

なのでCLSトークンの注意の大きさを元に不要なトークンを削除すればよい、というわけにはいかなくて、削除すると精度が落ちる(表1)
f:id:Ninhydrin:20220228095532p:plain

ということで単純に削除するのではなくマージして、トークンの再構成を行う。
Multi-Head SAの各ヘッド h = [1,...,H]のCLSトークンによるAttentionベクトル \boldsymbol{a}^{(h)}の平均 \bar{a} = \Sigma^{H}_{h=1}\boldsymbol{a}^{(a)}/Hを使ってtop-k個のトークンを決め、それ以外のトークンをマージする(図2参照)。
正確にはマージではなく、Attention Weightを使った加重平均( \boldsymbol{x}_{fused} = \Sigma_{i \in \mathcal{N}}a_i \boldsymbol{x}_i)
f:id:Ninhydrin:20220228095726p:plain

一定の割合のトークンを削除する手法だと、背景に対応するトークンは情報は少ないので削除しても影響は少ないが、画像全体に対してオブジェクトが占める割合が大きいと画像に対してはオブジェクトに関わるトークンを削除することになりパフォーマンスが落ちる。
提案手法のようにマージすれば情報の少ないトークンでも有効活用できる。

本手法を可視化したのが図3。
f:id:Ninhydrin:20220228101123p:plain

実験・結果

ImageNetで学習・検証。
提案手法名はEViT。比較対象はDeiTとLV-ViT。
表2のEViT with inattentive token fusionが提案手法で、EViT without inattentive token fusionは不要なトークンを削除するタイプ。
提案手法の方が精度低下が少ない、と言いたいところだが表2を見ると削除もマージも正直あまり差がない(一応標準偏差は小さい)。
f:id:Ninhydrin:20220301085848p:plain

f:id:Ninhydrin:20220301090040p:plain

学習済みDeiT-S/Bをオラクルにした(CLSトークンが学習済みだと大事なトークンがわかった状態である)。
f:id:Ninhydrin:20220301090807p:plain

f:id:Ninhydrin:20220301091152p:plain

なおmultiply-accumulate computations(MACs)メトリックはtorchprofileを使って測ったらしい(知らなかった)。

所感

CLSトークンがトークンの重要度と関係するという既存研究結果を知らなかった。画像から作ったトークンと違って目的が目的ゆえ、たしかに言われるとそうかという気持ち。
そしてそれを使って加重平均をとってトークンをマージして減らすというのは良い発想だが、削除して減らしてもあまり結果に変化がなさそうなのが少々残念。
削除はnetwork pruningに似ておりViTに限らずMLPでも削除ではなくマージする戦略があったのか気になる。