chainerのupdaterを自作して複雑なネットワークを訓練する

なぜupdaterの自作が必要か

chainerで様々なニューラルネットを試していると、どこかで複数のモデルが組み合つもの、 複数の出力を持つものなど、込み入ったネットワークを訓練したいことがあると思います。

よくあるmnistのサンプルなどでは

optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
(略)
updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

という感じで単一のモデルを対象にして訓練ループを進めているのですが、このままでは 複数モデルや複数出力が絡むネットワークの訓練は扱えません。

自分でゴリゴリと更新やらトレーニングやら評価のプロセスを全て自分で書くことも可能は可能ですが、 せっかくならtrainerやoptimizerなどchainerの作った枠組みをどうせなら活かしたいです そこで必要になるのが、それぞれのネットワークの訓練に適した形のupdaterの自作です。

updaterの役割

updaterはあるネットワークと訓練データが与えられた時に、

  • 訓練データから切り出したミニバッチからの損失の計算方法
  • 計算された損失からどのようにネットワークの重みを更新するのか

を定義するモジュールです。

updaterはtrainerに対して一つだけ定義され、一つのupdaterは

  • 最適化対象となるモデルを含むoptimizer(1つ以上、複数の場合もあり)
  • datasetを含み、epochやバッチサイズを定義したiterator

を受け取ります。(下図)

f:id:mizti:20170923205123p:plain]

updaterの実装方法

updaterの作成方法は色々あるのかと思いますが、以下では最も基本的な StandardUpdaterを継承/オーバーライドする形でのupdater定義方法を記載します。

StandardUpdaterをオーバーライドして使う場合に、最低限変更が必要になるのが

  • __init__
  • update_core

です。

initの実装

__init__は言うまでもなくコンストラクタで、受け取る引数を定義します。 必ず受け取らないとエラーになるものはありませんが、その後のupdate_coreで重み更新を行うために必要な 以下のような引数は設定したほうがよいと思います。

  • 訓練データを含むiterator
  • 損失関数の計算に必要なモデル(群)
  • 重み変更対象になるモデルをsetupしたoptimizer(群)
  • iteratorをミニバッチに変換するためのconverter
  • 各種処理を行うためのデバイス指定(cpuなら-1、gpuなら0以上の整数)

また、__init__ 内で

  • self._iterator (iteratorをまとめたした辞書)
  • self._optimizers (optimizerをまとめた辞書)
  • self.iteration=0

の3点はtrainerやupdater_core外のupdater処理で参照されているのでこの名前で 宣言しておくのが良いと思います(今回の手法では)。

update_coreの実装

update_coreでは、updaterの中心的な役割となる

  1. iteratorから入力データ/ラベルの取り出し
  2. 入力データ/ラベルからの損失計算
  3. 損失からネットワークの更新

を定義します。

  1. iteratorから入力データ/ラベルの取り出し: 後で損失計算できればなんでもよいですが、 chainerのDatasetをセットしたiteratorを渡している場合、next()メソッドを使うと [ (入力データ, 教師ラベル), (入力データ, 教師ラベル).. ] というlistが取れるため、convert.concat_examplesを使うと ( [batch_size分の入力データ] , [batch_size分の教師ラベル] ) というtupleに簡単に変換できます。

  2. 入力データ/ラベルからの損失計算: これは目的に応じて。基本的に入力データと教師ラベルをモデルのcallや然るべきメソッドに渡せば 算出されるように作っていると思います。

  3. 損失からネットワークの更新: Chainerでネットワークの重み更新を行う場合、基本的には ① 各更新対象モデルのcleargrad()による勾配の初期化 ② lossからのbackward() による誤差逆伝搬 ③ optimizer.update()による各モデルの勾配更新 の流れとなります。 (このあたりはChainerのチュートリアルやプレイグラウンド(ここここ)あたりに目を通すと良いと思います)

以下、update_coreを実装する上で幾つか気をつけるべき点です

  • モデルの持つ重みを個別にcleargrad()しても良いですが、Chainクラスの場合cleargrads()メソッドにより重みを一度に初期化できます
  • .barckward()は次元を持たない(つまりスカラーな)Variableに対してしか実行できません
  • backward()は指定したlossの計算過程をChainインスタンスを跨って伝搬します
  • optimizer.update()はoptimizerに設定したモデルのみが重み更新対象になります。(逆に重み更新したくないモデルはupdate()しなければ良いです)

実装例

以下、参考になるかわかりませんが私の作ったUpdater例です。 このUpdaterは下図のような畳み込み層に対して複数のLinear層がそれぞれ個別に 入力を受け取り、別々に損失を計算するようなネットワークを訓練するために作ったものです。

