[論文メモ] BatchFormer: Learning to Explore Sample Relationships for Robust Representation Learning

arxiv.org

CVPR2022

サンプル間の関係をネットワーク内部で学習するフレームワークを提案。

サンプル間の関係を調査するフレームワークは色々あるが、基本的に入力や出力時点で行う。
ミニバッチの中でのインタラクションはテスト時等を考えると適用は難しい(Batch Normもドメインシフトの問題とか)。
学習時はサンプル間のインタラクションを使い、テスト時にはインタラクション不要なモジュールが好ましい。

手法

ネットワーク内部で自発的にミニバッチ内のサンプル間関係を学習させる。

backboneネットワークは個々のサンプルの特徴を学習する。このときはサンプル間のインタラクションはない。
この特徴に対してサンプル間の関係を学習するために、バッチ方向についてAttentionを行うBatch Transformer(BatchFormer)モジュールを導入する

BatchFormerは式(1), (2)で表されるPost Norm LNのTransformer Blockからなる。

バッチ方向でAttentionするためCross-Attentionと見れなくもない。
BatchFormerの出力をClassifierに入力しクラス予測を行う。

しかし、BatchFormerはバッチ内の統計情報を使うのでテスト時には適さない。そこでテスト時はBatchFormerモジュールを取り除く。
しかし、BatchFormerモジュールを取り除くと予測ができなくなるので、Classifierをbackboneの出力を入力としても学習する。
図2にはClassifierが2つあるが、2つはパラメータ等を共有した同じネットワークで、テスト時は入力画像 -> backbone -> Classifierという流れで予測を行う。

このようにBatchFormerモジュールは取り外し可能なモジュールで、学習もEnd2Endで行える。


図3はBatchFormerの有無での勾配の伝播を図示したもの。

普通の学習だと N個のサンプル X_nが与えられたとき、それぞれのloss  L_nについての勾配(図の実線)が流れるが、BatchFormerでは \frac{\partial L_i}{\partial X_j}, i \neq jも流れる。

実験・結果

Long-Tailed RecognitionやZero-shotについて実験。
詳細等は省略。






所感

バッチ方向でのAttentionとテスト時の実行方法の提案。
BatchFormer自体は特徴量を入力として特徴量を出力する追加モジュールなので、既存のResNetでもViTでも追加できるのはよい。とりあえず試してみるか的な使い方ができそう。コードとしての分離もやりやすそう。
ただ、ミニバッチサイズの大きさがある程度確保できないと効果は薄そう。
また、ミニバッチ内のサンプル間でのAttentionがどういった影響があるのか少々懐疑的。
StyleGAN2のDiscriminatorで採用していたminibatch standard deviation layerと同じように、Discriminatorに使うとmode collapseを抑えられるかも?(そもそも学習ができなくなりそうだが)