Born Again Neural Networks

[1805.04770] Born Again Neural Networks

本来Knowledge distillation (KD) はモデルの圧縮に使われているが、そうではなくKDによって教師モデルより高いパフォーマンスの生徒モデルを作ることを試みたのがこの論文。
自分の記憶だとKD的なものは2014年あたりからちらほら見かけるようになった。
[1312.6184] Do Deep Nets Really Need to be Deep?
深い(Deep)ネットワークではなく層の数的に薄い(Shallow)ネットワークじゃだめなのか?ということをKDを用いて調べたものや
[1412.6550] FitNets: Hints for Thin Deep Nets
KDを用いてよりDeepで細いネットワークを生成するものなどが記憶にある。これらはネットワークの圧縮関係で教師モデルと同じ精度を保ったまま高速・軽量を目指したものだった。それに対してBorn Again Neural Networksは生徒モデルとして教師モデルと同じ、もしくはより複雑なモデルにすることで精度向上を図っている。手法はこうである。出来上がった教師モデルを用いて生徒モデルを学習する。これによってできた生徒モデルを新たな教師モデルとし再度ランダムに初期化された生徒モデルを学習する。これを繰り返す。これだけである。しかしこれで精度の向上が見られたそうな。なお、途中の生徒モデルを用いてアンサンブルもできる。これに関してはなるほどとなる。
f:id:Ninhydrin:20180608002800p:plain
本論文で提案するBorn Again Networks(BANs)について。正直図の通り。一般的にSGDなどを用いてDNNのような関数f(x, θ)ラベル付きデータ(x, y)に対するlossを最小化するパラメータ*θを求める。経験的にこの*θは最適解に近く、損失関数の変更によって改善できる可能性があるのではという考えに基づいている。よくある損失関数の変更はL2正則化とかだ。BANsではKDを用いて情報量豊富な出力を教師ラベルとすることでより良いモデルを得る。


DKの成功は本当に正解ラベル以外の情報のおかげなの?それとも単純に重み付けのおかげなの?これに関しても検証してる。
実験はCifar-10, 100とPenn Tree Bankの3つ。MNISTは必要ないがImageNetかMS COCOは欲しくないか?BANsは時間がかかるから仕方ないかな。
f:id:Ninhydrin:20180612232602p:plain
まずはCifar-10について。モデルが複雑になると効果が薄いように感じられる。まあCifar-10だしという気持ちもある。

f:id:Ninhydrin:20180612232940p:plain
table 2はCifar-100に対してBANを行った結果。Cifar-10のときと同様に複雑になるほど効果が薄い。アンサンブルに関してはかなりの効果が見込める。ここでCWTMとDKPPという二つの項目がある。CWTM(Confidence Weighted by Teacher Max)とDKPP(DK with Permuted Predictions)はDKの効果についての検証を行ったものである。CWTMは教師モデルの最大値のみを用いてラベルを重み付けし、他のラベルについての情報は扱わない。DKPPはCWTMと同様に教師モデルの結果を使うが、最大値以外の出力についてはその重みを入れ替える。結果としてCWTMとDKPPともに精度向上に効果があり、DKPPのほうが効果が高い。しかし、どちらも単純なDKには及ばなかった(つまりDKは効果があったということ)。

実装する際に特別なことも必要なく精度向上が見込めるという意味では使い勝手はありそう。ただ本当に効果が得られるのか、どれくらいでBorn Againすればいいのかなどは難しそう(というか難しかった)。

Learn to Pay Attention

https://arxiv.org/abs/1804.02391
ICLR2018。 CNNにおけるチャンネル方向ではなく空間方向へのAttentionモデルの提案。タイトルがいいよね。
図を見たほうが早い。
f:id:Ninhydrin:20180516223808j:plain
決められたそれぞれの特徴量マップに対してAttentionを行う。画像のどこに注目するかという感じかな。
これによって得られた特徴量を入力としてクラス識別を行う。
まずレイヤー{ \displaystyle s \in \{1, ... S\}} から特徴集合{\mathcal{L}^{s} = \{\mathcal{l}^{s}_{1}, \mathcal{l}^{s}_{2},...,\mathcal{l}^{s}_{n} \}}を抽出する。この\mathcal{l}^{s}_{i}は空間方向の特徴。ネットワークの出力のクラス識別の前の部分をグローバル特徴mathcal{g}とする。図で言うとFC-1の出力部分。\mathcal{l}^{s}_{i}mathcal{g}と同じ次元に線形写像したベクトルの集合\hat{\mathcal{L}^{s}}\mathcal{g}をcompatibility function \mathcal{C}に入力。\mathcal{C}(\hat{\mathcal{L}^{s}}, \mathcal{g}) = \{c^{s}_{1}, c^{s}_{2}, ... , c^{s}_{n}\}。compatibility scoreをsoftmax関数で正規化

