次の方法で共有


視覚化に Graphviz を使用する

CNTKは、オープンソースのグラフ視覚化ソフトウェアである Graphviz を使用して、モデルの基になる計算グラフを簡単に視覚化する方法を提供します。

ユース ケースを説明するために、まず、CNTK Layers ライブラリを使用して単純な畳み込みネットワークを構築してみましょう。

import cntk as C

def create_model(x, num_classes):
  with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):
    model = C.layers.Sequential([
      C.layers.For(range(3), lambda i: [
        C.layers.Convolution((5,5), [32,32,64][i], pad=True, name=['conv1', 'conv2', 'conv3'][i]),
        C.layers.MaxPooling((3,3), strides=(2,2), name=['pool1', 'pool2', 'pool3'][i])
        ]),
      C.layers.Dense(64, name='fc1'),
      C.layers.Dense(num_classes, activation=None, name='classify')
    ])
  return model(x)

10 クラスの 32 x 32 イメージで構成される CIFAR-10 データセットに関するトレーニングを行うと仮定すると、それに対応して入力図形を割り当てることができます。 CNTKで使用するために CIFAR-10 データセットをダウンロードして準備する手順については、CNTK 201A チュートリアルを参照してください。

input_var = C.input_variable((3,32,32))
z = create_model(input_var)

計算グラフの基になる説明を取得するには、CNTKモジュールに関数をplotcntk.logging.graph提供します。 この関数は plot(root, filename=None) 、指定されたノードから始まるグラフのネットワーク記述を root 返します。 さらに、指定した場合 filename 、メソッドは DOT、PNG、PDF、または SVG ファイル (ファイル名のサフィックスに対応) を出力します。

DOT 出力を出力するには、(pip install pydot_ng) をインストールpydot-ngする必要があります。 PNG、PDF、SVG 出力が必要な場合は、 Graphviz に加えて pydot-ng. Graphviz をインストールしたら、Graphviz バイナリが PATH 環境変数内にあることを確認します。

import pydot_ng

graph_description = C.logging.graph.plot(z, "graph.png")
print(graph_description)
Convolution(Parameter5302, Parameter5303, Input5301) -> Block6011_Output_0;

MaxPooling(Block6011_Output_0) -> Block6023_Output_0;

Convolution(Parameter5332, Parameter5333, Block6023_Output_0) -> Block6039_Output_0;

MaxPooling(Block6039_Output_0) -> Block6051_Output_0;

Convolution(Parameter5362, Parameter5363, Block6051_Output_0) -> Block6067_Output_0;

MaxPooling(Block6067_Output_0) -> Block6079_Output_0;

Dense(Parameter5710, Parameter5711, Block6079_Output_0) -> Block6095_Output_0;

Dense(Parameter5730, Parameter5731, Block6095_Output_0) -> Block6110_Output_0;

Jupyter Notebookを使用している場合は、グラフの視覚化出力をインラインで表示できます。

from IPython.display import Image

display(Image(filename="graph.png"))

Graphviz output

使用する logging.graph.plot視覚化の詳細な例については、「 デバッグ方法 」のマニュアルを参照してください。