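This repository does not declare an open-source license file (LICENSE); before using it, check the project description and the licenses of its upstream dependencies.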
single_train.py 1.46 KB
electronick_pro committed on 2023-08-21 11:21 (commit message: predict)
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import ModelDef, FedUtils  # project-local modules (model definition and federated-learning helpers), not shown here
import mysql.connector

learning_rate = 1e-3
epochs = 60
batch_size = 64

if __name__ == "__main__":
    # ToTensor() scales pixels to [0, 1]. Note that Normalize(0, 1) subtracts 0 and
    # divides by 1, so it leaves the tensors unchanged.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)])
    trainset = datasets.MNIST(root=r'./data', train=True, download=True, transform=transform)
    testset = datasets.MNIST(root=r'./data', train=False, download=True, transform=transform)
    # Reshuffle the training set every epoch; evaluate the test set in a fixed order.
    train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(testset, batch_size=batch_size)

    # Prefer CUDA, then Apple's MPS backend, then fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model = ModelDef.LeNet([str(e) for e in range(10)]).to(device)  # class labels "0".."9"
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for i in range(epochs):
        print(f"===Epoch {i}===")
        # model_train / model_test are defined in ModelDef; their return values are
        # unpacked here but not used further in this script.
        cor_train, trl = model.model_train(device, train_dataloader, loss_fn, optimizer)
        cor_test, tsl, confusions = model.model_test(device, test_dataloader, loss_fn)

    # Store the trained weights and label names in the local MySQL database.
    # The credentials are hardcoded; consider loading them from configuration instead.
    cnx = mysql.connector.connect(
        user="sgxclientuser",
        password="1234aA!!",
        host="127.0.0.1",
        database="sgxclient"
    )
    FedUtils.db_save_model(cnx, "LeNet", 0, model.state_dict(), model.label_names)
    cnx.close()
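The script passes `transforms.Normalize(0, 1)`, which leaves the input tensors unchanged. The plain-Python sketch below reproduces the per-channel arithmetic that torchvision documents for `Normalize`, `output = (input - mean) / std`, without using torchvision itself. The helper names are hypothetical. The statistics 0.1307 and 0.3081 are the values commonly quoted for MNIST; they are not defined anywhere in this repository.

```python
# Plain-Python illustration of the per-channel arithmetic that
# torchvision's transforms.Normalize(mean, std) applies:
#     output = (input - mean) / std
# Helper names are hypothetical; this is not the torchvision implementation.

MNIST_MEAN = 0.1307  # commonly quoted mean of MNIST training pixels after ToTensor()
MNIST_STD = 0.3081   # commonly quoted standard deviation of the same pixels


def to_unit_range(raw_pixels):
    """Mimic ToTensor(): map 8-bit intensities 0..255 to floats in [0, 1]."""
    return [p / 255 for p in raw_pixels]


def normalize(pixels, mean, std):
    """Shift each pixel by `mean` and divide it by `std`."""
    return [(p - mean) / std for p in pixels]


if __name__ == "__main__":
    scaled = to_unit_range([0, 64, 128, 255])
    # Normalize(0, 1): subtracting 0 and dividing by 1 returns the same values.
    print(normalize(scaled, 0, 1) == scaled)
    # With the dataset statistics, background pixels (0) become negative and
    # the inputs are roughly centred on zero.
    print([round(v, 4) for v in normalize(scaled, MNIST_MEAN, MNIST_STD)])
```

In the training script, the corresponding torchvision call would be `transforms.Normalize((0.1307,), (0.3081,))`. Whether to change it is a modelling choice, because LeNet can also train on unnormalized [0, 1] inputs.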
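The database user and password are written directly into the `mysql.connector.connect(...)` call. One common alternative is to read the settings from environment variables. The following is a minimal sketch of that approach. The variable names `SGXCLIENT_DB_*` and the helper `db_config` are hypothetical and do not come from the repository. The non-secret defaults are copied from the script, and the password deliberately has no default.

```python
import os
from typing import Dict, Mapping, Optional

# Hypothetical environment variable names; the original script hardcodes
# these values in its mysql.connector.connect(...) call.
ENV_KEYS = {
    "user": "SGXCLIENT_DB_USER",
    "password": "SGXCLIENT_DB_PASSWORD",
    "host": "SGXCLIENT_DB_HOST",
    "database": "SGXCLIENT_DB_NAME",
}

# Non-secret defaults taken from the script; there is no default password.
DEFAULTS = {"user": "sgxclientuser", "host": "127.0.0.1", "database": "sgxclient"}


def db_config(environ: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Build keyword arguments for mysql.connector.connect() from the environment.

    Raises KeyError when a setting has neither an environment value nor a default.
    """
    env = os.environ if environ is None else environ
    config = {}
    for key, var in ENV_KEYS.items():
        if var in env:
            config[key] = env[var]
        elif key in DEFAULTS:
            config[key] = DEFAULTS[key]
        else:
            raise KeyError(f"set {var} to provide the database {key}")
    return config


if __name__ == "__main__":
    cfg = db_config({"SGXCLIENT_DB_PASSWORD": "example-only"})
    # Mask the password before printing the resolved settings.
    print({k: ("***" if k == "password" else v) for k, v in cfg.items()})
```

In the training script, the connection would then be opened with `cnx = mysql.connector.connect(**db_config())`. The keys `user`, `password`, `host`, and `database` are standard Connector/Python connection arguments. Raising an error for a missing password makes a misconfigured run fail immediately, rather than at some point after 60 epochs of training.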