加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
coor_mask.py 7.67 KB
一键复制 编辑 原始数据 按行查看 历史
electronick_pro 提交于 2023-09-18 17:38 . to test
import Comm, config
import os, tarfile, shutil, argparse
import numpy as np
import ModelDef
import torch
import random
# Fixed seed so the random client shuffle (random.sample in grouping_algorithm)
# yields the same grouping on every run.
random.seed(42)
# Marked provisional ("trial") by the original author.
def aggregate_mask(arrs):
    """Intersect a list of masks element-wise.

    Args:
        arrs: non-empty sequence of array-likes; the first one fixes the
            starting shape, later ones are combined with ``np.logical_and``.

    Returns:
        ``(mask, proportion)`` where ``mask`` is the boolean AND of every
        array and ``proportion`` is the fraction of its entries that are True.
    """
    combined = np.ones(arrs[0].shape, dtype=np.bool_)
    for current in arrs:
        combined = np.logical_and(combined, current)
    proportion = np.sum(combined) / combined.size
    return combined, proportion
def group_aggregate_masks(clients_ip, name_list):
    """Intersect the masks of a group of clients, parameter by parameter.

    For every name in ``name_list`` the file ``<coor_mask_dir(ip)>/<name>.npy``
    is loaded for each client ip and the arrays are combined with
    ``aggregate_mask``.

    Returns:
        ``(masks, mean_proportion)``: a dict mapping each name to its
        intersected mask, and the unweighted mean over names of the fraction
        of True entries in each intersected mask.
    """
    masks = {}
    proportions = []
    for name in name_list:
        per_client = [np.load(f"{config.coor_mask_dir(ip)}/{name}.npy") for ip in clients_ip]
        masks[name], proportion = aggregate_mask(per_client)
        proportions.append(proportion)
    return masks, np.average(np.array(proportions))
# Two grouping algorithms follow; both were marked provisional ("trial") by the original author.
# TODO: algorithm design - now it takes random strategy
def grouping_algorithm(clients_ip, name_list, servers_ip):
    """Randomly split the clients into one near-equal group per server.

    The clients are shuffled with the module-level ``random`` state and cut
    into ``len(servers_ip)`` consecutive chunks whose sizes differ by at most
    one; the first ``len(clients_ip) % len(servers_ip)`` servers receive the
    larger chunks. Each group's masks are intersected with
    ``group_aggregate_masks``.

    Generate a dict like:
    {
        'server0.ip.address': {
            'clients': ['client0.ip.address','client1.ip.address','...'],
            'masks': {
                'param.name.0': <ndarray>,
                'param.name.1': <ndarray>,
                '...': <ndarray>
            },
        },
        'server1.ip.address': {
            'clients': ['client2.ip.address','client3.ip.address','...'],
            'masks': {
                'param.name.0': <ndarray>,
                'param.name.1': <ndarray>,
                '...': <ndarray>
            },
        },
        '...': { ... }
    }

    A server that gets no clients (possible when there are fewer clients than
    servers) still appears in the result with an empty ``clients`` list and an
    empty ``masks`` dict instead of crashing the aggregation.

    Raises:
        ValueError: if ``servers_ip`` is empty.
    """
    if not servers_ip:
        raise ValueError("grouping_algorithm needs at least one server")
    ret = {}
    # random.sample returns a new shuffled list, so the caller's list is left untouched.
    shuffled = random.sample(clients_ip, k=len(clients_ip))
    base, rem = divmod(len(shuffled), len(servers_ip))
    start = 0
    for i, server_ip in enumerate(servers_ip):
        size = base + 1 if i < rem else base
        group_clients = shuffled[start:start + size]
        start += size
        ret[server_ip] = {
            "clients": group_clients,
            "masks": {}
        }
        if not group_clients:
            # aggregate_mask indexes arrs[0], so an empty group cannot be aggregated.
            print(f"Server {server_ip} was assigned no clients\n")
            continue
        masks, p = group_aggregate_masks(group_clients, name_list)
        ret[server_ip]["masks"] = masks
        print(f"Aggregated proportion in server {server_ip}: {p:>.6f}\n")
    return ret
