[論文メモ] Prioritized Training on Points that are learnable, Worth Learning, and Not Yet Learnt

arxiv.org
ICML2022
間違っているかもしれないので注意。

巨大なデータセットに関する学習の高速化方法を提案

最近はWebで集めた巨大なデータセットで学習するモデルが増えている(GPT-3やCLIPなど)。
データセットが大きい分、学習にも一ヶ月やそれ以上の時間がかかる。
学習の高速化のためデータセットをフィルタリングする方法も提案されているが、Webで集めたデータはノイズだらけでラベルも正しいか怪しかったり似たサンプルが多くなったりする。

こういった巨大なデータセットで学習する際にはサンプルを選択し学習の効率化を図る。
簡単なサンプルから学習し収束を早めるカリキュラム学習なども提案されていてるが、すでに学習し終わった(冗長)サンプルをスキップできない。
そういった冗長なサンプルを選択しないためlossの大きい難しいサンプルを選択する方法もあるが、調べてみるとlossの大きいサンプルはたいていラベルが間違っている等ラベルに問題があるためあまり有効ではない。
そしてわずかに外れ値のようなデータも含まれる。外れ値なデータはテストデータに出現する可能性が低いので学ぶ価値はない(ホンマか?)

これらを解決し、有効なサンプルを選択したいというお気持ち


手法

reducible holdout loss selection (RHO-LOSS)を提案。

学習を行うとき、パラメータ \thetaを持つモデル p(y|x; \theta)をデータ \mathcal{D} = \{(x_i, y_i)\}^n_{i=1}を使ってSGDで学習する。
各学習ステップ tではバッチサイズ n_bのミニバッチ b_tを使って学習をする。

オンラインでのバッチ選択ではより大きなバッチサイズ n_B > n_bのミニバッチ B_tからサンプルを選択し n_bサイズのバッチを生成する。

このような既存のオンラインでは学習データセット全体に対するlossを下げるサンプルを選択するが、提案手法ではホールドアウトセットに対してのlossを小さくする。
これを毎回測定するのは当然コストが高いので、実際にそれらのサンプルを学習せずにホールドアウトセットに対してのlossを最も減らせるサンプルを探す。

簡潔化のために各学習ステップtにおけるある一つのサンプル  (x, y) \in B_tについて説明する。
現在までの学習してきたデータを \mathcal{D}_tとして、現在のモデルを p(y'|x'; \mathcal{D}_t)とする。
そして学習データと同じ分布を持つホールドアウトセットを \mathcal{D}_{ho}=\{(x^{ho}_i, y^{ho}_i)\}^{n^{ho}}_{i=1} とする(簡略化のために \textbf{x}^{ho}, \textbf{y}^{ho}とする)。

目的は大きめに選んだバッチの中からホールドアウトセットのlossを最小にするサンプルを選択すること。つまり以下の式のサンプルを見つけること。


Deriving a tractable selection function

モデルのパラメータを事前分布 p(\theta)を持つ確率変数として扱い、学修済みデータ \mathcal{D}_tを使って事後確率 p(\theta|\mathcal{D}_t)を推論する。
モデルは p(y|x, \mathcal{D}_t) = \int_{\theta}p(y|x, \theta)p(\theta|\mathcal{D_t})d\thetaとなる。
ベイズの定理と条件独立から以下の様に式変形できる(符号は反転した)。

 L[\cdot]はクロスエントロピー loss。
 \mathcal{D}_t, \mathcal{D}_{ho}を条件としたニューラルネットに対してベイズ推定は困難なのでSGDで代用(近似1)。
一項目の L[y|x, \mathcal{D}_t]は \mathcal{D}_tで学習した現在のモデルの (x, y)に対してのloss。
二項目の L[y|x; \mathcal{D}_{ho}, \mathcal{D}_{t}]は \mathcal{D}_t \mathcal{D}_{ho}で学習したモデルによるloss。

式(2)は扱い易いが学習の度に更新されるので計算コストが少々高い。そこで二項目をホールドアウトセットのみで学習したモデルで置き換える(近似2)。
これができると学習開始前に各サンプルについて一度だけ計算しておくだけで済む。

 (x, y)における L[y|x; \mathcal{D}_{ho}]をここではirreducible holdout loss(IL)と呼ぶ(ホールドアウトセット学習後の残ったlossだから)。
式(2)のlossをreducible holdout lossと呼ぶ。

