Learn to Pay Attention

https://arxiv.org/abs/1804.02391
ICLR2018。 CNNにおけるチャンネル方向ではなく空間方向へのAttentionモデルの提案。タイトルがいいよね。
図を見たほうが早い。
f:id:Ninhydrin:20180516223808j:plain
決められたそれぞれの特徴量マップに対してAttentionを行う。画像のどこに注目するかという感じかな。
これによって得られた特徴量を入力としてクラス識別を行う。
まずレイヤー{ \displaystyle s \in \{1, ... S\}} から特徴集合{\mathcal{L}^{s} = \{\mathcal{l}^{s}_{1}, \mathcal{l}^{s}_{2},...,\mathcal{l}^{s}_{n} \}}を抽出する。この\mathcal{l}^{s}_{i}は空間方向の特徴。ネットワークの出力のクラス識別の前の部分をグローバル特徴mathcal{g}とする。図で言うとFC-1の出力部分。\mathcal{l}^{s}_{i}mathcal{g}と同じ次元に線形写像したベクトルの集合\hat{\mathcal{L}^{s}}\mathcal{g}をcompatibility function \mathcal{C}に入力。\mathcal{C}(\hat{\mathcal{L}^{s}}, \mathcal{g}) = \{c^{s}_{1}, c^{s}_{2}, ... , c^{s}_{n}\}。compatibility scoreをsoftmax関数で正規化

a^{s}_{i} = \frac{exp(c^{s}_{i})}{\Sigma^{n}_{j}(c^{s}_{j})}, i \in \{1, ... ,n\}
正規化されたcompatibility score\mathcal{A}^{s} = \{a^{s}_{1}, a^{s}_{2}, ..., a^{s}_{n}\}を用いてレイヤー\mathcal{s}ごとに一つの特徴ベクトル\mathcal{g}^{s}_{a} = \Sigma^{n}_{i=1}{a^{s}_{i}\cdot \mathcal{l}^{s}_{i}}を得る。この\mathcal{g}^{s}_{a}を用いてクラス識別を行うが\mathcal{S} > 1のときは\mathcal{g}^{s}_{a}が複数になる。そのときはつなげて一つのベクトルとして識別するもよし、それぞれの\mathcal{g}^{s}_{a}を用いてそれぞれ識別を行い、その平均を使うもよし。

compatibility関数Cについては線形モデル
 c^s_i = \langle \mathcal{u}, \mathcal{l}^s_i + \mathcal{g}\rangle   i \in \{1, ... ,n\}
もしくはドット積
 c^s_i =  \langle \mathcal{l}^s_i, \mathcal{g}\rangle    i \in \{1, ... ,n\}
を用いる。

実験結果
f:id:Ninhydrin:20180516231842p:plain

attXのXの部分に最後からいくつのレイヤーをattentionするか、dpはglobal featureをドット積で、pcは線形モデル。concatはベクトルをすべて結合してFCに入力、indepはそれぞれでクラス予測を行いその平均を結果とする。基本的に比較はVGG-att2-concat-pcで、元のVGGに対して改善が見られるようだ。Table2のRN(ResNet)は他の論文からのものでImageNetでpre-trainしているらしい。それに対してVGG-att2-concat-pcはCifar-100でpre-trainしている。
このあとの実験はconcat-pcのみ
次はAdversarial Attack。今後はこういうものも検証したほうがいいのかな。
f:id:Ninhydrin:20180530004652p:plain
VGGより5%も低いとのことだがここはしっかり調べてない。ノイズが知覚できるレベル(Fig5の5列目とか)になるとVgg-att2-concat-pcのほうが大きくなる。
f:id:Ninhydrin:20180530231137p:plain
f:id:Ninhydrin:20180530232029p:plain
Table 7はクロスドメイン。Fine turningってことでいいのかな?Cifar 10/ 100のそれぞれで学習したモデルを特徴量抽出器とし、その出力をSVMに入力し汎化性能を測っている。どのデータセットに対しても平均6%ほどの改善。Cifarは32x32と解像度が低いが600x600サイズのデータセットでも向上が見られる。STLデータセットはCifar10とクラスが同じらしい。

f:id:Ninhydrin:20180530232108p:plain
最後は弱教師あり学習によるセグメンテーション。他のAttentionモデルよりも良い結果。saliency-basedに対しても車だけは良い結果となった。

End to EndでAttentinoを学習できるCNNというのは面白かった。grad-camなどCNNがどこに注目しているかを調べる手法があったがそれと似たようなことができるのも良い。実装も簡単だしいろいろと応用ができそう。GANなどの生成系にもつかえないかな。それと付録にglobal featureをクエリとすることができるとあった。
f:id:Ninhydrin:20180530233448p:plain

global feature ってクエリじゃね?ってことでまず、ある物体が顕著な画像をクエリ画像、その物体が含まれているが他の物体も多い複雑な画像をターゲット画像と呼ぶ。
Fig 10の1列目がクエリ画像、2列目がターゲット画像。ターゲット画像から得られた10層目の特徴マップに対してターゲット画像で得られたglobal featureを適用したのが3列目と6列目。この二つはcompatibility関数がドット積かパラメータを用いた関数かの違い。4列目と7列目はターゲット画像で得られた特徴マップに対してクエリ画像で得られたglobal featureを適用したもの。5列目と8列目は元のglobal featureを用いた特徴マップとの変化の割合を表している。ドット積を用いたほうがターゲットとなる物体の付近での変化が激しい。それに対して線形モデルの方は変化がほとんどない。これはパラメータ\mathcal{u}がそのあたりの特徴を獲得しているかららしい。そのためテストデータのような未知の画像に対しても適用できる。

読んだ感じこんなだった。間違いなどあればガンガン指摘してくださるとありがたいです。