代码拉取完成,页面将自动刷新
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import torch
import ModelDef, FedUtils
import mysql.connector
learning_rate = 1e-3
epochs = 60
batch_size = 64
if __name__ == "__main__":
trainset = datasets.MNIST(root=r'./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)]))
testset = datasets.MNIST(root=r'./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)]))
train_dataloader = DataLoader(trainset, batch_size=batch_size)
test_dataloader = DataLoader(testset, batch_size=batch_size)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = ModelDef.LeNet([str(e) for e in range(10)]).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
for i in range(epochs):
print(f"===Epoch {i}===")
cor_train, trl = model.model_train(device, train_dataloader, loss_fn, optimizer)
cor_test, tsl, confusions = model.model_test(device, test_dataloader, loss_fn)
cnx = mysql.connector.connect(
user="sgxclientuser",
password="1234aA!!",
host="127.0.0.1",
database="sgxclient"
)
FedUtils.db_save_model(cnx, "LeNet", 0, model.state_dict(), model.label_names)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。