加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test.py 1.52 KB
一键复制 编辑 原始数据 按行查看 历史
Shivelino 提交于 2023-12-26 16:48 . chore: 少量修改
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file test.py
@brief Evaluate a trained model on the MNIST test set and print its accuracy.
@details
@author Shivelino
@date 2023-12-23 19:10
@version 0.0.1
@par Copyright(c):
@par todo:
@par history:
"""
import torch
import argparse
import cv2
import matplotlib.pyplot as plt
import numpy as np
from nets import get_model
from utils import get_device, get_dataloader_mnist
def test(opt):
    """Evaluate a saved model on the MNIST test loader and print its accuracy.

    The weights are read from ``model/model_<opt.model>.pth`` and loaded into
    the network returned by ``get_model(opt.model)``.

    Args:
        opt: Parsed options object. The attributes read here are ``model``
            (model name), ``data_dir`` and ``batch_size`` (both forwarded to
            ``get_dataloader_mnist``).

    Returns:
        None. The result is reported on stdout.
    """
    # get dataloader (only the second returned loader is used for evaluation)
    _, testloader = get_dataloader_mnist(opt.data_dir, opt.batch_size)

    # init model
    device = get_device()
    print(f"Current Model: {opt.model}")
    model = get_model(opt.model).to(device)

    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one.
    # NOTE(review): assumes get_device() returns a torch.device or device
    # string, as its use in `.to(device)` suggests.
    state_dict = torch.load(f'model/model_{opt.model}.pth', map_location=device)
    model.load_state_dict(state_dict)

    # switch the model to evaluation mode
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # index of the largest output along dim 1 is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of failing with ZeroDivisionError.
    if total == 0:
        print('No test samples found; accuracy cannot be computed.')
        return

    accuracy = correct / total
    print(f'Accuracy on the test set: {accuracy * 100:.2f}%')
if __name__ == '__main__':
    # Script entry point: declare the command-line options from a small table,
    # parse them, and hand the resulting namespace to test().
    cli = argparse.ArgumentParser()
    option_table = (
        ('--model', str, "lenet", 'model'),
        ('--data_dir', str, "data", 'data directory'),
        ('--batch_size', int, 128, 'size of the batches'),
    )
    for flag, kind, fallback, description in option_table:
        cli.add_argument(flag, type=kind, default=fallback, help=description)
    options = cli.parse_args()
    test(options)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化