加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 6.20 KB
一键复制 编辑 原始数据 按行查看 历史
hanpc1125 提交于 2022-01-26 11:29 . first commit
from __future__ import print_function
from argparse import ArgumentParser
import cv2
import csv
import os.path
import numpy as np
import torch
from torch.optim import Adam, lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from criterion import CrossEntropyLoss2d
from datasets import CD2014
from datasets import levir
import sys
#sys.path.append("./correlation_package/build/lib.linux-x86_64-3.5")
import cscdnet
import utils.transforms as trans
# Pin the process to the first GPU; every .cuda() call below lands on it.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
def colormap():
    """Return the binary change-detection palette as a (2, 3) uint8 array.

    Row 0 (no change) is black, row 1 (change) is white.
    """
    palette = np.zeros((2, 3), dtype=np.uint8)
    palette[1] = 255
    return palette
class Colorization:
    """Turn a single-channel label image into a 3-channel RGB byte image.

    Each label value is painted with the corresponding row of the palette
    produced by ``colormap()`` (truncated to the first ``n`` entries).
    """

    def __init__(self, n=2):
        palette = colormap()
        self.cmap = torch.from_numpy(np.array(palette[:n]))

    def __call__(self, gray_image):
        # gray_image is expected to be (1, H, W); channel 0 holds the labels.
        _, height, width = gray_image.size()
        rgb = torch.ByteTensor(3, height, width).fill_(0)
        for label in range(len(self.cmap)):
            selected = gray_image[0] == label
            for channel in range(3):
                rgb[channel][selected] = self.cmap[label][channel]
        return rgb
class Training:
    """Drive the change-detection training loop.

    Builds the DataLoader named by ``args.dataset``, trains the model
    created in :meth:`run`, logs per-iteration loss to ``loss.csv`` and
    tensorboard, and saves periodic checkpoints under ``dn_save``.
    """

    def __init__(self, arguments):
        self.args = arguments
        # Global iteration counter shared by train/checkpoint/log_tbx.
        self.icount = 0
        # All outputs (checkpoints, loss.csv, tensorboard log) go here.
        self.dn_save = os.path.join(self.args.checkpointdir, 'cdnet',
                                    'checkpointdir', 'set{}'.format(self.args.cvset))

    def _build_loader(self):
        """Create the training DataLoader for the configured dataset.

        Raises:
            ValueError: if ``args.dataset`` is neither 'cdnet' nor 'levir'
                (previously this surfaced later as a NameError on
                ``dataset_train``).
        """
        cdnet_path = self.args.datadir
        # Data and label roots are the same directory for both datasets.
        data_path = os.path.join(cdnet_path, "dataset")
        if self.args.dataset == "cdnet":
            transform_det = trans.Compose([
                trans.Scale((512, 768)),
            ])
            dataset = CD2014.Dataset(data_path, data_path,
                                     cdnet_path + "/dataset/supply.txt", 'train',
                                     transform=True, transform_med=transform_det)
        elif self.args.dataset == "levir":
            transform_det = trans.Compose([
                trans.Scale((512, 512)),
            ])
            dataset = levir.Dataset(data_path, data_path,
                                    cdnet_path + "/dataset/train.txt", 'train',
                                    transform=True, transform_med=transform_det)
        else:
            raise ValueError('unknown dataset: {}'.format(self.args.dataset))
        return DataLoader(dataset, num_workers=self.args.num_workers,
                          batch_size=self.args.batch_size, shuffle=True)

    def train(self):
        """Train ``self.model`` for ``args.max_iteration`` iterations."""
        self.color_transform = Colorization(2)
        dataset_train = self._build_loader()

        self.test_path = os.path.join(self.dn_save, 'test')
        if not os.path.exists(self.test_path):
            os.makedirs(self.test_path)

        # Loss, optimizer and a linearly decaying LR schedule
        # (factor goes from 1 at iteration 0 to 0 at max_iteration).
        weight = torch.ones(2)
        criterion = CrossEntropyLoss2d(weight.cuda())
        optimizer = Adam(self.model.parameters(), lr=0.0001, betas=(0.5, 0.999))
        decay = lambda icount: float(self.args.max_iteration - icount) / float(self.args.max_iteration)
        model_lr_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=decay)

        self.writers = SummaryWriter(os.path.join(self.dn_save, 'log'))

        fn_loss = os.path.join(self.dn_save, 'loss.csv')
        # `with` guarantees loss.csv is flushed and closed even if training raises.
        with open(fn_loss, 'w') as f_loss:
            loss_writer = csv.writer(f_loss)
            while self.icount < self.args.max_iteration:
                for inputs_train, mask_train in dataset_train:
                    # Variable wrapping dropped: it is a no-op since torch 0.4
                    # (this file already relies on the 0.4+ .item() API).
                    inputs_train = inputs_train.cuda()
                    mask_train = mask_train.cuda()
                    outputs_train, feature_maps = self.model(inputs_train)
                    optimizer.zero_grad()
                    # Labels live in channel 0 of the mask tensor.
                    self.loss = criterion(outputs_train, mask_train[:, 0])
                    self.loss.backward()
                    optimizer.step()
                    self.icount += 1
                    print("self.icount: ", self.icount)
                    loss_writer.writerow([self.icount, self.loss.item()])
                    if self.args.icount_save > 0 and self.icount % self.args.icount_save == 0:
                        self.checkpoint()
                    # Stop mid-epoch once the iteration budget is spent
                    # (the original kept running until the epoch finished).
                    if self.icount >= self.args.max_iteration:
                        break
                # PyTorch >= 1.1 requires scheduler.step() AFTER optimizer.step();
                # the original called it at the top of the epoch loop.
                model_lr_scheduler.step()

    def log_tbx(self, image):
        """Write the current loss scalar and a result image to tensorboard."""
        writer = self.writers
        writer.add_scalar('data/loss', self.loss.item(), self.icount)
        writer.add_image('change detection', image, self.icount)

    def checkpoint(self):
        """Save model weights, named by architecture variant and iteration."""
        if self.args.use_corr:
            filename = 'cscdnet-{0:08d}.pth'.format(self.icount)
        else:
            filename = 'cdnet-{0:08d}.pth'.format(self.icount)
        torch.save(self.model.state_dict(), os.path.join(self.dn_save, filename))
        print('save: {0} (iteration: {1})'.format(filename, self.icount))

    def run(self):
        """Build the model on GPU and start training."""
        # Honor --use-corr: checkpoint() already names files by this flag,
        # but the original always built the model with corr=False.
        self.model = cscdnet.Model(inc=6, outc=2, corr=self.args.use_corr, pretrained=True)
        self.model = self.model.cuda()
        self.train()
if __name__ == '__main__':
    # Command-line interface; --checkpointdir/--datadir/--dataset are mandatory.
    cli = ArgumentParser(description='Start training ...')
    cli.add_argument('--checkpointdir', required=True)
    cli.add_argument('--datadir', required=True)
    cli.add_argument('--dataset', required=True)
    cli.add_argument('--use-corr', action='store_true', help='using correlation layer')
    cli.add_argument('--max-iteration', type=int, default=50000)
    cli.add_argument('--num-workers', type=int, default=4)
    cli.add_argument('--batch-size', type=int, default=32)
    cli.add_argument('--cvset', type=int, default=0)
    cli.add_argument('--icount-save', type=int, default=10)
    Training(cli.parse_args()).run()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化