Преобразование модели обучения PyTorch в формат ONNX

На предыдущем этапе работы с этим учебником мы использовали PyTorch для создания модели машинного обучения. Однако эта модель является файлом .pth. Чтобы иметь возможность интегрировать этот файл с приложением Windows ML, вам понадобится преобразовать модель в формат ONNX.

Экспорт модели.

Для экспорта модели нужно использовать функцию torch.onnx.export(). Эта функция выполняет модель и записывает трассировку того, какие операторы используются для расчета выходных данных.

  1. Скопируйте следующий код в файл PyTorchTraining.py в Visual Studio и вставьте его над функцией main.
import torch.onnx 

#Function to Convert to ONNX 
def Convert_ONNX(): 

    # set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor  
    dummy_input = torch.randn(1, input_size, requires_grad=True)  

    # Export the model   
    torch.onnx.export(model,         # model being run 
         dummy_input,       # model input (or a tuple for multiple inputs) 
         "ImageClassifier.onnx",       # where to save the model  
         export_params=True,  # store the trained parameter weights inside the model file 
         opset_version=10,    # the ONNX version to export the model to 
         do_constant_folding=True,  # whether to execute constant folding for optimization 
         input_names = ['modelInput'],   # the model's input names 
         output_names = ['modelOutput'], # the model's output names 
         dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX') 

Прежде чем экспортировать модель, нужно вызвать model.eval() или model.train(False), поскольку эти переключатели позволяют задать для модели режим вывода. Такое действие необходимо, поскольку операторы dropout или batchnorm работают по-разному в режиме вывода и обучения.

  1. Чтобы выполнить преобразование в ONNX, добавьте вызов функции преобразования в функцию main. Заново обучать модель не нужно, поэтому мы закомментируем некоторые функции, которые нам больше не понадобится выполнять. Функция main будет выглядеть следующим образом.
if __name__ == "__main__": 

    # Let's build our model 
    #train(5) 
    #print('Finished Training') 

    # Test which classes performed well 
    #testAccuracy() 

    # Let's load the model we just created and test the accuracy per label 
    model = Network() 
    path = "myFirstModel.pth" 
    model.load_state_dict(torch.load(path)) 

    # Test with batch of images 
    #testBatch() 
    # Test how the classes performed 
    #testClassess() 
 
    # Conversion to ONNX 
    Convert_ONNX() 
  1. Запустите проект еще раз, нажав кнопку Start Debugging на панели инструментов или клавишу F5. Обучать модель снова не понадобится. Все, что нужно сделать — просто загрузить существующую модель из папки проекта.

Выходные данные должны выглядеть следующим образом.

ONNX conversion process

Перейдите к расположению проекта и найдите модель ONNX рядом с моделью .pth.

Примечание.

Хотите узнать больше? Ознакомьтесь с руководством PyTorch по экспорту модели.

Обзор модели.

  1. Откройте файл модели ImageClassifier.onnx с помощью Neutron.

  2. Выберите узел data, чтобы открыть свойства модели.

ONNX model properties

Как мы видим, в качестве входных данных для модели нужно использовать 32-разрядный свободно перемещаемый объект (многомерный массив) тензор, а в качестве выходных данных возвращается число с плавающей точкой тензора. Массив выходных данных будет содержать вероятность для каждой метки. При построении модели метки обозначаются 10 числами, каждое из которых представляет десять классов объектов.

Метка 0 Метка 1 Метка 2 Метка 3 Метка 4 Метка 5 Метка 6 Метка 7 Метка 8 Метка 9
0 1 2 3 4 5 6 7 8 9
управления автомобиль птица cat deer собака frog лошадь ship truck

Вам нужно будет извлечь эти значения, чтобы отобразить правильный прогноз в приложении Windows ML.

Дальнейшие действия

Наша модель готова к использованию. Затем, что касается главного события, давайте создадим приложение Windows и запустим его локально на устройстве Windows.