seabornによる統計データ可視化(ポケモン種族値を例に)(2)

前記事の続きです

mizti.hatenablog.com

3. Regression plots

2つの系列のデータを受け取り、回帰や残差を可視化します。回帰を可視化するregplotと残差を可視化するresiplotがあります。
regplot
plt.figure(figsize=(9, 9))
ax = sns.regplot(x=(pkmn["Total"] - pkmn["Sp. Atk"]), y=(pkmn["Total"] - pkmn["Attack"]))
d = ax.set(xlabel='Physical-Attacker Status', ylabel='Sp-Attacker Status') # d is just for supressing output
f:id:mizti:20171118193647p:plain 
regplotでは与えられた2つのデータから散布図を描き、回帰線を引きます。 散布プロットと回帰線の有無はそれぞれscatter / fit_regパラメータで切り替え可能です。
上記ではステータス合計値 "Total" から特殊攻撃力/攻撃力を引いた値をそれぞれX軸/Y軸とし、物理攻撃アタッカー性能/特殊アタッカー性能としてプロットしてみました。
デフォルトでは回帰の方法は線形回帰(order=1の多項式回帰)となっていますが、多項式回帰の次数を増やすことができます。また、yに真偽値を入れてlogistic=Trueを指定することでロジスティック回帰を用いることも可能です。
plt.figure(figsize=(9, 9))
ax = sns.regplot(x=pkmn["Total"] , y=pkmn["Legendary"], logistic=True)
 f:id:mizti:20171118193424p:plain
伝説のポケモン(y=1.0)とそれ以外(y=0.0)で分類し、ロジスティック回帰を行いました。
回帰線の周囲にある薄い色付きの領域は信頼区間で、デフォルトでは95%です。ciパラメータで%を指定することができます。今回のデータの場合は全てのサンプルが網羅されているのであまり意味がありませんが、母集団からの標本によって作ったデータの場合は、100回のサンプリング群のうち95回はこのピンク色の範囲に回帰線が含まれるだろう、と見ることができます。 
residplot
plt.figure(figsize=(9, 9))
ax = sns.residplot(x=(pkmn["Total"] - pkmn["Sp. Atk"]), y=(pkmn["Total"] - pkmn["Attack"]))
d = ax.set(xlabel='Physical-Attacker Status', ylabel="distance from regression line") # d is just for supressing output
f:id:mizti:20171118193306p:plain 
residplotは残差を散布図としてプロットします。(残差、とは予測値とデータの離れ具合を指す統計上の用語です)
すこし分かり辛いですが、y=0の線が線形回帰線にあたり、そこからどのくらい離れているかがy軸になっています。 何らかの回帰を行ったあとで、回帰線からの離れ方がx軸の系列とどのような関係にあるか見るのに適しています。
ここでは、上記のregplotの最初で書いた線形回帰と同じデータの残差を可視化してみました。
線形回帰だとregplotを見れば普通は十分ですが、多項式回帰などで回帰線が複雑な曲線を描くような場合には残差の分布がわかりづらくなるため、有用に使えると思います。 

4. Matrix plots

Matrix plotsは、複数のカテゴリデータ同士の関係を可視化します。
heatmap
 
s2 = pkmn["Type 2"]
labels2, levels2 = pd.factorize(s2)
labels1 = levels2.get_indexer(pkmn["Type 1"])
s = np.zeros*1
for i in range(len(labels1)):
    if(labels2[i]<0): # assign "N/A" as 18th in table
        labels2[i] = 18
    s[labels1[i] ,labels2[i]] += 1
 
plt.figure(figsize=(11,9))
ax = sns.heatmap( s, vmin=0, vmax=7, annot=True, cmap="magma", xticklabels=levels2, yticklabels=levels2)
d = ax.set(xlabel='Type 2', ylabel='Type 1') # d is just for supressing output 

f:id:mizti:20171118193140p:plain

heatmapはその名の通り受け取った配列をヒートマップで可視化します。
今回はタイプ同士の組み合わせ分布を可視化してみました。
データ解析周りでheatmapを一番良く使うのは、分類問題で予測ラベル/正解ラベルの組み合わせを可視化するケースだと思います。(この利用方法の場合、confusion matrixなどとも呼ばれます)
予測が正解に完全に一致していれば左上から右下に綺麗に赤いが一直線に並び、そこからはみ出した部分は予想誤りということになります。
色温度の組み合わせは用途に合わせて https://matplotlib.org/examples/color/colormaps_reference.html から自由に選択してください。
clustermap
import random
st = pkmn.loc[:, ['HP', 'Defense', 'Sp. Def']].sample(30, random_state=179)
plt.figure(figsize=(11,9))
ax = sns.clustermap(st, metric='euclidean', method='average') 

f:id:mizti:20171118193102p:plain

clustermapは、指定されたデータセットについて類似度によってクラスタリングを行った結果をデンドログラムで表示します。(各データについてはヒートマップのように色付けを行いながら)
例ではポケモンのHP / 防御力 / 特殊防御力を基準として類似度の近い者同士が近くなるように描画しています。 あまり数や次元が多いと計算に時間がかかるため、ランダムに抜き出した30種類に絞っています。
metricオプションによって、各データをベクトルと見做したときの、二つのデータ間の類似度の測り方を指定できます。例えば、 metric='euclidean' であればユークリッド距離、 metric='cosine'であればコサイン類似度によって、といった具合です。
また、methodオプションではクラスター同士をどのように結合していくかの規則を指定できます。 例えば、method='single'であれば、あるクラスターを結合する先を選ぶ際に、両方の含まれる最も近い点同士の距離が最も小さいクラスターと、method='average'であればクラスターに含まれる全てのデータ同士の距離の平均が最も小さいクラスターと結合されます。
実際に活用する場合では、素のデータをそのままクラスタリングするのではなく、何らかの形で特徴量抽出を行った上でクラスタリングすることが多いと思います。
 

5. Timeseries plots

tsplotの一種類のみです。時系列データを描画します。
tsplot
v = ["HP","Attack","Sp. Atk","Defense", "Sp. Def", "Speed"]
mean = pkmn.groupby(["Generation"])["HP","Attack","Sp. Atk","Defense", "Sp. Def", "Speed"].mean()
#std = pkmn.groupby(["Generation"])["HP","Attack","Sp. Atk","Defense", "Sp. Def", "Speed"].std()
plt.figure(figsize=(13,9))
ax = sns.tsplot(mean["HP"], color="Green")
ax = sns.tsplot(mean["Attack"], color="Red")
ax = sns.tsplot(mean["Sp. Atk"], color="Purple")
ax = sns.tsplot(mean["Defense"], color="Grey")
ax = sns.tsplot(mean["Sp. Def"], color="Blue")
ax = sns.tsplot(mean["Speed"], color="Orange") 

