代码拉取完成,页面将自动刷新
import torch
from torchvision import datasets, transforms
from model import MyLeNet5
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from torch.autograd import Variable
# Compose():将多个transforms的操作整合在一起
data_transform = transforms.Compose([
# ToTensor():数据转化为Tensor格式
transforms.ToTensor()
])
# 加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transform, download=False)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
val_data_iter = iter(test_dataloader)
val_image, val_label = val_data_iter._next_data()
input_data=torch.rand(1, 1, 28, 28) ##
# int(val_label[1])
input_name = "input0"
shape_list = [(input_name, val_image.shape)]
# 模型实例化,将模型转到device
model = MyLeNet5()
# 加载train.py里训练好的模型
model.load_state_dict(torch.load("./save_model/my_model_dict.pth", weights_only=True))
scripted_model=torch.jit.trace(model, input_data).eval()
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
target = "llvm"
target_host = "llvm"
ctx = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, target_host=target_host, params=params)
'''
######################################################
with open("./save_model/my_model.json","w") as f:
f.write(str(lib.graph_json))
lib.export_library("./save_model/my_model.o")
with open("./save_model/my_model.params", "wb") as f:
f.write(relay.save_param_dict(params))
######################################################
with open("./save_model/my_model.json", "r") as fi:
loaded_json = fi.read()
loaded_lib = tvm.runtime.load_module("./save_model/my_model.o")
loaded_params = bytearray(open("./save_model/my_model.params", "rb").read())
ctx = tvm.cpu()
module = graph_executor.create(loaded_json, loaded_lib, ctx)
module.load_params(loaded_params)
'''
module = graph_executor.GraphModule(lib["default"](ctx))
import time
start_time = time.time() # 记录开始时间
for i in range(10):
x, y = test_dataset[i][0], test_dataset[i][1]
x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False)
module.set_input(input_name, x)
module.run()
tvm_output = module.get_output(0)
top1_tvm = np.argmax(tvm_output.asnumpy()[0])
# print("predict=", top1_tvm, "ref=", y)
end_time = time.time() # 记录结束时间
elapsed_time = end_time - start_time # 计算运行时间
print("程序运行时间为:", elapsed_time, "秒")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。