Compartir a través de


Uso de Graphviz para la visualización

CNTK proporciona una manera sencilla de visualizar el gráfico computacional subyacente de un modelo mediante Graphviz, un software de visualización de grafos de código abierto.

Para ilustrar un caso de uso, primero vamos a crear una red convolucional sencilla mediante la biblioteca de capas de CNTK.

import cntk as C

def create_model(x, num_classes):
  with C.layers.default_options(init=C.glorot_uniform(), activation=C.relu):
    model = C.layers.Sequential([
      C.layers.For(range(3), lambda i: [
        C.layers.Convolution((5,5), [32,32,64][i], pad=True, name=['conv1', 'conv2', 'conv3'][i]),
        C.layers.MaxPooling((3,3), strides=(2,2), name=['pool1', 'pool2', 'pool3'][i])
        ]),
      C.layers.Dense(64, name='fc1'),
      C.layers.Dense(num_classes, activation=None, name='classify')
    ])
  return model(x)

Ahora suponiendo que estamos entrenando en el conjunto de datos CIFAR-10, que consta de imágenes de 32 x 32 en 10 clases, podemos asignar la forma de entrada correspondientemente. Consulte el tutorial de CNTK 201A para obtener instrucciones sobre cómo descargar y preparar el conjunto de datos CIFAR-10 para su uso en CNTK.

input_var = C.input_variable((3,32,32))
z = create_model(input_var)

Para obtener la descripción subyacente del gráfico computacional, CNTK proporciona una plot función en el cntk.logging.graph módulo. La plot(root, filename=None) función devuelve una descripción de red del grafo a partir del root nodo proporcionado. Además, si filename se especifica , el método genera un archivo DOT, PNG, PDF o SVG (correspondiente al sufijo de nombre de archivo).

Para generar la salida de DOT, deberá instalar pydot-ng (pip install pydot_ng). Y si desea una salida PNG, PDF o SVG, necesitará Graphviz además pydot-ngde . Una vez que haya instalado Graphviz, asegúrese de que los archivos binarios de Graphviz están en la variable de entorno PATH.

import pydot_ng

graph_description = C.logging.graph.plot(z, "graph.png")
print(graph_description)
Convolution(Parameter5302, Parameter5303, Input5301) -> Block6011_Output_0;

MaxPooling(Block6011_Output_0) -> Block6023_Output_0;

Convolution(Parameter5332, Parameter5333, Block6023_Output_0) -> Block6039_Output_0;

MaxPooling(Block6039_Output_0) -> Block6051_Output_0;

Convolution(Parameter5362, Parameter5363, Block6051_Output_0) -> Block6067_Output_0;

MaxPooling(Block6067_Output_0) -> Block6079_Output_0;

Dense(Parameter5710, Parameter5711, Block6079_Output_0) -> Block6095_Output_0;

Dense(Parameter5730, Parameter5731, Block6095_Output_0) -> Block6110_Output_0;

Si usa un Jupyter Notebook, puede mostrar la salida de visualización de grafos insertada:

from IPython.display import Image

display(Image(filename="graph.png"))

Graphviz output

Para obtener un ejemplo más detallado de visualización mediante logging.graph.plot, consulte el manual de depuración .