a^{s}_{i} = \frac{exp(c^{s}_{i})}{\Sigma^{n}_{j}(c^{s}_{j})}, i \in \{1, ... ,n\}
正規化されたcompatibility score\mathcal{A}^{s} = \{a^{s}_{1}, a^{s}_{2}, ..., a^{s}_{n}\}を用いてレイヤー\mathcal{s}ごとに一つの特徴ベクトル\mathcal{g}^{s}_{a} = \Sigma^{n}_{i=1}{a^{s}_{i}\cdot \mathcal{l}^{s}_{i}}を得る。この\mathcal{g}^{s}_{a}を用いてクラス識別を行うが\mathcal{S} > 1のときは\mathcal{g}^{s}_{a}が複数になる。そのときはつなげて一つのベクトルとして識別するもよし、それぞれの\mathcal{g}^{s}_{a}を用いてそれぞれ識別を行い、その平均を使うもよし。

compatibility関数Cについては線形モデル
 c^s_i = \langle \mathcal{u}, \mathcal{l}^s_i + \mathcal{g}\rangle   i \in \{1, ... ,n\}
もしくはドット積
 c^s_i =  \langle \mathcal{l}^s_i, \mathcal{g}\rangle    i \in \{1, ... ,n\}
を用いる。

実験結果
f:id:Ninhydrin:20180516231842p:plain

attXのXの部分に最後からいくつのレイヤーをattentionするか、dpはglobal featureをドット積で、pcは線形モデル。concatはベクトルをすべて結合してFCに入力、indepはそれぞれでクラス予測を行いその平均を結果とする。基本的に比較はVGG-att2-concat-pcで、元のVGGに対して改善が見られるようだ。Table2のRN(ResNet)は他の論文からのものでImageNetでpre-trainしているらしい。それに対してVGG-att2-concat-pcはCifar-100でpre-trainしている。
このあとの実験はconcat-pcのみ
次はAdversarial Attack。今後はこういうものも検証したほうがいいのかな。
f:id:Ninhydrin:20180530004652p:plain
VGGより5%も低いとのことだがここはしっかり調べてない。ノイズが知覚できるレベル(Fig5の5列目とか)になるとVgg-att2-concat-pcのほうが大きくなる。
f:id:Ninhydrin:20180530231137p:plain
f:id:Ninhydrin:20180530232029p:plain
Table 7はクロスドメイン。Fine turningってことでいいのかな?Cifar 10/ 100のそれぞれで学習したモデルを特徴量抽出器とし、その出力をSVMに入力し汎化性能を測っている。どのデータセットに対しても平均6%ほどの改善。Cifarは32x32と解像度が低いが600x600サイズのデータセットでも向上が見られる。STLデータセットはCifar10とクラスが同じらしい。

f:id:Ninhydrin:20180530232108p:plain
最後は弱教師あり学習によるセグメンテーション。他のAttentionモデルよりも良い結果。saliency-basedに対しても車だけは良い結果となった。

End to EndでAttentinoを学習できるCNNというのは面白かった。grad-camなどCNNがどこに注目しているかを調べる手法があったがそれと似たようなことができるのも良い。実装も簡単だしいろいろと応用ができそう。GANなどの生成系にもつかえないかな。それと付録にglobal featureをクエリとすることができるとあった。
f:id:Ninhydrin:20180530233448p:plain

