[論文メモ] AdaViT: Adaptive Tokens for Efficient Vision Transformer
CVPR2022 Oral
NVIDIA
あくまでメモ。間違っているかも。
ViTでトークンにhalting scoreを導入し予測時に各レイヤーでスコアに従い間引くことで僅かな精度低下で速度を大幅に向上させた。
手法
番目のレイヤーでの番目のトークンをとする。は次元数。シーケンス長はとする。
各トークンからhalting score を算出するhalting module を導入。
halting moduleは以下の式。
はトークンのe次元目の特徴量、はシグモイド関数、は学習可能なスカラ変数ですべてのレイヤーとトークンで共有される。またeは0を採用(トークンは十分にスパースで能力もあるので適当な1次元をhalting用に割いた感じ)。
halting scoreは全レイヤーで積み上げで、しきい値を超えたレイヤー 以降ではトークンは0にされattentionからも除外される。
最終レイヤーではクラストークン以外はすべて削除される。
AdaViTの例が下の図2。下の緑のトークンは途中で削除されattentionを通した影響(破線)を与えなくなる。
レイヤーをまたいでhalting scoreの経過を追跡するためにhalting scoreの残り を計算する(式(6))
のときに1を初めて超えるので。
早期のhaltingを促す ponder lossはを用いて
となる。
また、クラストークン にhalting probabilityを導入し、その重み付き平均でクラス分類を行う。
halting probabilityはを用いて
のときはもう停止しているので0。
この確率の意味がいまいちわからない。そのレイヤーが与える影響度的な?
lossは以下。はクラストークンをクラス分類にするためのpost process
レイヤーをまたいだhalting scoreの分布 についても制限を加える。
このが事前に決めた分布 になるようにKLダイバージェンスでlossをとる。
ここではをガウシアンとする(ベル状の形なので、ネットワークの中間あたりでhalting socreが大きくなる)。
最終的なlossは
各はlossのバランスを取るための重み。
実験・結果
例によって省略
所感
確かに画像の端の方のトークンとかはクラス分類にあまり影響がなさそうなので省いても良さそうではあるが、ヒューリスティックは厳しそうなので、ネットワーク自身に選択させるのは良さそう。
以前に画像の中心がなくてもある程度予測出来る的な論文もあった(はず)が、そういった問題を解決することにも繋がりそう。
多少の精度劣化があるのは残念。また、test時の工夫による速度改善なので、実装に依存になるのも残念。大規模なシステムでこれらを許容できるなら選択の余地はあるかも。
読んでいて、halting scoreとhalting probabilityの使い分けが少々おかしいところがあった(気がした)。
そしてhalting probabilityの意味がよくわからない(クラストークンのときだけしか出てこないし)。