f:id:mizti:20171118193017p:plain

tsplotは文字通り時系列を可視化するためのplotです。
あまり時系列になりそうなものがなかったのですが、一応世代を時系列とみなして各ステータスの平均の推移をプロットしてみました。 これだけではあんまりなので、もう一つ、データを生成してプロットしてみます。
x=np.linspace(0,16,32)
sin = np.sin(x) + np.random.normal(0, 1.2,  (16,32))
cos = np.cos(x) + np.random.rand(16, 32) + np.random.randn(16, 1)
 
plt.figure(figsize=(13,9))
 
ax = sns.tsplot(sin,color='Green')
ax = sns.tsplot(cos,color='Reds')

f:id:mizti:20171118192947p:plain

それぞれsin, cosに対して正規分布や一様ランダムを加えて動きを与えたプロットです。 今回の場合で先程と異なるのは、各時系列に対して各系列16ずつの値のばらつきを持っていることです。このような場合、実線は各時の平均となり、信頼区間は色付きで表されます。(デフォルトでは68%信頼区間)

6. Miscellaneous plots

文字通り、「その他」のプロット。palplot一つしかありません。 
palplot
plt.figure(figsize=(15, 8))
ax = sns.palplot(sns.color_palette(n_colors=24), 2)
ax = sns.palplot(sns.color_palette("Set1", 24))
ax = sns.palplot(sns.color_palette("terrain", 5))
ax = sns.palplot(sns.color_palette("terrain", 8))
ax = sns.palplot(sns.color_palette("terrain", 15))
ax = sns.palplot(sns.color_palette("inferno", 15))

f:id:mizti:20171118192902p:plain

f:id:mizti:20171118192851p:plain

f:id:mizti:20171118192822p:plain 

f:id:mizti:20171118192752p:plain

f:id:mizti:20171118192747p:plain

f:id:mizti:20171118192734p:plain

palplotはデータ可視化用のプロットではなく、現在設定されているパレットの内容を確認するためのプロットです。 特に他者にデータを可視化する場合、見やすかったりデータのイメージに合った色を選ぶのが重要になるので色の調整等に 使いましょう
 

7. Axis grids

最後はAxis gridsカテゴリです。
このカテゴリは若干特殊で、複数の変数の組み合わせに対して今までにあげたplotを描くための枠を用意し、一度にグラフを描けるseabornを"便利に"使うためのツールがまとまっています。 このため、このカテゴリは上記で見てきたplotをオプションで指定して利用するものが多く含まれています。 
Facet Grid
plt.figure(figsize=(16,16))
g = sns.FacetGrid(pkmn, col="Generation")
 f:id:mizti:20171118192614p:plain FacetGridは、指定した要素ごとにプロット領域を定義するクラスです。これ単体ではプロットは行われません。 (matplotlibに馴染みのある人であればsubplotをまとめて定義してくれる、と表現すればわかりやすいかもしれません) FacetGridは後述のfactorplot, lmplotで継承されて利用されているようです。
rowパラメータにより横軸に展開する要素を、colパラメータで縦軸に展開する要素を指定できます。
In [32]:
plt.figure(figsize=(16,16))
g = sns.FacetGrid(pkmn, col="Generation", row="Legendary")
g = g.map(plt.hist, "HP")

f:id:mizti:20171118192539p:plain

FacetGridを代入した変数(上記ではg)にmapを適用することによって、matplotlib等でグラフを描画することもできます。
上の例では上段が非伝説、下段が伝説で、左から順にGenerationごとのHPの分布を可視化しています。 
factorplot
plt.figure(figsize=(16,16))
g = sns.factorplot(x="Generation", y="Speed", hue="Legendary", data=pkmn, kind="box")
 f:id:mizti:20171118192327p:plain
