Nota:
El acceso a esta página requiere autorización. Puede intentar iniciar sesión o cambiar directorios.
El acceso a esta página requiere autorización. Puede intentar cambiar los directorios.
En este artículo se sobre cómo guardar y cargar modelos y puntos de comprobación durante el entrenamiento.
Guardar y cargar modelos
Guardar modelo
Para guardar un modelo en el archivo, use la save() función y especifique una ruta de acceso de archivo para el modelo guardado.
import cntk as C
x = C.input_variable(<input shape>)
z = create_model(x) #user-defined
z.save("myModel.model")
Tenga en cuenta que un modelo guardado de esta manera mediante la API de bibliotecade CNTK tendrá el formato model-v2. El formato model-v2 es un formato de serialización de modelos basado en Protobuf, introducido en CNTK v2. (Para obtener más información, consulte formato de modelo CNTK). Aunque se puede usar cualquier extensión de archivo (y es posible que vea el uso de varias extensiones de archivo como .model, .dnn, .cmf en la documentación de CNTK), le recomendamos que se mantenga en la convención de uso .model para los modelos de CNTK.
Modelo de carga
Para cargar un modelo desde el archivo en CNTK:
from cntk.ops.functions import load_model
z = load_model("myModel.model")
Como alternativa, también puede cargar el modelo mediante load():
z = C.Function.load("myModel.model")
Recursos relacionados
Entre los ejemplos siguientes se incluyen algunas tareas realizadas habitualmente que implican guardar y cargar modelos entrenados.
- Evaluación de un modelo entrenado guardado
- Acceso a los parámetros de un modelo entrenado guardado
- Evaluar y escribir capas específicas de un modelo entrenado guardado
Formato ONNX
CNTK también admite el ahorro y la carga de modelos en formato ONNX , lo que permite la interoperabilidad entre otros marcos, como Caffe2, PyTorch y MXNet.
Para guardar un modelo en el formato ONNX, basta con especificar el parámetro format:
z.save("myModel.onnx", format=C.ModelFormat.ONNX)
Y para cargar un modelo desde ONNX:
z = C.Function.load("myModel.onnx", format=C.ModelFormat.ONNX)
Puede encontrar más tutoriales específicos de ONNX aquí.
Puntos de comprobación durante el entrenamiento
Durante el entrenamiento de redes neuronales profundas, la práctica de la creación de puntos de comprobación permite al usuario tomar instantáneas del estado del modelo y ponderaciones en intervalos regulares. Dado que el entrenamiento de modelos de aprendizaje profundo puede llevar mucho tiempo, los puntos de control garantizan un nivel de tolerancia a errores en caso de errores de hardware o software. Además, la creación de puntos de control permite al usuario reanudar el entrenamiento desde el último punto de control guardado y conservar los modelos de mejor rendimiento (es decir, durante el ajuste de hiperparámetros).
En CNTK, una manera de guardar el estado del modelo es mediante el método aformentioned save() . Sin embargo, este método solo guarda el estado del modelo y no otras entidades con estado del script de entrenamiento, como el estado actual del origen de minibatch y el instructor, que también son necesarios para restaurar un estado para reanudar el entrenamiento. Para ello, CNTK proporciona la API de punto de control para las dos maneras de entrenamiento del modelo: 1) la API de bajo nivel Trainer.train_minibatch y 2) la API de alto nivel Function.train (training_session). Para obtener más información sobre los diferentes métodos de entrenamiento, consulte el manual "Entrenamiento del modelo mediante api declarativa e imperativa".
Punto de comprobación de bajo nivel (Trainer.train_minibatch)
En el método Trainer.train_minibatch/test_minibatch de entrenamiento, el usuario tiene control total sobre cada actualización de minibatch y los datos se alimentan al entrenador a través de un bucle explícito, minibatch por minibatch. En este caso, para realizar puntos de control durante el entrenamiento, el usuario tiene que controlar manualmente el estado actual del origen del minibatch y el instructor.
Guardar punto de control
Una vez creado una instancia del objeto Trainer, durante el entrenamiento puede controlar el estado del modelo y del instructor llamando al save_checkpoint método mientras se encuentra dentro del bucle de entrenamiento:
# get the checkpoint state of the minibatch source (mb_source)
mb_source_state = mb_source.get_checkpoint_state()
# pass the minibatch source state to the external_state parameter of save_checkpoint()
checkpoint = "myModel.dnn" #filename to store the checkpoint
trainer.save_checkpoint(checkpoint, mb_source_state)
Restauración desde el punto de control
Para reanudar el entrenamiento desde un archivo de punto de control, restaure el estado de origen y entrenador de minibatch mediante los métodos correspondientes restore_from_checkpoint :
checkpoint = "myModel.dnn"
mb_source_state = trainer.restore_from_checkpoint(checkpoint)
mb_source.restore_from_checkpoint(mb_source_state)
Para obtener un ejemplo completo de puntos de comprobación manuales durante el entrenamiento, consulte aquí.
Puntos de comprobación de alto nivel (Function.train)
En lugar de escribir explícitamente el bucle de entrenamiento, el usuario puede usar los métodos Function.train/test , que se encargan de los distintos aspectos de una sesión de entrenamiento, incluidos los orígenes de datos, los puntos de comprobación y la impresión de progreso.
Para habilitar los puntos de control, el usuario debe proporcionar una devolución de llamada de configuración de punto de comprobación creando una instancia de la CheckpointConfig clase . A continuación, la devolución de llamada se encarga de los puntos de comprobación coherentes con la frecuencia especificada durante el entrenamiento. Para restaurar desde el último punto de control disponible antes del inicio del entrenamiento, el restore parámetro se establece de forma predeterminada en True. Si desea guardar todos los puntos de control del entrenamiento, establezca en preserve_allTrue (el valor predeterminado es False). Los nombres de archivo del punto de control se guardan de la siguiente manera: por ejemplo, si filename="myModel.dnn", los puntos de control serán myModel.dnn.0, myModel.dnn.1myModel.dnn.2, , etc.
Una vez que haya definido el origen de minibatch (), el modelo (mb_source), la función de criterio (zcriterion), el aprendiz y el mapa de entrada (si mb_source es un lector de datos), puede configurar el método de entrenamiento para que realice una devolución de llamada de punto de comprobación:
checkpoint = "myModel.dnn"
checkpoint_frequency = 100
checkpoint_config = C.CheckpointConfig(checkpoint, frequency=checkpoint_frequency, preserve_all=True)
criterion.train(mb_source, model_inputs_to_streams = input_map, parameter_learners=[learner], callbacks=[checkpoint_config])
Tenga en cuenta que Function.train tiene parámetros adicionales que se pueden especificar.
Puede encontrar ejemplos más detallados aquí (en la sección 2) y en el Tutorial de CNTK 200 (en el "Ejemplo de entrenamiento avanzado").
Nota sobre training_session
Function.train es un contenedor cómodo alrededor cntk.train.trainer.Trainer de y cntk.train.training_session.TrainingSession, que encapsula un bucle de entrenamiento típico. A veces puede encontrar ejemplos que crean instancias explícitas de un TrainingSession objeto mediante training_session. En ese caso, el código de entrenamiento anterior Function.train se puede expresar de forma equivalente como tal:
trainer = C.Trainer(z, criterion, [learner])
checkpoint_frequency = 100
checkpoint_config = C.CheckpointConfig(checkpoint, frequency=checkpoint_frequency, preserve_all=True)
C.train.training_session(
trainer=trainer,
mb_source=mb_source,
mb_size=mb_size,
model_inputs_to_streams=input_map,
checkpoint_config=checkpoint_config
).train()
Entrenamiento distribuido
Durante el entrenamiento distribuido en CNTK, si el usuario usa save para guardar el modelo, un error común es que acabará llamando save a todos los procesos, lo que puede provocar condiciones de carrera al escribir el modelo en el archivo. Para reducir la probabilidad de error, considere la posibilidad de establecer puntos de control con los save_checkpoint/restore_from_checkpoint métodos (o CheckpointConfig para Function.train), que protegen contra estos peligros de carrera al guardar explícitamente toda la instantánea de entrenamiento distribuida solo del trabajador de rango 0.
La sección 1.2 ("Bucle manual distribuido") del manual "Entrenar modelo mediante API declarativa e imperativa" tiene un ejemplo completo de puntos de comprobación con save_checkpoint/restore_from_checkpoint durante el entrenamiento distribuido al usar el nivel Trainer.train_minibatch APIbajo.