f:id:mizti:20170917000954p:plain

import six
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda, training, reporter
from chainer.datasets import get_mnist
from chainer.training import trainer, extensions
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer.datasets import get_mnist
from chainer import optimizer as optimizer_module

class MyUpdater(training.StandardUpdater):
    def __init__(self, iterator, base_cnn, classifiers, base_cnn_optimizer, cl_optimizers, converter=convert.concat_examples, device=None):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main':iterator}
        self._iterators = iterator
        self.base_cnn = base_cnn
        self.classifiers = classifiers

        self._optimizers = {}
        self._optimizers['base_cnn_opt'] = base_cnn_optimizer
        for i in range(0, len(cl_optimizers)):
            self._optimizers[str(i)] = cl_optimizers[i]

        self.converter = convert.concat_examples
        self.device = device
        self.iteration = 0

    def update_core(self):
        iterator = self._iterators['main'].next()
        in_arrays = self.converter(iterator, self.device)

        xp = np if int(self.device) == -1 else cuda.cupy
        x_batch = xp.array(in_arrays[0])
        t_batch = xp.array(in_arrays[1])
        y = self.base_cnn(x_batch)

        loss_dic = {}
        for i, classifier in enumerate(self.classifiers):
            loss = classifier(y, t_batch[:,i])
            loss_dic[str(i)] = loss

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.target.cleargrads()

        for name, loss in six.iteritems(loss_dic):
            loss.backward()

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.update()

少しわかりづらいですが、"base_cnn"が左側の畳み込み層のモデルを表し、 classifiersが右側のLinearモデルのリストになっています。(optimizerも同様)

そしてそれぞれの出力から誤差を逆伝搬し、畳み込み層は全ての出力からの誤差を蓄積した上で updateされるようになっています。

スレッドとキューとキューランナーとコーディネータの関係

TensorflowのDeveloper’s Guideのスレッドとキュー解説をかいつまんでまとめてみる。

Threading and Queues  |  TensorFlow

  • キュー

文字通り、データのキュー。FIFOQueueなど。 q = tf.FIFOQueue(3, "float") として定義した後、 x = q.dequeue()でデータのポップ、 q_inc = q.enqueue([y]) でデータのプッシュができる

  • キューランナー

キューに対して、繰り返し同じ操作でenqueueオペレーションを実行するスレッドを指定数だけ作成する

queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
sess = tf.Session()
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)

上記のコードで、queueに対して、exampleという操作(op)でエンキューを行うスレッドを4つ作成していることになる。 注意点は、あくまでエンキューを行うスレッドであるということ。 また、キューランナーでスレッドを作成するときには、そのスレッドたちの「監督役」になるCoordinaorを指定する必要がある。

  • コーディネータ

前述のキューランナーでスレッドを作成する際に、スレッドの監督役になるオブジェクト。 coord = tf.train.Coordinator() で作成したあとキューランナーに渡すと、 coord.should_stop()でスレッドを殺すべきタイミングを教えてくれたり、 coord.request_stop()でスレッドの停止を命令したり coord.join(threads)で指定したスレッド(群)が停止するまで待機したりできる。

Tensorflow + Jupyterのsave & restore時のトラブルとその回避方法

Tensorflowでモデルを保存しようとする場合にsaveしたモデルをrestoreすることができないトラブルに遭遇した。

保存側:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

saver = tf.train.Saver()

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
for _ in range(3):
    print(_)
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    save_path = saver.save(sess, "tmp/model.ckpt")
    print(save_path)

print(sess.run(b))
print(sess.run(W))

読込側:

import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
saver = tf.train.Saver()

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "tmp/model.ckpt")
  print("Model restored.")
  print(sess.run(b))
  print(sess.run(W))

上記のコードはmnist sampleにsaveとrestoreを追加したもの。シェル上では正しく動作するのだが、jupyter上で動作させると NotFoundError: Key Variable_43 not found in checkpoint というエラーが発生する。(43の部分は実行毎に異なる)

StackOverflowには同様の事象の報告が上がっている。Googleのエンジニアからの回答によると、 ・Jupyterがコードの再実行時のパフォーマンスを上げるために勝手に変数をコピーする=> ・その結果、Tensorflowが名前の重複を避けるために新しい名前をアサインして別Variablesが増える ということが起きているようです。

