[論文メモ] AdaViT: Adaptive Tokens for Efficient Vision Transformer

arxiv.org

CVPR2022 Oral
NVIDIA

あくまでメモ。間違っているかも。

ViTでトークンにhalting scoreを導入し予測時に各レイヤーでスコアに従い間引くことで僅かな精度低下で速度を大幅に向上させた。

手法

 l番目のレイヤーでの k番目のトークンを t^l_k \in \mathbb{R}^{E}とする。 Eは次元数。シーケンス長は Kとする。
トークンからhalting score  h^l_kを算出するhalting module  Hを導入。

halting moduleは以下の式。

 t^l_{k,e}トークンのe次元目の特徴量、 \sigmaシグモイド関数 \beta, \gammaは学習可能なスカラ変数ですべてのレイヤーとトークンで共有される。またeは0を採用(トークンは十分にスパースで能力もあるので適当な1次元をhalting用に割いた感じ)。

halting scoreは全レイヤーで積み上げで、しきい値 1 - \epsilonを超えたレイヤー  N_k以降ではトークンは0にされattentionからも除外される。

最終レイヤーではクラストークン以外はすべて削除される。


AdaViTの例が下の図2。下の緑のトークンは途中で削除されattentionを通した影響(破線)を与えなくなる。

レイヤーをまたいでhalting scoreの経過を追跡するためにhalting scoreの残り  r_kを計算する(式(6))

 N_kのときに1を初めて超えるので r_k < 1

早期のhaltingを促す ponder lossは r_kを用いて

となる。


また、クラストーク t_cにhalting probabilityを導入し、その重み付き平均でクラス分類を行う。
halting probabilityは r_kを用いて

 l > N_kのときはもう停止しているので0。
この確率の意味がいまいちわからない。そのレイヤーが与える影響度的な?

lossは以下。 \mathcal{C}はクラストークンをクラス分類にするためのpost process

レイヤーをまたいだhalting scoreの分布  \mathcal{H}についても制限を加える。

この \mathcal{H}が事前に決めた分布  \mathcal{H}^{target}になるようにKLダイバージェンスでlossをとる。

ここでは \mathcal{H}^{target}をガウシアンとする(ベル状の形なので、ネットワークの中間あたりでhalting socreが大きくなる)。
最終的なlossは

 \alphaはlossのバランスを取るための重み。

実験・結果

例によって省略


所感

確かに画像の端の方のトークンとかはクラス分類にあまり影響がなさそうなので省いても良さそうではあるが、ヒューリスティックは厳しそうなので、ネットワーク自身に選択させるのは良さそう。
以前に画像の中心がなくてもある程度予測出来る的な論文もあった(はず)が、そういった問題を解決することにも繋がりそう。
多少の精度劣化があるのは残念。また、test時の工夫による速度改善なので、実装に依存になるのも残念。大規模なシステムでこれらを許容できるなら選択の余地はあるかも。
読んでいて、halting scoreとhalting probabilityの使い分けが少々おかしいところがあった(気がした)。
そしてhalting probabilityの意味がよくわからない(クラストークンのときだけしか出てこないし)。