ただ、はじめにILを計算するためのモデルをホールドアウトセットを使って学習する必要がある。
このコストを下げるため、IL用には精度の低い小さいモデルで学習を行う(近似3)。

当初の目的であるバッチ内からサンプルを選択するためのホールドアウトセットのloss(式(1))は以下の様に計算できる。

Understanding reducible loss

RHO-LOSSを使うことで、なぜ冗長・ノイズ・タスクあまり関係ないサンプルを見つけられるのか。

冗長サンプル

ここでの冗長サンプルはすでに学習済みでlossを減らせないサンプルのこと。
冗長サンプルは学習lossが小さいので式(3)も小さくなり選択されなくなる。

ノイズサンプル

曖昧・間違ったラベルでlossが大きいサンプルのこと。
こういったサンプルはホールドアウトセットで学習したモデルでもうまく推測できず、結果的にILが大きくなる。
学習lossが大きくてもILが相殺するため結果的に式(3)も小さくなり選択されなくなる。

タスクあまり関係ないサンプル

外れ値的なサンプルのこと。既存のlossベースなサンプル選択だとこのサンプルはlossが大きくなるため選択されやすい。
このサンプルは学習したいデータの分布の密度の小さいところのサンプルで、このサンプルを学習してもホールドアウトセットのlossを下げる影響は小さく、外れ値でないサンプルの学習を優先すべき。
このサンプルもノイズサンプル同様にILが大きくなるため選択されにくい。

実験・結果

データセットはQMNISTやCIFAR、Clothing-1Mなど7つ。
学習データセットをホールドアウトセットに分けて利用。

まず近似1~3について、これが問題ないかの確認。データセットはQMNISTで10%のノイズとダブリを加えることで現実のスクレイピングデータに近くした。
それでも式(2)は扱いにくいのでコストが高いがより正確な近似を基準にする。

5つのニューラルネットでアンサンブルを行い、それらをtステップ毎に収束するように学習を行う(近似0)。

提案手法ではベイズ部分をSGDで近似したが、それに対して近似0のがコストは高いが元の式に近い。
まずこの近似0について、アンサンブルではなく1つのモデルにしたときと近似0とのスピアマンの相関係数を計算すると0.75で似たサンプルが選択されている(表1のNon-Bayesian)。
次に近似2( L[y|x; \mathcal{D}_{ho}, \mathcal{D}_{t}]を L[y|x; \mathcal{D}_{ho}]に)について。これによって選ばれるサンプルも相関係数が0.63と高い(表1のNot updating IL model)
近似3(ILモデルを本来のモデルより小さいモデルで置き換え)については表1 Small IL modelの0.51とこれも高い。

RHO-LOSSの特性についての調査。
結果は図2で元の学習に対してどれくらい速度が向上したか(同じ精度になるまでの時間)。
ターゲットモデルはResNet-18。
RHO-LOSSは本来ノイズの多い巨大なデータセットに効果的だが、ここではクリーンで中規模なデータセットにした。
1行目(Default)はILモデルもResNet-18としてRHO-LOSSを適用した結果。
2行目はILモデルをより小さいモデルにした
3行目はホールドアウトセットなしでデータセットを半分にし、それぞれで作ったILモデルをILモデルの学習してない半分で学習したもの。
4行目はiLモデルはそのままに、ターゲットモデルを変更したもの。GoogleNetやMobileNet-v2等、7つのアーキテクチャで検証。
5行目は同じILモデルでターゲットモデルの学習時のバッチサイズや学習率等のハイパラについてグリッドサーチした結果。


ノイズサンプル等が実際どれくらい選択されているのかを確認するために10%のデータを破損させてそのサンプルを追跡(図3)。
ノイズサンプルや関係のないサンプルは勾配ベースやlossベースでは選択されやすい。
冗長サンプルについては勾配ベース等でも選択されない。


学習速度の向上について比較。
各データセットについてターゲットの精度に到達するまでのepoch数を計測(表2)。
提案手法が速い。

所感

結構学習が高速化できそうなので良さそうだが、ホールドアウトセットによる事前学習が少々面倒そう。
はじめのお試しの学習でモデルの精度が足りなそうとかだったらそれをILモデルとして使ってもいいのかも?
回帰問題等ではどうなのか気になる。
式的には比較対象はlossなのでクラス分類でも回帰問題でも問題なさそうではあるが。