import torch
from torch import Tensor, nn, tensor
from tqdm import tqdm
import numpy as np
import os


class RBFLayer(nn.Linear):
    """Euclidean RBF output layer: each output unit emits the squared
    distance between the input and one row of ``self.weight`` (a class
    prototype)."""

    def forward(self, input) -> Tensor:
        # One column of squared distances per prototype -> (N, out_features).
        return torch.cat([torch.sum((input - w) ** 2, 1).reshape(-1, 1) for w in self.weight], 1)
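
# Usage sketch (an assumption about intended use, consistent with model_train
# below: outputs are penalties, so predictions use argmin and losses are
# applied to the *negated* distances):
#
#   rbf = RBFLayer(84, 10, bias=False)
#   out = rbf(torch.randn(32, 84))  # (32, 10) squared distances
#   pred = out.argmin(1)            # closest prototype wins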


class LeNet(nn.Module):
    """LeNet-5 after LeCun et al. (1998): scaled-tanh squashing A*tanh(S*a),
    sparse C3 connections, and an RBF output layer."""

    A = 1.7159
    S = 2 / 3

    def __init__(self, label_names=None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Avoid a mutable default argument; default to digit labels "0".."9".
        self.label_names = label_names if label_names is not None else [str(i) for i in range(10)]
        self.c1 = nn.Conv2d(1, 6, 5, 1, 2)
        self.s2 = nn.AvgPool2d(2, 2)
        # C3 connection table: 6 maps see 3 consecutive S2 maps, 9 see 4 maps,
        # and the last one sees all 6.
        self.c3s = nn.ModuleList([])
        for _ in range(6):
            self.c3s.append(nn.Conv2d(3, 1, 5))
        for _ in range(9):
            self.c3s.append(nn.Conv2d(4, 1, 5))
        self.c3s.append(nn.Conv2d(6, 1, 5))
        self.s4 = nn.AvgPool2d(2, 2)
        self.c5 = nn.Conv2d(16, 120, 5)
        self.f6 = nn.Linear(120, 84)
        self.out = RBFLayer(84, 10, bias=False)

    def forward(self, x):
        x = self.A * torch.tanh(self.S * self.c1(x))
        x = self.A * torch.tanh(self.S * self.s2(x))
        # C3: each of the 16 feature maps is connected to a fixed subset of
        # the 6 S2 maps (Table I of LeCun et al. 1998).
        xs = [None] * 16
        xs[0] = self.c3s[0](x[:, :3, :, :])
        xs[1] = self.c3s[1](x[:, 1:4, :, :])
        xs[2] = self.c3s[2](x[:, 2:5, :, :])
        xs[3] = self.c3s[3](x[:, 3:6, :, :])
        xs[4] = self.c3s[4](x[:, [4, 5, 0], :, :])
        xs[5] = self.c3s[5](x[:, [5, 0, 1], :, :])
        xs[6] = self.c3s[6](x[:, :4, :, :])
        xs[7] = self.c3s[7](x[:, 1:5, :, :])
        xs[8] = self.c3s[8](x[:, 2:6, :, :])
        xs[9] = self.c3s[9](x[:, [3, 4, 5, 0], :, :])
        xs[10] = self.c3s[10](x[:, [4, 5, 0, 1], :, :])
        xs[11] = self.c3s[11](x[:, [5, 0, 1, 2], :, :])
        xs[12] = self.c3s[12](x[:, [0, 1, 3, 4], :, :])
        xs[13] = self.c3s[13](x[:, [1, 2, 4, 5], :, :])
        xs[14] = self.c3s[14](x[:, [2, 3, 5, 0], :, :])
        xs[15] = self.c3s[15](x)
        x = self.A * torch.tanh(self.S * torch.cat(xs, 1))
        x = self.A * torch.tanh(self.S * self.s4(x))
        x = self.A * torch.tanh(self.S * self.c5(x))
        x = torch.flatten(x, 1)
        x = self.A * torch.tanh(self.S * self.f6(x))
        x = self.out(x)
        return x
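
    # Shape trace for a 1x28x28 MNIST digit (a sketch; C1 pads by 2):
    #   (N,1,28,28) -C1-> (N,6,28,28) -S2-> (N,6,14,14) -C3-> (N,16,10,10)
    #   -S4-> (N,16,5,5) -C5-> (N,120,1,1) -F6-> (N,84) -out-> (N,10) distances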

    def model_train(self, device, dataloader, loss_fn, optimizer):
        size = len(dataloader.dataset)
        self.train()
        train_loss, correct = 0, 0
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            pred = self(X)
            # RBF outputs are distances, so negate them to get logit-like
            # scores for the loss and predict with argmin.
            loss = loss_fn(-pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            correct += (pred.argmin(1) == y).type(torch.float).sum().item()
            if batch % 100 == 0:
                loss_val, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")
        train_loss /= (batch + 1)
        correct /= size
        print(f"Train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f}")
        return correct, train_loss

    def model_test(self, device, dataloader, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        # Per-class confusion counts, keyed by label name.
        confusions = {label: {"TP": 0, "FP": 0, "FN": 0} for label in self.label_names}
        with torch.no_grad():
            for X, y in tqdm(dataloader, desc="Test"):
                X, y = X.to(device), y.to(device)
                pred_val = self(X)
                pred_res = pred_val.argmin(1)
                test_loss += loss_fn(-pred_val, y).item()
                correct += (pred_res == y).type(torch.float).sum().item()
                for i, label in enumerate(self.label_names):
                    confusions[label]["TP"] += torch.logical_and(pred_res == i, y == i).type(torch.float).sum().item()
                    confusions[label]["FP"] += torch.logical_and(pred_res == i, y != i).type(torch.float).sum().item()
                    confusions[label]["FN"] += torch.logical_and(pred_res != i, y == i).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
        return correct, test_loss, confusions
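
    # From the returned counts, per-class metrics follow directly (a sketch):
    #   precision = TP / (TP + FP),  recall = TP / (TP + FN)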

    def model_detect(self, device, dataset, batch_size):
        """Label every image in ``dataset``; assumes each dataset item is a
        5-tuple whose first two fields are (path, image)."""
        self.eval()
        path_batches = []
        im_batches = []
        print("Making batches...")
        for i, (path, im, _, _, _) in enumerate(tqdm(dataset)):
            if i % batch_size == 0:
                path_batches.append([])
                im_batches.append([])
            path_batches[-1].append(path)
            im_batches[-1].append(im[:1, ...])  # keep only the first channel
        res = {}
        with torch.no_grad():
            for path_batch, im_batch in tqdm(zip(path_batches, im_batches), total=len(im_batches)):
                X = tensor(np.stack(im_batch, 0), dtype=torch.float, device=device)
                pred_val = self(X)
                pred_res = pred_val.argmin(1)
                for path, pr in zip(path_batch, list(pred_res)):
                    res[os.path.basename(path)] = self.label_names[int(pr.item())]
        return res


class ResUnit(nn.Module):
    """Basic two-conv residual block; downsamples with stride 2 and projects
    the shortcut through a 1x1 conv whenever the channel count changes."""

    def __init__(self, in_channels, out_channels, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.c1 = nn.Conv2d(in_channels, out_channels, 3,
                            1 if in_channels == out_channels else 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.c2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            self.c0 = nn.Conv2d(in_channels, out_channels, 1, 2, bias=False)
            self.bn0 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        y = nn.functional.relu(self.bn1(self.c1(x)))
        y = self.bn2(self.c2(y))
        if self.in_channels != self.out_channels:
            x = self.bn0(self.c0(x))  # projection shortcut
        return nn.functional.relu(x + y)
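
# Quick shape check for the downsampling case (a sketch):
#
#   unit = ResUnit(64, 128)
#   y = unit(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)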


class ResNet(nn.Module):
    """ResNet-18-style classifier (four stages of two basic blocks) for
    10-class RGB inputs such as CIFAR-10."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.res = nn.Sequential(
            ResUnit(64, 64),
            ResUnit(64, 64),
            ResUnit(64, 128),
            ResUnit(128, 128),
            ResUnit(128, 256),
            ResUnit(256, 256),
            ResUnit(256, 512),
            ResUnit(512, 512),
        )
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        # Stem: conv -> BatchNorm -> ReLU.
        x = nn.functional.relu(self.bn(self.conv(x)))
        x = self.res(x)
        x = torch.mean(x, (2, 3))  # global average pooling
        x = self.fc(x)
        return x

    def model_train(self, device, dataloader, loss_fn, optimizer):
        size = len(dataloader.dataset)
        self.train()
        train_loss, correct = 0, 0
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            pred = self(X)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            if batch % 100 == 0:
                loss_val, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")
        train_loss /= (batch + 1)
        correct /= size
        print(f"Train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f}")
        return correct, train_loss

    def model_test(self, device, dataloader, loss_fn, labels_idx):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        # Per-class confusion counts, keyed by label index.
        confusions = {label: {"TP": 0, "FP": 0, "FN": 0} for label in labels_idx}
        with torch.no_grad():
            for X, y in tqdm(dataloader, desc="Test"):
                X, y = X.to(device), y.to(device)
                pred_val = self(X)
                pred_res = pred_val.argmax(1)
                test_loss += loss_fn(pred_val, y).item()
                correct += (pred_res == y).type(torch.float).sum().item()
                for label in labels_idx:
                    confusions[label]["TP"] += torch.logical_and(pred_res == label, y == label).type(torch.float).sum().item()
                    confusions[label]["FP"] += torch.logical_and(pred_res == label, y != label).type(torch.float).sum().item()
                    confusions[label]["FN"] += torch.logical_and(pred_res != label, y == label).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
        return correct, test_loss, confusions
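

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original training pipeline):
    # push random inputs through both models and check the output shapes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lenet = LeNet().to(device)
    digits = torch.randn(4, 1, 28, 28, device=device)
    print(lenet(digits).shape)   # torch.Size([4, 10]) squared distances
    resnet = ResNet().to(device)
    images = torch.randn(4, 3, 32, 32, device=device)
    print(resnet(images).shape)  # torch.Size([4, 10]) logits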