Where can and should I save a trained model?

Kirill Meisser 1 Reputation point


I am new to Azure and am having trouble saving a CNN model created by my training script. I am trying to save it in my project folder under the 'outputs' directory, but the model is not being saved. The file is over 300 MB, which may be a further problem. I have looked around for a solution but have not found much that helps. I am using PyTorch and saving the model with the torch.save(...) function into a .tar file. My question is: how and where should the models generated by a training script be saved, and how would one later load that same model to use it for inference? Any help is greatly appreciated :)

Azure Machine Learning

1 answer

  1. GiftA-MSFT 11,096 Reputation points

    Hi, here's an example training script that shows how to save a trained PyTorch model. Once it is saved, you can register, deploy, and consume the model. For more details on training PyTorch models, please see this document, and feel free to review the guidance on where to save and write files for experiments. If you're still unable to save the trained model, please share the error message or the steps to reproduce the issue. Hope this helps!

    import argparse
    import os

    import torch


    def main():
        print("Torch version:", torch.__version__)
        # get command-line arguments
        parser = argparse.ArgumentParser()
        parser.add_argument('--num_epochs', type=int, default=25,
                            help='number of epochs to train')
        # the './outputs' default is an assumption added here so the script also runs without the flag
        parser.add_argument('--output_dir', type=str, default='./outputs',
                            help='output directory')
        parser.add_argument('--learning_rate', type=float,
                            default=0.001, help='learning rate')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        args = parser.parse_args()
        # download_data() and fine_tune_model() are defined elsewhere in the training script
        data_dir = download_data()
        print("data directory is: " + data_dir)
        model = fine_tune_model(args.num_epochs, data_dir,
                                args.learning_rate, args.momentum)
        # create the output directory if it does not exist, then save the trained model
        os.makedirs(args.output_dir, exist_ok=True)
        torch.save(model, os.path.join(args.output_dir, 'model.pt'))


    if __name__ == "__main__":
        main()
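
The question also asks how to load the saved model again for inference. The following is only a minimal sketch under stated assumptions, not the script from the answer above: it saves the model's `state_dict` (the approach the PyTorch documentation generally recommends over pickling the whole model object) and rebuilds the model before loading the weights. `TinyCNN`, `save_weights`, and `load_for_inference` are hypothetical names used for illustration; your own model class would take the place of `TinyCNN`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for the trained CNN; replace with your own model class."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.classifier(x)


def save_weights(model: nn.Module, output_dir: str, filename: str = "model.pt") -> str:
    """Write only the learned parameters (state_dict) into output_dir."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    torch.save(model.state_dict(), path)
    return path


def load_for_inference(path: str, num_classes: int = 2) -> nn.Module:
    """Rebuild the architecture, load the saved weights, and switch to eval mode."""
    model = TinyCNN(num_classes)
    # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # disables dropout and uses running batch-norm statistics
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        trained = TinyCNN()  # in practice this would be the model returned by training
        path = save_weights(trained, os.path.join(tmp, "outputs"))
        restored = load_for_inference(path)
        with torch.no_grad():
            logits = restored(torch.randn(1, 3, 32, 32))
        print(logits.shape)
```

Saving the `state_dict` means the model class must be importable wherever the weights are loaded, but the file does not depend on the exact module layout that existed at training time, which tends to make it more portable than a pickled model object.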

    --- *Kindly Accept Answer if the information helps. Thanks.*
