@jakepoz
Created June 26, 2020 23:47
Exports a YoloV5 model as torchscript
"""Exports a pytorch *.pt model to *.onnx format
Usage:
$ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
"""
import argparse

import torch

from models.common import *
from utils import google_utils
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    print(opt)

    # Parameters
    f = opt.weights.replace('.pt', '.torchscript')  # torchscript filename
    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # dummy input, shape (batch, 3, height, width)

    # Load pytorch model
    google_utils.attempt_download(opt.weights)
    model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
    model.eval()

    # Don't fuse layers
    # model.fuse()

    # Export to torchscript
    model.model[-1].export = True  # set Detect() layer export=True (returns raw grid outputs, skips box decoding)
    _ = model(img)  # dry run
    traced_script_module = torch.jit.trace(model, img)
    traced_script_module.save(f)
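
    # Optional sanity check (a sketch, not part of the original gist): with export=True the
    # model returns a list of raw per-scale grid tensors, so the traced module's outputs
    # can be compared element-wise against the eager model's outputs on the same dummy input.
    with torch.no_grad():
        eager_out = model(img)
        traced_out = traced_script_module(img)
    for a, b in zip(eager_out, traced_out):
        assert torch.allclose(a, b, atol=1e-5), 'traced output diverged from eager model'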
govindamagrawal commented Jul 27, 2020

Hi jakepoz,
Thanks for the script that converts the model from torch to TorchScript. However, I could not make sense of the output of the TorchScript model. Here is the code to reproduce the issue.

# Using the PyTorch model directly
import torch

weights = './yolov5s.pt'
img = torch.zeros((1, 3, 224, 224))  # dummy input, shape (batch, 3, height, width)
model = torch.load(weights, map_location=torch.device('cpu'))['model'].float()
y = model(img)

# Using the exported TorchScript model
import torch

img = torch.zeros((1, 3, 224, 224))  # dummy input, shape (batch, 3, height, width)
script_model = torch.jit.load('yolov5s.torchscript.pt')
out = script_model(img)

The shape of y[0] is torch.Size([1, 3087, 85]), whereas the shape of out[0] is torch.Size([1, 3, 7, 7, 85]), and I am not able to understand the difference. Could you please help with this?
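
For context on the shapes (a sketch, not from the gist author): setting export=True makes the Detect() layer return the raw per-scale grid tensors of shape (batch, anchors, grid_h, grid_w, classes + 5) instead of the decoded predictions. Assuming the standard YOLOv5 strides of 8/16/32 and 3 anchors per scale, a 224x224 input gives 3 * (28*28 + 14*14 + 7*7) = 3087 predictions, which matches y[0]. Reusing script_model from the snippet above, flattening the raw grids reproduces that layout in shape only, since the values are raw network outputs rather than decoded boxes:

import torch

img_size = 224
strides = (8, 16, 32)     # standard YOLOv5 strides (assumption)
anchors_per_scale = 3

# 3 * (28*28 + 14*14 + 7*7) = 3087, matching y[0].shape[1]
num_preds = anchors_per_scale * sum((img_size // s) ** 2 for s in strides)
print(num_preds)  # 3087

# Flatten the exported model's per-scale grids into a (batch, N, 85) tensor.
# Same shape as y[0], but the values are raw logits, not decoded boxes.
outs = script_model(torch.zeros((1, 3, img_size, img_size)))
flat = torch.cat([o.reshape(o.shape[0], -1, o.shape[-1]) for o in outs], dim=1)
print(flat.shape)  # torch.Size([1, 3087, 85])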
