[論文メモ] Efficient Training of Visual Transformers with Small Datasets

arxiv.org
github.com

少量データでVision Transformer(ViT)を学習するときにサブタスクとしてパッチ間の距離を学習することで精度を向上させる。

ViTは大量のデータセットで学習することで高いパフォーマンスを発揮するが、逆にデータセットが小さいと精度が出にくい。
最近のViT(第二世代)はそれを克服するためconvolutionなどを利用し、データがそこまで大きくなくても(言うてImageNet1Kクラス)ResNet等を凌ぐ精度を出している。
しかし、寄り小さいデータセットのときにどうなのかについては明らかではない。その辺りに焦点を当てた論文。
実際、小さいデータセットだと精度がいまいち。

そこでパッチ間の距離を予測する自己教師ありのサブタスクを導入することでブーストさせる。

手法


一般的にViTは入力を K \times Kのパッチにして扱う。一部のViTではconvolutionやpoolingを導入してパッチのマージを行い出力のパッチ数が入力と異なることがある。そこで出力は k \times k、ただし基本 k \leq Kとする。
入力画像を xとして、その最終的な埋め込みを G_x = \{ \textbf{e}_{i,j}\}_{1 \leq i,j \leq k} \textbf{e}_{i,j} \in \mathcal{R}^dとする( dは埋め込みの次元サイズ)。この G_xからランダムに2つの埋め込み( \textbf{e}_{i,j}, \textbf{e}_{p,h})を選択し距離 (t_u, t_v)^{T}を計算する。

次に \textbf{e}_{i,j} \textbf{e}_{p,h}MLP(f)に通して距離を予測させる( (d_u, d_v)^T = f(\textbf{e}_{i,j}, \textbf{e}_{p, h})^T)。
この予測した距離が先程の距離と等しくなるようにL1 lossで学習をする。

 Bはミニバッチ。
最終的なlossは元のクラス分類のlossと合わせて \mathcal{L}_{total} = \mathcal{L}_{ce} + \lambda \mathcal{L}_{drloc}

なお、元のアーキテクチャと比較のため基本的に対象のViTのアーキテクチャは変えず、追加のMLPをくっつけるだけにする。
Swin Transformerは出力が 7 \times 7だが、T2TとCvTは[14 \times 14]なので 2 \times 2のAvgPoolをして出力を揃える。これはあくまで距離予測タスクにのみ使用し、クラス分類に影響はしない。

実験結果

データセットのサイズ等は表1を参照。

距離学習に使うパッチペアのサンプル数を mとして比較した結果が表2。距離学習は有効そう。

クラッチからの学習でepoch数による比較結果が表3,4。小さいデータセットのときはそこそこ効果が高そう。
またSwin-Tに関して、Swin-Tにはrelative positional embeddingが入っていて距離学習は必要ないのではないかと思うが距離学習を行うことで精度が向上している。

fine-tuningでも効果あり(表5)

所感

パッチ間の距離を追加で予測するだけで精度が向上するとのこと。
positional encodingを使っていてもブーストされるのは少し意外(まあ最終的な埋め込みまで位置情報が残っているとは思わないが)。
IN-100は試行回数が5回なのでそこそこ信用できそう。
他は書いていないということは試行回数が1回とかかもしれないが、それでもSwin-Tの精度向上を見ると効果はありそう。
ただなんでそんなに効果があるのか少々不思議。距離を予測するにはセマンティックな情報を扱う必要があり、より良い情報が埋め込まれるから?ほんとに?