上の方の例でviolineplotやpointplotで、複数系列を同時に描画した例がありました。factorplotは、「複数の系列を同時に描画する」という目的でseabornの各種プロットを整理して呼び出せるようにしたもの、と考えられます。
対象にしたい連続値が一つで、多数のカテゴリに分類して状況を把握したい場合はこのプロットを使うと効率的に可視化できます。
plt.figure(figsize=(16,16))
g = sns.factorplot(x="Generation", y="Sp. Atk", hue="Legendary", data=pkmn, kind="lv", col="Type 1", col_wrap=3 

f:id:mizti:20171118192216p:plain

 
タイプ毎に分けたプロット領域に、世代/伝説であるかどうかを分けてグラフを一度に描画しています。
colに"Type 1"を指定してタイプごとに分かれたプロット領域を一直線に並べ、 col_wrapを4に指定することで、4つごとに改行させています。
プロットのタイプにはpoint, bar, count, box, violin, strip, lvといった、Categorical Plotのメソッドを選択することができます。
 
lmplot
plt.figure(figsize=(16,16))
g = sns.lmplot(x="Attack", y="Sp. Atk", data=pkmn, col="Generation", col_wrap=3)
f:id:mizti:20171118192117p:plain 
先程のfactorplotがCategorical plotsを一度に展開するものだったとすれば、lmplotはRegression plotsを指定した要素分類ごとに一度に描画するもの、と考えて良いと思います。
上記では世代毎のAttackとSp. Atkの回帰線付きの散布図にしました。col, row, col_wrapあたりの使い方はfactorplotと共通しています。また、回帰アルゴリズムの選択方法等はregplotと同様です。
Pairgrid & pairplot
実質的にはPairgridを継承したpairplotを使うことが多いと思われるため、一緒に記載します。 Facet Gridと同様、Pairgridで描画したsubplotにmatplotlibを使って直接様々なグラフを描画することも可能です
plt.figure(figsize=(16,16))
#d = pkmn.loc[pkmn.columns.isin(["HP", "Attack", "Sp. Atk", "Defense", "Sp. Def", "Speed"])]
d = pkmn[["HP", "Attack", "Sp. Atk", "Defense", "Sp. Def", "Speed"]]
g = sns.pairplot(data=d)
 

f:id:mizti:20171118192044p:plain

pairplotは与えられたデータの変数の要素の交差表を自動的に作成し、
  • 同一データ同士の交差する点(対角線)についてはヒストグラム
  • 異なるデータ同士が交差する点については散布図
を描画します。
上記では6種のステータスを取り出し、pairplotをしています。 異なるステータスが重なっている箇所では散布図が、同じステータス同士が重なっている箇所ではヒストグラムが描画されているのがわかると思います。
また、ヒストグラムと散布図が描画されている箇所はそれぞれkde、regressionで置き換えることも可能です。
pairplotは素性の分からない多数の種類のデータが与えられた時に、それぞれのデータ系列同士の相関性などを確認する上でとても有用になります。
In [37]:
plt.figure(figsize=(16,16))
d = pkmn[["HP", "Attack", "Sp. Atk", "Defense", "Sp. Def", "Speed"]]
g = sns.pairplot(data=d, kind="reg", diag_kind="kde")
 f:id:mizti:20171118191824p:plain
 
JointGrid & jointplot
JointGridとjointplotもPairGrid&pairplotと同様、一緒に記載します。 FacetGridやPairGirdと同様、JointGridで作ったsubplotにmatplotlibでグラフを描画することも可能ですが、本稿はseabornに限定するため省略します。
plt.figure(figsize=(16,16))
g = sns.jointplot(x = pkmn["HP"], y = pkmn["Attack"]) 

f:id:mizti:20171118191947p:plain

jointplotは2種類の連続値を受け取り、それぞれ片方ずつのヒストグラムと、両方を組み合わせた散布図を描画します。
散布図が描かれている部分については、kindによって各種散布図(“scatter” | “reg” | “resid” | “kde” | “hex”)を パターン指定できます。
plt.figure(figsize=(16,16))
g = sns.jointplot(x = pkmn["Defense"], y = pkmn["Sp. Def"], kind="hex") .set_axis_labels("Defense", "Sp.Defense")
 

f:id:mizti:20171118192008p:plain

 

*1:19,19

seabornによる統計データ可視化(ポケモン種族値を例に)(1)

データの可視化をまとめて学んでおこうと思って書きました。

はじめに

データ分析はデータの可視化から

機械学習や統計分析をするに当たって、データの可視化は
  • 対象のデータに対して洞察を深める
  • 処理の結果を評価する
  • 成果を分かりやすく他人に説明する
など、様々な局面で重要になります。
KaggleのKenel (分析/処理の過程をまとめたもの) をみても対象のデータに対する洞察を行う過程が全体の半分以上を占めていることが少なくありません。データを正しく可視化することは、データ分析や機械学習全般の土台にあたる作業です。
今回は、データの統計的可視化でよく使われるライブラリ "Seaborn" を用いてよく使う可視化パターンについてまとめてみます。

環境とデータ

実行環境にはKaglleのKernelを使いますが、オープンソースライブラリJupyterを使えばほぼおなじことが可能です。 また、ちょうど最近GoogleのG-Suiteで 公開されたColaboratory というツールでも同じように動くと思います。
また、データにはkaggleで公開されている
を使います。
これはポケモンのステータスのデータセットです。今回の場合、データの洞察そのものが目的ではなく可視化方法の整理が目的なので、前処理無しで使えてできるだけ平易なデータを用いました。(あと、みんなもうアヤメの花びらの長さや幅は飽きてると思うので) 

準備

必要となるライブラリの読み込みます。
  • numpy: 行列式を扱うためのライブラリ
  • pandas: csv形式のような表データを扱うためのライブラリ
  • matplotlib: グラフを描画する基本となるライブラリ(seabornはmatplotlibのラッパーとして動作)
  • seaborn: 今回のメインとなる統計データをグラフ化するライブラリ
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns 
データを読み込んで利用準備。
pkmn = pd.read_csv('../input/Pokemon.csv') 
データ最初の何例かを取り出して様子を眺める。
pkmn.head()
 
 
#
Name
Type 1
Type 2
Total
HP
Attack
Defense
Sp. Atk
Sp. Def
Speed
Generation
Legendary
0
1
Bulbasaur
Grass
Poison
318
45
49
49
65
65
45
1
False
1
2
Ivysaur
Grass
Poison
405
60
62
63
80
80
60
1
False
2
3
Venusaur
Grass
Poison
525
80
82
83
100
100
80
1
False
3
3
VenusaurMega Venusaur
Grass
Poison
625
80
100
123
122
120
80
1
False
4
4
Charmander
Fire
NaN
309
39
52
43
60
50
65
1
False
 
ID、名前、タイプ1、タイプ2、ステータス合計、HP、攻撃力、防御力、特攻、特防、素早さ、初登場世代、伝説のポケモンであるか否かという構成になってることがわかります。
以下、Seabornでデータを可視化していきます。 

Seabornのメソッドやクラスの分類について

SeabornはMatplotlibのラッパーライブラリで、Matplotlibに比べて直感的にグラフを描くことができます。
上記を見ると、大まかに「〜Grid」というクラスと、「〜plot」というメソッドに分かれているのがわかると思います。 〜plotは変数やオプションを与えて簡易にグラフを描画することができます。さらにグラフを精密に作り込みたいときは〜Gridクラスを使って枠を作り、.plotメソッドでデータ描画していくという生のmatplotlibに近い使い方が出来るようになっています。
また、各メソッドは大まかな適用対象/目的ごとに大分類されているようです。 以下では、グラフ描画メソッドを大分類ごとに使っていきます。(説明の流れ上、Axis Gridは一番最後に回しています)

1. Categorical plots

対象のデータが二つの変数を持ち、片方がカテゴリ別でもう片方が連続値を持つ場合の描画に用います。(countplotを除く) 以下では、各世代(タイトルナンバリング))ごとにHPの分布を可視化していきます。 
stripplot
sns.set_style("whitegrid")
sns.set_palette("husl")
sns.set_context("notebook")
plt.figure(figsize=(15, 8))
ax = sns.stripplot(x=pkmn["Generation"], y =pkmn["HP"])

f:id:mizti:20171118185619p:plain

 
stripplotは素朴にstrip(細い端切れ)のように一直線上にデータをプロットします。
 
少数のサンプルに対して素朴なカテゴリごとのデータ分布を見るには良いですが、上でいう50 ~ 100の間など沢山の点が重なっている部分ではどのくらいの重なりがあるのか分かり辛いことが多いです。また、中央値等の所在も分かり辛いです。
これ単体で積極的に使うことは無いですが、他のグラフと組み合わせて使うことがあります。
第一世代と第二世代でHPが外れ値のように大きくなってるのは
pkmn[pkmn["HP"]>200]
 
#
Name
Type 1
Type 2
Total
HP
Attack
Defense
Sp. Atk
Sp. Def
Speed
Generation
Legendary
121
113
Chansey
Normal
NaN
450
250
5
5
35
105
50
1
False
261
242
Blissey
Normal
NaN
540
255
10
10
75
135
55
2
False
 
ということでChansey(和名: ラッキー) とBlissey(和名: ハピナス)のようです。実際のデータ分析でも、このようにグラフから外れ値の有無/内容を確認していく作業は重要になります。
 
