加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
server_send_param.py 1.65 KB
一键复制 编辑 原始数据 按行查看 历史
electronick_pro 提交于 2023-06-26 13:41 . bugfix
import Comm, config
import tarfile, os, shutil
import numpy as np
import FedUtils
def _read_name_list(lines):
    """Return the parameter names listed in a ``NameList`` file.

    Only lines that start with ``"clauc"`` are used. For each such line the
    stripped text before the first ``"-"`` is kept and its first three
    characters are dropped, e.g. ``"claucW1-abc"`` -> ``"ucW1"``. All other
    lines (including blank ones) are ignored.

    Args:
        lines: Iterable of text lines, e.g. an open file object.

    Returns:
        List of names in file order.
    """
    # A blank line can never start with "clauc", and iterating a file never
    # yields None, so the prefix test alone is sufficient.
    return [ln.strip().split("-")[0][3:] for ln in lines if ln.startswith("clauc")]


def _load_cla_param(filename):
    """Load a flat parameter array stored either whole or as numbered parts.

    If *filename* exists it is loaded directly with
    ``FedUtils.load_sgx_encrypted_flat_param``. Otherwise the files
    ``{filename}-part0``, ``{filename}-part1``, ... are loaded in order up to
    the first missing index and concatenated.

    Raises:
        FileNotFoundError: neither *filename* nor ``{filename}-part0`` exists.
    """
    if os.path.exists(filename):
        return FedUtils.load_sgx_encrypted_flat_param(filename)
    parts = []
    idx = 0
    while os.path.exists(f"{filename}-part{idx}"):
        parts.append(FedUtils.load_sgx_encrypted_flat_param(f"{filename}-part{idx}"))
        idx += 1
    if not parts:
        # np.concatenate([]) would fail with an opaque "need at least one
        # array" message; report which file is actually missing instead.
        raise FileNotFoundError(f"Neither {filename} nor {filename}-part0 exists")
    return np.concatenate(parts)


def _count_matches(param, cla_param, tol=1e-6):
    """Return how many elements of two equally sized arrays differ by < *tol*.

    Both inputs are flattened before comparison.

    Raises:
        ValueError: the arrays contain different numbers of elements. This
            prevents NumPy broadcasting from silently comparing the wrong
            elements (e.g. a size-1 array against a full tensor).
    """
    a = np.asarray(param).reshape(-1)
    b = np.asarray(cla_param).reshape(-1)
    if a.size != b.size:
        raise ValueError(f"Parameter size mismatch: {a.size} vs {b.size}")
    return int(np.sum(np.abs(a - b) < tol))


def _accuracy_message(correct, total):
    """Format the verification summary line; safe when nothing was compared."""
    if total == 0:
        # An empty NameList would otherwise raise ZeroDivisionError and abort
        # the script before the aggregated parameters are sent to clients.
        return "No parameters were verified in this epoch"
    return f"The verfication accuracy of the epoch is {(correct / total):>7f}"


def main():
    """Verify the aggregated parameters and distribute them to all clients.

    Steps:
      1. Read parameter names from the ``NameList`` file located in
         ``config.server_param_dir`` of the first client IP.
      2. Delete the parent directory of ``config.server_param_dir("0")``.
      3. For every name, compare the array loaded from ``cla{name}`` (or its
         ``-partN`` pieces) against the ``np.load``-ed file ``{name}`` in
         ``config.server_aggregated_param_dir``. Print the fraction of
         elements that agree within 1e-6.
      4. Tar the aggregated directory and delete the directory. Send the
         tarball to each client's ``param_recv_port``, then delete the tarball.
    """
    conf = config.get_dyna_config()
    name_list_path = f"{config.server_param_dir(conf['clients_ip'][0])}/NameList"
    with open(name_list_path, "r") as fi:
        name_list = _read_name_list(fi)
    # NOTE(review): presumably this parent directory holds every client's
    # uploaded params; confirm nothing else lives there before removal.
    shutil.rmtree(os.path.dirname(config.server_param_dir("0")))

    # Verify
    agg_dir = config.server_aggregated_param_dir
    correct, total = 0, 0
    for name in name_list:
        cla_param = _load_cla_param(f"{agg_dir}/cla{name}")
        param = np.load(f"{agg_dir}/{name}").reshape(-1)
        total += param.size
        correct += _count_matches(param, cla_param)
    print(_accuracy_message(correct, total))

    # Package and distribute
    tar_path = config.server_aggregated_param_tar_path
    with tarfile.open(tar_path, "w") as tar:
        # arcname "" stores members relative to agg_dir itself.
        tar.add(agg_dir, "")
    shutil.rmtree(agg_dir)
    for ip in conf["clients_ip"]:
        Comm.send_file(tar_path, ip, conf["clients_info"][ip]["param_recv_port"])
    os.remove(tar_path)


if __name__ == "__main__":
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化