[論文メモ] Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning
NeurIPS2021
単純にデータを学習して予測をするのではなく、学習データそのもの(全体)を使って予測をするNon-Parametric Transformers (NPTs)を提案。
手法
Non-Parametric Transformers(NPTs)について。
NPTsは予測を改善するためにデータポイント間の関係を暗に学習する。
NPTsの要点は3つ。
1) データセットのデータ全てを入力として予測を行う。テスト時はテストデータも全て。
2) self-attention(SA)を使ってデータポイント間の関係をモデル化する。
3) 学習はBERTのようにstochastic maskingで欠損したデータを再構成する
Datasets as Inputs
NPTsはデータセット を入力として受け取る。
各データは行でスタックされ、
で行、すなわち各データポイントを参照し、
で列、すなわち全てのデータポイントのある要素を参照する。
また、データポイントの最後にはラベルデータが入っている(の要素。なので一般的な入力に使われる特徴量自体は次元)。
が各特徴量の値で画像だとピクセル値とかになる。
これを使って学習を行うが、学習はmasked language modeling。
を用意し、マスクした入力からマスク部分を予測するタスクで学習する。つまりを予測する。
NPT Architecture
アーキテクチャは図2を参照
図2(a)で表される入力をデータポイントごとに線形写像してを得る(図2(b))。
reshapして次元の行列にし、データポイントのシーケンス(次元特徴が個のシーケンス)とみなしてmulti-head self-attention(MHSA)を適用。
そして再度reshape(を)し今度は要素方向のシーケンス(次元特徴が個のシーケンス)とみなしてMHSAを適用。
最後に、線形変換をして出力を得る。
MHSAでデータポイント間の関係を学習できる。
またデータポイントのスタックの順番は予測に影響しないことに注意。
この1つ目のMHSA(データポイント間のSA)をAttention Between Datapoints (ABD)、
2つ目のMHSA(特徴間のSA)をAttention Between Attributes (ABA)と呼ぶ。
実験
詳細は論文参照。
一般的な教師あり学習の結果
NPTsがデータポイント間でのインタラクションをしているかの実験結果。
予測したいデータの欠損なしのデータとラベル部分をマスクしたデータを入力してラベルを予測する。
予測したいラベルは入力されているのでそこを参照すれば100%答えられるはず、といった問題設定。
NPTsがテストデータの予測のために他のデータポイントを利用しているかの実験。
予測したいデータポイント以外のデータポイントについて、属性を他のデータポイントとをランダムに入れ替え -> 予測 を何度か繰り返して精度を測る。
他のデータポイントに依存していば元の状態で測った精度より低下するはず。
属性レベルではなく、データポイントレベルでデータを参照していることを確かめるため?