加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
ResNet_train.py 5.57 KB
一键复制 编辑 原始数据 按行查看 历史
myl135 提交于 2023-07-10 10:20 . first
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from Resnext import resnet50
#from Resnext import eca_resnet50
def _num_workers(batch_size):
    """Return how many DataLoader worker processes to use.

    The count is capped at 8 and at the number of CPUs, and is 0 (load in
    the main process) when ``batch_size`` is 1 or less.
    """
    # os.cpu_count() returns None when the CPU count cannot be determined.
    cpu_count = os.cpu_count() or 1
    return min([cpu_count, batch_size if batch_size > 1 else 0, 8])


def _write_class_indices(class_to_idx, path='class_indices.json'):
    """Write the inverse of ``class_to_idx`` (index -> class name) to *path*.

    JSON object keys are always strings, so index 0 is stored as ``"0"``.
    """
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    with open(path, 'w', encoding='utf-8') as json_file:
        json.dump(idx_to_class, json_file, indent=4)


def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one optimisation pass over *loader* and return the mean per-batch loss.

    ``epoch`` is zero-based and ``epochs`` is the total epoch count; both are
    used only for the progress-bar label.
    """
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()  # clear gradients left from the previous step
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 batch_loss)
    return running_loss / len(loader)


def _evaluate(net, loader, device, epoch, epochs):
    """Return the top-1 accuracy of *net* over every sample in *loader*.

    ``epoch`` is zero-based and ``epochs`` is the total epoch count; both are
    used only for the progress-bar label.
    """
    net.eval()
    correct = 0
    with torch.no_grad():  # validation needs no gradients
        val_bar = tqdm(loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # Column index of the largest logit in each row is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
    return correct / len(loader.dataset)


def main():
    """Fine-tune a pretrained ResNet-50 on an ImageFolder dataset.

    Reads ``train/`` and ``val/`` (one sub-folder per class) under the dataset
    root, plus pretrained weights from ``./resnet50.pth``. It writes
    ``class_indices.json`` (index -> class name) and saves
    ``./resNet50-<epoch>.pth`` every time validation accuracy improves.

    Raises:
        FileNotFoundError: the dataset root or the pretrained weights are missing.
        ValueError: the train and val splits do not share the same class folders.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Training uses random crop + flip augmentation; validation uses a
    # deterministic resize + centre crop. Both normalise with the ImageNet
    # channel mean/std that torchvision's pretrained weights expect.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Dataset root; it must contain train/<class>/... and val/<class>/...
    image_path = os.path.join("/home/linux/my/datasets", "archive")
    # Raise rather than assert so the check survives `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError('{} path does not exist.'.format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    # Class name -> index, e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, ...}
    class_to_idx = train_dataset.class_to_idx
    # Persist the inverse mapping so predictions can be turned back into names.
    _write_class_indices(class_to_idx)

    batch_size = 16
    nw = _num_workers(batch_size)
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    # Labels are only comparable if both splits assign the same index to each class.
    if validate_dataset.class_to_idx != class_to_idx:
        raise ValueError('train/val class folders differ: {} vs {}'.format(
            class_to_idx, validate_dataset.class_to_idx))
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=nw)

    net = resnet50()
    model_weight_path = "./resnet50.pth"
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
    # Pretrained weights are loaded before the head is swapped, which presumes
    # the checkpoint's keys/shapes match the default resnet50() -- TODO confirm.
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # Replace the final fully connected layer with one output per dataset class.
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, len(class_to_idx))
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 3
    best_acc = 0.0
    save_path = './resNet50-{}.pth'
    for epoch in range(epochs):
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer,
                                      device, epoch, epochs)
        val_accurate = _evaluate(net, validate_loader, device, epoch, epochs)
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))
        # Save a checkpoint for every epoch that beats the best accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path.format(epoch + 1))

    print('Finished Training')


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化