swarmplot
plt.figure(figsize=(15, 8))
ax = sns.swarmplot(x=pkmn["Generation"], y =pkmn["HP"])
 

f:id:mizti:20171118190101p:plain

 
同じデータを可視化していきます。swarmplotはswarm(群れ)というだけあって、まるでバッファローの群れのようにパラメータの同じサンプルの分布具合が見て取れます。直感的に分布を見て取れるのが大きな利点です。
boxplot
plt.figure(figsize=(15, 8))
ax = sns.boxplot(x=pkmn["Generation"], y =pkmn["HP"])
 

f:id:mizti:20171118190201p:plain

boxplotはswarmplotほど直感的ではないですが、重要な統計情報がまとまっています。

f:id:mizti:20171118190630p:plain

中心線は平均ではなく中央値です。
それぞれのカテゴリ毎の絶対数に興味がなく、分布の偏りに対する興味が強い場合に使うと良さそうです。
violinlot
plt.figure(figsize=(15, 8))
ax = sns.violinplot(x=pkmn["Generation"], y =pkmn["HP"]) 

f:id:mizti:20171118190758p:plain

一見、変な形のグラフに見えますがよく見るとswarmplotの直感的な部分とboxplotの分布が定量的にが読みやすい長所を兼ね備えた優れたグラフです。名前の通りバイオリンのような形をしてます。
ただし、平滑化の過程で本来存在しない値のサンプルが存在するように見えてしまうことがあるので注意が必要です。 (例えば、[9,10,10,11,99,100,101]のような極端なデータを普通にviolinplotすると-50近辺までのサンプルが存在するかのようなグラフが描かれてしまいます)
また、Categorical plots全般ですが、hueパラメータに特定のフラグを与えることで、さらに系列を分割して比較することができます。
plt.figure(figsize=(15, 8))
ax = sns.violinplot(x=pkmn["Generation"], y =pkmn["HP"], hue=pkmn["Legendary"], split=True) 

f:id:mizti:20171118190826p:plain

例として、普通のポケモンのHP分布(左側)と伝説のポケモンのHP分布(右側)を比べてみました。伝説のポケモンは基本的にHPが高く設定されているようです。 
また、vionlinplotに限らずSeabornでは描くグラフは任意の他のグラフと同時に描画することができます。
plt.figure(figsize=(15, 8))
ax = sns.violinplot(x=pkmn["Generation"], y =pkmn["HP"], inner=None, color="0.95", linewidth=0.3)
ax = sns.swarmplot(x=pkmn["Generation"], y =pkmn["HP"]) 
 
swarmplotとviolinplotを組み合わせました。見やすくするためにviolinplotの色や線の太さを調整しています。 

f:id:mizti:20171118190936p:plain

lvplot
plt.figure(figsize=(15, 8))
ax = sns.lvplot(x=pkmn["Generation"], y =pkmn["HP"]) 

f:id:mizti:20171118191005p:plain

