Learning Confidence for Out-of-Distribution Detection in Neural Networks

クラス分類などで実際にニューラルネットを実際に使うとわかるが、ある入力が特定のクラスである確率が9割を超えていても間違えている事がある。
Adversarial Attackなどを考えればよく分かる。
この論文では、入力に対してクラスラベルの推定と共にその予測分布に対しての確信度も同時に推定する手法を提案している。
入力に対して、クラスAである確率が90%という出力が得られたとしても、確信度が低いものは間違っている可能性がある。
要するにグレーゾーンというわけだ。実システムで使う場合は、確信度の低いものは人が確認するといった手続きになるだろう。でも、そのグレーゾーンがわかることは非常に役立つ。
提案手法は実装もしやすい。

p, c = f(x, \theta)
p_i, c \in [0, 1], \sum^{M}_{i=1}p_i =1

入力 xに対しての出力はクラス確率 pと確信度cを得る。 cはsigmoidを通している。
そして、クラス確率 p_iと入力 xの実際のラベル分布y_iを確信度でブレンドした

p_{i}'=cp + (1 - c)y_i

をネットワークの出力として、この p_{i}' y_iとのクロスエントロピーを誤差として学習を行う。

 \mathcal{L}_t = - \sum^{M}_{i=1}\log(p_{i}')y_i

しかし、このlossだけだとネットワークはc=0とすることで常に正解ラベルを参照でき、lossを0にできてしまう。そこで-\log(c)のペナルティを与える。

 \mathcal{L}_c = -\log(c)

結果として全体のloss関数は

 \mathcal{L} = \mathcal{L}_t + \lambda\mathcal{L}_c

となる。\lambdaは確信度のlossを調整するハイパーパラメータで、確信度のlossの大きさによって学習中に逐一調整を行う。

クラス確率の予測が正確になればなるほど正解ラベルに頼る必要がなくなり確信度が1に張り付いてしまう。そうならないように\lambdaを小さくすることで正解ラベルを参照するコストを小さくする。逆に予測がうまくできない初期状態などは、正解ラベルを参照しすぎ、学習をしなくなるので\lambdaを大きくし、正解ラベル参照コストを大きくする。

つまり、予測に自信が持てないときはコスト( L_c)を払うことで、正解ラベルを使うことができるわけだ。ただ、正解ラベルを使うのにはコストが必要で、これが予測を間違えるよりも安くなければならない。予測を間違えるより、正解ラベルを使うコストほうが高ければ当然正解ラベルを使わなくなり、普通のクロスエントロピー誤差を使った学習と同じになってしまう。


chainerで書くとしたら大体以下のようになるだろう。

# 正解ラベルtはone-hot エンコーディング
y = F.softmax(h_class)
confidence = F.sigmoid(h_confidence)
c = F.broadcast_to(confidence.reshape(batch_size, class_num), y.shape)
y_prime = y * c + t * (1 - c)
loss_classify = - F.mean(F.log(F.select_item(y_prime, t.argmax(1))))
loss_confidence = - F.mean(F.log(confidence)) * lm

h_class、h_confidenceはニューラルネットのそれぞれの出力、tはone-hot形式の正解ラベル、batch_sizeはバッチサイズ、class_numはクラスの数、lmは係数ラムダを表す。
注意点としては2つ。確信度自体の出力はスカラーなのでクラス予測と同じ形のベクトルに変換することと、y_primeはすでに分布になっているのでF.softmax_cross_entropyを使うのではなく、単純にクロスエントロピーをとること。なお、正解ラベルは出力した分布とブレンドするのでone-hotエンコーディングにしておくor 変換する必要がある。

係数ラムダの調整はchainerのExtensionsでlossを見て調整するか、ここで直接調整すればよいだろう。

lm += 0.01 if loss_confidence > 0.3 else -0.01

みたいな感じで。
論文では、確信度誤差を0.3に保つように\lambdaを調整するとあった。e^{-0.3} \fallingdotseq 0.741なので学習データの中に難しいサンプルが25%程度含まれているという事前知識がある気がする。この難しいサンプルの割合が事前にわかっているなら、それに合わせて\lambdaの調整も変えたほうがいいだろう。

実際に使用してみるとなかなか良い。自分がやった実験ではラムダは0.55前後で安定した。
また、面白いことに学習後のモデルのvalidation accuracyが僅かだが上昇した(正解ラベルを混ぜたわけではない)。
予想だが、確信度の低いものは学習に悪影響を与えるサンプルで、それを学習に使用しなかった(確信度を下げ、正解ラベルのみ使う)ことで、より良い特徴を捕らえられたのでは。データセットクリーニングに近いのかもしれない。

あいかわらずひどい日本語だ...