これを避けるために、3つの方法が提示されています。

  1. モデルの構築を始める前にtf.reset_default_graph()を呼び出す。(defaultとして設定されているGraphが消える点に注意)

  2. tf.train.Saver()を呼び出す際に、保存するVariableを明示的に指定する。例としては saver = tf.train.Saver(var_list={"b1": b1, "W1": W1, "b2": b2, "W2": W2}) が紹介されていますが、こちらのように、Variableの宣言時にtrainable=Trueを宣言し、tf.trainable_variables()Saverに渡す方法もあるようです

  3. Graphのスコープをwith tf.Graph().as_default():で明示的に指定し、その中でモデルの宣言とrestoreを行う

副作用が少なそうで、行儀が良さそうなのは2かな..という気がするのでこれでやっています。 Tensorflowのsaveとrestoreについては、Saverコンストラクタの呼び出しタイミングがセンシティブだったりするので色々注意が必要ですね。


補足: エラー詳細

INFO:tensorflow:Restoring parameters from tmp/model.ckpt

---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
<ipython-input-36-1322488c39ce> in <module>()
     12 with tf.Session() as sess:
     13   # Restore variables from disk.
---> 14   saver.restore(sess, "tmp/model.ckpt")
     15   print("Model restored.")
     16 

/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc in restore(self, sess, save_path)
   1455     logging.info("Restoring parameters from %s", save_path)
   1456     sess.run(self.saver_def.restore_op_name,
-> 1457              {self.saver_def.filename_tensor_name: save_path})
   1458 
   1459   @staticmethod

/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_call(self, fn, *args)
   1050         except KeyError:
   1051           pass
-> 1052       raise type(e)(node_def, op, message)
   1053 
   1054   def _extend_graph(self):

NotFoundError: Key Variable_43 not found in checkpoint
         [[Node: save_21/RestoreV2_38 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_21/Const_0, save_21/RestoreV2_38/tensor_names, save_21/RestoreV2_38/shape_and_slices)]]

Caused by op u'save_21/RestoreV2_38', defined at:
  File "/Users/miz/.pyenv/versions/2.7.11/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/Users/miz/.pyenv/versions/2.7.11/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-36-1322488c39ce>", line 10, in <module>
    saver = tf.train.Saver()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1056, in __init__
    self.build()
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1086, in build
    restore_sequentially=self._restore_sequentially)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 691, in build
    restore_sequentially, reshape)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 407, in _AddRestoreOps
    tensors = self.restore_op(filename_tensor, saveable, preferred_shard)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 247, in restore_op
    [spec.tensor.dtype])[0])
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 669, in restore_v2
    dtypes=dtypes, name=name)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/miz/.pyenv/versions/2.7.11/envs/tensorflow_test/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

NotFoundError (see above for traceback): Key Variable_43 not found in checkpoint
         [[Node: save_21/RestoreV2_38 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_21/Const_0, save_21/RestoreV2_38/tensor_names, save_21/RestoreV2_38/shape_and_slices)]]

動画紹介: 確率的ニューラルネットワークについて

昨年末VAEの勉強をしているときに、以下のようなことがわからなくて時間を喰っていた。

  • なんで確率的に毎回出力の異なるモデルが必要になるのか
  • reparameterization trickがなぜ必要になるのか

いろんな人に質問したり、関連サイトを読み漁ってなんとかなったのだが、 PFN得居さんのPyCON2016の講演動画ががとてもわかりやすいので、VAEみたいな確率的NNを 触る前に見ることをお勧め。


5_05 [招待講演 / Invited Talk] 確率的ニューラルネットの学習と Chainer による実装

以下、講演内容前半のメモ。

確率的NNとは何か

出力までの過程に確率分布からのサンプリングが含まれており、そのサンプリングの結果によって出力が変動するような ニューラルネットワークのこと

確率的NNをなぜ使うのか

分類問題などでは、出力として決定的な結果がほしいが、問題によっては出力として非決定的な結果がほしいことがある

例: * 実際に確率分布そのものがほしい場合 * 生成モデル(変分AEなど)

確率的NNを使う上で難しい点

途中に確率的なunitが入っている場合、何を最適化するのか? 確率的なunitが入っている場合、同じ入力でもNNを通すたびに出力が変わる。 出力をたくさん得ると、「出力の分布」が得られる。その「出力の分布の良さ」=確率分布の良さを測りたい

このようなことをする場合、各試行に対してloss値が定まるように設計することが多い。 そして出力をたくさん取って、各試行に対してlossを計算

=>その平均を取ることで「lossの期待値」が取れるため、これの最小化を最適化の目標とすることが多い。

確率分布自体は連続なものだが、 試行にかけられる時間が有限であるため、サンプル値は離散的にしか取ることができない。したがって近似が必要。

これに対処するための代表的な手法が下記2つ。