global feature ってクエリじゃね?ってことでまず、ある物体が顕著な画像をクエリ画像、その物体が含まれているが他の物体も多い複雑な画像をターゲット画像と呼ぶ。
Fig 10の1列目がクエリ画像、2列目がターゲット画像。ターゲット画像から得られた10層目の特徴マップに対してターゲット画像で得られたglobal featureを適用したのが3列目と6列目。この二つはcompatibility関数がドット積かパラメータを用いた関数かの違い。4列目と7列目はターゲット画像で得られた特徴マップに対してクエリ画像で得られたglobal featureを適用したもの。5列目と8列目は元のglobal featureを用いた特徴マップとの変化の割合を表している。ドット積を用いたほうがターゲットとなる物体の付近での変化が激しい。それに対して線形モデルの方は変化がほとんどない。これはパラメータ\mathcal{u}がそのあたりの特徴を獲得しているかららしい。そのためテストデータのような未知の画像に対しても適用できる。

読んだ感じこんなだった。間違いなどあればガンガン指摘してくださるとありがたいです。

Raspberry Piとpyenv

当然Pythonを入れねばならない。Pythonならpyenvである。

git clone https://github.com/pyenv/pyenv.git $HOME/.pyenv

して.zshrcとかに

export PYENV_ROOT=${HOME}/.pyenv
if [ -d ${PYENV_ROOT} ]; then
    export PATH=${PYENV_ROOT}/bin:$PATH
    export PYTHONPATH=./pyenv/python:$PYTHONPATH
    eval "$(pyenv init -)"
fi

書いて

source ~/.zshrc

して

pyenv install 3.6.5

して終わり。とか思っていたが最後のインストールで躓いた。なぜだ。どうやらopenSSLがないかららしい。
必要なもの諸々を入れる。

sudo apt-get install -y git openssl libssl-dev libbz2-dev libreadline-dev libsqlite3-dev

これで問題なくインストールできる。
ちなみにbz2とreadlineがないとインストール後にWarningを食らう。こんな感じ↓

$ pyenv install 3.6.5
Downloading Python-3.6.5.tar.xz...
-> https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz
Installing Python-3.6.5...
WARNING: The Python bz2 extension was not compiled. Missing the bzip2 lib?
WARNING: The Python readline extension was not compiled. Missing the GNU readline lib?

sqliteがないとipythonとかのhistoryが保存できない。ひとまずこれでPythonの環境が作れた。

Raspberry Piの設定の備忘録

Rapberry Piを買った。単純にRaspberry Pi が欲しかったのとSwitch Botのためだ。ninhydrin.hatenablog.com

デスクトップ用途ではなくサーバー用途なのでそのために幾つか設定を行った。思っていた以上に良かったので今後も購入する可能性がある。そのときのために設定を書き残しておく。

ユーザーの追加

sudo adduser hoge
sudo gpasswd -a hoge sudo
sudo gpasswd -d pi sudo
rm /etc/sudoers.d/010_pi-nopasswd # パスワード無しでのsudoをさせない

sudo passwd pi # pi ユーザーのパスワードを変更
# su hoge ログイン

piユーザーのパスワードを変更しているがpi ユーザーの名前を変更したほうがいいかもしれない。削除はしない。
rootのパスワードについては設定しなければログインそのものができないそうなのであえて設定はしていない。

公開鍵認証でSSH接続

まずは鍵を作成。

ssh-keygen -t rsa
chmod 600 id_rsa

パーミッションは忘れてはいけない。
次に公開鍵(pubの方)をRasPiに渡す。

cat id_rsa.pub >> authorized_keys
chmod 600 authorized_keys

ここでもパーミッションを忘れない。
次に何らかのエディタでsshd_configを編集

sudo emacs /etc/ssh/sshd_config

幾つか変更

Port 11011 # 22以外に適当に
PermitRootLogin no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile .ssh/authorized_keys
PasswordAuthentication no # パスワードログイン禁止

sshサーバーを再起動

sudo /etc/init.d/ssh restart

起動時にSSHを有効にする

sudo touch /boot/ssh

VNCで接続できるようにする

こちらを参照
Raspberry Piの設定【VNCサーバ(tightVNC)の設定】 - Aldebaranな人のブログ
ありがとうございました。
ただこちらの方法だと起動時にvncサーバーが立ち上がらなかった。シバンの

#! /bin/sh

をファイルの一番上に持ってくることで起動するようになった。

HDMIをオフにする

VNCで接続できるようになったのでHDMI出力はいらない

tvservice -o # HDMIオフ
tvservice -p # HDMIオン

このオフの設定を/etc/rc.localに書き込む

sudo emacs /etc/rc.local # root権限が必要

linuxの起動時に自動的に実行するコマンドを書き込むファイルらしい。

