[論文メモ] Class Adaptive Network Calibration

https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_Class_Adaptive_Network_Calibration_CVPR_2023_paper.pdf

CVPR2023

クラス不均衡なデータを効率的に学習する手法を提案。

クラスの分布が不均衡・裾が長い場合にDNNは自身過剰な予測を出すことがある。これを調整することをここではキャリブレーションと呼ぶ。
このキャリブレーション方法としては主に2種類ある。

1つ目は事後処理でキャリブレーションする方法で、検証セットを使ってlogitを調整するパラメータを設定する。低コストで効果的だが学習したモデルや検証セットの影響を受けやすい。

2つ目は学習中に同時にキャリブレーションを行う方法。メインの目的関数に加えてキャリブレーションの目的関数も追加するというもの。Label SmoothingやFocal lossなんかがこれに該当する。
これらはlogit距離を0に近づけるペナルティ項として定式化できる() nOTE

ただこれにも問題があり、
1) 各クラスの重みが同じで難しいクラスに対応できない
2) 重みの調整は適応的ではなく事前に行われるため、最適な結果が得られない

これらをなんとかするための拡張ラグランジュ乗数アルゴリズムに基づいたlabel smoothing手法、CALS-ALMを提案。

手法

 Nをサンプル数、サンプルを \mathcal{x}、ラベルを yとするとデータセット
 \mathcal{D} = \{(\boldsymbol{x}^{(i)}, y^{(i)})\}^N_{i=1}

になる。なおクラス数は K個。

パラメータ \thetaを持つDNNを F_{\theta}とするとlogitは \mathcal{l} = F_{\theta}(\mathcal{x})
softmaxで確率にすると \mathcal{s} = \text{softmax}(\mathcal{l}) = \frac{\exp \boldsymbol{l}}{\sum \exp \boldsymbol{l}}

クロスエントロピーlossは
 \mathcal{L}_{\text{CE}} (\mathcal{x}, y) = - \Sigma^{K}_{k=1}y_k \log s_k

となる。なお、基本的に \boldsymbol{y}はone-hotエンコーディングなことに注意。

既存のMargin-based Label Smoothing (MbLS)

Margin-basedな手法のlossは

 \min_{\theta} \sum^N_{i=1}\mathcal{L}_{\text{CE}}(\boldsymbol{x}^{(i)}, y^{(i)}) + \lambda \sum^N_{i=1} \sum^K_{j=1} \max\{  0, \text{max}_k \{  l^{(i)}_k  \} - l^{(i)}_j - m  \}
の形式。なお \lambda \in \mathbb{R}_+
CE lossに追加しマージンの制約を設けた感じ。マージンの制約は各サンプルについてのlogitの各値と最大logitとのマージンがm以下になるようにする。各クラスのlogit同士にあまり大きなマージンができないようにするという感じ?

これも非常に強力なキャリブレーションだが、すべてのサンプル・クラスに対して均一のペナルティを与えることになり最適ではない。
最適にするなら以下の式のように \lambdaをサンプル・クラスについて分ける必要がある

 \min_{\theta} \sum^N_{i=1}\mathcal{L}_{\text{CE}}(\boldsymbol{x}^{(i)}, y^{(i)}) + \sum^N_{i=1} \sum^K_{j=1} \Lambda_{ij} \max\{  0, \text{max}_k \{  l^{(i)}_k  \} - l^{(i)}_j - m  \}

ただし\Lambda \in \mathbb{R}^{N \times K}_+
最適化の観点からは \Lambda^*ラグランジュ定数で、最適なパラメータ \theta ^*とのペア、 (\theta^*, \Lambda^*)が存在する。

当然だがImageNetのようなサンプルもクラスも巨大なデータセットや、ピクセルにクラスを割り当てるセグメンテーション問題を考えればこの最適化は現実的ではない。

そこでサンプルレベルのペナルティを緩和し、クラスレベルとする。

 \min_{\theta} \sum^N_{i=1}\mathcal{L}_{\text{CE}}(\boldsymbol{x}^{(i)}, y^{(i)}) + \sum^N_{i=1} \sum^K_{j=1} \lambda_{j} \max\{  0, \text{max}_k \{  l^{(i)}_k  \} - l^{(i)}_j - m  \}

ただし、 (\lambda_j)_{1 \leq j \leq K} \in \mathbb{R}^K_+
それでもImageNetでは K=1000で少々複雑。

Class Adaptive Network Calibration

 Kが大きいときでも適応できるためにAugmented Lagrangian Multiplier(ALM)法 (拡張ラグランジュ乗数法)を利用する。

