我将 PyTorch 模型导出为 ONNX 的代码如下：
ONNX 版本为 1.6.0。
import torch
import os
import numpy as np
from model.mobilenext import mnext
def main():
    """Load the trained ``mnext`` checkpoint and export the network to ONNX.

    Reads ``mobiilenext_se.pth`` from the current working directory, builds a
    2-class ``mnext`` network, loads the checkpoint weights into it, and writes
    ``mobilenext_se.onnx`` by tracing the network with one random input of
    shape (1, 3, 48, 48).
    """
    # NOTE(review): the checkpoint name spells "mobiilenext" (double "i") while
    # the output is "mobilenext"; kept as-is because it must match the file on
    # disk -- confirm the spelling is intentional.
    initial_checkpoint = 'mobiilenext_se.pth'
    network = mnext(num_classes=2)
    # Switch to eval mode before tracing so any train/eval-dependent layers
    # (e.g. BatchNorm, Dropout) are exported with inference behavior.
    network.eval()
    # Deserialize every tensor onto the CPU regardless of the device it was
    # saved from (equivalent to the identity map_location lambda).
    state = torch.load(initial_checkpoint, map_location='cpu')
    # The original blind ``k[7:]`` presumably strips the 7-character "module."
    # prefix that nn.DataParallel adds to parameter names. Strip it only when
    # it is actually present, so keys saved without it are not corrupted.
    prefix = 'module.'
    state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }
    network.load_state_dict(state)
    # Tracing input: one sample, presumably NCHW (3 channels, 48x48) -- confirm
    # against the model's expected input size.
    dummy_input1 = torch.randn(1, 3, 48, 48)
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(network, dummy_input1, "mobilenext_se.onnx", verbose=True,
                      input_names=input_names, output_names=output_names)
# Run the export only when executed as a script, not when imported.
# (The original ``if name == 'main':`` raised NameError at import time because
# ``name`` is undefined; the dunder underscores were lost.)
if __name__ == '__main__':
    main()
我将 PyTorch 模型导出为 ONNX 的代码如下：
ONNX 版本为 1.6.0。
import torch
import os
import numpy as np
from model.mobilenext import mnext
def main():
    """Load the trained ``mnext`` checkpoint and export the network to ONNX.

    Reads ``mobiilenext_se.pth`` from the current working directory, builds a
    2-class ``mnext`` network, loads the checkpoint weights into it, and writes
    ``mobilenext_se.onnx`` by tracing the network with one random input of
    shape (1, 3, 48, 48).
    """
    # NOTE(review): the checkpoint name spells "mobiilenext" (double "i") while
    # the output is "mobilenext"; kept as-is because it must match the file on
    # disk -- confirm the spelling is intentional.
    initial_checkpoint = 'mobiilenext_se.pth'
    network = mnext(num_classes=2)
    # Switch to eval mode before tracing so any train/eval-dependent layers
    # (e.g. BatchNorm, Dropout) are exported with inference behavior.
    network.eval()
    # Deserialize every tensor onto the CPU regardless of the device it was
    # saved from (equivalent to the identity map_location lambda).
    state = torch.load(initial_checkpoint, map_location='cpu')
    # The original blind ``k[7:]`` presumably strips the 7-character "module."
    # prefix that nn.DataParallel adds to parameter names. Strip it only when
    # it is actually present, so keys saved without it are not corrupted.
    prefix = 'module.'
    state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }
    network.load_state_dict(state)
    # Tracing input: one sample, presumably NCHW (3 channels, 48x48) -- confirm
    # against the model's expected input size.
    dummy_input1 = torch.randn(1, 3, 48, 48)
    input_names = ["input"]
    output_names = ["output"]
    torch.onnx.export(network, dummy_input1, "mobilenext_se.onnx", verbose=True,
                      input_names=input_names, output_names=output_names)
# Run the export only when executed as a script, not when imported.
# (The original ``if name == 'main':`` raised NameError at import time because
# ``name`` is undefined; the dunder underscores were lost.)
if __name__ == '__main__':
    main()