代码拉取完成,页面将自动刷新
import numpy as np
import os, io, tarfile, shutil
from Crypto.Cipher import AES
from datetime import datetime, timezone
from tqdm import tqdm
from torch import tensor, Generator
from torch import load as tor_load, save as tor_save
from torch.utils.data import Subset, random_split
import json
import random
# Seed Python's global RNG at import time; biased_MNIST() and biased_CIFAR()
# shuffle their index lists with random.sample(), which draws from this RNG.
random.seed(42)
import config, Comm
# Passed as topk_frac to filter_param_mask() when aggregate_masks() builds masks.
alpha = 0.2
# Written to AggregationWeight.npy by save_params().
aggregation_weight = 1
# save_params() splits the masked part of a parameter into several encrypted
# "-partN" files once it holds more than this many elements.
sgx_split_size_threshold = 300
# Network settings read by aggregate_masks(). They are None here, so presumably
# another module assigns them before aggregate_masks() runs -- TODO confirm.
coordinator_ip = None
clients_info = None
self_ip = None
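# Test idea (save/load round trip): build any PyTorch model, write it to the database, read it back, and check that the two state dicts match.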
def db_save_model(cnx, model_name, jobid, state_dict, label_names, block_size=128):
    """Store a serialized state dict in the database as a sequence of blocks.

    One row is inserted into ``nnmodel`` and the bytes produced by
    ``torch.save`` are inserted into ``nnmodel_content`` in consecutive
    blocks, each tagged with its sequence number and actual length.
    Everything is committed once, after the last block.

    Args:
        cnx: Open DB-API connection (must provide ``cursor()`` and ``commit()``).
        model_name: Value for the ``name`` column of ``nnmodel``.
        jobid: Value for the ``id_job`` column of ``nnmodel``.
        state_dict: Object handed to ``torch.save``.
        label_names: JSON-serializable object stored in ``label_names``.
        block_size: Number of bytes per ``nnmodel_content`` row. Defaults to
            128, the value that used to be hard-coded.
            NOTE(review): presumably bounded by the width of the ``content``
            column -- confirm against the schema before raising it.

    Returns:
        The ``lastrowid`` of the inserted ``nnmodel`` row.
    """
    insert_block = "INSERT INTO nnmodel_content (id_nnmodel, content, length, sequence, create_time, update_time) VALUES (%s, %s, %s, %s, %s, %s)"

    # Serialize before touching the database so that a serialization failure
    # leaves no pending INSERT behind.
    with io.BytesIO() as buf:
        tor_save(state_dict, buf)
        arr = buf.getvalue()

    cursor = cnx.cursor()
    try:
        cur_dt = datetime.now(tz=timezone.utc)
        cursor.execute(
            "INSERT INTO nnmodel (name, label_names, id_job, create_time, update_time) VALUES (%s, %s, %s, %s, %s)",
            (model_name, json.dumps(label_names), jobid, cur_dt, cur_dt)
        )
        model_id = cursor.lastrowid
        for seq, beg in enumerate(tqdm(range(0, len(arr), block_size))):
            cur_dt = datetime.now(tz=timezone.utc)
            block = arr[beg : (beg + block_size)]
            cursor.execute(insert_block, (model_id, block, len(block), seq, cur_dt, cur_dt))
        cnx.commit()
    finally:
        # Release the cursor even when an INSERT or the commit fails.
        cursor.close()
    return model_id
def db_load_state_dict(cnx, model_id):
    """Reassemble and deserialize a state dict stored by ``db_save_model``.

    Reads the ``nnmodel_content`` rows of ``model_id`` in ``sequence`` order,
    concatenates the first ``length`` bytes of every ``content`` value and
    passes the result to ``torch.load``.

    Args:
        cnx: Open DB-API connection (must provide ``cursor()``).
        model_id: ``id_nnmodel`` of the rows to read.

    Returns:
        Whatever ``torch.load`` reconstructs from the stored bytes.
    """
    with io.BytesIO() as buf:
        cursor = cnx.cursor()
        try:
            cursor.execute(
                "SELECT content, length FROM nnmodel_content WHERE id_nnmodel=%s ORDER BY sequence",
                (model_id,)
            )
            for (content, length) in tqdm(cursor):
                # Only the first `length` bytes of the column are payload.
                buf.write(content[:length])
        finally:
            # Release the cursor even when the query or a fetch fails.
            cursor.close()
        # Rewind and read from the same buffer instead of copying the payload
        # into a second BytesIO.
        buf.seek(0)
        # NOTE(review): torch.load unpickles its input; only use this on
        # database content that is trusted.
        return tor_load(buf)
# Test idea: check that the fraction of 1s in the returned mask is less than topk_frac.
def filter_param_mask(param, topk_frac):
    """Return a boolean mask selecting the largest-magnitude entries of ``param``.

    Args:
        param: NumPy array of any shape.
        topk_frac: Upper bound (exclusive) on the fraction of entries selected.

    Returns:
        Boolean array with the shape of ``param``. An entry is True when its
        rank by absolute value (0 = smallest) is strictly greater than
        ``floor(param.size * (1 - topk_frac))``, so strictly fewer than
        ``topk_frac * param.size`` entries are True.
    """
    topk_num = np.floor(param.size * (1.0 - topk_frac)).astype(np.int64)
    # argsort yields the permutation that sorts the flattened magnitudes, not
    # the rank of each element. Comparing that permutation with the threshold
    # (as the previous code did) marks effectively arbitrary positions, so
    # invert it to obtain per-element ranks first. A stable sort keeps the
    # result deterministic when magnitudes tie.
    order = np.argsort(np.abs(param), axis=None, kind="stable")
    ranks = np.empty(order.size, dtype=np.int64)
    ranks[order] = np.arange(order.size)
    return (ranks > topk_num).reshape(param.shape)
# Test idea: the returned cumulative block ends should finish at total_size (i.e. the block sizes sum to total_size).
def split_block(total_size, split_size):
    """Return the cumulative end offsets of consecutive blocks.

    ``total_size`` is cut into blocks of ``split_size`` elements plus one
    shorter trailing block when there is a remainder. The result lists the
    running end position of every block, e.g. ``split_block(7, 3)`` gives
    ``[3, 6, 7]``.
    """
    full_blocks, remainder = total_size // split_size, total_size % split_size
    # End offsets of the full-size blocks: split_size, 2 * split_size, ...
    ends = [split_size * k for k in range(1, full_blocks + 1)]
    if remainder > 0:
        # The short trailing block ends `remainder` past the last full block.
        ends.append((ends[-1] if ends else 0) + remainder)
    return ends
def find_ref(model, param_name):
    """Resolve a dotted path such as ``"layers.0.weight"`` starting at ``model``.

    Each path component is first looked up as an attribute; when the current
    object has no such attribute the component is converted to ``int`` and
    used as an index instead.
    """
    def descend(parent, key):
        # Attribute access first, integer indexing as the fallback.
        try:
            return getattr(parent, key)
        except AttributeError:
            return parent[int(key)]

    target = model
    for key in param_name.split("."):
        target = descend(target, key)
    return target
def save_sgx_encrytped_param(filename, arr):
    """Encrypt a flattened float64 copy of ``arr`` and write it to ``filename``.

    The array is reshaped to one dimension, cast to float64, converted to raw
    bytes, encrypted with AES in CTR mode under ``config.key`` and written as
    a binary file.
    """
    plain_bytes = arr.reshape(-1).astype(np.float64).tobytes()
    # NOTE(review): the nonce and initial counter are constant, so every file
    # encrypted under the same key reuses the same CTR keystream. Left as is
    # because load_sgx_encrypted_flat_param() (and presumably the remote peer)
    # must use the same values -- confirm before changing.
    cipher = AES.new(config.key, AES.MODE_CTR, initial_value=0, nonce=bytes([0]))
    cipher_bytes = cipher.encrypt(plain_bytes)
    with open(filename, "wb") as fo:
        fo.write(cipher_bytes)
def load_sgx_encrypted_flat_param(filename):
    """Read ``filename``, decrypt it and return its content as a flat float64 array.

    Uses the same AES-CTR settings (``config.key``, nonce ``bytes([0])``,
    initial counter 0) as ``save_sgx_encrytped_param``.
    """
    with open(filename, "rb") as fi:
        cipher_bytes = fi.read()
    cipher = AES.new(config.key, AES.MODE_CTR, initial_value=0, nonce=bytes([0]))
    origin_bytes = cipher.decrypt(cipher_bytes)
    # Interpret the decrypted bytes directly as float64 values.
    return np.frombuffer(origin_bytes, dtype=np.float64)
def aggregate_masks(params):
    """Build one mask per parameter and exchange the masks with the coordinator.

    For every entry of ``params`` a mask is computed with
    ``filter_param_mask(param, alpha)`` and saved as ``<name>.npy`` in
    ``config.client_mask_dir`` together with a ``NameList`` file. The
    directory is packed into ``config.client_mask_tar_path``, removed, and
    the tar file is handed to ``Comm.send_file``. A tar file is then obtained
    through ``Comm.recv_file`` and extracted back into
    ``config.client_mask_dir``.

    Args:
        params: Mapping from parameter name to NumPy array.
    """
    mask_file_ls = []
    if not os.path.exists(config.client_mask_dir):
        os.makedirs(config.client_mask_dir)
    print("Saving masks...")
    with tqdm(total=len(params), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name, param in params.items():
            pb.set_postfix({"param": name})
            mask = filter_param_mask(param, alpha)
            np.save(f"{config.client_mask_dir}/{name}.npy", mask)
            mask_file_ls.append(f"{name}.npy")
            pb.update(1)
    # NameList records the saved mask file names, one per line.
    with open(f"{config.client_mask_dir}/NameList", "w") as fo:
        fo.writelines([(ln+"\n") for ln in mask_file_ls])
    # An empty arcname puts the directory content at the root of the archive.
    with tarfile.open(config.client_mask_tar_path, "w") as tar:
        tar.add(config.client_mask_dir, "")
    shutil.rmtree(config.client_mask_dir)
    print("Sending masks to coordinator...")
    Comm.send_file(config.client_mask_tar_path, coordinator_ip, clients_info[self_ip]["mask_send_port"])
    os.remove(config.client_mask_tar_path)
    print("Receiving aggregated masks from coordinator...")
    # TODO (original note): handle no dir
    Comm.recv_file(config.client_mask_tar_path, clients_info[self_ip]["mask_recv_port"])
    # NOTE(review): the archive comes from the network and extractall() is
    # called without an extraction filter -- confirm the sender is trusted.
    with tarfile.open(config.client_mask_tar_path, "r") as tar:
        tar.extractall(config.client_mask_dir)
    os.remove(config.client_mask_tar_path)
def save_params(param_dict):
    """Split every parameter by its aggregated mask and pack the result into a tar file.

    Steps, in order:
      1. Convert every tensor of ``param_dict`` to a NumPy array.
      2. Call ``aggregate_masks`` to produce the masks in ``config.client_mask_dir``.
      3. Read the first line of the ``serverip`` file in that directory and
         store it as ``server_ip`` in the dynamic configuration.
      4. For each parameter, write the values at the mask's non-zero positions
         as encrypted ``cla<name>`` file(s) and the parameter with those
         positions zeroed as ``uc<name>.npy``.
      5. Write ``NameList`` and ``AggregationWeight.npy``, pack the directory
         into ``config.client_param_tar_path`` and remove both working
         directories.

    Args:
        param_dict: Mapping from parameter name to a tensor supporting
            ``.cpu().numpy()``.

    Returns:
        Dict mapping each parameter name to the mask array loaded for it.
    """
    params = {}
    for name, torchparam in param_dict.items():
        params[name] = torchparam.cpu().numpy()
    aggregate_masks(params)
    with open(f"{config.client_mask_dir}/serverip", "r") as fi:
        conf = config.get_dyna_config()
        conf["server_ip"] = fi.readline().strip()
        config.set_dyna_config(conf)
    param_file_ls = []
    masks = {}
    if not os.path.exists(config.client_param_dir):
        os.makedirs(config.client_param_dir)
    print("Saving parameters...")
    with tqdm(total=len(params), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name, param in params.items():
            mask = np.load(f"{config.client_mask_dir}/{name}.npy")
            masks[name] = mask
            # Values at the mask's non-zero positions, flattened.
            cla_param = param[mask.nonzero()].reshape(-1)
            # Same shape as param, with the masked positions set to zero.
            uc_param = param * np.logical_not(mask)
            if cla_param.size > sgx_split_size_threshold:
                # split_block() returns cumulative ends; the last one equals
                # the array size and is dropped because np.split only wants
                # the interior cut points.
                for i, arr in enumerate(np.split(cla_param, split_block(cla_param.size, sgx_split_size_threshold)[:-1])):
                    pb.set_postfix({"param": f"{name}-{i}"})
                    save_sgx_encrytped_param(f"{config.client_param_dir}/cla{name}-part{i}", arr)
                    param_file_ls.append(f"cla{name}-part{i}")
            else:
                pb.set_postfix({"param": name})
                save_sgx_encrytped_param(f"{config.client_param_dir}/cla{name}", cla_param)
                param_file_ls.append(f"cla{name}")
            np.save(f"{config.client_param_dir}/uc{name}.npy", uc_param)
            param_file_ls.append(f"uc{name}.npy")
            pb.update(1)
    # NameList records every written parameter file name, one per line.
    with open(f"{config.client_param_dir}/NameList", "w") as fo:
        fo.writelines([(ln+"\n") for ln in param_file_ls])
    np.save(f"{config.client_param_dir}/AggregationWeight.npy", np.array([aggregation_weight], dtype=np.float64))
    # An empty arcname puts the directory content at the root of the archive.
    with tarfile.open(config.client_param_tar_path, "w") as tar:
        tar.add(config.client_param_dir, "")
    shutil.rmtree(config.client_param_dir)
    shutil.rmtree(config.client_mask_dir)
    return masks
def load_params(param_dict, masks, device):
    """Read parameters from aggregation, overwrite the input dict and return it.

    The tar file at ``config.client_param_tar_path`` is extracted into
    ``config.client_param_dir`` and deleted. For every name in ``param_dict``
    the decrypted ``cla<name>`` values (one file or consecutive ``-part<i>``
    files) are written into the non-zero positions of ``masks[name]``, the
    content of ``uc<name>.npy`` is added, and the sum replaces
    ``param_dict[name]`` as a new tensor. The working directory is removed
    at the end.

    Args:
        param_dict: A dict whose parameter values must be initialized as
            Tensors with the correct shape; their ``shape`` and ``dtype`` are
            reused for the loaded values.
        masks: Dict mapping parameter name to its mask array (the value
            returned by ``save_params``).
        device: The device where the parameter tensors will be stored.

    Returns:
        ``param_dict`` itself, with every value replaced.
    """
    # NOTE(review): extractall() is called without an extraction filter --
    # confirm the archive always comes from a trusted source.
    with tarfile.open(config.client_param_tar_path, "r") as tar:
        tar.extractall(config.client_param_dir)
    os.remove(config.client_param_tar_path)
    print("Loading parameters...")
    with tqdm(total=len(param_dict), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name in param_dict.keys():
            filename = f"{config.client_param_dir}/cla{name}"
            if os.path.exists(filename):
                pb.set_postfix({"param": name})
                cla_param = load_sgx_encrypted_flat_param(filename)
            else:
                # No single file: collect cla<name>-part0, -part1, ... until
                # the next index is missing.
                arr_ls = []
                i = 0
                while os.path.exists(f"{filename}-part{i}"):
                    pb.set_postfix({"param": f"{name}-{i}"})
                    arr_ls.append(load_sgx_encrypted_flat_param(f"{filename}-part{i}"))
                    i += 1
                # NOTE(review): arr_ls is empty when no part file exists, in
                # which case np.concatenate raises.
                cla_param = np.concatenate(arr_ls)
            # Cast the mask to the value dtype, then overwrite its non-zero
            # positions with the loaded values.
            mask = masks[name].astype(cla_param.dtype)
            mask[mask.nonzero()] = cla_param
            cla_param = mask.reshape(tuple(param_dict[name].shape))
            uc_param = np.load(f"{config.client_param_dir}/uc{name}.npy")
            param = cla_param + uc_param
            param_dict[name] = tensor(param, dtype=param_dict[name].dtype, device=device)
            pb.update(1)
    shutil.rmtree(config.client_param_dir)
    return param_dict
def biased_MNIST(ds):
    """Split ``ds`` into two label-biased, shuffled subsets.

    Sample indices are grouped by label in order of first appearance. For the
    first label 95% of its samples (rounded down) go to subset 0 and the rest
    to subset 1; the share given to subset 0 is replaced by ``1 - share``
    after every label, so it alternates between labels. Each index list is
    shuffled with ``random.sample`` before being wrapped in a ``Subset``.

    Args:
        ds: Dataset exposing ``targets``, an iterable of values convertible to int.

    Returns:
        List of two ``Subset`` objects over ``ds``.
    """
    by_label = {}
    for position, target in enumerate([int(t) for t in ds.targets]):
        by_label.setdefault(target, []).append(position)

    head, tail = [], []
    share = 0.95
    for members in by_label.values():
        cut = int(np.floor(share * len(members)))
        head.extend(members[:cut])
        tail.extend(members[cut:])
        # Flip the bias for the next label.
        share = 1 - share
    return [Subset(ds, random.sample(bucket, len(bucket))) for bucket in (head, tail)]
def random_CIFAR(ds):
    """Randomly split ``ds`` into four equal fractions with a generator seeded to 42."""
    seeded_generator = Generator().manual_seed(42)
    fractions = [0.25, 0.25, 0.25, 0.25]
    return random_split(ds, fractions, seeded_generator)
def random_CIFAR2(ds):
    """Randomly split ``ds`` into two equal fractions with a generator seeded to 42."""
    seeded_generator = Generator().manual_seed(42)
    fractions = [0.5, 0.5]
    return random_split(ds, fractions, seeded_generator)
def biased_CIFAR(ds):
    """Split ``ds`` into four label-biased, shuffled subsets.

    Sample indices are grouped by label in order of first appearance. For the
    i-th label, the first 80% of its samples (rounded down) go to subset
    ``i % 4``, the next 7% to subset ``(i + 1) % 4``, the following 7% to
    subset ``(i + 2) % 4`` and whatever remains to subset ``(i + 3) % 4``.
    Each index list is shuffled with ``random.sample`` before being wrapped
    in a ``Subset``.

    Args:
        ds: Dataset exposing ``targets``, an iterable of values convertible to int.

    Returns:
        List of four ``Subset`` objects over ``ds``.
    """
    major, minor = 0.8, 0.07
    by_label = {}
    for position, target in enumerate([int(t) for t in ds.targets]):
        by_label.setdefault(target, []).append(position)

    buckets = [[] for _ in range(4)]
    for turn, members in enumerate(by_label.values()):
        total = len(members)
        major_count = int(np.floor(major * total))
        minor_count = int(np.floor(minor * total))
        # Consecutive slice boundaries: major share, two minor shares, rest.
        bounds = (
            0,
            major_count,
            major_count + minor_count,
            major_count + 2 * minor_count,
            total,
        )
        for offset in range(4):
            chunk = members[bounds[offset]:bounds[offset + 1]]
            buckets[(turn + offset) % 4].extend(chunk)
    return [Subset(ds, random.sample(bucket, len(bucket))) for bucket in buckets]
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。