代码拉取完成,页面将自动刷新
同步操作将从 小乐major/VoiceprintRecognition-Pytorch 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os
import re
import shutil
import time
from datetime import datetime, timedelta
import argparse
import functools
import numpy as np
import torch
from torch.nn import DataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchsummary import summary
from utils.reader import CustomDataset
from utils.metrics import ArcNet
from utils.resnet import resnet34
from utils.utility import add_arguments, print_arguments
# Command-line options for the training script. Each option is registered
# through the project helper `add_arguments` as (name, type, default, help).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)

# Option table, registered in exactly this order.
# NOTE(review): `pretrained_model` is declared here but is not read anywhere
# in the code visible in this file.
_ARG_SPECS = (
    ('gpus', str, '0', '训练使用的GPU序号,使用英文逗号,隔开,如:0,1'),
    ('batch_size', int, 32, '训练的批量大小'),
    ('num_workers', int, 4, '读取数据的线程数量'),
    ('num_epoch', int, 50, '训练的轮数'),
    ('num_classes', int, 3242, '分类的类别数量'),
    ('learning_rate', float, 1e-3, '初始学习率的大小'),
    ('weight_decay', float, 5e-4, 'weight_decay的大小'),
    ('lr_step', int, 10, '学习率衰减步数'),
    ('input_shape', str, '(1, 257, 257)', '数据输入的形状'),
    ('train_list_path', str, 'dataset/train_list.txt', '训练数据的数据列表路径'),
    ('test_list_path', str, 'dataset/test_list.txt', '测试数据的数据列表路径'),
    ('save_model', str, 'models/', '模型保存的路径'),
    ('resume', str, None, '恢复训练,当为None则不使用恢复模型'),
    ('pretrained_model', str, None, '预训练模型的路径,当为None则不使用预训练模型'),
)
for _spec in _ARG_SPECS:
    add_arg(*_spec)

# Parsed once at import time; the rest of the script reads this namespace.
args = parser.parse_args()
# Evaluate the model
@torch.no_grad()
def test(model, metric_fc, test_loader, device):
    """Return the classification accuracy over every sample in ``test_loader``.

    Args:
        model: Callable mapping a batch of inputs to feature vectors.
        metric_fc: Callable taking ``(feature, label)`` and returning per-class
            scores; the predicted class is the argmax over axis 1.
        test_loader: Iterable yielding ``(spec_mag, label)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: ``correct / total`` over all samples seen, or ``0.0`` when the
        loader yields no samples.
    """
    correct = 0
    total = 0
    for spec_mag, label in test_loader:
        spec_mag = spec_mag.to(device)
        label = label.to(device).long()
        feature = model(spec_mag)
        output = metric_fc(feature, label)
        predicted = np.argmax(output.data.cpu().numpy(), axis=1)
        expected = label.data.cpu().numpy()
        # Count samples rather than averaging per-batch accuracies, so a
        # smaller final batch does not get the same weight as a full one.
        correct += int((predicted == expected).sum())
        total += len(expected)
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return correct / total


# This helper is named `test`, which matches pytest's default name-based
# collection; opt out so it is never mistaken for a test case when imported
# into a test module.
test.__test__ = False
# Save the model
def save_model(args, model, metric_fc, optimizer, epoch_id):
    """Write a checkpoint for ``epoch_id`` and prune the one from three epochs back.

    Files written under ``args.save_model``:
      * ``epoch_<epoch_id>/model_params.pth``     - ``model.state_dict()``
      * ``epoch_<epoch_id>/metric_fc_params.pth`` - ``metric_fc.state_dict()``
      * ``epoch_<epoch_id>/optimizer.pth``        - ``optimizer.state_dict()``
      * ``resnet34.pth``                          - TorchScript export of ``model``

    Args:
        args: Namespace providing ``save_model`` (the output directory).
        model: Module whose parameters are saved and which is scripted.
        metric_fc: Module whose parameters are saved.
        optimizer: Optimizer whose state is saved.
        epoch_id: Epoch number used to name the checkpoint directory.
    """
    model_params_path = os.path.join(args.save_model, 'epoch_%d' % epoch_id)
    os.makedirs(model_params_path, exist_ok=True)
    # Save the model parameters and the optimizer state
    torch.save(model.state_dict(), os.path.join(model_params_path, 'model_params.pth'))
    torch.save(metric_fc.state_dict(), os.path.join(model_params_path, 'metric_fc_params.pth'))
    torch.save(optimizer.state_dict(), os.path.join(model_params_path, 'optimizer.pth'))
    # Delete the old checkpoint (only the one exactly three epochs earlier)
    old_model_path = os.path.join(args.save_model, 'epoch_%d' % (epoch_id - 3))
    if os.path.exists(old_model_path):
        shutil.rmtree(old_model_path)
    # Save the whole model (structure and parameters) as TorchScript
    all_model_path = os.path.join(args.save_model, 'resnet34.pth')
    parent_dir = os.path.dirname(all_model_path)
    # dirname is '' when saving into the current directory; os.makedirs('')
    # would raise, so only create the directory when there is one to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.jit.save(torch.jit.script(model), all_model_path)