1. likelihood-ratio method (LR)

よく使われる手法。強化学習分野で「reinforce」と呼ばれていたものとほぼ一緒。 まず普通にforwardする(この過程でサンプリングが入る) => lossが高かったら今の選択が行われにくいように下げる / lossが低かったらあまり下げない。

出てきたlossが大きいか小さいか自体の判断が本来は難しいが、systematicにやるために、出てきたlossに比例する割合でその確率を下げる。

 
\Delta \mu := f(z) \nabla_\mu \log p(z|\mu)  \qquad (z \sim p(z|\mu))

式の意味

z : NNの出力。zは確率分布p(z|μ)に従う

f(z) :サンプルされたlossの値(fがloss関数)

\nabla_\mu \log p(z|\mu) : 平均値μの条件下でzを取る確率の勾配。この方向に進めば、zを取る確率が下がる。

これは1回だけだとzを取る確率が下がるだけだが、 f(z)が小さい方向には小さく f(z)が大きい方向には大きく 確率が下がるため、これを繰り返すことで相対的にf(z)が小さい確率が生き残って、高いままになる。

この式は真の期待値の勾配の不偏推定になっており、無限回サンプルすることで本物のgradientに一致する。

(ただし、これ自体は最小化したかったもの自体の勾配ではないのでもうひと工夫必要)

LRの問題点

勾配の大きさのブレが大きい。(サンプリングによりlossが何千倍、というオーダーでブレることもある) このため、varianceを下げるためのテクニックが必要になる。 よく使われるのが"baseline"という手法。 相対的にlossが小さい方に向けて移動させていくという動きは、定数を引いても挙動は変わらない(不偏推定性が保たれる)

うまい定数を引くことでvarianceが小さくなることが知られており、何らかの形で定数を決定し、 f(z)の代わりにf(z) - bを目標にする (bにはf(z)の平均などを用いる)

勾配の方向を一次元的に調節するだけなので、varianceの減らし方としては弱いが、手軽。

2. reparameterization trick

最近、もっと良いvarianceの少ない手法が開発された。それがreparameterization trick (ただし、正規分布のような連続な分布にしか使えない手法)

f:id:mizti:20170130230237p:plain

サンプリングを含むユニットを書き換えて、サンプリングを入力μとσ2と「平均0、分散1の、パラメータとは全く関係のないノイズ」の入力を掛け合わせるユニットとみなす。

このようにみなすことで偶然性(stochasticity)のない決定的なNNであるとみなせる。 => 普通にbackpropすれば良くなる。 しかも実装がかんたん。 ただし、ガウシアンでしか使えない。

離散のユニットについてはやっぱりLRを使わなければならない。

そうは言っても、予測したい問題が離散的な場合、離散的なモデルを使ったほうが良いことはあり、 学習しやすい方法を見つけ出すのは重要な問題として残っている。

(また、得居さんが離散的な値の場合について分散が大きくならない手法を研究しており、そのうち発表するとのこと。 (スライド上ではReLEGという名前がついていてた))

Chainer でのcoding方法

以下、Chainerでどうやって書くの?という話。 要点のみ。

  • VAEの場合の例 => ガウシアンは連続値なのでreparameterization trick

  • sigmoid belief network(SBN)の例 => 離散値なのでLRを使う (そのうちSBNの勉強をする機会があったら見返してみよう)

注意点: backpropのルートから外したい部分は

  • chainerのChain宣言 __init__内のsuper.__init__(..){ ... }から外すこと
  • ChainerのVarianceではなく、Variance内のnumpy / cupyの値を直接操作すること

着彩済イラストから綺麗に線画を抽出する方法

機械学習のテーマの一つとして自動着彩があります。この中で、特にイラストの自動着彩を考えると 未着彩と着彩済みのペアが学習用サンプルとして大量に必要となりますが、まとまった量を入手するのはなかなか難しいという問題があります。

すると、カラーイラストから線画を抽出することを考えたくなるのですが、 一般的な輪郭検出を用いると「輪郭線自体の輪郭」が抽出されてしまい、線がぼやけてしまうという問題があります。

例えば f:id:mizti:20170121224135j:plain に対して輪郭検出を実施すると、 f:id:mizti:20170121224155j:plain となります。 (拡大) f:id:mizti:20170121224451j:plain

右頬の輪郭線に対して、肌側、背景側それぞれの境界が検出されてしまい、線が2本引かれてしまっていることがわかります。

で、綺麗な輪郭抽出ができず困っていたのですが、ピーFN(一体何FNなんだ...)のtaizanさんが投稿されたこちらのエントリ

qiita.com