class DJS:
    """Disjoint-set (union-find) forest over integer indices.

    ``nodes[i]`` holds the parent of ``i``; an index whose parent is itself is
    the root (representative) of its set.
    """

    def __init__(self, indices) -> None:
        # Parent array, used (not copied). Expected to start as list(range(n))
        # so that every index begins as its own singleton set.
        self.nodes = indices

    def find(self, idx):
        """Return the root of ``idx``'s set, halving the path on the way up."""
        while idx != self.nodes[idx]:
            # Re-point idx at its grandparent, then step there. This must be two
            # statements: the chained form ``idx = self.nodes[idx] = ...`` binds
            # ``idx`` first and then writes to the *grandparent's* slot, making
            # it a fake root and corrupting the sets.
            self.nodes[idx] = self.nodes[self.nodes[idx]]
            idx = self.nodes[idx]
        return idx

    def union(self, idx1, idx2):
        """Merge the set containing ``idx2`` into the set containing ``idx1``.

        The root of ``idx2``'s set becomes a child of the root of ``idx1``'s set.
        """
        idx1, idx2 = self.find(idx1), self.find(idx2)
        self.nodes[idx2] = idx1
def grouping_algorithm2(clients_ip, name_list, servers_ip):
    """Group clients whose masks are similar, one group per server.

    Citation: the grouping method is inspired by the arXiv paper
    "PFA: Privacy-preserving Federated Adaptation for Effective Model
    Personalization" by Bingyan Liu, Yao Guo and Xiangqun Chen.

    Every unordered pair of distinct clients is scored with the aggregated
    proportion returned by ``group_aggregate_masks``. Pairs are visited from
    most to least similar, and their sets are merged with a disjoint set until
    only ``len(servers_ip)`` groups remain. Groups are handed to servers in
    ascending order of their disjoint-set root index.

    Returns the same layout as ``grouping_algorithm``:
    ``{server_ip: {"clients": [...], "masks": {name: ndarray}}}``. When there
    are fewer clients than servers, the surplus servers get an empty
    ``clients`` list and an empty ``masks`` dict.

    Raises:
        RuntimeError: if merging leaves more groups than servers, which would
            otherwise silently leave clients unassigned.
    """
    n, m = len(clients_ip), len(servers_ip)
    # Score each unordered pair of *distinct* clients; a client paired with
    # itself can never trigger a merge, so self-pairs would only waste I/O.
    scored_pairs = []
    for i in range(n):
        for j in range(i + 1, n):
            _, p = group_aggregate_masks([clients_ip[i], clients_ip[j]], name_list)
            scored_pairs.append(((i, j), p))
    # Most similar first; sort is stable, so ties keep pair-generation order.
    scored_pairs.sort(key=lambda e: e[1], reverse=True)

    djs = DJS(list(range(n)))
    groups_left = n
    for (u, v), _ in scored_pairs:
        # reduce n clients to m servers
        if groups_left <= m:
            break
        # the client pair with max similarity among those not yet visited
        u, v = djs.find(u), djs.find(v)
        if u != v:
            djs.union(u, v)
            groups_left -= 1

    # Collect members per root; ordering by root index matches the original
    # sparse-list construction.
    members = {}
    for i in range(n):
        members.setdefault(djs.find(i), []).append(i)
    group_indices = [members[root] for root in sorted(members)]
    if len(group_indices) > m:
        raise RuntimeError(
            f"grouping produced {len(group_indices)} groups for {m} servers; "
            "some clients would be left unassigned"
        )

    ret = {}
    for i, server_ip in enumerate(servers_ip):
        if i < len(group_indices):
            group_clients = [clients_ip[idx] for idx in group_indices[i]]
        else:
            group_clients = []
        ret[server_ip] = {
            "clients": group_clients,
            "masks": {}
        }
        if not group_clients:
            # No clients for this server: nothing to aggregate.
            print(f"Server {server_ip} was assigned no clients\n")
            continue
        masks, p = group_aggregate_masks(group_clients, name_list)
        ret[server_ip]["masks"] = masks
        print(f"Aggregated proportion in server {server_ip}: {p:>.6f}\n")
    return ret
