# NOTE(review): the three lines below were Gitee web-page banner text captured
# by the scrape, not Python code; preserved (translated) as comments so the
# file parses:
#   "Code pull complete; the page will refresh automatically."
#   "Sync will force-sync from myl135/python_restnet50; this overwrites any
#    changes made since the fork and cannot be undone!!!"
#   "After confirming, the sync runs in the background; the page refreshes
#    when done — please be patient."
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from Resnext import resnet50
#from Resnext import eca_resnet50
def main():
    """Fine-tune a pretrained ResNet-50 on a 5-class flower dataset.

    Side effects:
        - Reads train/val image folders from /home/linux/my/datasets/archive.
        - Writes the index -> class-name mapping to ./class_indices.json.
        - Saves model weights to ./resNet50-{epoch}.pth whenever validation
          accuracy improves.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Pre-processing: random augmentation for training, deterministic
    # resize/center-crop for validation. Normalization uses the standard
    # ImageNet channel statistics (required by the pretrained weights).
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406],
                                                          [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406],
                                                        [0.229, 0.224, 0.225])])}

    # Expected dataset layout:
    #   /home/linux/my/datasets/archive/
    #       train/<one folder per class>
    #       val/<one folder per class>
    image_path = os.path.join("/home/linux/my/datasets", "archive")
    assert os.path.exists(image_path), '{} path does not exist.'.format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> index,
    # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}.
    # Invert it to index -> class name so predictions can be decoded at
    # inference time. BUG FIX: the original used
    # `dict((val, key) for val, key in flower_list.items())`, which merely
    # renames the loop variables and reproduces the same name -> index dict;
    # iterating as (key, val) performs the intended inversion.
    flower_list = train_dataset.class_to_idx
    cla_dict = {idx: cls_name for cls_name, idx in flower_list.items()}

    # Persist the mapping for the prediction script. indent=4 pretty-prints.
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    # Number of dataloader worker processes, capped at 8.
    # BUG FIX: nw was computed but both loaders hard-coded num_workers=1;
    # it is now actually used.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=nw)

    # Build the network and load the ImageNet-pretrained checkpoint on CPU
    # first (map_location='cpu') so loading works on machines without a GPU.
    net = resnet50()
    model_weight_path = "./resnet50.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))

    # Replace the final fully-connected layer for 5-way flower classification.
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    # Only optimize parameters that require gradients.
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 3
    best_acc = 0.0
    save_path = './resNet50-{}.pth'  # formatted with the epoch number on save
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # ---- training phase ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # ---- validation phase (no gradients needed) ----
        net.eval()
        acc = 0.0  # running count of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # argmax over class logits -> predicted class index
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Checkpoint whenever validation accuracy improves.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path.format(epoch + 1))

    print('Finished Training')
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
# NOTE(review): the two lines below were a Gitee content-moderation footer
# captured by the scrape, not Python code; preserved (translated) as comments:
#   "Content that may be inappropriate to display is not shown on this page.
#    You may review and edit it yourself using the relevant editing features."
#   "If you confirm the content does not involve improper language / pure ad
#    redirection / violence / vulgar pornography / infringement / piracy /
#    falsehood / valueless content or material violating national laws and
#    regulations, you may click submit to appeal and we will process it
#    promptly."