では非常に綺麗に線画抽出ができており、どのようにやっているか気になっていたところ

f:id:mizti:20170121224813p:plain

との情報が。ということでやってみました。

f:id:mizti:20170121230350j:plain 拡大 f:id:mizti:20170121230510j:plain

線がだぶることなく、綺麗に抽出できているようです。すごい!

(ここまでの絵は村田蓮爾氏のものを引用させていただいています)

手順詳細

以下、手法の詳細についてです。

以降の絵はpixivで見つけたLpipさんイラスト を例にさせていただいてます。

今回使ったのはcv2のpythonライブラリです。

画像を開く

I = cv2.imread('data/before.png')

dilationする

kernel = np.ones((5,5), np.uint8)
dilation = cv2.dilate(I, kernel, iterations = 1)

f:id:mizti:20170121232739p:plain

元画像とのdiffを取る

diff = cv2.subtract(I, dilation)

f:id:mizti:20170121233257p:plain

白黒反転する

diff_inv = 255 - diff

f:id:mizti:20170121232842p:plain

グレースケール化して書き出し

diff_inv_binarized = cv2.threshold(diff_inv, 100, 255, cv2.THRESH_BINARY)
cv2.imwrite('after.png', diff_inv)

f:id:mizti:20170122103923p:plain

まとめると

I = cv2.imread('data/before.png')
kernel = np.ones((5,5), np.uint8)
dilation = cv2.dilate(I, kernel, iterations = 1)

diff = cv2.subtract(I, dilation)
diff_inv = 255 - diff
diff_inv_binarized = cv2.threshold(diff_inv, 100, 255, cv2.THRESH_BINARY)
cv2.imwrite('after.png', diff_inv)

です

chainer: 独自datasetを定義する方法

f:id:mizti:20170113130134p:plain

f:id:mizti:20171013202945p:plain

chainerで独自データセットクラスを作るための方法を明示的に示したドキュメントが 見当たらなかったので、備忘録をかねて書く。実はとっても簡単。

  1. データセットにするクラスは chainer.dataset.DatasetMixinを継承する
  2. 内部に持っているデータの数を返却する __len__(self) メソッドを実装する。このメソッドは整数でデータ数を返却する
  3. i番目のデータを取得する get_example(self, i) メソッドを実装する。このメソッドは、
    • 画像配列などのデータ
    • ラベル

の2つを返却する(return image_array, label みたいな感じで)

本当に必要なことはこのたった3つです。 Datasetのクラスを定義するタイミングで画像を全部読み込んでもよいですが、 get_exampleを呼び出すタイミングで実際の読み込みを 行うのでも構いません。

実例:

import sys
import random
import numpy as np
from PIL import Image
import csv
import chainer
from chainer import datasets

class ImageDataset(chainer.dataset.DatasetMixin):
    def __init__(self, normalize=True, flatten=True, train=True, max_size=200):
        self._normalize = normalize
        self._flatten = flatten
        self._train = train
        self._max_size = max_size
        pairs = []
        with open('data/filename_label_list.tsv', newline='') as f:
            tsv = csv.reader(f, delimiter='\t')
            for row in tsv:
                if 'jpg' in row[0]:
                    pairs.append(row)

        self._pairs = pairs

    def __len__(self):
        return len(self._pairs)

    def get_image(self, filename):
        image = Image.open('data/' + filename)
        new_w = self._max_size + 1
        new_h = self._max_size
        image = image.resize((new_w, new_h), Image.BICUBIC)
        image_array = np.asarray(image)
        return image_array

        # type cast
        image_array = image_array.astype('float32')
        label = np.int32(label)
        return image_array, label

    def get_example(self, i):
        filename = self._pairs[i][0]
        image_array = self.get_image(filename)
        if self._normalize:
            image_array = image_array / 255
        if self._flatten:
            image_array = image_array.flatten()
        else:
            if image_array.ndim == 2:
                mage_array = image_array[np.newaxis,:]
        image_array = image_array.astype('float32')
        image_array = image_array.transpose(2, 0, 1) # order of rgb / h / w
        label = np.int32(self._pairs[i][1])
        return image_array, label
  • __init__で、ファイル名とラベルのリストを読み込んでいます。ここで self._pairsにリストの各行を入れていますが、画像データはまだ読み込んでいません
  • __len__self._pairsの項目数を返却するだけ
  • get_exampleで実際の画像読み込み=> 配列化とlabelの返却を行っています(画像読み込みは、エポック数が多い場合は __init__内で先に全部読み込んでおいたほうが早い場合もあるかもしれません)
  • __len__が返却する値がdatasetのサイズとみなされます。その結果、
    • len / minibatch_sizeがepoch内で学習されるminibatchの個数となる
    • Iteratorは0番目からlen番目までの要素をget_example(i)で取得するようになる