デフォルトのログインユーザーを変更

sudo emacs /etc/lightdm/lightdm.conf

autologin-user=piとなっているところを任意のユーザーに変更する。


とりあえずこんなところかな。何かあったら追記していく。

Raspberry PiとGoogle HomeとSwitch Botと

Switch Botに興味があったので買ってしまった。2個で7000円なので一つあたり3500円。これぐらいの値段なら買ってもいいかなと思える。

Google Homeを持っておりGoogle HomeからSwitch Botを動かすのが本命。Google Homeと連携しようと思うとSwitch Linkが必要になるがこれはいい値段するので買うのを躊躇してしいた。色々調べてみるとRaspberry Piから操作できるようなのでRaspberry Piで操作する(汎用性も高いしね)。エンジニアなのに恥ずかしながらRaspberry Piを触ったことがないのでいい機会だ。個人的なRasPiの設定も後日書く予定


Raspberry PiGoogle Homeの連携について参考にしたのは以下
IFTTTとBeebotteを使ってGoogleHomeからRaspberryPiを操作する - Qiita

おかげさまでGoogle Homeに話しかけて何かしらのレスポンスを受け取れるようになった(Beebottle便利ですね)。ここまでくればこっちのものであとはSwitch Bot にかぎらずpythonでなんでもできる。今後も何かやっていこう。
Switch Botpythonで動かすコードはこちら。このSwitch Botを動かす処理をon_messageの中に書けば終わり。これでGoogle HomeからSwitchLinkを使わずにSwitch Botを動かせるようになった。

ただ困ったところが2つある。
一つ目は反応が少し遅い。p = Peripheral("ff:ff:ff:ff:ff:ff", "random")で接続した時に失敗することがある(失敗しても1回ぐらい)。これはfor文tryすることでなんとかなる。
二つ目の問題は重大でセキュリティに関して。SwitchBotは専用アプリがあって一つ一つのSwitch Botにパスワードを設定できる。パスワードを設定することで他人がSwitch Bot動かせないようになっており安心。なのだがパスワードを設定すると上記pythonスクリプトが動かない。つまり現状Google HomeからセキュアにSwitch Botを動かすにはSwitchLinkしか選択肢がない。うーむ、どうしたもんか。
なおソースコードはこちら。まあ殆ど上記の記事のまんまである。

import json
import logging
import time

import paho.mqtt.client as mqtt
import binascii
from bluepy.btle import Peripheral


TOKEN = "token_***********"  # beebottleのチャンネルトークン
HOSTNAME = "mqtt.beebotte.com"
PORT = 8883
TOPIC = "channel_name/resource_name"
CACERT = "mqtt.beebotte.com.pem" # 証明書 https://beebotte.com/certs/mqtt.beebotte.com.pem

SWITCH_MAC = "ff:ff:ff:ff:ff:ff"
SWITCH_NAME = "switch_name"

_LOGGER = logging.getLogger(__name__)

class SwitchBot:
    def __init__(self, name, mac):
        self._mac = mac
        self._name = name

    def turn_on(self):
        for connection in range(1,6):
            try:
                p = Peripheral(self._mac, "random")
            except:
                _LOGGER.error(f'Connection attempt failed after {connection} tries')
                time.sleep(1)
                continue
            break
        else:
            _LOGGER.error('Connection to Switchbot failed')

        try:
            hand_service = p.getServiceByUUID("cba20d00-224d-11e6-9fb8-0002a5d5c51b")
            hand = hand_service.getCharacteristics("cba20002-224d-11e6-9fb8-0002a5d5c51b")[0]
            hand.write(binascii.a2b_hex("570100"))
            p.disconnect()
        except:
            _LOGGER.error("Cannot connect to switchbot.")

switch_bot = SwitchBot(SWITCH_NAME, SWITCH_MAC)


def on_connect(client, userdata, flags, respons_code):
    print('status {0}'.format(respons_code))
    client.subscribe(TOPIC)

def on_message(client, userdata, msg):
    data = json.loads(msg.payload.decode("utf-8"))["data"][0]
    data = {key:value.strip() for key, value in data.items()}
    #  この辺で条件分岐してスイッチを動かす
    swich_bot.turn_on()

client = mqtt.Client()
client.username_pw_set(f"token:{TOKEN}")
client.on_connect = on_connect
client.on_message = on_message
client.tls_set(CACERT)
client.connect(HOSTNAME, port=PORT, keepalive=60)
client.loop_forever()

