[論文メモ] Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning

arxiv.org

NeurIPS2021

単純にデータを学習して予測をするのではなく、学習データそのもの(全体)を使って予測をするNon-Parametric Transformers (NPTs)を提案。

手法

Non-Parametric Transformers(NPTs)について。
NPTsは予測を改善するためにデータポイント間の関係を暗に学習する。
NPTsの要点は3つ。
1) データセットのデータ全てを入力として予測を行う。テスト時はテストデータも全て。
2) self-attention(SA)を使ってデータポイント間の関係をモデル化する。
3) 学習はBERTのようにstochastic maskingで欠損したデータを再構成する

Datasets as Inputs

NPTsはデータセット  \textbf{X} \in \mathbb{R}^{n \times d}を入力として受け取る。
各データは行でスタックされ、
 \{\textbf{X}_{i,:} \in \mathbb{R}^d|i \in 1...n \}で行、すなわち各データポイントを参照し、
 \{\textbf{X}_{:,j} \in \mathbb{R}^n|j \in 1...d \}で列、すなわち全てのデータポイントのある要素を参照する。
また、データポイントの最後にはラベルデータが入っている( \textbf{X}_{:,d}の要素。なので一般的な入力に使われる特徴量自体は d - 1次元)。
\{X_{:,j}|j \neq d \}が各特徴量の値で画像だとピクセル値とかになる。
これを使って学習を行うが、学習はmasked language modeling。
 \textbf{M} \in \mathbb{R}^{n \times d}を用意し、マスクした入力 \textbf{X}^M=\{\textbf{X}_{i,j} | \textbf{M}_{i,j}=1\}からマスク部分 \textbf{X}^O=\{\textbf{X}_{i,j} | \textbf{M}_{i,j}=0\}を予測するタスクで学習する。つまり p(\textbf{X}^M|\textbf{X}^O)を予測する。

NPT Architecture

アーキテクチャは図2を参照
f:id:Ninhydrin:20211102093455p:plain
図2(a)で表される入力をデータポイントごとに線形写像して \textbf{H}^{(0)} \in \mathbb{R}^{n \times d \times e}を得る(図2(b))。
reshapして n \times d\cdot e次元の行列にし、データポイントのシーケンス( d次元特徴が n個のシーケンス)とみなしてmulti-head self-attention(MHSA)を適用。
そして再度reshape( n \times d\cdot e n \times d \times e)し今度は要素方向のシーケンス( e次元特徴が d個のシーケンス)とみなしてMHSAを適用。
最後に、線形変換をして出力 \hat{\textbf{X}} \in \mathbb{R}^{n \times d}を得る。
MHSAでデータポイント間の関係を学習できる。
またデータポイントのスタックの順番は予測に影響しないことに注意。

この1つ目のMHSA(データポイント間のSA)をAttention Between Datapoints (ABD)、
2つ目のMHSA(特徴間のSA)をAttention Between Attributes (ABA)と呼ぶ。

Masking and Optimization

マスクは特徴量部分とラベル部分、別々の確率で行う。
特徴量部分( \textbf{X}_{i,j}, j \neq d)は p_{\verb|feature|}、ラベル部分( \textbf{X}_{:,d})は p_{\verb|target|}

目的関数はラベル部分と特徴量部分に対してのNLLでそれぞれ \mathcal{L}^{\verb|Targets|} \mathcal{L}^{\verb|Features|}
NPT全体のlossは
 \mathcal{L}^{\verb|NPT|} = (1 - \lambda)\mathcal{L}^{\verb|Targets|} + \lambda \mathcal{L}^{\verb|Features|}

確率的なマスクを使った学習はマスクされていない部分からマスク部分を予測する。
なので一般的なモデルのようにモデルパラメータ \mathbb{\theta}で入力から出力へのマッピングを学習するのではなく、学習データをどう参考にするのかについてをパラメータ全てを使って学習することになる。

ただし、データセットが大きいとGPUメモリとかが厳しい。そこでデータセットをランダムサンプリングしてミニバッチ化する。

実験

詳細は論文参照。

一般的な教師あり学習の結果
f:id:Ninhydrin:20211104091441p:plain


NPTsがデータポイント間でのインタラクションをしているかの実験結果。
予測したいデータの欠損なしのデータとラベル部分をマスクしたデータを入力してラベルを予測する。
予測したいラベルは入力されているのでそこを参照すれば100%答えられるはず、といった問題設定。
f:id:Ninhydrin:20211104092451p:plain

NPTsがテストデータの予測のために他のデータポイントを利用しているかの実験。
予測したいデータポイント以外のデータポイントについて、属性を他のデータポイントとをランダムに入れ替え -> 予測 を何度か繰り返して精度を測る。
他のデータポイントに依存していば元の状態で測った精度より低下するはず。
属性レベルではなく、データポイントレベルでデータを参照していることを確かめるため?
f:id:Ninhydrin:20211104093907p:plain

所感

データセット全てを参照して予測を行うというのは面白い発想。Attentionを使えば確かに問題はなさそう。
ただ、データセットが大きすぎると厳しくて、few-shotであったり、データ数が少ないときに有効な手段だと思われ。
画像系のデータセットサイズではちょっと無理そう(ImageNet-1Kとか)。