また、注意しておいたほうが良いことが数点だけあります。

  • ラベルを整数で返却する場合、 np.int32( label ) という感じで np.int32にキャストして返却すること. 普通のintでも回せますが、GPUを使わず、CPUのみで実行しようとするとき、labelがint型だとCUDA environment is not correctly set upという関連の分かり辛いエラーで怒られてしまいます
  • 画像を普通に読み込むと、0 ~ 255 の整数データになるため、0.0 ~ 1.0に正規化すること(私は'float32'型を指定しています)
  • 画像を返却する際は(色次元数, h, w) という順番に軸変換を行っておくこと。普通にPIL等でイメージを読み込むと、( h, w, 色次元数 ) という順になるため、 image_array = image_array.transpose(2, 0, 1)などで変換が必要です
  • datasetクラスをinitializeする際に全てのデータを読み込んだり生成したりするのはGPUメモリ容量的に得策ではありません。initialize時にはデータのリストだけ作り、get_example内でデータの実体を読み込む/生成するようにしたほうが良いと思います。(特にデータ数が多い場合)

いろんなデータセットで楽しむきっかけになれば幸いです。

AWS Step FunctionsとLambdaでディープラーニングの訓練を全自動化する

動機とやったことの概要

詳細

Lambdaに付与する権限

たぶん以下くらいの権限がLambda実行時に必要。

AWSLambdaAMIExecutionRole
AmazonS3FullAccess
AmazonEC2SpotFleetRole
AWSLambdaBasicExecutionRole
AmazonSNSFullAccess
EC2ReadOnly ("ec2:DescribeSpotInstanceRequests"リソースへのアクセスを追加)

Step Functionsの入力

{
    "exec_name": "pix2pix-20161231",
    "repository_url": "https://github.com/mattya/chainer-pix2pix.git",
    "repository_name": "chainer-pix2pix",
    "data_dir": "/home/ubuntu/data",
    "output_dir": "/home/ubuntu/result",
    "data_get_command": "/home/ubuntu/.pyenv/shims/aws s3 cp s3://pix2pixfacade/ /home/ubuntu/data --recursive",
    "exec_command": "/home/ubuntu/.pyenv/shims/python /home/ubuntu/chainer-pix2pix/train_facade.py -g 0 -e 100 -i /home/ubuntu/data --out /home/ubuntu/result --snapshot_interval 10000"
}
変数名 説明
exec_name この実行の名前。バケット名にもなるため、アンダースコアを使わずkebab-case推奨
repository_url git cloneする対象のリポジトリURL
repository_name git cloneしたあと取得できるリポジトリ
data_dir データを格納するディレクト
output_dir 訓練結果等を格納するディレクト
data_get_command データを取得するなど、訓練開始前に実施する
exec_command 訓練実施コマンド

やってることの中身

Step Functionの定義

{
  "Comment" : "Machine learning execution with spot instance",
  "StartAt" : "CreateS3Bucket",
  "States"  : {
    "CreateS3Bucket": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:create_s3_bucket",
      "Next"      : "RequestSpotInstance"
    },
    "RequestSpotInstance": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:request_spot_instance",
      "Next"      : "WaitBidding"
    },
    "WaitBidding": {
      "Type"      : "Wait",
      "Seconds"   : 30,
      "Next"      : "CheckBiddingResult"
    },
    "CheckBiddingResult": {
      "Type"      : "Task",
      "Resource"  : "arn:aws:lambda:ap-northeast-1:999999999999:function:check_bidding_result",
      "Next": "ChoiceBiddingResult"
    },
    "ChoiceBiddingResult": {
      "Type" : "Choice",
      "Choices": [
        {
          "Variable": "$.request_result",
          "BooleanEquals": true,
          "Next": "NotifyRequestSuccess"
        },
        {
          "Variable": "$.request_result",
          "BooleanEquals": false,
          "Next": "NotifyRequestFailed"
        }
      ],
      "Default": "NotifyRequestFailed"
    },
    "NotifyRequestFailed": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "SpotRequestFailed"
    },
    "SpotRequestFailed": {
          "Type": "Fail",
          "Error": "SpotRequestError",
          "Cause": "Spot price bidding too low"
    },
    "NotifyRequestSuccess": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "WaitTaskComplete"
    },
    "WaitTaskComplete": {
      "Type"      : "Wait",
      "Seconds"   : 10,
      "Next"      : "CheckTaskCompleted"
    },
    "CheckTaskCompleted": {
      "Type" : "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:check_task_completed",
      "Next": "ChoiceTaskCompleted"
    },
    "ChoiceTaskCompleted": {
      "Type" : "Choice",
      "Choices": [
        {
          "Variable": "$.task_completed",
          "BooleanEquals": true,
          "Next": "NotifyTaskCompleted"
        },
        {
          "Variable": "$.task_completed",
          "BooleanEquals": false,
          "Next": "WaitTaskComplete"
        }
      ],
      "Default": "WaitTaskComplete"
    },
    "NotifyTaskCompleted":{
      "Type": "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:send_sms_message",
      "Next": "WaitInstanceDelete"
    },
    "WaitInstanceDelete": {
      "Type"      : "Wait",
      "Seconds"   : 1800,
      "Next"      : "DeleteSpotInstance"
    },
    "DeleteSpotInstance": {
      "Type": "Task",
      "Resource": "arn:aws:lambda:ap-northeast-1:999999999999:function:delete_ec2_instance",
      "End": true
    }
  }
}
  • 判断分岐以外は直列に流してるだけ。
  • 処理途中に生成されるID類はeventに追加しながら下流に流す
  • S3作成とスポットインスタンスリクエストはParallelにしても良いかも(面倒くさいのでやってない..)
  • 訓練完了から30分は削除せずに待つ。サーバに未練があればこの間に実行を停止する。