一般的なALM法はラグランジュ関数にペナルティ項を追加したもので、最適化の条件を満たすまで最適化とラグランジュ乗数とペナルティ項の係数を更新を繰り返すアルゴリズム。ここでは詳しくは省略。

j回目のラグランジュ関数は以下の式。

 \min_x \mathcal{L}^{(j)}(x) = f(x) + \sum^n_i=1P(h_i(x), \rho^{(j)}_i, \lambda^{(j)}_i)

 h(x)は制約。
ペナルティ項と制約を合わせたペナルティ関数をここでは Pとし以下を満たす。

 \forall z \in \mathbb{R}, P'(z, \rho, \lambda) = \frac{\partial}{\partial z} P(z, \rho, \lambda) \geq 0
 P'(0, \rho, \lambda) = \lambda

 Pと各パラメータの関係を図にしたのが図2

ALM法は凸最適化で非凸のときは保証が無いが、非凸でも効果的なのがよく知られている。にも関わらずDNNの文脈ではほとんど検討されない。

そこでALM法をキャリブレーションに利用する。

 \min_{\theta} \sum^N_{i=1}\mathcal{L}_{\text{CE}} (\boldsymbol{x}^{(i)}, y^{(i)}) + \sum^K_{k=1}P(d^{(i)}_k - m, \rho_k, \lambda_k)

なお d^{(i)}_k = \max \{ \boldsymbol{l}^{(i)} \} - l^{(i)}_k \in \mathbb{R_+}


ペナルティ項を足し合わせるのではなく平均化し、マージン m > 0で制約を正規化し、最終的なloss関数を得る。
 \min_{\theta} \sum^N_{i=1}\mathcal{L}_{\text{CE}} (\boldsymbol{x}^{(i)}, y^{(i)}) + \frac{1}{K}\sum^K_{k=1}P(\frac{d^{(i)}_k}{m} - 1, \rho_k, \lambda_k)

ただ、すぐに過適合してしまうので各エポック毎検証データを使ってラグランジュ乗数を更新する。

 \lambda^{(j+1)}_k = \frac{1}{|\mathcal{D}_{\text{val}}|} \sum_{(\boldsymbol{x}, y) \in \mathcal{D}_{\text{val}}} P'(\frac{d_k}{m} -1 , \rho^{(j)}_j, \lambda^{(j)}_k)

 \rhoについては制約が満たされておらず、かつ制約項の値が現象していない場合に \gamma倍する。

既存研究と実際に実験した結果から、 Pとして以下のPHR関数を利用する。

{
\begin{eqnarray}
\text{PHR}(z, \rho, \lambda) = \left\{ \begin{array}{ll} 
  \lambda z  + \frac{1}{2} \rho z^2 & (\lambda + \rho z \geq 0)  \\
  - \frac{\lambda^2}{2 \rho} & \text{otherwise} 
\end{array}\right.
\end{eqnarray}
}

アルゴリズムアルゴリズム2を参照。

実験・結果

データセットはTiny-ImageNetとImageNetとImageNetLT。
ImageNet-LTは裾が長い分布をしている。
またセグメンテーションタスクとしてPASCAL VOC2012、NLPタスクとして 20 Newsgroups。
評価指標としてよく使われているExpected Calibration Error (ECE)を採用。

 ECE = \sum^M_{m=1}\frac{|B_m|}{N} |A_m - C_m|

 Mはbinの数、 Nはテストサンプル数(ここでは15に固定)、 B_m m番目のbinの予測確信度、 A_m m番目のサンプルのacuuracy、 C_mm番目のサンプルの平均確信度。
またAdaptive ECEも。

先程提案した手法をCALS-ALM、 \lambdaの更新を以下のヒューリスティックにしたものをCALS-HRとする。

画像分類の結果。提案手法がキャリブレーション指標では優れている。

ablationとして学習中の \lambdaの変化とペナルティ関数Pとマージン mの影響をグラフ化(図3)。

はじめは精度を上げるためにECEも \lambdaも上昇するが、途中で調整が始まりECE、 \lambdaともに減少し始める。
ペナルティ関数とマージンについてはPHRが最もよく、 m \approx 10が良さそう。

セグメンテーションとNLPの結果は表2、3を参照。


ただ、この手法の制限として学習データセットと同じ分布の検証データセットが必要なこと。
検証データがi.i.dのときについては今後検証予定だそう。

所感

DNNに拡張ラグランジュ乗数法を適用したものだが、シンプルで悪くなさそう。
現状ではvalidationのデータセットの分布が学習データセット同じである必要があるのがどれくらい厳しい制約なのか気になる。
また、キャリブレーションができたとして実際どういう結果になっているのかも気になるところ。
Githubにpytorchのコードが公開されているのですぐに試すこともできる。
github.com