def train():
    """Train the resnet34 feature extractor together with the ArcNet head.

    All settings come from the module-level ``args``. Each epoch runs one pass
    over the training list, steps the learning-rate scheduler, evaluates on
    the test list with ``test`` and writes a checkpoint with ``save_model``.
    When ``args.resume`` is set, model, head and optimizer state are loaded
    from that directory and training continues after the epoch encoded in its
    ``epoch_<n>`` name. Training always runs on ``torch.device("cuda")``.
    """
    # GPU ids exactly as given on the command line.
    gpu_ids = [int(i) for i in args.gpus.split(',')]
    # When CUDA_VISIBLE_DEVICES holds the same list (the script entry point
    # sets it), CUDA renumbers the visible devices 0..n-1, so DataParallel
    # must be given the renumbered ids rather than the physical ones.
    if os.environ.get("CUDA_VISIBLE_DEVICES") == args.gpus:
        device_ids = list(range(len(gpu_ids)))
    else:
        device_ids = gpu_ids
    multi_gpu = len(device_ids) > 1
    # Shape of the input data
    # NOTE(review): eval() executes whatever string is passed as --input_shape;
    # acceptable for a locally run script, but do not feed it untrusted input.
    input_shape = eval(args.input_shape)
    # Build the datasets and loaders
    train_dataset = CustomDataset(args.train_list_path, model='train', spec_len=input_shape[2])
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size * len(device_ids),
                              shuffle=True,
                              num_workers=args.num_workers)
    test_dataset = CustomDataset(args.test_list_path, model='test', spec_len=input_shape[2])
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    device = torch.device("cuda")
    # Build the model and the classification head
    model = resnet34()
    metric_fc = ArcNet(512, args.num_classes, scale=64, margin=0.5)
    if multi_gpu:
        model = DataParallel(model, device_ids=device_ids, output_device=device_ids[0])
        metric_fc = DataParallel(metric_fc, device_ids=device_ids, output_device=device_ids[0])
    model.to(device)
    metric_fc.to(device)
    # summary() needs the bare module, not the DataParallel wrapper
    if multi_gpu:
        summary(model.module, input_shape)
    else:
        summary(model, input_shape)

    # First epoch to run; stays 0 unless a checkpoint is resumed
    last_epoch = 0
    # Optimizer over both the backbone and the head
    optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
                                lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
    # Learning-rate decay schedule
    scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1, verbose=True)
    # Loss function
    criterion = torch.nn.CrossEntropyLoss()

    # Load model parameters and optimizer state
    if args.resume:
        optimizer_state = torch.load(os.path.join(args.resume, 'optimizer.pth'))
        optimizer.load_state_dict(optimizer_state)
        # The SGD state dict carries no 'step' entry, so the epoch is taken
        # from the checkpoint directory name written by save_model
        # ('epoch_<n>'). That checkpoint was saved after epoch n finished,
        # so training continues with epoch n + 1.
        matched = re.search(r'epoch_(\d+)$', os.path.basename(os.path.normpath(args.resume)))
        if matched:
            last_epoch = int(matched.group(1)) + 1
            # Keep the StepLR decay boundaries where an uninterrupted run
            # would have had them.
            scheduler.last_epoch = last_epoch
        else:
            print('无法从恢复路径解析epoch数,将从第0轮开始训练')
        # State dicts were saved from the bare modules, so load into them
        target_model = model.module if multi_gpu else model
        target_fc = metric_fc.module if multi_gpu else metric_fc
        target_model.load_state_dict(torch.load(os.path.join(args.resume, 'model_params.pth')))
        target_fc.load_state_dict(torch.load(os.path.join(args.resume, 'metric_fc_params.pth')))
        print('成功加载模型参数和优化方法参数')

    # Start training
    sum_batch = len(train_loader) * (args.num_epoch - last_epoch)
    for epoch_id in range(last_epoch, args.num_epoch):
        for batch_id, data in enumerate(train_loader):
            start = time.time()
            data_input, label = data
            data_input = data_input.to(device)
            label = label.to(device).long()
            feature = model(data_input)
            output = metric_fc(feature, label)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Log progress every 100 batches
            if batch_id % 100 == 0:
                output = output.data.cpu().numpy()
                output = np.argmax(output, axis=1)
                label = label.data.cpu().numpy()
                acc = np.mean((output == label).astype(int))
                # ETA = duration of this batch (ms) * number of batches left
                eta_sec = ((time.time() - start) * 1000) * (sum_batch - (epoch_id - last_epoch) * len(train_loader) - batch_id)
                eta_str = str(timedelta(seconds=int(eta_sec / 1000)))
                print('[%s] Train epoch %d, batch: %d/%d, loss: %f, accuracy: %f, eta: %s' % (
                    datetime.now(), epoch_id, batch_id, len(train_loader), loss.item(), acc.item(), eta_str))
        scheduler.step()
        # Evaluate
        model.eval()
        print('='*70)
        accuracy = test(model, metric_fc, test_loader, device)
        model.train()
        print('[{}] Test epoch {} Accuracy {:.5}'.format(datetime.now(), epoch_id, accuracy))
        print('='*70)
        # Save the model (always the bare modules, never the DataParallel wrappers)
        if multi_gpu:
            save_model(args, model.module, metric_fc.module, optimizer, epoch_id)
        else:
            save_model(args, model, metric_fc, optimizer, epoch_id)
if __name__ == '__main__':
    # Script entry point: expose only the requested GPUs to CUDA, echo the
    # parsed arguments, then run training.
    os.environ.update({"CUDA_VISIBLE_DEVICES": args.gpus})
    print_arguments(args)
    train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。