読者です 読者をやめる 読者になる 読者になる

生成モデルpix2pixを動かしてみる

pix2pix is 何

  • 2016年11月に発表された、任意の画像を入力にして、それを何らかの形で加工して出力する、というある種の条件付きGAN。
  • GANって何: 画像等のデータ入力を真似て偽造する生成器と、そのデータが生成されたものか本物かを識別する鑑別器を互いに競わせるように訓練することで、本物によく似たなデータを作り出せるようにする「生成モデルおよびその訓練手法」

    • 2014年に公表された当初は訓練の不安定性の問題が大きかったが、Batch Normalizationの導入や条件付けなど安定を増すノウハウが蓄積され、ここ2年注目を浴びている。
    • 先週GAN / DCGANについてまとめてみたエントリを書いたので参考になるかも
  • 「画像を入力にして画像を出力にする」ようなタスクは世の中に無数に存在していて、その潜在的な適用範囲の広さが特徴。(着彩、塗り分け、単純化etc..)

  • 下記はpix2pixサイトのサンプル。 f:id:mizti:20161217233325p:plain

  • より詳細なサンプル例は下記で見られる

pix2pixの構成

  • 大きく3つのネットワークから構成される:

    • Encoder: 画像の畳み込みにより、入力画像の特徴量を圧縮する。
    • Decoder: Encoderで圧縮された特徴量を逆畳み込み(転置畳み込み)x6層したあと畳み込みx1層により画像に変換する
    • Discriminator: 「入力画像」と「真の出力サンプルまたはDecoderの出力」2つの画像入力を行う、「本物らしさ」を出力する

    • (追記) Decoder  n - i 層は、直前の層の他に  i 層のEncoderの出力も同時に受け取る(U-Netというらしい)

    • (追記) Discriminatorは、"Patch GAN"と名付けられているようだけどちゃんと読み解けてないです...
  • 3つのネットワークの損失関数:

    • Encoder: 「生成された偽画像と真の画像の差異」と「Enc->Decが出力した画像がDiscriminatorに偽物と思われた度合いのsoftplus->バッチ・画素数平均」の重み付き和
    • Decoder: Encoderの損失関数と同じ
    • Discriminator: 本物の画像(t_out)を偽物と判定した度合い(softplus->バッチ画素平均)と偽物の画像(x_out)を本物と判定した度合いの和

動かしてみる

例によってMattyaさんが例によって神速でChainer実装しているので、今日はそのまま動かしてみる。

github.com

  1. git cloneする
  2. 訓練用データとなるFacadeという様々な建物の前面部分写真/構成情報のデータを落としておく。 CMP Facade Database
  3. データの場所、出力の場所、処理するGPU番号等をオプションで指定し起動する

と、たったこれだけ。 (前提環境を整えるのは先回のエントリの記載内容。ただし、chainerは1.19以上が必要)

  • 訓練条件は、デフォルトのまま
    • ミニバッチサイズ1
    • 300枚の画像セットについてランダムな順番で300イテレーションで1エポック
    • エポック数200で全訓練完了

結果

入力画像 / 生成画像 / 正解画像

f:id:mizti:20161218015605p:plain:w180 f:id:mizti:20161218015414p:plain:w180 f:id:mizti:20161218015423p:plain:w180

入力画像 / 生成画像 / 正解画像

f:id:mizti:20161218015658p:plain:w180 f:id:mizti:20161218015712p:plain:w180 f:id:mizti:20161218015706p:plain:w180

入力画像 / 生成画像 / 正解画像

f:id:mizti:20161218015817p:plain:w180 f:id:mizti:20161218015825p:plain:w180 f:id:mizti:20161218015838p:plain:w180

入力画像 / 生成画像 / 正解画像

f:id:mizti:20161218015923p:plain:w180 f:id:mizti:20161218015933p:plain:w180 f:id:mizti:20161218015936p:plain:w180

入力画像 / 生成画像 / 正解画像

f:id:mizti:20161218020413p:plain:w180 f:id:mizti:20161218020017p:plain:w180 f:id:mizti:20161218020019p:plain:w180

(生成し損ねた..)

  • 損失関数の推移

f:id:mizti:20161218113548p:plain

思ったこととか

  • 1万回目くらいで既に建物にしか見えない画像を生成できるようになってる。まずこれがすごい

  • Enc / Decの損失関数って、「正解画像との一致度」「Discriminator騙せた度合い」両方損失として使えそうだけど どうやるんだろう... > からの「重みづけして足すだけ」という分かりやすさ。この重み変えると学習結果にどう影響するのか興味ある。 きっと重要なハイパーパラメータ。

  • 回が進むにつれて、画像による明度差が生まれやすくなっている。

    • 色相の差が生まれづらいのは、入力画像と建物色の間にはっきりした相関関係が無いから?
      • じゃ何で明度差を生み出せたんだろう(?)
  • まっすぐあるべき線がまっすぐにならないのは他の生成手法でもよくある。(ニューラルネットワークって「まっすぐな線」引くの苦手ですよね...)生成画像の黒ずみと相俟ってスラムみたいな印象に...

  • 損失関数の推移だけみると、まだ学習が収束してないのでepoch数を増やしても良さそう

  • Mattyaさんの実装で使われていた「CBR層」(畳み込み/BN/ReLU/Dropoutをセットにした層)が便利。

    • up / down指定するだけで畳み込み / 逆畳み込み両方に使える(!)

新しいタスクへの適用

  • 新しいデータセットにしたいと思って、pixivからイラスト拾い集めてみている(やっと500枚集まった)。うまくいけば着彩タスクを作れるかもしれない
    • 未着彩の線画提供 / 着彩させてblogに載せさせてやってもいいぜっていう絵師さんいませんか