def get_args():
    """Parse the coordinator's command-line options.

    Options:
        --model: name of the model class looked up on ``ModelDef`` (required).
        --coordinator-initialized: flag; when present the coordinator builds
            the initial model and sends it to every client.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"required": True}),
        ("--coordinator-initialized", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    return parsed
if __name__ == "__main__":

    def _safe_extractall(tar, dest):
        """Extract ``tar`` into ``dest`` without letting members escape ``dest``.

        The mask archives arrive from clients over the network, so member names
        are untrusted: an absolute path, a ``..`` component or a link could make
        a plain ``extractall`` write anywhere on this machine. Uses tarfile's
        ``"data"`` extraction filter when this Python provides it, otherwise
        validates every member path by hand.
        """
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
            return
        root = os.path.realpath(dest)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise RuntimeError(f"Unsafe path in mask archive: {member.name!r}")
            if member.issym() or member.islnk():
                raise RuntimeError(f"Links are not allowed in mask archive: {member.name!r}")
        tar.extractall(root)

    conf = config.get_dyna_config()
    args = get_args()

    if args.coordinator_initialized:
        # Build a fresh model, send its initial weights to every client, then
        # delete the temporary local copy.
        print("Initializing model...")
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        model = getattr(ModelDef, args.model)().to(device)
        torch.save(model.state_dict(), config.init_model_path)
        for ip, info in conf["clients_info"].items():
            print(f"Sending init model to client {ip}...")
            Comm.send_file(config.init_model_path, ip, info["model_recv_port"])
        os.remove(config.init_model_path)

    # NOTE(review): number of mask-coordination rounds derived from the training
    # schedule; the "- 1" presumably mirrors the clients' loop — confirm.
    fed_rounds = (conf["epochs"] - 1) // conf["epochs_per_fed"]
    for _ in range(fed_rounds):
        # 1. Receive one mask archive per client and unpack it into that client's mask dir.
        os.makedirs(os.path.dirname(config.coor_mask_tar_path("0")), exist_ok=True)
        for ip, info in conf["clients_info"].items():
            print(f"Receiving masks from client {ip}...")
            Comm.recv_file(config.coor_mask_tar_path(ip), info["mask_send_port"])
        print("Unzipping masks...")
        for ip in conf["clients_info"]:
            with tarfile.open(config.coor_mask_tar_path(ip), "r") as tar:
                _safe_extractall(tar, config.coor_mask_dir(ip))
            os.remove(config.coor_mask_tar_path(ip))

        # 2. Read parameter names from the first client's NameList and group the clients.
        print("Grouping masks...")
        first_client = next(iter(conf["clients_info"]))
        name_list = []
        with open(f"{config.coor_mask_dir(first_client)}/NameList", "r") as fi:
            for ln in fi:
                entry = ln.strip()
                if not entry:
                    continue
                # Drops the last 4 characters — presumably a ".npy" suffix; TODO confirm NameList format.
                name_list.append(entry[:-4])
        groups = grouping_algorithm(list(conf["clients_info"].keys()), name_list, list(conf["servers_info"].keys()))

        # 3. For every server: send it its client list, then pack the group's
        #    aggregated masks plus the server ip and send them to each client.
        os.makedirs(os.path.dirname(config.coor_addr_path(next(iter(groups)))), exist_ok=True)
        for server_ip, group in groups.items():
            addr_path = config.coor_addr_path(server_ip)
            with open(addr_path, "w") as fo:
                fo.writelines([(ln + "\n") for ln in group["clients"]])
            print(f"Sending client ips to their server {server_ip}...")
            Comm.send_file(addr_path, server_ip, conf["servers_info"][server_ip]["coor_recv_port"])
            os.remove(addr_path)

            mask_dir = config.coor_aggregated_mask_dir(server_ip)
            os.makedirs(mask_dir, exist_ok=True)
            with open(f"{mask_dir}/serverip", "w") as fo:
                fo.writelines([server_ip])
            for name, mask in group["masks"].items():
                np.save(f"{mask_dir}/{name}.npy", mask)
            mask_tar = config.coor_aggregated_mask_tar_path(server_ip)
            with tarfile.open(mask_tar, "w") as tar:
                tar.add(mask_dir, "")
            shutil.rmtree(mask_dir)
            for client_ip in group["clients"]:
                print(f"Sending masks and server ip {server_ip} to client {client_ip}...")
                Comm.send_file(mask_tar, client_ip, conf["clients_info"][client_ip]["mask_recv_port"])
            os.remove(mask_tar)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化