箱ひげ図と似た感じのletter value plotという方式のプロットを行います。 "letter value plot"は割りと最近になって提唱されたboxplotの改良版にあたるグラフ描画手法です。 (http://vita.had.co.nz/papers/letter-value-plot.html)
boxplotが比較的小規模なデータに対して手書きすることを前提に設計されており、大規模なデータセットに対して沢山の情報が抜け落ちてしまうという欠点を持っているのに対して、lvplotではより多くのletter value (要約値)を図に反映することができます。
実際、第3/1四分位数より上/下についても、細かに段階が付けられており大量のサンプルからなるデータの分布の特徴をboxplot以上に精密に読み取ることができます。
 
pointplot
plt.figure(figsize=(15, 8))
ax = sns.pointplot(x=pkmn["Generation"], y =pkmn["HP"]) 

f:id:mizti:20171118191117p:plain

pointplotは極めてシンプルな図です。 点が打たれている箇所が平均値(mean)で、線が引かれているのが95%信頼区間です。 (boxplotやlvplotは中央値(median)と第x分位数です)
95%信頼区間の95という数字や平均値の求め方については引数により制御することができます。
このグラフは情報がシンプルなので、複数の系列を比較するのに向いています。 例えば、
plt.figure(figsize=(15, 8))
ax = sns.pointplot(x=pkmn["Generation"], y =pkmn["HP"], hue=( pkmn["Type 1"].isin(["Grass"]) | pkmn["Type 2"].isin(["Grass"])), dodge=True)
 

f:id:mizti:20171118191346p:plain

草タイプ以外(左)と草タイプ(右)のHPの分布を比較できました。シリーズ全般を通じて、草タイプのポケモンのHPの平均は低めという傾向が読み取れます。
平均や信頼区間を簡単に比較できるので、pointplotの系列比較はわりとよく使われている気がします。 
barplot
plt.figure(figsize=(15, 8))
ax = sns.barplot(x=pkmn["Generation"], y =pkmn["HP"],  capsize=.2)
 
barplotは文字通り棒グラフです。
棒グラフの値は各データの平均値で、縦線が信頼区間なので、可視化されている情報としてはpointplotと代わりありません。ただ、見ての通り0を起点として描かれるので値全体に占める信頼区間の大きさやカテゴリ毎の平均差異が全体のうちでどのくらいなのか、という部分はpointplotより見やすいかと思います。
(capsizeを指定すると信頼区間に横棒が生えます。見やすくなるのでオススメ)
このグラフももちろんhueを指定してさらに系列分割できます。
countplot
plt.figure(figsize=(15, 8))
ax = sns.countplot(x=pkmn["Generation"])
 
 
countplotはここまでのCategorical Plotと異なり、単純に指定されたカテゴリに含まれるデータがいくつあるかをカウントします。(yを代わりに指定することもできますが、横棒グラフになるだけです。xと同時には指定できません)
上記ではGenerationを指定したので、世代毎のデータ数(つまりポケモン数)が棒グラフになっています。 このプロットは下記のようにタイプ別の数え分けに使うことが多いと思います。
plt.figure(figsize=(15, 8))
ax = sns.countplot(x=pkmn["Generation"], hue=pkmn["Type 1"], palette="Set1" )
 
"Type 1" で分類してカウントしました。色は見分けが付きやすいパレットを選択しています。 (ポケモンは属性を二つ持つので、実際には"Type 2"も勘案する必要がありますがここでは気にしない)

2. Distribution plots

カテゴリラベルではなく何らかの数量であるデータ(1変数または2変数)を可視化するために用います。 APIリファレンスと順番が異なりますが、説明のわかりやすさのためkdeplot, rugplotを先に記載し、distplotを後に記載します。 
kdeplot
plt.figure(figsize=(15, 8))
ax = sns.kdeplot(pkmn[pkmn["Legendary"]==False]["Total"])
 

f:id:mizti:20171118194411p:plain

変数列を1つ、または2つ受け取り、対象のデータを元にカーネル密度推定(KDE: Kernel Density Estimation)を行い、その結果をプロットします。(つまり、平たく言えば「変数の分布を確率分布に変換」します) 結果として確率密度関数が描かれることになるので、全体を積分した値は常に1となります。
ここでは伝説のポケモンを除くポケモンの、パラメータ値合計("Total")でplotしてみました。例えば、伝説でないポケモン全部から一体をランダムに選ぶと、このKDEで変換した確率密度関数で計算するとHP種族値が500ぴったりのポケモンを引く確率は約3.8%程度となります。
プロットの使い途としては、サンプルデータを確率密度関数に変換する時に分布の様子をみたり下記のようなパラメータを調整したりするために使うことが多いと思います。
KDEは以下のような式で算出される関数です。

f:id:mizti:20171118194330p:plain 

ただし、n: データ数, h: バンド幅, K: カーネル関数
バンド幅hhを小さい値にするほど分布の細かい特徴が反映されやすく、逆に大きくするほど平滑になります。(標準では自動的にバンド幅を選択するアルゴリズムが選択されています)
カーネル関数Kには標準ではガウスカーネルが選択されていますが、別のカーネル関数を指定することもできます。
# バンド幅を10と指定した場合
plt.figure(figsize=(15, 8))
ax = sns.kdeplot(pkmn[pkmn["Legendary"]==False]["Total"], bw="10")
f:id:mizti:20171118194220p:plain 
kdeplotでは、1変数ではなく2変数を指定することもできます。
plt.figure(figsize=(10, 10))
ax = sns.kdeplot(pkmn["Attack"], pkmn["Sp. Atk"], shade=True) 

f:id:mizti:20171118194112p:plain

攻撃力と特殊攻撃力を2変数としてみました。色の濃い所が確率密度の高いところです。(shadeを指定しないと等高線表示になります)
この場合、値が大きい方に行くに従ってSpかAtkのどちらかに分かれていき(相関が弱くなっていき)、全体が三角形に近い形になっている様子を見て取ることができます。
 
rugplot
plt.figure(figsize=(15, 8))
ax = sns.rugplot(pkmn[pkmn["Legendary"]==False]["Total"])
f:id:mizti:20171118194009p:plain 
rugplotはカテゴリが一つしかないときのstripplotと同じようなもので、各データを単に一本一本の棒で描画していきます。
kdeの最初の例と同じく伝説以外のポケモンのステータス合計値をデータとしました。
rugplotは単体でつかうことはあまりないと思います。後述のdistplotなどと合わせて補助的にデータの分布を可視化するたに使うことが多いと思います。
distplot
plt.figure(figsize=(15, 8))
ax = sns.distplot(pkmn[pkmn["Legendary"]==False]["Total"])

f:id:mizti:20171118193917p:plain

distplotは指定した1変数の分布を可視化します。ここでは伝説のポケモンを除くポケモンの、パラメータ値合計("Total")でplotしてみました。
distplotはごくシンプルに指定された変数のヒストグラムを描画し、合わせてkdeやrugを一緒に描画してくれます。 kdeやrug, ヒストグラムはそれぞれオプションでon / offできます。
plt.figure(figsize=(15, 8))
ax = sns.distplot(pkmn[pkmn["Legendary"]==False]["Total"], kde=False, rug=True)
 
f:id:mizti:20171118193834p:plain

つづきはこちらです

mizti.hatenablog.com

chainer: トレーニングモジュールの拡張方法まとめ

chainerのチュートリアルをこなしたあと、

  • mnistなどライブラリが標準で用意してる以外の独自データセットをどうやって作ったら良いんだろう?
  • 複数のモデルが絡むネットワークを同時並行に訓練/評価するにはどうすれば?

といった点で困ることが多くありました。

結果的にはChainerのトレーニング関係のコンポーネントを拡張/自作するのが良さそうなのですが、 それぞれにちょっとずつポイントが有ります。ここでは、トレーニング関連のオブジェクトを概観しながら 今までに書いてきたエントリをまとめます。(今後も増えるかも..)

chainerのトレーニング関連コンポーネント概観

f:id:mizti:20171025204322p:plain

Trainer:

登録されたUpdaterを使って指定したepoch数やイテレーション数分だけ訓練を実行します。

Extensions:

Trainerが実行される過程で付属的に実行される処理を規定します。

mizti.hatenablog.com

モデルの評価を行うEvaluatorもExtensionの一種として定義されます。

mizti.hatenablog.com

Updater:

Datasetを含むIterator、訓練対象のモデルオブジェクトと勾配更新方法を規定するOptimizerを受け取って、 具体的にどのようにモデル群を更新していくかを規定します。

mizti.hatenablog.com

Optimizer:

SGD, Adamなど勾配計算後のパラメータ更新方法を規定します。

Model (Link / Chain)

ニューラルネットワーク(NN)の本体です。基本的にはLinkやFunctionを積み重ねてChainでラッピングしNNを構成していきます。 (Chain自体もLinkを継承して作られたLinkオブジェクトの一種なため、ChainとChainを繋いだりChainとLinkを繋いだりもできます)

Datasetと合わせてLinkやChainで作られるChainerのNNの概要を図にまとめてみました。

f:id:mizti:20171025224939p:plain

Iterator:

下記のデータセットを受け取り、ミニバッチサイズや取り出し順のポリシーを設定します。

Dataset:

画像などのデータとラベルの対を多数保持します。下記のエントリーでまとめたように、 Datasetクラスを継承することで極めて簡単かつ柔軟に定義できます

mizti.hatenablog.com

chainer: Evaluatorを自作してトレーニング中のモデルの評価を柔軟に行う

Evaluatorとは

f:id:mizti:20171021163105p:plain

DNNの訓練を行う中でモデルの訓練が意図通り進んでいるかを評価したくなることが多いと思います。

Chainerでは定義したモデルの訓練を行う際にそのモデルの評価を行うための仕組みとしてEvaluatorという 仕組みを持っています。

このEvaluatorは 以前解説したExtensionの一種として作られています。

本質的にはExtensionを自分で自作すればモデルの評価ももちろん可能なのですが、 Evaluatorを継承してカスタマイズすることでaccuracyやlossの計上、イテレータの状態管理などをスマートに行うことができます。

chainer 標準のEvaluator

まず、継承元になるEvaluatorの作りを確認してみましょう。 Evaluatorについては、chainer標準のextensions.Evaluatorがかなり汎用的に作られており、自作せずに済むのならそれに越したことはないので。

Trainer extensions — Chainer 2.0.0 documentation

Evaluatorは、まず下記のようなオブジェクトを受け取ります。

  • iterator: 評価用のデータセット、ミニバッチサイズ等が設定されたイテレータオブジェクト
  • target: 評価対象となるモデル、もしくはモデルの列挙されたdict。
  • converter: イテレータから取り出した(データ, ラベル)のタプルを訓練用のミニバッチに変換する関数
  • device: 評価計算を行うために利用するGPU番号
  • eval_hook: 評価前に実行される関数(なくてもok)
  • eval_func: 評価を行うために呼び出される関数 (指定されない場合、targetに渡したモデルのcallが代わりに利用される)

chainer.training.extensions.evaluator — Chainer 2.0.0 documentation

で、それぞれ主要なメソッドの動作をざっくり確認すると

  • __init__ :
    • 渡された引数をインスタンス変数として格納する。
    • 特に、targetにモデルのdictではなく単一のLinkが渡された場合、そのlinkを"main"という名前で辞書登録しなおす
  • __call__:
    • Reporter objectを作成して、targetとして渡された各リンクのを監視対象に指定する。
    • evaluateメソッドを呼び出し、その結果をreporterを使ってreportする。 (__call__の戻り値は参照されておらず、reporter_module.reportに渡したdictが印字対象な点に注意)
  • evaluate:
    • 渡されたiterator("main")からbatchを取り出す
    • 渡されたモデル("main")、もしくはeval_funcにbatchから取り出したデータ/ラベルを入力する(ここでobserverにaccuracy/lossが記録される)
    • observerに書かれたaccuracy/lossをsummaryに蓄積。
      • summaryはDictSummaryのインスタンス。DictSummaryはキー毎に投入された値の回数や平均値、二乗和を蓄積でき、平均や分散などの統計値を取り出せます
    • 最後にsummaryに蓄積された値を平均して呼び出し元(__call__)に返却

ということをしています。

つまり、標準のEvaluatorではこういうことが可能です。

  • 単一のモデルと単一イテレータの評価。
  • モデルのcallを評価対象にしても良いし、eval_funcで関数を渡して評価に使っても良い

標準のEvaluatorは複数のモデルをtargetに辞書として受け取ることはできますが、 evaluateメソッドが'main'のみを用いて評価するようになっているため、複数のモデルを評価に用いることはできません。

Evaluatorを自作する

逆に標準のEvaluatorではできないこと、例えば

  • 複数のモデルやイテレータを使った評価(たとえばGANのGeneratorとDiscriminatorなど)
  • accuracyやloss以外の指標値の出力(一応、eval_funcを使えば可能ですが)

などがしたい場合にはEvaluatorを自作すると良いかと思います。 Updaterの自作をした場合には対応するEvaluatorを作りたいことが多いかと思います。

拡張例

様々な拡張方法があると思いますので、やりたいことベースでchainer標準のEvaluatorを継承して独自のEvaluatorを定義する幾つか例を挙げていきます。

①とにかく指定した値をログやレポートに表示させたい

印字させたい項目を項目名をkey、スカラー値をvalueに持つ辞書をreporter_module.reportに渡せば、 とりあえず指定した値をログやレポートに表示させられます。

from chainer import reporter as reporter_module
from chainer.training import extensions

class MyEvaluator(extensions.Evaluator):
    def __call__(self, trainer=None):
        result = {"hoge": 4, "piyo": 88}
        reporter_module.report(result)
        return None

出力:

    {
        (略)
        "hoge": 4.0,
        "piyo": 88.0
    }

chainerのReportで受け取る辞書(dict)の各値はスカラー値であることが必須です(文字列やリストは渡せません)。

②複数のモデルを用いて評価を行う

影響しあう複数のモデルを並列に訓練しているなどで評価を行いたい場合、 evaluateでtargetに指定するモデルをself._targetsから取り出す際に指定するorループで順に呼び出すなどすると良いと思います。

呼び出し元:

trainer.extend(MyEvaluator(test_iter, {"model1": model, "model2":model2}, device=args.gpu))

Evaluator側:

class MyEvaluator(extensions.Evaluator):
    default_name="myval"
    def evaluate(self):
        #target = self._targets['main']

        summary = reporter_module.DictSummary()
        for name, target in six.iteritems(self._targets):
            iterator = self._iterators['main']
            #target = self._targets['main']
            eval_func = self.eval_func or target

            if self.eval_hook:
                self.eval_hook(self)

            if hasattr(iterator, 'reset'):
                iterator.reset()
                it = iterator
            else:
                it = copy.copy(iterator)

            #summary = reporter_module.DictSummary()
            for batch in it:
                observation = {}
                with reporter_module.report_scope(observation):
                    in_arrays = self.converter(batch, self.device)
                    with function.no_backprop_mode():
                        if isinstance(in_arrays, tuple):
                            eval_func(*in_arrays)
                        elif isinstance(in_arrays, dict):
                            eval_func(**in_arrays)
                        else:
                            eval_func(in_arrays)

                summary.add(observation)
        return summary.compute_mean()

色々書いてあるように見えますが、元のevaluate()からの変更点は

  1. クラス変数default_nameを指定している(下記のログ出力のようにReporterがログ項目の接頭辞にしてくれます)
  2. target = self.targets['main']ではなくself.targetsからループで取り出すようにしている
  3. summaryの宣言をそのループの外側に書いた

だけです。

self._targetにはEvaluator定義時に指定したモデルが入っているのですが、

  • 単一のモデル渡す(そのモデルが"main"という名前で__init__内で辞書登録される)
  • モデルを辞書で渡す

のどちらでも良いようになっています。

このようにすることで

    {
         (略)
        "myval/model1/loss": 0.07286249771073926,
        "myval/model1/accuracy": 0.9748000055551529,
        "myval/model2/accuracy": 0.0888000001013279,
        "myval/model2/loss": 2.3258586740493774
    }

のように各モデルに対する評価を出力できます。

③モデルの評価中に独自指標値を出力

Inceptionのように一つのモデルから複数の出力がある場合など、accuracyとloss以外の指標を計測してログに出力したいことも多いと思います。そのような場合には

    def evaluate(self):
        iterator = self._iterators['main']
        target = self._targets['main']
        eval_func = self.eval_func or target

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        summary = reporter_module.DictSummary()

        for batch in it:
            observation = {}
            with reporter_module.report_scope(observation):
                in_arrays = self.converter(batch, self.device)
                with function.no_backprop_mode():
                    if isinstance(in_arrays, tuple):
                        eval_func(*in_arrays)
                    elif isinstance(in_arrays, dict):
                        eval_func(**in_arrays)
                    else:
                        eval_func(in_arrays)

            summary.add({MyEvaluator.default_name + '/currenttime': int(time.time())})
            print(observation)
            summary.add(observation)

        return summary.compute_mean()
  • summary.add({MyEvaluator.default_name + '/currenttime': int(time.time())})を足しています

これはreporter_module.report_scope(observation)のスコープ内でchainer.reporter.report(dict)が呼び出されると、 observationにdictが追加されるという仕組みを用いています

例ではUnixtimeをログに出していますが、モデルや出力に関する適切な数値を渡すことで 評価中のモデルについて都合の良い指標を出力できます。

(私の場合だと例えば、文字列を認識するモデルに対して正解文字列までの編集距離を出力するのに使っていました)

④GANのUpdater / Evaluator

下記の記事がGAN用のUpdater / Evaluatorの対の実装例になっているので、GANを実装したい方は参考に できると思います。

qiita.com


chainerで少し複雑なモデルを初めて扱うことになると評価をどうしようか迷うと思いますが(実際迷いました)、このようにEvaluatorを拡張することで柔軟に対処できるようになるかと思います。

chainer: Extensionを自作してディープラーニングの訓練に独自処理を挟み込む

なぜExtensionを自作するのか

Chainerのモデルのトレーニング中に、

  • たまにモデルの出力をダンプさせたい(特に生成系で必要になる)
  • 1エポック毎に学習率を手動で変更したい

など、独自の処理を定期的に挟みたくなることがあるかと思います。 このような願いを叶えるのがchainerのextensionという仕組みです。 extensionはtrainerに仕掛けておくことで指定した間隔ごとに独自に定義した処理を実行してくれます。

extensionはtrainerにひも付き、一つのtrainerにいくつでもextensionを設定することができます。

f:id:mizti:20170923205550p:plain

Extensionの自作方法

モデルを定期的に評価するEvaluatorもこのextensionを使って実装されています。 モデルの評価用途にはこのEvaluatorを拡張したほうが早いでしょう。 ここでは評価以外のより一般的にextensionを自作する方法にフォーカスします

例えば、公式ではこのようなExtensionが用意されています。よく使うのはこのあたりでしょうか。

Extension名 概要
extensions.Evaluator モデルの損失や正解率を評価する
extensions.snapshot モデルの重みをファイルに保存する
extensions.LogReport Evaluator等で測定した各種の評価値をログに出力する
extensions.PrintReport 指定した評価値を標準出力やファイルに出力する

これらの公式Extensionはチュートリアル等でも言及されていますが、 extensionは簡単に自作することも可能です。

簡単なextensionの例として、「"hoge"とpirntするだけ」のextensionを作ってみましょう。

main.py
   |
   + lib -- print_hoge.py

このようなディレクトリ構成で、 定義側(print_hoge.py):

print_hoge.py
def print_hoge():
    @training.make_extension(trigger=(1, 'epoch'))
    def _print_hoge(trainer):
        print("hoge")
    return _print_hoge

呼び出し側 (main.py):

from lib.print_hoge import *
(略)
trainer = training.Trainer(updater, (30, 'epoch'), out=args.output)
trainer.extend(print_hoge())

たったこれだけの記述で、1epochに一回、hogeと標準出力に出力されるようになります。

Extensionのカスタマイズ

動作する間隔を変えたければ

 - trigger=(1, 'iteration'): 1イテレーションに1回
 - trigger=(3, 'epoch'): 3エポックに一回

などtriggerを変更すればokです。 また、extensionの中でモデルを参照したければtrainer.update.modelとういうように、 trainerに登録されたupdater経由でアクセス可能です。

下記のように実行時に引数を与えることも可能です。

print_hoge.py
def print_hoge(message):
    @training.make_extension(trigger=(1, 'epoch'))
    def _print_hoge(trainer):
        print(message)
    return _print_hoge
呼び出し側 (main.py)
from lib.print_hoge import *
(略)
trainer = training.Trainer(updater, (30, 'epoch'), out=args.output)
trainer.extend(print_hoge("fuga"))

後はextension名と中身の処理を好きなように書き換えればokです。


以下はもう少し詳しく知りたい方向け

書き方のポイントは、関数print_hoge()が関数「_print_hoge」自体を返り値に返却していることです。 実際、「関数を返却する関数」を定義しなくても

print_hoge.py
@training.make_extension(trigger=(1, 'epoch'))
def print_hoge(trainer):
    print("hoge")

と定義して、呼び出し側(main.py)で

from lib.print_hoge import *
(略)
trainer = training.Trainer(updater, (30, 'epoch'), out=args.output)
trainer.extend(print_hoge)

trainer.extendに関数を直接渡せば最初の例と同じように動作します。 (trainer.extend(print_hoge())ではなくtrainer.extend(print_hoge)となっていることに注意してください)

上記のようにも書けるのですが、chainerの公式コードでは一度「関数を返す関数(ないしクラス)」を使う実装がなされています。 これは外側の関数(print_hoge)を関数と内側の関数(_print_hoge)の間で変数を定義することで_print_hogeで参照する変数のクロージャを作ることができ、色々便利だからかと思います。

また、

@training.make_extension(trigger=(1, 'epoch'))
def print_hoge(trainer):
    print("hoge")
@training.make_extension(trigger=(1, 'epoch'))

という処理も一見何か難しいことをしているように見えますが、 make_extension関数内で行われていることは下記とほぼ同義です。

def sample_recog(trainer):
    print("hoge")

sample_recog.trigger=(1, 'iteration')

つまり、極端な話

def sample_recog():
    def _sample_recog(trainer):
        print("hoge")
    _sample_recog.trigger=(1, 'iteration')
    return _sample_recog

このように定義してもmake_extensionを使ったときと同じように動きます。

chainerのupdaterを自作して複雑なネットワークを訓練する

なぜupdaterの自作が必要か

chainerで様々なニューラルネットを試していると、どこかで複数のモデルが組み合つもの、 複数の出力を持つものなど、込み入ったネットワークを訓練したいことがあると思います。

よくあるmnistのサンプルなどでは

optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
(略)
updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

という感じで単一のモデルを対象にして訓練ループを進めているのですが、このままでは 複数モデルや複数出力が絡むネットワークの訓練は扱えません。

自分でゴリゴリと更新やらトレーニングやら評価のプロセスを全て自分で書くことも可能は可能ですが、 せっかくならtrainerやoptimizerなどchainerの作った枠組みをどうせなら活かしたいです そこで必要になるのが、それぞれのネットワークの訓練に適した形のupdaterの自作です。

updaterの役割

updaterはあるネットワークと訓練データが与えられた時に、

  • 訓練データから切り出したミニバッチからの損失の計算方法
  • 計算された損失からどのようにネットワークの重みを更新するのか

を定義するモジュールです。

updaterはtrainerに対して一つだけ定義され、一つのupdaterは

  • 最適化対象となるモデルを含むoptimizer(1つ以上、複数の場合もあり)
  • datasetを含み、epochやバッチサイズを定義したiterator

を受け取ります。(下図)

f:id:mizti:20170923205123p:plain]

