Dela via


Konvertera din PyTorch-träningsmodell till ONNX

Anmärkning

För större funktioner kan PyTorch också användas med DirectML i Windows.

I föregående steg i den här självstudien använde vi PyTorch för att skapa vår maskininlärningsmodell. Den modellen är dock en .pth fil. För att kunna integrera den med Windows ML-appen måste du konvertera modellen till ONNX-format.

Exportera modellen

Om du vill exportera en modell använder torch.onnx.export() du funktionen. Den här funktionen kör modellen och registrerar en spårning av vilka operatorer som används för att beräkna utdata.

  1. Kopiera följande kod till PyTorchTraining.py filen i Visual Studio ovanför huvudfunktionen.
import torch.onnx 

#Function to Convert to ONNX 
def Convert_ONNX(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "ImageClassifier.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

Det är viktigt att anropa model.eval() eller model.train(False) innan du exporterar modellen, eftersom detta ställer in modellen till slutsatsdragningsläge. Detta behövs eftersom operatorer som dropout eller batchnorm beter sig annorlunda i inferens- och träningsläge.

  1. Om du vill köra konverteringen till ONNX lägger du till ett anrop till konverteringsfunktionen i huvudfunktionen. Du behöver inte träna modellen igen, så vi kommenterar ut några funktioner som vi inte längre behöver köra. Huvudfunktionen är följande.
if __name__ == "__main__": 

    # Let's build our model 
    #train(5) 
    #print('Finished Training') 

    # Test which classes performed well 
    #testAccuracy() 

    # Let's load the model we just created and test the accuracy per label 
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 

    # Test with batch of images 
    #testBatch() 
    # Test how the classes performed 
    #testClassess() 
 
    # Conversion to ONNX 
    Convert_ONNX() 
  1. Kör projektet igen genom att välja knappen Start Debugging i verktygsfältet eller trycka på F5. Du behöver inte träna modellen igen. Läs bara in den befintliga modellen från projektmappen.

Utdata kommer att vara följande.

ONNX-konverteringsprocess

Gå till projektplatsen och leta reda på ONNX-modellen bredvid .pth modellen.

Anmärkning

Vill du veta mer? Granska PyTorch-självstudien om hur du exporterar en modell.

Utforska din modell.

  1. ImageClassifier.onnx Öppna modellfilen med Netron.

  2. Välj datanoden för att öppna modellegenskaperna.

Egenskaper för ONNX-modell

Som du ser kräver modellen ett 32-bitars tensor-objekt (flerdimensionell matris) som indata och returnerar en Tensor-flyttal som utdata. Utdatamatrisen inkluderar sannolikheten för varje etikett. Som du skapade modellen representeras etiketterna av 10 tal och varje tal representerar de tio objektklasserna.

Etikett 0 Etikett 1 Etikett 2 Etikett 3 Etikett 4 Etikett 5 Etikett 6 Etikett 7 Etikett 8 Etikett 9
0 1 2 3 4 5 6 7 8 9
flygplan bil fågel katt hjort hund groda häst skepp lastbil

Du måste extrahera dessa värden för att visa rätt förutsägelse med Windows ML-appen.

Nästa steg

Vår modell är redo att distribueras. För huvudhändelsen ska vi sedan skapa ett Windows-program och köra det lokalt på din Windows-enhet.