Where can and should I save a trained model?

Kirill Meisser 1 Reputation point
2021-11-15T15:13:07.06+00:00

Hi,

I am new to Azure and am having trouble saving a CNN model created by my training script. I am trying to save it in my project folder under the 'outputs' directory, but the model does not get saved. The model is over 300 MB, which may be part of the problem. I have looked for a solution but have not found much that helps. I am using PyTorch and am saving the model with the torch.save(...) function into a .tar file. How and where should one save the models generated by a training script, and how would one then load that same model for inference? Any help is greatly appreciated :)

Azure Machine Learning
An Azure machine learning service for building and deploying models.

1 answer

  1. GiftA-MSFT 11,171 Reputation points
    2021-11-15T19:05:25.077+00:00

    Hi, here's an example training script that shows how to save a trained PyTorch model. You can then register, deploy, and consume the model. For more details on training PyTorch models, please see this document. Also, feel free to review where to save and write files for experiments. If you're still unable to save the trained model, please share the error message or the steps to reproduce the issue. Hope this helps!

    import argparse
    import os

    import torch

    # download_data() and fine_tune_model() are helper functions defined
    # elsewhere in the full sample training script (not shown here).


    def main():
        print("Torch version:", torch.__version__)

        # get command-line arguments
        parser = argparse.ArgumentParser()
        parser.add_argument('--num_epochs', type=int, default=25,
                            help='number of epochs to train')
        # default to ./outputs so os.makedirs() never receives None
        parser.add_argument('--output_dir', type=str, default='outputs',
                            help='output directory')
        parser.add_argument('--learning_rate', type=float,
                            default=0.001, help='learning rate')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        args = parser.parse_args()

        data_dir = download_data()
        print("data directory is: " + data_dir)
        model = fine_tune_model(args.num_epochs, data_dir,
                                args.learning_rate, args.momentum)
        # create the output folder if needed, then pickle the whole model object
        os.makedirs(args.output_dir, exist_ok=True)
        torch.save(model, os.path.join(args.output_dir, 'model.pt'))


    if __name__ == "__main__":
        main()
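
The script above pickles the whole model object with torch.save(model, ...). A commonly recommended alternative in the PyTorch documentation is to save only the model's state_dict and rebuild the architecture before loading it, which also covers the "how do I load it for inference" part of the question. The sketch below is illustrative only: TinyCNN is a hypothetical stand-in for your real network, save_weights and load_for_inference are hypothetical helpers, and the file name model_state.pt is an assumption. If you keep the full-model save from the script instead, torch.load on model.pt followed by model.eval() also works, provided the model's class definition is importable where you load it (recent PyTorch versions may also require weights_only=False for a pickled full model).

```python
import os

import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in for the real CNN; replace with your own architecture."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # collapse spatial dims to 1x1
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))


def save_weights(model: nn.Module, output_dir: str = "outputs") -> str:
    """Hypothetical helper: save only the learned parameters (state_dict)."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, "model_state.pt")
    torch.save(model.state_dict(), path)
    return path


def load_for_inference(path: str, num_classes: int = 2) -> nn.Module:
    """Hypothetical helper: rebuild the architecture, load weights, switch to eval mode."""
    model = TinyCNN(num_classes)
    # map_location lets weights trained on a GPU load on a CPU-only machine
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # disables dropout and uses running batch-norm statistics
    return model


if __name__ == "__main__":
    trained = TinyCNN()
    path = save_weights(trained)
    restored = load_for_inference(path)
    with torch.no_grad():  # no gradient tracking needed for inference
        logits = restored(torch.randn(1, 3, 32, 32))
    print(logits.shape)  # torch.Size([1, 2])
```

Because the state_dict holds only tensors, the file can be loaded on any machine that has the class definition, independently of how the training script itself was structured.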
    