updaterの実装方法

updaterの作成方法は色々あるのかと思いますが、以下では最も基本的な StandardUpdaterを継承/オーバーライドする形でのupdater定義方法を記載します。

StandardUpdaterをオーバーライドして使う場合に、最低限変更が必要になるのが

  • __init__
  • update_core

です。

initの実装

__init__は言うまでもなくコンストラクタで、受け取る引数を定義します。 必ず受け取らないとエラーになるものはありませんが、その後のupdate_coreで重み更新を行うために必要な 以下のような引数は設定したほうがよいと思います。

  • 訓練データを含むiterator
  • 損失関数の計算に必要なモデル(群)
  • 重み変更対象になるモデルをsetupしたoptimizer(群)
  • iteratorをミニバッチに変換するためのconverter
  • 各種処理を行うためのデバイス指定(cpuなら-1、gpuなら0以上の整数)

また、__init__ 内で

  • self._iterator (iteratorをまとめたした辞書)
  • self._optimizers (optimizerをまとめた辞書)
  • self.iteration=0

の3点はtrainerやupdater_core外のupdater処理で参照されているのでこの名前で 宣言しておくのが良いと思います(今回の手法では)。

update_coreの実装

update_coreでは、updaterの中心的な役割となる

  1. iteratorから入力データ/ラベルの取り出し
  2. 入力データ/ラベルからの損失計算
  3. 損失からネットワークの更新