S3バケット作成

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import os

def lookup(s3, bucket_name):
  try:
    s3.meta.client.head_bucket(Bucket=bucket_name)
  except botocore.exceptions.ClientError as e:
    error_code = int(e.response['Error']['Code'])

    if error_code == 404:
      return False

    return True

def create_bucket(bucket_name):
    s3 = boto3.resource('s3')
    response = ''
    if not lookup(s3, bucket_name):
       response = s3.create_bucket(Bucket=bucket_name)

    return response

def lambda_handler(event, context):
    response = create_bucket(event['exec_name'])
    return event
  • eventからexec_nameを取り出してバケット名に
  • その名前のバケットがなければ作る

スポットインスタンスのリクエスト

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import logging
import base64
import os

SPOT_PRICE = '0.8'
REGION = 'ap-northeast-1'
AMI_ID = 'ami-9999999f'
KEY_NAME = 'your_keyname'
INSTANCE_TYPE = 'g2.2xlarge'
SECURITY_GRUOP_ID = ['sg-9999999']

def request_spot_instance(user_data):
    ec2_client = boto3.client('ec2',
        region_name = REGION
    )
    response = ec2_client.request_spot_instances(
        SpotPrice = SPOT_PRICE,
        Type = 'one-time',
        LaunchSpecification = {
            'ImageId': AMI_ID,
            'KeyName': KEY_NAME,
            'InstanceType': INSTANCE_TYPE,
            'UserData': user_data,
            'Placement':{},
            'SecurityGroupIds': SECURITY_GRUOP_ID
        }
    )
    return response

