Serialization

This article goes over saving and loading models and checkpointing during training.

Saving and loading models

Save model

To save a model to file, use the save() function and specify a filepath for the saved model.

import cntk as C

x = C.input_variable(<input shape>)
z = create_model(x) #user-defined 
z.save("myModel.model")

Note that a model saved in this way using the CNTK Library API will be in the model-v2 format, a Protobuf-based model serialization format introduced in CNTK v2. (For more information, refer to CNTK model format.) While any file extension can be used (you may see extensions such as .model, .dnn, and .cmf across the CNTK documentation), we recommend sticking to the convention of using .model for your CNTK models.

Load model

To load a model from file into CNTK:

from cntk.ops.functions import load_model

z = load_model("myModel.model")

Alternatively, you can load your model using Function.load():

z = C.Function.load("myModel.model")

The following examples include some commonly performed tasks involving the saving and loading of trained models.

ONNX format

CNTK also supports the saving and loading of models in the ONNX format, which allows for interoperability among other frameworks, including Caffe2, PyTorch and MXNet.

To save a model to the ONNX format, simply specify the format parameter:

z.save("myModel.onnx", format=C.ModelFormat.ONNX)

And to load a model from ONNX:

z = C.Function.load("myModel.onnx", format=C.ModelFormat.ONNX)

You can find more ONNX-specific tutorials here.

Checkpointing during training

During the training of deep neural networks, checkpointing lets the user take snapshots of the model state and weights at regular intervals. Since training deep learning models can be extremely time-consuming, checkpointing provides a level of fault tolerance in the event of hardware or software failures. It also lets the user resume training from the last saved checkpoint and retain the best-performing models (e.g. during hyperparameter tuning).

In CNTK, one way to save the model state is with the aforementioned save() method. However, this saves only the model state, not other stateful entities of the training script, such as the current state of the minibatch source and the trainer, which are also needed to restore the state and resume training. To this end, CNTK provides a checkpointing API for both ways of training a model: 1) the low-level Trainer.train_minibatch API and 2) the high-level Function.train (training_session) API. For more information on the different training methods, see the manual 'Train model using declarative and imperative API.'

Low-level checkpointing (Trainer.train_minibatch)

In the Trainer.train_minibatch/test_minibatch method of training, the user has full control over each minibatch update and data is fed to the trainer through an explicit loop, minibatch by minibatch. In this case, to perform checkpointing during training, the user has to manually checkpoint the current state of both the minibatch source and the trainer.

Save checkpoint

Once your Trainer object is instantiated, during training you can checkpoint the model and trainer state by calling the save_checkpoint method while inside your training loop:

# get the checkpoint state of the minibatch source (mb_source)
mb_source_state = mb_source.get_checkpoint_state()

# pass the minibatch source state to the external_state parameter of save_checkpoint()
checkpoint = "myModel.dnn" #filename to store the checkpoint
trainer.save_checkpoint(checkpoint, mb_source_state)

Restore from checkpoint

To resume training from a checkpoint file, restore the minibatch source and trainer state using the corresponding restore_from_checkpoint methods:

checkpoint = "myModel.dnn"
mb_source_state = trainer.restore_from_checkpoint(checkpoint)
mb_source.restore_from_checkpoint(mb_source_state)

For a complete example of manual checkpointing during training, refer here.

High-level checkpointing (Function.train)

Instead of explicitly writing the training loop, the user can use the Function.train/test methods, which take care of the different aspects of a training session, including data sources, checkpointing, and progress printing.

To enable checkpointing, the user must provide a checkpoint configuration callback by instantiating the CheckpointConfig class. The callback then takes care of consistent checkpointing at the specified frequency during training. To restore from the last available checkpoint before the start of training, the restore parameter defaults to True. If you would like to keep all checkpoints produced during training, set preserve_all to True (the default is False). Checkpoint files are then numbered sequentially: for example, if filename="myModel.dnn", the checkpoints will be saved as myModel.dnn.0, myModel.dnn.1, myModel.dnn.2, and so on.
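
As a quick illustration of that naming scheme (plain Python, no CNTK required), each preserved checkpoint is simply the base filename with an index appended:

```python
filename = "myModel.dnn"

# with preserve_all=True, each checkpoint keeps the base name plus an index
checkpoint_names = ["{}.{}".format(filename, i) for i in range(3)]
print(checkpoint_names)  # prints ['myModel.dnn.0', 'myModel.dnn.1', 'myModel.dnn.2']
```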

Once you've defined your minibatch source (mb_source), model (z), criterion function (criterion), learner and input map (if mb_source is a data reader), you can configure the train method to take in a checkpoint callback:

checkpoint = "myModel.dnn"
checkpoint_frequency = 100
checkpoint_config = C.CheckpointConfig(checkpoint, frequency=checkpoint_frequency, preserve_all=True)
criterion.train(mb_source, model_inputs_to_streams=input_map, parameter_learners=[learner], callbacks=[checkpoint_config])

Note that Function.train has additional parameters that can be specified.

More detailed examples can be found here (in section 2) and in the CNTK 200 Tutorial (under the "Advanced Training Example").

Note on training_session

Function.train is a convenience wrapper around cntk.train.trainer.Trainer and cntk.train.training_session.TrainingSession, which encapsulates a typical training loop. You may sometimes come across examples that explicitly instantiate a TrainingSession object using training_session. In that case, the above Function.train training code can equivalently be expressed as follows:

trainer = C.Trainer(z, criterion, [learner])
checkpoint_frequency = 100
checkpoint_config = C.CheckpointConfig(checkpoint, frequency=checkpoint_frequency, preserve_all=True)

C.train.training_session(
  trainer=trainer,
  mb_source=mb_source,
  mb_size=mb_size,
  model_inputs_to_streams=input_map,
  checkpoint_config=checkpoint_config
).train()

Distributed training

During distributed training in CNTK, a common mistake when using save is to call it from every worker process, which can lead to race conditions when writing the model to file. To lower the likelihood of error, consider instead checkpointing with the save_checkpoint/restore_from_checkpoint methods (or CheckpointConfig for Function.train), which guard against such race hazards by saving the entire distributed training snapshot from only the rank-0 worker.

Section 1.2 ("Distributed manual loop") of the "Train model using declarative and imperative API" manual has a full example of checkpointing with save_checkpoint/restore_from_checkpoint during distributed training when using the low-level Trainer.train_minibatch API.