を定義します。

  1. iteratorから入力データ/ラベルの取り出し: 後で損失計算できればなんでもよいですが、 chainerのDatasetをセットしたiteratorを渡している場合、next()メソッドを使うと [ (入力データ, 教師ラベル), (入力データ, 教師ラベル).. ] というlistが取れるため、convert.concat_examplesを使うと ( [batch_size分の入力データ] , [batch_size分の教師ラベル] ) というtupleに簡単に変換できます。

  2. 入力データ/ラベルからの損失計算: これは目的に応じて。基本的に入力データと教師ラベルをモデルのcallや然るべきメソッドに渡せば 算出されるように作っていると思います。

  3. 損失からネットワークの更新: Chainerでネットワークの重み更新を行う場合、基本的には ① 各更新対象モデルのcleargrad()による勾配の初期化 ② lossからのbackward() による誤差逆伝搬 ③ optimizer.update()による各モデルの勾配更新 の流れとなります。 (このあたりはChainerのチュートリアルやプレイグラウンド(ここここ)あたりに目を通すと良いと思います)

以下、update_coreを実装する上で幾つか気をつけるべき点です

  • モデルの持つ重みを個別にcleargrad()しても良いですが、Chainクラスの場合cleargrads()メソッドにより重みを一度に初期化できます
  • .barckward()は次元を持たない(つまりスカラーな)Variableに対してしか実行できません
  • backward()は指定したlossの計算過程をChainインスタンスを跨って伝搬します
  • optimizer.update()はoptimizerに設定したモデルのみが重み更新対象になります。(逆に重み更新したくないモデルはupdate()しなければ良いです)