def lambda_handler(event, context):
    REPOSITORY_URL  = event["repository_url"]
    REPOSITORY_NAME = event["repository_name"]
    BUCKET_NAME = event["exec_name"]

    shell='''#!/bin/sh
    sudo -s ubuntu
    cd /home/ubuntu
    sudo -u ubuntu mkdir /home/ubuntu/.aws
    sudo -u ubuntu mkdir /home/ubuntu/completed
    sudo -u ubuntu git clone {5}
    sudo -u ubuntu mkdir {0}
    sudo -u ubuntu mkdir {1}

    sudo -u ubuntu echo "[default]" >> /home/ubuntu/.aws/credentials
    sudo -u ubuntu echo "aws_access_key_id={2}" >> /home/ubuntu/.aws/credentials
    sudo -u ubuntu echo "aws_secret_access_key={3}" >> /home/ubuntu/.aws/credentials

    sudo -u ubuntu echo "*/5 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync {1} s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp {1}/log s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 cp /home/ubuntu/trace.log s3://{4} > /dev/null 2>&1" >> mycron
    sudo -u ubuntu echo "*/1 * * * * /home/ubuntu/.pyenv/shims/aws s3 sync /home/ubuntu/completed s3://{4} > /dev/null 2>&1" >> mycron

    sudo -u ubuntu /usr/bin/crontab mycron
    sudo -u ubuntu /bin/rm /home/ubuntu/mycron

    PATH="/usr/local/cuda/bin:$PATH"
    LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"

    sudo -u ubuntu cd /home/ubuntu/{6}

    sudo -u ubuntu touch trace.log
    sudo -u ubuntu echo `pwd` >> trace.log  2>&1
    sudo -u ubuntu echo `which python` >> trace.log  2>&1
    sudo -u ubuntu echo 'repository_name: {6}' >> trace.log 2>&1
    sudo -u ubuntu echo 'dataget_command: {7}' >> trace.log 2>&1
    sudo -u ubuntu echo 'exec_command: {8}' >> trace.log 2>&1
    sudo -u ubuntu {7}  > /dev/null 2>> trace.log
    sudo -u ubuntu echo `ls /home/ubuntu/data | wc` >> trace.log

    PATH="/usr/local/cuda/bin:$PATH"
    LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
    sudo -u ubuntu -i {8}  >> trace.log 2>&1
    sudo -u ubuntu touch /home/ubuntu/completed/completed.log
    '''

    shell_code = shell.format(
        event["data_dir"],
        event["output_dir"],
        os.environ.get('S3_ACCESS_KEY_ID'),
        os.environ.get('S3_SECRET_ACCESS_KEY'),
        event["exec_name"],
        event["repository_url"],
        event["repository_name"],
        event["data_get_command"],
        event["exec_command"]
        )
    user_data = base64.encodestring(shell_code.encode('utf-8')).decode('ascii')
    response = request_spot_instance(user_data)
    event["spot_instance_request_id"] = response["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
    return event
  • インスタンスタイプや入札価格は定数にして、StepFunction実行時の入力(event)からは引かないようにしている(eventはコードの実行条件のみにし、環境調達条件はLambda側に持たせるポリシーのつもり)
  • AMIは、chainer、CUDA等はインストール完了いているものがある前提
  • インスタンスをリクエストしたあとuser_dataをシェルスクリプトにして流し込んでる
  • 大体の汚い処理はここのシェルスクリプトに凝縮されている
    • S3へのupload系タスクはcronに登録
    • その後、パスを通して訓練の開始
  • S3_ACCESS_KEY_ID / S3_SECRET_ACCESS_KEYはIAMのwrite権限のある鍵をLambda Functionの環境変数に登録しておく。
  • 実行時のログはtrace.logに出力 > これもS3に随時Up
  • 実行完了後に、completed.logを出力。これがS3のバケットに入ると、StepFunctions側でタスク完了とみなされる

入札結果確認

def check_bidding_result(spot_instance_request_id):
    ec2_client = boto3.client('ec2',
        region_name = REGION
    )
    response = ec2_client.describe_spot_instance_requests(
      SpotInstanceRequestIds = [spot_instance_request_id]
    )
    return response

def lambda_handler(event, context):
    response = check_bidding_result(event["spot_instance_request_id"])
    event["request_result"] = (response['SpotInstanceRequests'][0]['Status']['Code']==u'fulfilled')

    if event["request_result"]:
        event["instance_id"] = response['SpotInstanceRequests'][0]['InstanceId']

    return event
  • スポットインスタンスリクエスト時に取得した'SpotInstanceRequests'から、入札の結果を確認する

通知

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import boto3
import json
import os

TOPIC_ARN = 'arn:aws:sns:ap-northeast-1:9999999999:training_end_notification_mail' # Mail
REGION = 'ap-northeast-1'


def send_sms_message(event, context):
    sns = boto3.client('sns',
        region_name = REGION
    )

    message = ''
    subject = ''
    if "completed" in event:
        subject = 'Training ended'
        message = '''task completed!
        result: https://console.aws.amazon.com/s3/home?bucket={0}

        -----
        {1}
        '''.format(event["exec_name"], event)
    else:
        if event["request_result"]:
            subject = 'request fulfilled'
            message = '''
            Spot Request Fulfilled! {0}
            '''.format(event["exec_name"])
        else:
            subject = 'request failed'
            message = '''
            Spot Request Fails! {0}
            '''.format(event["exec_name"])

        response = sns.publish(
            TopicArn = TOPIC_ARN,
            Subject = subject,
            Message = message
        )

    return response

def lambda_handler(event, context):
    response = send_sms_message(event, context)
    return event
  • 通知の宛先、通知手段は、事前にSNS側に登録し、Topic ARNを発行しておく
  • 作成されたインスタンスのIDはeventに追加して下流に流す

その他のLambda

  • あとは特別なことはしていないリポジトリをご参照ください

リポジトリ

github.com

改良案とか、◯◯をXXでやらないなんて有りえない!とかあればお気軽に @mizti までコメントください (AWS今までちゃんと触ってこなかった勢なので話せれば嬉しいです)

f:id:mizti:20170101205025p:plain