[論文メモ] AugMax: Adversarial Composition of Random Augmentations for Robust Training

arxiv.org
github.com


NeurIPS2021

AugMixをベースとして、データの多様性と難しさの両方を持ったデータ拡張フレームワークAugMaxを提案
またAugMaxだと学習が困難なのでそれを解決するDual-Batch-and-Instance Normalization(DbBIN)を提案。

概要

データセットは限られていて、どうしても実世界の問題をカバーしきれない(Out-of-distribution (OOD))。
この問題を解決するためにデータ拡張があるが、データ拡張には2つの種類がある

多様性

データに多くの変形を加え学習することでモデルに多様性を持たせる。画像タスクだと回転やフリップなどがそう。
その中でもAugMixは複数の変形を加え統合することで非常に多様な画像を生成できる。
AugMixについては以下が詳しい。
hotcocoastudy.hatenablog.jp

難しさ

敵対的なデータ拡張によってモデルを騙すような難しいサンプル(hard case)を生成し、難しいサンプルを増強する。
これはモデルにとって弱いところの修正になり、モデルをロバストにするのに役立つ。
モデルの弱い部分を積極的に修正し堅牢性を上げられる。
しかしそういったサンプルの作成はコストが掛かり、より多く学習時間が必要になる。

AugMaxはこの2つを統合したフレームワーク

手法

AugMax: Augmented Training with Unified Diversity and Hardness

学習データとそのラベルをそれぞれ \boldsymbol{x}, \boldsymbol{y}として、学習のゴールはパラメータ \boldsymbol{\theta}を持つ識別モデル fを堅牢にすること。
 \mathcal{L}(クロスエントロピーとか)を損失関数として損失を最小化する。
f:id:Ninhydrin:20211028095748p:plain

AugMaxはadversarial mixing parameter  m^* \boldsymbol{w}^*を持った関数 g(\cdot)で表現される。
このパラメータはAugMixと同じ役割だが、AugMixが乱数で与えられるのに対してAugMaxは学習で更新される。
元の画像 \boldsymbol{x}_{orig}にAugMaxを適用した画像を \boldsymbol{x}^*=g(\boldsymbol{x}_{orig}; m^*, \boldsymbol{w}^*)となる

 m^* \boldsymbol{w}^*は以下の式3の最適化で決定する。
f:id:Ninhydrin:20211029090922p:plain
意味としてはAugMaxで拡張した画像をモデルに入力したときにlossが大きくなる(正解を間違える)様に m^* \boldsymbol{w}^*を学習する。

 \boldsymbol{w}=\sigma(\boldsymbol{p})(\sigmaはsoftmax)として以下の式4に書き直せる。
f:id:Ninhydrin:20211029091219p:plain

AugMaxによる学習全体は以下の最適化になる。 \mathcal{L}_cはAugMixでも使われていたconsistency lossでデータ拡張後の画像が拡張前と大きく変化しないようにする正則化項。
JS(\cdot)はJensen-Shannon ダイバージェンス \tilde{\boldsymbol{x}}はAugMix後の画像。
f:id:Ninhydrin:20211029091242p:plain
f:id:Ninhydrin:20211029091302p:plain

Adversarial attackとしても式4ぐらいならそこまで計算コストがかからない。

可視化

AugMaxによるデータ拡張後の中間特徴を可視化・比較した結果が図1。
特徴抽出に使ったモデルはResNeXt29でCIFAR-10を普通のデータ拡張(ランダムフリップ、変形)で学習したものでAugMixやAugMax等の拡張は使っていない(未知)。
3つのクラスについて100サンプル。
f:id:Ninhydrin:20211029092306p:plain
図1(a)が元のデータで各クラスタ間に大きな空白地帯ができ分離しているのがわかる。この空白地帯に属するテストデータが来たときにモデルの出力は不確かなものになる。
AugMix(図1(b))はデータを拡散しサンプルの多様性を高めているのがわかる。しかし、決定境界付近の難しいサンプルはほとんど生成できない。
また、PGD attack(図1(c))は難しいサンプルを生成できるが、多様性を失い中央に集まってしまうのが確認できる。
それらに対してAugMix(図1(d))はデータの多様性を保ちつつ、決定境界付近の難しいサンプルも生成できていることが確認できる。

DuBIN: Disentangled Normalization for Heterogeneous Features

AugMaxは非常に多くの範囲をカバーするデータを生成できモデルをロバストにできるが単純にAugMaxを適用しただけではAugMixに比べ僅かな改善しか得られない。
これはAugMaxが非常に高次で不均一なサンプルを生成できるためで、モデルにもそれ相応の能力が必要になるから。
そこでDual Batch-and-Instance Normalization (DuBIN)を提案。

DuBINの全体像は図3を参照。
DuBINはDual Batch Normalization(BuBN)とInstance Normalization(IN)の並列した2つのパスから成る正規化レイヤー。

DuBNは更にクリーンなデータ用とadversarial attack用の2つのBNからなる。
この2つのドメインのデータの統計量は異なり、これらを混ぜたとき統計量が不適切になる。それをを防ぐ役割。
主にadversarial trainingで使われ、2つのドメインのグループレベルの統計量を分離する。
以下の論文で提案されている。
arxiv.org

INはデータ拡張による多様性をインスタンスレベルで分離する。これはデータ拡張による多様なサンプルに対応するため。
入力をチャンネル方向で分割し、それぞれDuBNとINに通して元に戻す。
AugMaxにおけるDuBINとDuBNの効果を比較する。
 \bar{\sigma}^2_c \bar{\sigma}^2_aはそれぞれクリーンなデータとAugMaxしたデータに対するBNの分散の平均。
DuBINを使ったモデルのほうがDuBNに比べ分散が小さいことがわかる。これはINがインスタンスレベルの多様性の統計量を扱ってくれるおかげでBNが多様性について扱う負担が減ったため。

実験

データセットはCIFAR10、CIFAR100、ImageNet、Tiny ImageNet(TIN)。
モデルアーキテクチャはCIFAR系に対してはResNet18、WRN40-2、ResNeXt29、ImageNet系にはResNet18。
またcommon natural corruptions(カメラブラー、ノイズや雨など)に対してのロバスト性を確かめるために、CIFAR10-C, CIFAR100-C, ImageNet-C and Tiny ImageNet-C (TIN-C)を利用。


評価指標は*-Cデータセットに対してのrobustness accuracy(RA)と普通のテストデータに対してのクラス分類の精度standard accuracy(SA)。
f:id:Ninhydrin:20211101095557p:plain
f:id:Ninhydrin:20211101100150p:plain
f:id:Ninhydrin:20211101100206p:plain
f:id:Ninhydrin:20211101100230p:plain

詳しい実験結果やablation studiesは論文を参照。

所感

AugMixのパラメータを学習可能にするというシンプルな手法。
サンプリングと違って同じようなサンプルが生成されそうで多様なデータを生成できなそうな気もするけど図1を見るとカバー範囲は広そう。
adversarial trainingでモデルの弱いところ(弱いサンプル)を学習に探して網羅していくのかな?
学習中のAugMaxのパラメータの変化とか見てみたい。
ターゲットとなるモデルアーキテクチャがResNetベースの少し古いものが多いのが気になる。
画像タスクでも最近ではViTとかTransformer系が出てきているが、それらについて実験していないとは考えにくい。
ただResNet152とかも比較に無いので、ある程度伸びしろの大きいアーキテクチャを意図的に選んでいるのでは?ViT系は逆効果 or 効果なしだった?とか勘ぐってしまう。