chainer: トレーニングモジュールの拡張方法まとめ

chainerのチュートリアルをこなしたあと、

  • mnistなどライブラリが標準で用意してる以外の独自データセットをどうやって作ったら良いんだろう?
  • 複数のモデルが絡むネットワークを同時並行に訓練/評価するにはどうすれば?

といった点で困ることが多くありました。

結果的にはChainerのトレーニング関係のコンポーネントを拡張/自作するのが良さそうなのですが、 それぞれにちょっとずつポイントが有ります。ここでは、トレーニング関連のオブジェクトを概観しながら 今までに書いてきたエントリをまとめます。(今後も増えるかも..)

chainerのトレーニング関連コンポーネント概観

f:id:mizti:20171025204322p:plain

Trainer:

登録されたUpdaterを使って指定したepoch数やイテレーション数分だけ訓練を実行します。

Extensions:

Trainerが実行される過程で付属的に実行される処理を規定します。

mizti.hatenablog.com

モデルの評価を行うEvaluatorもExtensionの一種として定義されます。

mizti.hatenablog.com

Updater:

Datasetを含むIterator、訓練対象のモデルオブジェクトと勾配更新方法を規定するOptimizerを受け取って、 具体的にどのようにモデル群を更新していくかを規定します。

mizti.hatenablog.com

Optimizer:

SGD, Adamなど勾配計算後のパラメータ更新方法を規定します。

Model (Link / Chain)

ニューラルネットワーク(NN)の本体です。基本的にはLinkやFunctionを積み重ねてChainでラッピングしNNを構成していきます。 (Chain自体もLinkを継承して作られたLinkオブジェクトの一種なため、ChainとChainを繋いだりChainとLinkを繋いだりもできます)

Datasetと合わせてLinkやChainで作られるChainerのNNの概要を図にまとめてみました。

f:id:mizti:20171025224939p:plain

Iterator:

下記のデータセットを受け取り、ミニバッチサイズや取り出し順のポリシーを設定します。

Dataset:

画像などのデータとラベルの対を多数保持します。下記のエントリーでまとめたように、 Datasetクラスを継承することで極めて簡単かつ柔軟に定義できます

mizti.hatenablog.com