Born Again Neural Networks

[1805.04770] Born Again Neural Networks

本来Knowledge distillation (KD) はモデルの圧縮に使われているが、そうではなくKDによって教師モデルより高いパフォーマンスの生徒モデルを作ることを試みたのがこの論文。
自分の記憶だとKD的なものは2014年あたりからちらほら見かけるようになった。
[1312.6184] Do Deep Nets Really Need to be Deep?
深い(Deep)ネットワークではなく層の数的に薄い(Shallow)ネットワークじゃだめなのか?ということをKDを用いて調べたものや
[1412.6550] FitNets: Hints for Thin Deep Nets
KDを用いてよりDeepで細いネットワークを生成するものなどが記憶にある。これらはネットワークの圧縮関係で教師モデルと同じ精度を保ったまま高速・軽量を目指したものだった。それに対してBorn Again Neural Networksは生徒モデルとして教師モデルと同じ、もしくはより複雑なモデルにすることで精度向上を図っている。手法はこうである。出来上がった教師モデルを用いて生徒モデルを学習する。これによってできた生徒モデルを新たな教師モデルとし再度ランダムに初期化された生徒モデルを学習する。これを繰り返す。これだけである。しかしこれで精度の向上が見られたそうな。なお、途中の生徒モデルを用いてアンサンブルもできる。これに関してはなるほどとなる。
f:id:Ninhydrin:20180608002800p:plain
本論文で提案するBorn Again Networks(BANs)について。正直図の通り。一般的にSGDなどを用いてDNNのような関数f(x, θ)ラベル付きデータ(x, y)に対するlossを最小化するパラメータ*θを求める。経験的にこの*θは最適解に近く、損失関数の変更によって改善できる可能性があるのではという考えに基づいている。よくある損失関数の変更はL2正則化とかだ。BANsではKDを用いて情報量豊富な出力を教師ラベルとすることでより良いモデルを得る。


DKの成功は本当に正解ラベル以外の情報のおかげなの?それとも単純に重み付けのおかげなの?これに関しても検証してる。
実験はCifar-10, 100とPenn Tree Bankの3つ。MNISTは必要ないがImageNetかMS COCOは欲しくないか?BANsは時間がかかるから仕方ないかな。
f:id:Ninhydrin:20180612232602p:plain
まずはCifar-10について。モデルが複雑になると効果が薄いように感じられる。まあCifar-10だしという気持ちもある。

f:id:Ninhydrin:20180612232940p:plain
table 2はCifar-100に対してBANを行った結果。Cifar-10のときと同様に複雑になるほど効果が薄い。アンサンブルに関してはかなりの効果が見込める。ここでCWTMとDKPPという二つの項目がある。CWTM(Confidence Weighted by Teacher Max)とDKPP(DK with Permuted Predictions)はDKの効果についての検証を行ったものである。CWTMは教師モデルの最大値のみを用いてラベルを重み付けし、他のラベルについての情報は扱わない。DKPPはCWTMと同様に教師モデルの結果を使うが、最大値以外の出力についてはその重みを入れ替える。結果としてCWTMとDKPPともに精度向上に効果があり、DKPPのほうが効果が高い。しかし、どちらも単純なDKには及ばなかった(つまりDKは効果があったということ)。

実装する際に特別なことも必要なく精度向上が見込めるという意味では使い勝手はありそう。ただ本当に効果が得られるのか、どれくらいでBorn Againすればいいのかなどは難しそう(というか難しかった)。