# FedUtils.py
import numpy as np
import os, io, tarfile, shutil
from Crypto.Cipher import AES
from datetime import datetime, timezone
from tqdm import tqdm
from torch import tensor, Generator
from torch import load as tor_load, save as tor_save
from torch.utils.data import Subset, random_split
import json
import random
random.seed(42)
import config, Comm
alpha = 0.2                      # fraction of each parameter's entries selected into the encrypted mask
aggregation_weight = 1           # this client's weight in server-side aggregation (saved as AggregationWeight.npy)
sgx_split_size_threshold = 300   # max number of values per SGX-encrypted chunk file
coordinator_ip = None            # set by the caller at runtime
clients_info = None              # set by the caller at runtime
self_ip = None                   # set by the caller at runtime
# Test storage/retrieval: write an arbitrary PyTorch model to the database, read it back, and check the results match (a sketch follows db_load_state_dict below).
def db_save_model(cnx, model_name, jobid, state_dict, label_names):
    # Insert one nnmodel row, then store the serialized state_dict in the
    # nnmodel_content table as 128-byte blocks, one row per block.
    cursor = cnx.cursor()
    cur_dt = datetime.now(tz=timezone.utc)
    cursor.execute(
        "INSERT INTO nnmodel (name, label_names, id_job, create_time, update_time) VALUES (%s, %s, %s, %s, %s)",
        (model_name, json.dumps(label_names), jobid, cur_dt, cur_dt)
    )
    model_id = cursor.lastrowid
    exc_state = "INSERT INTO nnmodel_content (id_nnmodel, content, length, sequence, create_time, update_time) VALUES (%s, %s, %s, %s, %s, %s)"
    buf = io.BytesIO()
    tor_save(state_dict, buf)
    arr = buf.getvalue()
    buf.close()
    for seq, beg in enumerate(tqdm(range(0, len(arr), 128))):
        cur_dt = datetime.now(tz=timezone.utc)
        block = arr[beg : (beg + 128)]
        exc_val = (model_id, block, len(block), seq, cur_dt, cur_dt)
        cursor.execute(exc_state, exc_val)
    cnx.commit()
    cursor.close()
    return model_id
def db_load_state_dict(cnx, model_id):
    cursor = cnx.cursor()
    cursor.execute(
        "SELECT content, length FROM nnmodel_content WHERE id_nnmodel=%s ORDER BY sequence",
        (model_id,)
    )
    buf = io.BytesIO()
    for (content, length) in tqdm(cursor):
        buf.write(content[:length])
    cursor.close()
    arr = buf.getvalue()
    buf.close()
    buf = io.BytesIO(arr)
    ret = tor_load(buf)
    buf.close()
    return ret
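# Illustrative sketch (not part of the original module): the round-trip check
# described above db_save_model -- save an arbitrary PyTorch state_dict, read
# it back, and verify the tensors match. The MySQL connection `cnx` is a
# hypothetical input; adapt it to the actual schema and credentials.
def _sketch_db_roundtrip_check(cnx):
    import torch
    model = torch.nn.Linear(4, 2)  # any small model works for the check
    saved = model.state_dict()
    model_id = db_save_model(cnx, "roundtrip-test", jobid=1,
                             state_dict=saved, label_names=["a", "b"])
    loaded = db_load_state_dict(cnx, model_id)
    assert saved.keys() == loaded.keys()
    for k in saved:
        assert torch.equal(saved[k], loaded[k]), f"mismatch in {k}"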
# Test: the fraction of True (1) entries in the returned mask should not exceed topk_frac (see the sketch below).
def filter_param_mask(param, topk_frac):
    # Select roughly the largest topk_frac fraction of entries by absolute value;
    # the double argsort gives each position its ascending rank in |param|.
    topk_num = np.floor(param.size * (1.0 - topk_frac)).astype(np.int64)
    ranks = np.argsort(np.argsort(np.abs(param), axis=None))
    return (ranks > topk_num).reshape(param.shape)
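# Illustrative sketch (not part of the original module): the check suggested by
# the comment above -- the fraction of selected (True) entries never exceeds
# topk_frac, and the mask keeps the parameter's shape.
def _sketch_mask_fraction_check():
    rng = np.random.default_rng(0)
    param = rng.standard_normal((8, 16))
    for topk_frac in (0.1, 0.2, 0.5):
        mask = filter_param_mask(param, topk_frac)
        assert mask.shape == param.shape
        assert mask.mean() <= topk_frac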
# split_block returns cumulative split points; the last entry should equal total_size (see the sketch below).
def split_block(total_size, split_size):
    # Return cumulative block boundaries, e.g. split_block(1000, 300) gives
    # [300, 600, 900, 1000]; all but the last entry can feed np.split directly.
    block_num = total_size // split_size
    block_rem = total_size % split_size
    num_arr = [split_size] * block_num
    if block_rem > 0:
        num_arr.append(block_rem)
    ret = []
    acc = 0
    for num in num_arr:
        acc += num
        ret.append(acc)
    return ret
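# Illustrative sketch (not part of the original module): the split points end
# at total_size, and np.split recovers blocks of split_size plus a remainder,
# matching how save_params chunks encrypted parameters below.
def _sketch_split_block_check():
    total_size, split_size = 1000, 300
    points = split_block(total_size, split_size)   # [300, 600, 900, 1000]
    assert points[-1] == total_size
    chunks = np.split(np.arange(total_size), points[:-1])
    assert [len(c) for c in chunks] == [300, 300, 300, 100]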
def find_ref(model, param_name):
    # Resolve a dotted parameter name such as "layer1.0.weight" to the referenced
    # attribute, falling back to integer indexing for sequence-like containers.
    obj = model
    for member_name in param_name.split("."):
        try:
            obj = getattr(obj, member_name)
        except AttributeError:
            obj = obj[int(member_name)]
    return obj
def save_sgx_encrytped_param(filename, arr):
    # Flatten the array to float64 bytes and encrypt with AES-CTR
    # (fixed zero nonce and counter) using the shared key from config.
    buf = io.BytesIO()
    buf.write(arr.reshape(-1).astype(np.float64).tobytes())
    cipher_bytes = AES.new(config.key, AES.MODE_CTR, initial_value=0, nonce=bytes([0])).encrypt(buf.getvalue())
    buf.close()
    with open(filename, "wb") as fo:
        fo.write(cipher_bytes)
def load_sgx_encrypted_flat_param(filename):
    # Decrypt a file written by save_sgx_encrytped_param and return the
    # contents as a flat float64 array.
    with open(filename, "rb") as fi:
        cipher_bytes = fi.read()
    origin_bytes = AES.new(config.key, AES.MODE_CTR, initial_value=0, nonce=bytes([0])).decrypt(cipher_bytes)
    return np.frombuffer(origin_bytes, dtype=np.float64)
def aggregate_masks(params):
    # Compute a top-k mask for every parameter, archive and send the masks to
    # the coordinator, then receive and unpack the aggregated masks.
    mask_file_ls = []
    if not os.path.exists(config.client_mask_dir):
        os.makedirs(config.client_mask_dir)
    print("Saving masks...")
    with tqdm(total=len(params), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name, param in params.items():
            pb.set_postfix({"param": name})
            mask = filter_param_mask(param, alpha)
            np.save(f"{config.client_mask_dir}/{name}.npy", mask)
            mask_file_ls.append(f"{name}.npy")
            pb.update(1)
    with open(f"{config.client_mask_dir}/NameList", "w") as fo:
        fo.writelines([(ln + "\n") for ln in mask_file_ls])
    with tarfile.open(config.client_mask_tar_path, "w") as tar:
        tar.add(config.client_mask_dir, "")
    shutil.rmtree(config.client_mask_dir)
    print("Sending masks to coordinator...")
    Comm.send_file(config.client_mask_tar_path, coordinator_ip, clients_info[self_ip]["mask_send_port"])
    os.remove(config.client_mask_tar_path)
    print("Receiving aggregated masks from coordinator...")
    # TODO: handle the case where the target directory does not exist
    Comm.recv_file(config.client_mask_tar_path, clients_info[self_ip]["mask_recv_port"])
    with tarfile.open(config.client_mask_tar_path, "r") as tar:
        tar.extractall(config.client_mask_dir)
    os.remove(config.client_mask_tar_path)
def save_params(param_dict):
    # Convert torch parameters to numpy, exchange masks with the coordinator,
    # then write the masked ("cla", encrypted) and unmasked ("uc") parts into a
    # tar archive that is later uploaded for aggregation.
    params = {}
    for name, torchparam in param_dict.items():
        params[name] = torchparam.cpu().numpy()
    aggregate_masks(params)
    # The mask archive returned by the coordinator also carries the aggregation server's IP.
    with open(f"{config.client_mask_dir}/serverip", "r") as fi:
        conf = config.get_dyna_config()
        conf["server_ip"] = fi.readline().strip()
        config.set_dyna_config(conf)
    param_file_ls = []
    masks = {}
    if not os.path.exists(config.client_param_dir):
        os.makedirs(config.client_param_dir)
    print("Saving parameters...")
    with tqdm(total=len(params), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name, param in params.items():
            mask = np.load(f"{config.client_mask_dir}/{name}.npy")
            masks[name] = mask
            cla_param = param[mask.nonzero()].reshape(-1)
            uc_param = param * np.logical_not(mask)
            if cla_param.size > sgx_split_size_threshold:
                for i, arr in enumerate(np.split(cla_param, split_block(cla_param.size, sgx_split_size_threshold)[:-1])):
                    pb.set_postfix({"param": f"{name}-{i}"})
                    save_sgx_encrytped_param(f"{config.client_param_dir}/cla{name}-part{i}", arr)
                    param_file_ls.append(f"cla{name}-part{i}")
            else:
                pb.set_postfix({"param": name})
                save_sgx_encrytped_param(f"{config.client_param_dir}/cla{name}", cla_param)
                param_file_ls.append(f"cla{name}")
            np.save(f"{config.client_param_dir}/uc{name}.npy", uc_param)
            param_file_ls.append(f"uc{name}.npy")
            pb.update(1)
    with open(f"{config.client_param_dir}/NameList", "w") as fo:
        fo.writelines([(ln + "\n") for ln in param_file_ls])
    np.save(f"{config.client_param_dir}/AggregationWeight.npy", np.array([aggregation_weight], dtype=np.float64))
    with tarfile.open(config.client_param_tar_path, "w") as tar:
        tar.add(config.client_param_dir, "")
    shutil.rmtree(config.client_param_dir)
    shutil.rmtree(config.client_mask_dir)
    return masks
def load_params(param_dict, masks, device):
    """Read parameters back from the aggregation result.

    The function overwrites the entries of the input dict in place and returns it.

    Args:
        param_dict: A dict whose values must already be tensors of the correct shape.
        masks: The per-parameter masks returned by save_params.
        device: The device on which the loaded parameter tensors will be stored.
    """
    with tarfile.open(config.client_param_tar_path, "r") as tar:
        tar.extractall(config.client_param_dir)
    os.remove(config.client_param_tar_path)
    print("Loading parameters...")
    with tqdm(total=len(param_dict), bar_format="|{bar}| {n_fmt}/{total_fmt} [{postfix}]") as pb:
        for name in param_dict.keys():
            filename = f"{config.client_param_dir}/cla{name}"
            if os.path.exists(filename):
                pb.set_postfix({"param": name})
                cla_param = load_sgx_encrypted_flat_param(filename)
            else:
                arr_ls = []
                i = 0
                while os.path.exists(f"{filename}-part{i}"):
                    pb.set_postfix({"param": f"{name}-{i}"})
                    arr_ls.append(load_sgx_encrypted_flat_param(f"{filename}-part{i}"))
                    i += 1
                cla_param = np.concatenate(arr_ls)
            # Scatter the encrypted ("cla") values back into the mask positions and
            # add the unencrypted ("uc") remainder to rebuild the full parameter.
            mask = masks[name].astype(cla_param.dtype)
            mask[mask.nonzero()] = cla_param
            cla_param = mask.reshape(tuple(param_dict[name].shape))
            uc_param = np.load(f"{config.client_param_dir}/uc{name}.npy")
            param = cla_param + uc_param
            param_dict[name] = tensor(param, dtype=param_dict[name].dtype, device=device)
            pb.update(1)
    shutil.rmtree(config.client_param_dir)
    return param_dict
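# Illustrative usage sketch (not part of the original module): one client-side
# round, assuming config paths, Comm, and the coordinator exchange are already
# set up. `model` and `device` are hypothetical inputs; state_dict() is used so
# the dict holds detached tensors of the correct shapes, as load_params expects.
def _sketch_client_round(model, device):
    state = model.state_dict()
    masks = save_params(state)
    # ... the parameter archive is exchanged with the coordinator out of band;
    # once the aggregated tar is back at config.client_param_tar_path ...
    aggregated = load_params(state, masks, device)
    model.load_state_dict(aggregated)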
def biased_MNIST(ds):
    # Split MNIST into two label-biased subsets: the split ratio alternates
    # between 95%/5% and 5%/95% from one label to the next.
    percentage = 0.95
    ret_indices = [[], []]
    label_indices = {}
    labels = [int(e) for e in ds.targets]
    for i, label in enumerate(labels):
        if label not in label_indices:
            label_indices[label] = []
        label_indices[label].append(i)
    for label, indices in label_indices.items():
        spl = int(np.floor(percentage * len(indices)))
        ret_indices[0] += indices[:spl]
        ret_indices[1] += indices[spl:]
        percentage = 1 - percentage
    return [Subset(ds, random.sample(indices, len(indices))) for indices in ret_indices]
def random_CIFAR(ds):
    return random_split(ds, [0.25, 0.25, 0.25, 0.25], Generator().manual_seed(42))
def random_CIFAR2(ds):
    return random_split(ds, [0.5, 0.5], Generator().manual_seed(42))
def biased_CIFAR(ds):
    # Split CIFAR into four label-biased subsets: for each label, about 80% of
    # its samples go to one subset, 7% to each of the next two, and the rest to
    # the fourth; the dominant subset rotates with the label index.
    a, b = 0.8, 0.07
    ret_indices = [[], [], [], []]
    label_indices = {}
    labels = [int(e) for e in ds.targets]
    for i, label in enumerate(labels):
        if label not in label_indices:
            label_indices[label] = []
        label_indices[label].append(i)
    for i, (label, indices) in enumerate(label_indices.items()):
        indices = list(indices)
        n = len(indices)
        spl = int(np.floor(a * n))
        ret_indices[i % 4] += indices[:spl]
        del indices[:spl]
        spl = int(np.floor(b * n))
        ret_indices[(i + 1) % 4] += indices[:spl]
        del indices[:spl]
        spl = int(np.floor(b * n))
        ret_indices[(i + 2) % 4] += indices[:spl]
        del indices[:spl]
        ret_indices[(i + 3) % 4] += indices
    return [Subset(ds, random.sample(indices, len(indices))) for indices in ret_indices]