実装例

以下、参考になるかわかりませんが私の作ったUpdater例です。 このUpdaterは下図のような畳み込み層に対して複数のLinear層がそれぞれ個別に 入力を受け取り、別々に損失を計算するようなネットワークを訓練するために作ったものです。

f:id:mizti:20170917000954p:plain

import six
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda, training, reporter
from chainer.datasets import get_mnist
from chainer.training import trainer, extensions
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer.datasets import get_mnist
from chainer import optimizer as optimizer_module

class MyUpdater(training.StandardUpdater):
    def __init__(self, iterator, base_cnn, classifiers, base_cnn_optimizer, cl_optimizers, converter=convert.concat_examples, device=None):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main':iterator}
        self._iterators = iterator
        self.base_cnn = base_cnn
        self.classifiers = classifiers

        self._optimizers = {}
        self._optimizers['base_cnn_opt'] = base_cnn_optimizer
        for i in range(0, len(cl_optimizers)):
            self._optimizers[str(i)] = cl_optimizers[i]

        self.converter = convert.concat_examples
        self.device = device
        self.iteration = 0

    def update_core(self):
        iterator = self._iterators['main'].next()
        in_arrays = self.converter(iterator, self.device)

        xp = np if int(self.device) == -1 else cuda.cupy
        x_batch = xp.array(in_arrays[0])
        t_batch = xp.array(in_arrays[1])
        y = self.base_cnn(x_batch)

        loss_dic = {}
        for i, classifier in enumerate(self.classifiers):
            loss = classifier(y, t_batch[:,i])
            loss_dic[str(i)] = loss

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.target.cleargrads()

        for name, loss in six.iteritems(loss_dic):
            loss.backward()

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.update()

少しわかりづらいですが、"base_cnn"が左側の畳み込み層のモデルを表し、 classifiersが右側のLinearモデルのリストになっています。(optimizerも同様)

そしてそれぞれの出力から誤差を逆伝搬し、畳み込み層は全ての出力からの誤差を蓄積した上で updateされるようになっています。

スレッドとキューとキューランナーとコーディネータの関係

TensorflowのDeveloper’s Guideのスレッドとキュー解説をかいつまんでまとめてみる。

Threading and Queues  |  TensorFlow

  • キュー

文字通り、データのキュー。FIFOQueueなど。 q = tf.FIFOQueue(3, "float") として定義した後、 x = q.dequeue()でデータのポップ、 q_inc = q.enqueue([y]) でデータのプッシュができる

  • キューランナー

キューに対して、繰り返し同じ操作でenqueueオペレーションを実行するスレッドを指定数だけ作成する

queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
sess = tf.Session()
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)

上記のコードで、queueに対して、exampleという操作(op)でエンキューを行うスレッドを4つ作成していることになる。 注意点は、あくまでエンキューを行うスレッドであるということ。 また、キューランナーでスレッドを作成するときには、そのスレッドたちの「監督役」になるCoordinaorを指定する必要がある。

  • コーディネータ

前述のキューランナーでスレッドを作成する際に、スレッドの監督役になるオブジェクト。 coord = tf.train.Coordinator() で作成したあとキューランナーに渡すと、 coord.should_stop()でスレッドを殺すべきタイミングを教えてくれたり、 coord.request_stop()でスレッドの停止を命令したり coord.join(threads)で指定したスレッド(群)が停止するまで待機したりできる。