import torch
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, ToTensor, Normalize
import numpy as np
import ModelDef, config, FedUtils
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
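# Workaround for the "OMP: Error #15" crash that occurs when more than one OpenMP runtime is loaded.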
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Number of iterations for the DLG (Deep Leakage from Gradients) reconstruction loop.
DLG_epochs = 600
# Passed to FedUtils.filter_param_mask to control how much of each gradient is leaked.
FedUtils.alpha = 0.9
def fetch_gradients(optimizer: torch.optim.Optimizer, is_partial_leaked):
    # Copy each parameter's gradient out of the optimizer. With is_partial_leaked set,
    # the entries selected by FedUtils.filter_param_mask are zeroed, so only part of
    # each gradient is exposed to the attacker.
    grads = {}
    for i, param in enumerate(optimizer.param_groups[0]["params"]):
        grad = param.grad.cpu().numpy()
        if is_partial_leaked:
            mask = FedUtils.filter_param_mask(grad, FedUtils.alpha)
            grads[str(i)] = torch.tensor(grad * np.logical_not(mask), dtype=param.dtype, device=param.device)
        else:
            grads[str(i)] = param.grad.detach().clone()
    return grads
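# Gradient-matching objective: sum of layer-wise L2 distances between the dummy
# gradients and the leaked target gradients. This is what the DLG attack minimizes.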
def grad_sum_dist(grads, target_grads):
ret = torch.tensor(0, dtype=grads[0].dtype, device=grads[0].device)
for i, grad in enumerate(grads):
dif = grad - target_grads[str(i)]
ret += torch.sqrt(torch.sum(dif ** 2))
return ret
def pic_avg_dist(X_dum, X):
    # Mean per-image L2 distance between the reconstructed batch and the ground truth.
    dif = X_dum - X
    dist = torch.sqrt(torch.sum(dif ** 2, (1, 2, 3)))
    return torch.mean(dist).item()
def feed_to_model(model, loss_fn, X_dum, pred_dum, target_grads):
    # Forward the dummy data, turn the dummy logits into a soft label, and measure how far
    # the resulting gradients are from the leaked target gradients.
    pred = model(X_dum)
    y_dum = torch.nn.functional.softmax(pred_dum, 1)
    loss = loss_fn(pred, y_dum)
    # create_graph=True keeps the gradient distance differentiable w.r.t. X_dum and pred_dum.
    grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    return grad_sum_dist(grads, target_grads)
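# Plot the reconstruction history: one row per image in the batch, one column per saved
# snapshot of X_dum, and the ground-truth image in the last column.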
def fig_results(X_dum_history, X):
fig = plt.figure(tight_layout=True)
gs = gridspec.GridSpec(X.shape[0], len(X_dum_history) + 1)
for i in range(X.shape[0]):
for j, X_dum in enumerate(X_dum_history):
ax = fig.add_subplot(gs[i,j])
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(X_dum[i,0].detach().cpu().numpy(), cmap="binary")
ax = fig.add_subplot(gs[i,len(X_dum_history)])
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(X[i,0].detach().cpu().numpy(), cmap="binary")
plt.show()
class CrossEntropyOneHot(torch.nn.Module):
    """Cross-entropy loss for one-hot (or soft) targets, computed from raw logits."""
    def forward(self, input, target):
        log_P = torch.nn.functional.log_softmax(input, 1)
        H = -target * log_P
        return torch.mean(torch.sum(H, 1))
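# Experiment: take one MNIST batch, do a single SGD step, "leak" (part of) the resulting
# gradients, then run the DLG attack to reconstruct the batch from those gradients alone.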
if __name__ == "__main__":
conf = config.get_dyna_config()
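    # Use a small batch: the whole batch is reconstructed jointly by the attack.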
conf["batch_size"] = 2
trainset = datasets.MNIST(root=r'./data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0, 1)]))
testset = datasets.MNIST(root=r'./data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0, 1)]))
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
client_trainset = random_split(trainset, [0.5,0.5], torch.Generator().manual_seed(42))[0]
train_dataloader = DataLoader(client_trainset, batch_size=conf["batch_size"])
test_dataloader = DataLoader(testset, batch_size=conf["batch_size"])
model = ModelDef.LeNet().to(device)
# model.load_state_dict(torch.load(config.init_model_path))
model = model.to(device)
loss_fn = CrossEntropyOneHot()
optimizer = torch.optim.SGD(model.parameters(), lr=float(conf["learning_rate"]))
model.train()
# A data batch
X, y = next(iter(train_dataloader))
X, y = X.to(device), torch.nn.functional.one_hot(y, 10).to(device)
pred = model(X)
    loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
    # Gradients "leaked" to the attacker; True requests the partially masked version
target_grads = fetch_gradients(optimizer, True)
optimizer.step()
# Initialize random dummy data and new model
X_dum, pred_dum = torch.normal(0, 1, size=X.shape, device=device, requires_grad=True), torch.normal(0, 1, size=(conf["batch_size"], 10), device=device, requires_grad=True)
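    # Optimize the dummy image and dummy logits directly; the model parameters stay fixed.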
optim_DLG = torch.optim.SGD([X_dum, pred_dum], lr=1)
X_dum_history = []
X_dum_history.append(X_dum.detach().clone())
grad_dist_min = np.inf
pic_dist_min = np.inf
grad_dist_min_t = -1
pic_dist_min_t = -1
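    # Reconstruction loop: adjust X_dum / pred_dum so the gradients they induce match the leaked ones.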
    for t in range(DLG_epochs):
        def closure():
            # Re-evaluate the gradient-matching loss for the current dummy data.
            optim_DLG.zero_grad()
            DLG_loss = feed_to_model(model, loss_fn, X_dum, pred_dum, target_grads)
            DLG_loss.backward()
            return DLG_loss
        grad_dist = optim_DLG.step(closure).item()
        pic_dist = pic_avg_dist(X_dum, X)
        print(f"Epoch {t}/{DLG_epochs} grad dist: {grad_dist:>7f}, pic dist: {pic_dist:>7f}")
if grad_dist < grad_dist_min:
grad_dist_min = grad_dist
grad_dist_min_t = t
if pic_dist < pic_dist_min:
pic_dist_min = pic_dist
pic_dist_min_t = t
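        # Keep a snapshot of the current reconstruction every 60 iterations for the final figure.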
if (t + 1) % 60 == 0:
X_dum_history.append(X_dum.detach().clone())
X_dum_history.append(X_dum.detach().clone())
print(f"Min grad dist @Epoch {grad_dist_min_t}, min pic dist @Epoch {pic_dist_min_t}")
fig_results(X_dum_history, X)