[論文メモ] QUADTREE ATTENTION FOR VISION TRANSFORMERS

arxiv.org
github.com

ICLR2022
Vision Transformer(ViT)のAttentionに四分木を導入して計算コストを下げた。


手法

ピラミッド構造にし、予測に影響しない(attention scoreが低い)箇所はそのまま、予測に影響する部分のみ深堀りしていくことで情報のロスの抑えつつ効率化を図る

特徴マップをダウンサンプリング(avarage pooling)することで、 L層のピラミッド構造を用意する。
一番粗い l=1層から処理をしていく。
ある層 lに対してattentionを行いattention scoreを計算する。
そしてattention scoreにが高い上位 K個のトークンを選択し、次の l + 1層にでは l層で選ばれたトークンをそれぞれ4分割し再度そのトークン間でattentionを計算する。
図2は L=3 K=2のときで、層1で選択された右2つのトークンが4分割され、層2でattentionに利用されている。
この処理を一番細かい層(元のトークンサイズ)になるまで繰り返す。
こうして出来上がった各層の粗さの異なるトークンを集約する必要があり、その手法としてQuadTree-AとQuadTree-Bの2つを提案。

QuadTree-A

すべての層から集約するパターン。
一番細かい層のあるあるクエリトークンを \textbf{q}_iに対応する出力 \textbf{m}_i(ここではメッセージと呼ぶ)について考える(図2(a)のQのlevel 3にある)。
QuadTree-Aでは \textbf{m}_iを各層から計算される部分メッセージ \textbf{m}^l_iの和とする。


 l -1層の上位 K個のトークンが所属する領域を \Gamma^l_iとする(図2(c)を参照)。ただし \Gamma^1_iは画像全体。
そして \Omega^l_i=\Gamma^l_i -\Gamma^{l+1}_iとする(図2(b)を参照)。

部分メッセージ \textbf{m}^l_iは領域 \Gamma^l_i内のトークン間で計算される



 \textbf{v}^l_jはバリュー、 s^l_{ij}=s^l_{ij}t^l_{ij} s^{l-1}_{ij}は対応する親のクエリとキーのスコア(attention の \textbf{QK}^{\mathsf{T}}的な)で  s^1_{ij}=1
 t^l_{ij}は分割された2x2のトークンによるattention後の値。

図2の(b)がメッセージ \textbf{m}_iの例で、緑の領域は \textbf{q}_iに関係ある部分で細かく、赤い箇所は関係ない部分で粗くなっている。

QuadTree-Aではクエリ・キー・バリューはaverage poolingでダウンサンプリングする。

QuadTree-B

QuadTree-Aではattention score  s^l_{ij}再帰的に計算されるため、細かい層の影響が小さくなる。そこで学習可能な重み w^l_iを導入する。

部分メッセージは以下で計算。

実験・結果

例の如く省略。
Image classificationだけ。
意外と精度は高いが、計算効率的に微妙な...

所感

四分木にして計算コストを抑えつつ、attention scoreで必要なトークンを選択するというのは面白いアイディア。
でもその構造を構築するための計算がそこそこコストになっていそうな気がするのは気のせい?(ピラミッド構造を作って内部で何度かattentionしてとか)。
思っていたよりも効率化されてない印象。
また実装も少々面倒そうな気がする。attention socoreでトークンの選択をするとなるとバッチ処理のときにきれいに書きにくそう。
地味に説明のない文字が数式にあらわれていて迷う。