[論文読み] Deceive D: Adaptive Pseudo Augmentation for GAN Training with Limited Data

arxiv.org
github.com

NeurIPS2021
f:id:Ninhydrin:20211116090537p:plain

データが少ないときのadversarial trainingではDiscriminator(D)の過適合がGenerator(G)の学習を妨げる。
少量データでもDとGの競争をより安定させるためのAdaptive Pseudo Augmentation (APA) を提案。

手法

GはDを騙すように学習するが、データが少ないとDがすべてを記憶し間違えなくなる。そうするとGへのフィードバックが無意味になってしまう。
図2はFFHQを普通(70k)と少量データ(7k)で学習した学習曲線。
少量データの時、Dはすぐに過学習し最終的にはreal/fakeをほぼ100%見分けられるようになっている。またFIDも発散していく。
f:id:Ninhydrin:20211116091236p:plain

Adaptive Pseudo Augmentation

GにはDを騙す能力自体はあるのでそれを引き出す。
手法としてはGで生成したサンプルで疑似real集合を作り、それをrealサンプルとして適応的にDに見せることでDが自信過剰になるのを抑制する。

ただ、単純にfakeサンプルをrealサンプルとして見せるとDが劣化するのであくまで適応的に行う。
そのために確率 p  \in [0, 1)を導入する
real画像をサンプリングする際に確率 pで疑似relaの画像から、確率 1-pで本当のreal画像から選択する。

Dの過学習は学習中にダイナミックなので、その度合いに合わせて pもダイナミックに変化してほしい。
そこでADA(Adaptive discriminator augmentation)に倣いDの過学習の度合い \lambdaを導入し3つに拡張する。
f:id:Ninhydrin:20211116093302p:plain
ただし
f:id:Ninhydrin:20211116093322p:plain
 \lambda_rはreal画像の一部からDのreal画像に対する予測、 \lambda_fはfakeに対して。
 \lambda_{rf}はrealとfakeの距離的なもの。これらはすべて、0の時は過学習していない、1のときは過学習といった指標になる。
実験では \lambda_rを採用し、ほかはablation studyで。

この指標 \lambdaの使い方は、まずしきい値 tを設定し( t=0.6)、 p=0で初期化。 tに対して \lambdaが大きいは pを増やし、\lambdaが小さいときは pを減らす。つまり過学習してるときは疑似realを利用してDの自信過剰状態を抑える。
これによりDの過学習を適応的に抑える。

Theoretical Analysis

 pは動的に決定するので分析のためそれを近似した \alphaで置き換える。 p \in [0, 1)なので 0 \leq \alpha < p_{max} < 1。ただし、 p_{max}は学習を通しての pの最大値)
APAのtwo-player minimaxゲームとして価値関数は式(3)のようになる。
f:id:Ninhydrin:20211116095000p:plain

Gを固定したときの最適なDは式(4)になる(証明は省略。論文にある)。
f:id:Ninhydrin:20211116095855p:plain

Dは価値関数V(G,D)の最大化だが、Gは逆の最小化。
Dの目的は条件付き確率 P(Y=y|x)の対数尤度の最大化になる( xがreal画像のとき y=1でfakeのとき y=0)。
Gの目的関数を書き換えると
f:id:Ninhydrin:20211116101144p:plain
になる。APAで学習したときの C(G)のglobal minimumについて考えると p_{g}=p_{data}のときで、そのとき C(G)=-\verb|log|4
 C(G)=-\verb|log|4になるのは[p_g =p_{data}]のときに D^*_G(x)=\frac{1}{2}となるので、これを式6に代入すればよい。
[p_g =p_{data}]のときにglobal minimumになることについての証明は論文を参照。

APAを使っても十分なモデルと時間があれば p_g p_{data}に収束する。

実験・結果

評価指標はFIDとIS。ベースモデルはStyleGAN2(SG2)。
少量データのときは圧倒的にFID、ISが良い。
また学習曲線を見ても過学習が抑えられているのがわかる(図6)。
f:id:Ninhydrin:20211117091124p:plain
f:id:Ninhydrin:20211117091444p:plain

他手法との比較。
FFHQ-5kではAPA単体だとADA単体に劣る。これはデータセット固有の多様性の問題と考えられる(Gも多様な画像を生成できない?)。
f:id:Ninhydrin:20211117091310p:plain
f:id:Ninhydrin:20211117092107p:plain
f:id:Ninhydrin:20211117092122p:plain

学習コストについて。彼らの環境(NVIDIA Tesla V100 x 8)で256x256の画像を使うときベースのStyleGAN2だと(4.740 ± 0.100) sec/kimg、APAは(4.789 ± 0.078) sec/kimgと無視できるレベルの増加。ADAは(5.327 ± 0.116) sec/kimgでAPAより大きい。

ablation study。過学習の度合いを測る \lambdaしきい値 tについて。
f:id:Ninhydrin:20211117092011p:plain


所感

Dに対してrealとしてfake画像を見せる、ただし適応的にというシンプルな手法で少量データでの学習を改善。
簡単過ぎて少し拍子抜けだった。
いくつかのデータセットをサイズも変えて他手法と比較していて、その辺りが重要だなと感じた。
見落としてるかもかもしれないが、疑似real画像のライフタイム、どれくらいで入れ替わるのか(集合サイズ)などが気になる。