パスワードを設定すると動かないのは私だけなのか?でもSwitch Botを動かすスクリプトを見た感じパスワードを送信するようなところはないのでMacアドレスさえわかれば近くのSwitch Botを動かせることになる。多分公開してないパスワード関連のAPIがあるのだろう。早く公開してほしいものである。もしもパスワード設定してもpythonから動かせたよという人がいたらぜひ教えて頂きたい。

pythonでターミナル上のカーソルを上に移動する

プログレスバーなどを作りたいときにsys.stdout.write("~\r") のような感じでキャリッジリターンを使っていたがこれだとカーソルを行頭には戻せるがn行上には戻せない。なので全ての情報を1行に記述するしかなく表示する情報に限界がある。カーソルの移動くらい方法があるだろうと調べてエスケープシーケンスにたどり着いた。ターミナル上のカーソルの移動や文字の削除などができる特殊文字列のようだ。キャリッジリターンもそうだったんだな。知らんかった。

使い方はESC[**といった感じでこれをディスプレイに送る。ESCは\033もしくは\x1bのどっちでもいい。やりたいことに応じて**の部分を変更する。例えばカーソルをn行上に移動するには\033[nAと書く。なのでカーソル3行上に移動させるには

print("\033[3A")

とすればいい。カーソルを上の行に移動できるのでprintでも問題ない。他にも\aでアラート音を出したり、\033[y;xHでコンソールのy行x列にカーソルを移動できたり、print("hoge\033[33mhoge\033[0mhoge")で真ん中のhogeだけ黄色くしたりできる。詳しくはエスケープシーケンス一覧などで検索すればいくらでも出てくるだろう。量が多すぎるのでここには書かない。まだまだ知らないことが多い

ちなみにプログレスバーtqdmを使っている。

numpyのRandomStateとmultiprocessing使用時のseed値の罠

numpyとmultiprocessingでシミュレーション的なことをしていた。思ったより良い感じにならなかったしなんとなくおかしいなと思って調査。
numpy.randomとmultiprocessingを使って並列に乱数生成する。適当にその乱数を出力してみる。

コード

import os
import time
from multiprocessing import Pool

import numpy as np


def func(x: int) -> None:
    print(f"x={x}, np.random={np.random.rand()}, pid={os.getpid()}")

if __name__ == '__main__':
    with Pool(10) as p:
        p.map(func, range(10))

結果

python test.py
x=0, np.random=0.5962941676890798, pid=29550
x=1, np.random=0.5962941676890798, pid=29551
x=2, np.random=0.5962941676890798, pid=29552
x=3, np.random=0.5962941676890798, pid=29553
x=4, np.random=0.5962941676890798, pid=29554
x=5, np.random=0.5962941676890798, pid=29555
x=6, np.random=0.5962941676890798, pid=29556
x=9, np.random=0.030654881568257686, pid=29550
x=7, np.random=0.5962941676890798, pid=29557
x=8, np.random=0.5962941676890798, pid=29558

ほとんど同じやんけ。
調べてみるとnumpyのRandomStateは生成時に現在時刻からseed値を決定するらしい(しっかり調べてない)。なんとなくわかった。
回避策としては親プロセスが子プロセスにseed値を与えるぐらい?

変更後

def func(y: tuple) -> None:
    x, seed = y
    np.random.seed(seed)
    print(f"x={x}, np.random={np.random.rand()}, pid={os.getpid()}")

if __name__ == '__main__':
    with Pool(10) as p:
        p.map(func, zip(range(10), np.random.randint(0, 2 ** 32 -1, 10)))

変更後結果

python test.py
x=0, np.random=0.9123107234323904, pid=30126
x=1, np.random=0.01700757843088274, pid=30127
x=2, np.random=0.15402676571119933, pid=30128
x=3, np.random=0.8384796480668341, pid=30129
x=4, np.random=0.03180545379613742, pid=30130
x=5, np.random=0.9923826065352217, pid=30131
x=6, np.random=0.8970540102911799, pid=30132
x=7, np.random=0.9861043353661219, pid=30133
x=8, np.random=0.6360555488868728, pid=30134
x=9, np.random=0.9287125485831335, pid=30135

もっとエレガントな方法はないのか。