chainer: Extensionを自作してディープラーニングの訓練に独自処理を挟み込む
なぜExtensionを自作するのか
Chainerのモデルのトレーニング中に、
- たまにモデルの出力をダンプさせたい(特に生成系で必要になる)
- 1エポック毎に学習率を手動で変更したい
など、独自の処理を定期的に挟みたくなることがあるかと思います。 このような願いを叶えるのがchainerのextensionという仕組みです。 extensionはtrainerに仕掛けておくことで指定した間隔ごとに独自に定義した処理を実行してくれます。
extensionはtrainerにひも付き、一つのtrainerにいくつでもextensionを設定することができます。
Extensionの自作方法
モデルを定期的に評価するEvaluatorもこのextensionを使って実装されています。 モデルの評価用途にはこのEvaluatorを拡張したほうが早いでしょう。 ここでは評価以外のより一般的にextensionを自作する方法にフォーカスします
例えば、公式ではこのようなExtensionが用意されています。よく使うのはこのあたりでしょうか。
Extension名 | 概要 |
---|---|
extensions.Evaluator | モデルの損失や正解率を評価する |
extensions.snapshot | モデルの重みをファイルに保存する |
extensions.LogReport | Evaluator等で測定した各種の評価値をログに出力する |
extensions.PrintReport | 指定した評価値を標準出力やファイルに出力する |
これらの公式Extensionはチュートリアル等でも言及されていますが、 extensionは簡単に自作することも可能です。
簡単なextensionの例として、「"hoge"とpirntするだけ」のextensionを作ってみましょう。
main.py | + lib -- print_hoge.py
このようなディレクトリ構成で、 定義側(print_hoge.py):
print_hoge.py def print_hoge(): @training.make_extension(trigger=(1, 'epoch')) def _print_hoge(trainer): print("hoge") return _print_hoge
呼び出し側 (main.py):
from lib.print_hoge import * (略) trainer = training.Trainer(updater, (30, 'epoch'), out=args.output) trainer.extend(print_hoge())
たったこれだけの記述で、1epochに一回、hogeと標準出力に出力されるようになります。
Extensionのカスタマイズ
動作する間隔を変えたければ
- trigger=(1, 'iteration'): 1イテレーションに1回
- trigger=(3, 'epoch'): 3エポックに一回
などtriggerを変更すればokです。
また、extensionの中でモデルを参照したければtrainer.update.model
とういうように、
trainerに登録されたupdater経由でアクセス可能です。
下記のように実行時に引数を与えることも可能です。
print_hoge.py def print_hoge(message): @training.make_extension(trigger=(1, 'epoch')) def _print_hoge(trainer): print(message) return _print_hoge 呼び出し側 (main.py) from lib.print_hoge import * (略) trainer = training.Trainer(updater, (30, 'epoch'), out=args.output) trainer.extend(print_hoge("fuga"))
後はextension名と中身の処理を好きなように書き換えればokです。
以下はもう少し詳しく知りたい方向け
書き方のポイントは、関数print_hoge()
が関数「_print_hoge
」自体を返り値に返却していることです。
実際、「関数を返却する関数」を定義しなくても
print_hoge.py @training.make_extension(trigger=(1, 'epoch')) def print_hoge(trainer): print("hoge")
と定義して、呼び出し側(main.py)で
from lib.print_hoge import * (略) trainer = training.Trainer(updater, (30, 'epoch'), out=args.output) trainer.extend(print_hoge)
とtrainer.extend
に関数を直接渡せば最初の例と同じように動作します。
(trainer.extend(print_hoge())
ではなくtrainer.extend(print_hoge)
となっていることに注意してください)
上記のようにも書けるのですが、chainerの公式コードでは一度「関数を返す関数(ないしクラス)」を使う実装がなされています。
これは外側の関数(print_hoge
)を関数と内側の関数(_print_hoge
)の間で変数を定義することで_print_hoge
で参照する変数のクロージャを作ることができ、色々便利だからかと思います。
また、
@training.make_extension(trigger=(1, 'epoch')) def print_hoge(trainer): print("hoge") @training.make_extension(trigger=(1, 'epoch'))
という処理も一見何か難しいことをしているように見えますが、
make_extension
関数内で行われていることは下記とほぼ同義です。
def sample_recog(trainer): print("hoge") sample_recog.trigger=(1, 'iteration')
つまり、極端な話
def sample_recog(): def _sample_recog(trainer): print("hoge") _sample_recog.trigger=(1, 'iteration') return _sample_recog
このように定義してもmake_extensionを使ったときと同じように動きます。