diff --git a/examples/community/cryptosfl/README.md b/examples/community/cryptosfl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..16676b231a94c825b97b0c2e2a1da2c0f1940120 --- /dev/null +++ b/examples/community/cryptosfl/README.md @@ -0,0 +1,88 @@ +## 拆分联邦学习的函数加密防御 + +## 描述 + +本项目是基于MindSpore框架针对拆分联邦学习模型的防御手段,利用函数加密技术来保护拆分层输出,并保持模型的精度。 + +## 模型 + +客户端模型采用基于MindSpore框架的MLP模型,包含1层隐藏层,其输出维度与服务器的输入维度保持一致。 + +## 数据集 + +使用手写数字数据集MNIST作为训练和测试数据集。使用以下命令下载后保存在`datasets`文件夹中。 + +```bash +cd datasets/ +python generate_data.py +``` + +## 加密配置 + +采用预生成的离散对数表加速函数加密的加密及解密过程。使用以下命令生成该表后保存在`config`文件夹中。 + +```bash +cd config/ +python generate_config.py +``` + +## 原型论文 + +Xu R, Joshi J, Li C. Nn-emd: Efficiently training neural networks using encrypted multi-sourced datasets[J]. IEEE +Transactions on Dependable and Secure Computing, 2021, 19(4): 2807-2820.[[pdf]](https://arxiv.org/pdf/2012.10547) + +Thapa C, Arachchige P C M, Camtepe S, et al. Splitfed: When federated learning meets split learning[C]//Proceedings of +the AAAI Conference on Artificial Intelligence. 2022, 36(8): 8485-8493.[[pdf]](https://arxiv.org/pdf/2004.12088) + +## 环境要求 + +Mindspore >= 1.9,硬件平台为GPU/CPU/Ascend。 + +## 脚本说明 + +```markdown +├── README.md +├── config +│ ├── generate_config.py //生成函数加密所需的离散对数表 +├── crypto +│ ├── sife_dynamic.py //基于内积的函数加密的实现 +│ ├── utils.py //函数加密实现所需的常用函数 +├── datasets +│ ├── generate_data.py //加载数据的相关代码 +├── nn +│ ├── smc.py //参与方加密及解密 +│ ├── worker.py //参与方模型训练及预测 +│ ├── utils.py //初始化参与方模型的函数 +└── example_cryptosfl.py //在拆分联邦学习上的应用例 +``` + +## 训练依赖 + +```markdown +numpy == 1.26.4 +gmpy2 == 2.1.5 +``` + +## 训练过程 + +```bash +python example_cryptosfl.py +``` + +## 默认训练参数 + +```markdown +epochs = 20 num_users = 5 batch_size = 50 model = 'mlp64' lr = 1e-3 +``` + +## 实验结果 + +```markdown +epoch 1, test accuracy = 93.50% epoch 2, test accuracy = 94.74% epoch 3, test accuracy = 95.43% epoch 4, test accuracy = +96.07% epoch 5, test accuracy = 96.22% epoch 6, test accuracy = 96.49% epoch 7, test accuracy = 96.77% epoch 8, test +accuracy = 96.79% epoch 9, test accuracy = 97.02% epoch 10, test accuracy = 97.05% epoch 11, test accuracy = 97.19% +epoch 12, test accuracy = 97.42% epoch 13, test accuracy = 97.59% epoch 14, test accuracy = 97.64% epoch 15, test +accuracy = 97.65% epoch 16, test accuracy = 97.76% epoch 17, test accuracy = 97.86% epoch 18, test accuracy = 97.87% +epoch 19, test accuracy = 97.95% epoch 20, test accuracy = 97.92% +``` + diff --git a/examples/community/cryptosfl/config/generate_config.py b/examples/community/cryptosfl/config/generate_config.py new file mode 100644 index 0000000000000000000000000000000000000000..415dc1a7b3d19994062d67ef14570a9b953027b0 --- /dev/null +++ b/examples/community/cryptosfl/config/generate_config.py @@ -0,0 +1,27 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""generate the config of a discrete logarithm table""" +from crypto.utils import generate_config_files + + +def test_generate_config_files(): + sec_param_config_file = './sec_param.json' + dlog_table_config_file = './dlog_b8.json' + func_value_bound = 100000000 + sec_param = 256 + generate_config_files(sec_param, sec_param_config_file, dlog_table_config_file, func_value_bound) + + +if __name__ == "__main__": + test_generate_config_files() diff --git a/examples/community/cryptosfl/crypto/__init__.py b/examples/community/cryptosfl/crypto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed17eb616f704841eb8c7bedccf4442b5c2f9b00 --- /dev/null +++ b/examples/community/cryptosfl/crypto/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/examples/community/cryptosfl/crypto/sife_dynamic.py b/examples/community/cryptosfl/crypto/sife_dynamic.py new file mode 100644 index 0000000000000000000000000000000000000000..1914a989bda9102ebfb5f2736df6a13c4c66367f --- /dev/null +++ b/examples/community/cryptosfl/crypto/sife_dynamic.py @@ -0,0 +1,165 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""underlying implementation of single-input functional encryption""" +import logging +import math + +import gmpy2 as gp +from crypto.utils import _param_generator +from crypto.utils import _random +from crypto.utils import _random_generator +from crypto.utils import load_sec_param_config + +logger = logging.getLogger(__name__) + + +class SIFEDynamicTPA: + """Class for Single-Input Functional Encryption with dynamic third-party authority.""" + + def __init__(self, eta, sec_param=256, sec_param_config=None): + self.eta = eta + if sec_param_config is not None: + self.p, self.q, self.r, self.g, self.sec_param = load_sec_param_config(sec_param_config) + else: + self.p, self.q, self.r = _param_generator(sec_param) + self.g = _random_generator(sec_param, self.p, self.r) + self.sec_param = sec_param + + def setup(self): + self.msk = [_random(self.p, self.sec_param) for i in range(self.eta)] + pk = [gp.powmod(self.g, self.msk[i], self.p) for i in range(self.eta)] + self.mpk = {'p': self.p, 'g': self.g, 'pk': pk} + + def generate_common_public_key(self): + pk = dict() + pk['g'] = gp.digits(self.mpk['g']) + pk['p'] = gp.digits(self.mpk['p']) + return pk + + def generate_public_key(self, vec_size): + assert vec_size <= self.eta + pk = dict() + pk['bound'] = vec_size + pk['g'] = gp.digits(self.mpk['g']) + pk['p'] = gp.digits(self.mpk['p']) + pk['pk'] = list() + for i in range(vec_size): + pk['pk'].append(gp.digits(self.mpk['pk'][i])) + return pk + + def generate_private_key(self, vec): + assert len(vec) <= self.eta + + sk = gp.mpz(0) + for i in range(len(vec)): + sk = gp.add(sk, gp.mul(self.msk[i], vec[i])) + return {'bound': len(vec), 'sk': gp.digits(sk)} + + +class SIFEDynamicClient: + """Class for Single-Input Functional Encryption with dynamic client functionality.""" + + def __init__(self, sec_param=256, role='dec', dlog=None): + if role not in ('dec', 'both'): + if dlog is not None: + self.dlog_table = dlog['dlog_table'] + self.func_bound = dlog['func_bound'] + else: + self.sec_param = sec_param + self.dlog_table = None + self.func_bound = None + elif role == 'enc': + self.sec_param = sec_param + + def encrypt(self, pk, vec): + """Encrypt the plaintext using the public key.""" + assert len(vec) == pk['bound'] + + p = gp.mpz(pk['p']) + g = gp.mpz(pk['g']) + + r = _random(p, self.sec_param) + ct0 = gp.digits(gp.powmod(g, r, p)) + ct_list = [] + for i in range(len(vec)): + ct_list.append(gp.digits( + gp.mul( + gp.powmod(gp.mpz(pk['pk'][i]), r, p), + gp.powmod(g, gp.mpz(int(vec[i])), p) + ) + )) + return {'ct0': ct0, 'ct_list': ct_list} + + def decrypt(self, pk, sk, vec, ct, max_innerprod): + """Decrypt the ciphertext using the secret key.""" + p = gp.mpz(pk['p']) + g = gp.mpz(pk['g']) + + res = gp.mpz(1) + for i in range(len(vec)): + res = gp.mul( + res, + gp.powmod(gp.mpz(ct['ct_list'][i]), gp.mpz(vec[i]), p) + ) + res = gp.t_mod(res, p) + g_f = gp.divm(res, gp.powmod(gp.mpz(ct['ct0']), gp.mpz(sk['sk']), p), p) + + f = self._solve_dlog(p, g, g_f, max_innerprod) + + return f + + def _solve_dlog(self, p, g, h, dlog_max): + """ + Attempts to solve for the discrete log x, where g^x = h mod p via + hash table. + """ + if self.dlog_table is not None and str(h) in self.dlog_table: + return self.dlog_table[str(h)] + logger.warning("did not find f in dlog table, may cost more time to compute") + return self._solve_dlog_naive(p, g, h, dlog_max) + + def _solve_dlog_naive(self, p, g, h, dlog_max): + """ + Attempts to solve for the discrete log x, where g^x = h, via + trial and error. Assumes that x is at most dlog_max. + """ + res = None + for j in range(dlog_max): + if gp.powmod(g, j, p) == gp.mpz(h): + res = j + break + if res is None: + h = gp.invert(h, p) + for i in range(dlog_max): + if gp.powmod(g, i, p) == gp.mpz(h): + res = -i + return res + + def _solve_dlog_bsgs(self, g, h, p): + """ + Attempts to solve for the discrete log x, where g^x = h mod p, + via the Baby-Step Giant-Step algorithm. + """ + m = math.ceil(math.sqrt(p - 1)) # phi(p) is p-1, if p is prime + # store hashmap of g^{1,2,...,m}(mod p) + hash_table = {pow(g, i, p): i for i in range(m)} + # precompute via Fermat's Little Theorem + c = pow(g, m * (p - 2), p) + # search for an equivalence in the table. Giant Step. + for j in range(m): + y = (h * pow(c, j, p)) % p + if y in hash_table: + return j * m + hash_table[y] + + return None \ No newline at end of file diff --git a/examples/community/cryptosfl/crypto/utils.py b/examples/community/cryptosfl/crypto/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..444200bbdd09e34eddfef263a6d5fffa1439d330 --- /dev/null +++ b/examples/community/cryptosfl/crypto/utils.py @@ -0,0 +1,139 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""functions for the implementation in single-input functional encryption""" +import json +import logging +import random + +import gmpy2 as gp + +logger = logging.getLogger(__name__) + + +def _random(maximum, bits): + rand_function = random.SystemRandom() + r = gp.mpz(rand_function.getrandbits(bits)) + while r >= maximum: + r = gp.mpz(rand_function.getrandbits(bits)) + return r + + +def _random_generator(bits, p, r): + while True: + h = _random(p, bits) + g = gp.powmod(h, r, p) + if not g == 1: + break + return g + + +def _random_prime(bits): + rand_function = random.SystemRandom() + r = gp.mpz(rand_function.getrandbits(bits)) + r = gp.bit_set(r, bits - 1) + return gp.next_prime(r) + + +def _param_generator(bits, r=2): + while True: + p = _random_prime(bits) + q = (p - 1) // 2 + if gp.is_prime(p) and gp.is_prime(q): + break + return p, q, r + + +def generate_config_files(sec_param, sec_param_config, dlog_table_config, func_bound): + """ + Generate configuration files for secure parameters and discrete log tables. + + Args: + sec_param (int): Security parameter. + sec_param_config (str): Path to save the security parameter configuration. + dlog_table_config (str): Path to save the discrete log table configuration. + func_bound (int): Function bound for discrete log table. + """ + p, q, r = _param_generator(sec_param) + g = _random_generator(sec_param, p, r) + group_info = { + 'p': gp.digits(p), + 'q': gp.digits(q), + 'r': gp.digits(r) + } + sec_param_dict = {'g': gp.digits(g), 'sec_param': sec_param, 'group': group_info} + + with open(sec_param_config, 'w') as outfile: + json.dump(sec_param_dict, outfile) + + dlog_table = dict() + bound = func_bound + 1 + for i in range(bound): + dlog_table[gp.digits(gp.powmod(g, i, p))] = i + for i in range(-1, -bound, -1): + dlog_table[gp.digits(gp.powmod(g, i, p))] = i + + dlog_table_dict = { + 'g': gp.digits(g), + 'func_bound': func_bound, + 'dlog_table': dlog_table + } + + with open(dlog_table_config, 'w') as outfile: + json.dump(dlog_table_dict, outfile) + + +def load_sec_param_config(sec_param_config_file): + """ + Load security parameter configuration from a file. + + Args: + sec_param_config_file (str): Path to the security parameter configuration file. + + Returns: + tuple: Contains p, q, r, g, and sec_param loaded from the configuration file. + """ + with open(sec_param_config_file, 'r') as infile: + sec_param_dict = json.load(infile) + + p = gp.mpz(sec_param_dict['group']['p']) + q = gp.mpz(sec_param_dict['group']['q']) + r = gp.mpz(sec_param_dict['group']['r']) + g = gp.mpz(sec_param_dict['g']) + sec_param = sec_param_dict['sec_param'] + + return p, q, r, g, sec_param + + +def load_dlog_table_config(dlog_table_config_file): + """ + Load discrete log table configuration from a file. + + Args: + dlog_table_config_file (str): Path to the discrete log table configuration file. + + Returns: + dict: Contains dlog_table, func_bound, and g loaded from the configuration file. + """ + with open(dlog_table_config_file, 'r') as infile: + store_dict = json.load(infile) + + dlog_table = store_dict['dlog_table'] + func_bound = store_dict['func_bound'] + g = gp.mpz(store_dict['g']) + + return { + 'dlog_table': dlog_table, + 'func_bound': func_bound, + 'g': g + } diff --git a/examples/community/cryptosfl/datasets/generate_data.py b/examples/community/cryptosfl/datasets/generate_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1a1d9aaa8fbee93fc445f1fd243ceb172dfc6d --- /dev/null +++ b/examples/community/cryptosfl/datasets/generate_data.py @@ -0,0 +1,19 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""generate data""" +from download import download + +url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \ + "notebook/datasets/MNIST_Data.zip" +path = download(url, "./", kind="zip", replace=True) diff --git a/examples/community/cryptosfl/example_cryptosfl.py b/examples/community/cryptosfl/example_cryptosfl.py new file mode 100644 index 0000000000000000000000000000000000000000..07ff43018b58ce5ea7cbba0332e2ff7eddfbf9c2 --- /dev/null +++ b/examples/community/cryptosfl/example_cryptosfl.py @@ -0,0 +1,198 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""an example of CryptoSFL using MLP as the model""" +import copy + +import mindspore +import mindspore.dataset as ds +import mindspore.nn as mspnn +import numpy as np +from mindspore import ParameterTuple, dtype as mstype, ops +from mindspore.dataset import vision, transforms + +from crypto.sife_dynamic import SIFEDynamicClient +from crypto.sife_dynamic import SIFEDynamicTPA +from crypto.utils import load_dlog_table_config +from nn.smc import Secure2PCClient +from nn.smc import Secure2PCServer +from nn.utils import initiate_server_model +from nn.worker import CryptoClient +from nn.worker import CryptoServer + + +def dataset_distribute(num_users): + """Distribute dataset among users.""" + num_items = int(60000 / num_users) + dict_users, all_indices = {}, [i for i in range(60000)] + for i in range(num_users): + dict_users[i] = set(np.random.choice(all_indices, num_items, replace=False)) + all_indices = list(set(all_indices) - dict_users[i]) + return dict_users + + +def datapipe(path, batch_size, sampler): + """Prepare data pipeline for training.""" + image_transforms = [ + vision.Rescale(1.0 / 255.0, 0), + vision.Normalize(mean=(0.1307,), std=(0.3081,)), + vision.HWC2CHW() + ] + label_transform = transforms.TypeCast(mindspore.int32) + + dataset = ds.MnistDataset(path, sampler=sampler) + dataset = dataset.map(image_transforms, 'image') + dataset = dataset.map(label_transform, 'label') + dataset = dataset.batch(batch_size) + return dataset + + +def get_dataloaders(num_users, batch_size): + """Get dataloaders for training and testing.""" + dict_users = dataset_distribute(num_users) + train_loader_dict = {} + for i in range(num_users): + train_loader_dict[i] = datapipe(path='./datasets/MNIST_Data/train', batch_size=batch_size, + sampler=list(dict_users[i])) + test_loader = datapipe(path='./datasets/MNIST_Data/test', batch_size=batch_size, sampler=list(range(10000))) + return train_loader_dict, test_loader + + +def aggregate_weights(w_locals_client, w_locals_server): + """Aggregate weights from multiple clients.""" + num_users = len(w_locals_client) + # server weights aggregation + w_avg_server = copy.deepcopy(w_locals_server[0]) + for k in range(len(w_avg_server)): + for i in range(1, len(w_locals_server)): + w_avg_server[k] += w_locals_server[i][k] + w_avg_server[k] = np.divide(w_avg_server[k], len(w_locals_server)) + + # client weights aggregation + w_avg_client = copy.deepcopy(w_locals_client[0]) + for k in w_avg_client.keys(): + for i in range(1, len(w_locals_client)): + w_avg_client[k] += w_locals_client[i][k] + w_avg_client[k] = ops.div(w_avg_client[k], len(w_locals_client)) + + for key, value in w_avg_client.items(): + w_avg_client[key] = mindspore.Parameter(value.asnumpy()) + + return [copy.deepcopy(w_avg_client) for _ in range(num_users)], \ + [copy.deepcopy(w_avg_server) for _ in range(num_users)] + + +def initiate_model_list(lr, model, precision_data=3, precision_weight=3): + """Initialize model list for clients and server.""" + # initiate the crypto system + sec_param_config_file = './config/sec_param.json' + dlog_table_config_file = './config/dlog_b8.json' + eta = 1250 + sec_param = 256 + dlog = load_dlog_table_config(dlog_table_config_file) + sife_tpa = SIFEDynamicTPA(eta, sec_param=sec_param, sec_param_config=sec_param_config_file) + sife_tpa.setup() + sife_enc_client = SIFEDynamicClient(role='enc') + sife_dec_client = SIFEDynamicClient(role='dec', dlog=dlog) + + secure2pc_client = Secure2PCClient(sife=(sife_tpa, sife_enc_client), precision=precision_data) + secure2pc_server = Secure2PCServer(sife=(sife_tpa, sife_dec_client), precision=(precision_data, precision_weight)) + + n_features, hidden_layers = initiate_server_model(model) + + client = CryptoClient(smc=secure2pc_client, model=model) + server = CryptoServer(n_features=n_features, hidden_layers=hidden_layers, eta=lr, smc=secure2pc_server) + return client, server + + +class GradNet(mspnn.Cell): + """Gradient network for training.""" + + def __init__(self, net, grad_wrt_output: np.ndarray): + super(GradNet, self).__init__() + self.net = net + self.params = ParameterTuple(net.trainable_params()) + self.grad_op = ops.GradOperation(get_by_list=True, sens_param=True) + + # grads of the smashed data + self.grad_wrt_output = mindspore.Tensor(grad_wrt_output, dtype=mstype.float32) + + def construct(self, x): + gradient_function = self.grad_op(self.net, self.params) + return gradient_function(x, self.grad_wrt_output) + + +def train_and_test(epochs, num_users, batch_size, model, lr): + """Train and test the model.""" + # set client & server model + client, server = initiate_model_list(lr, model) + + w_locals_client = [copy.deepcopy(client.parameters_dict()) for _ in range(num_users)] + w_locals_server = [copy.deepcopy(server.w) for _ in range(num_users)] + + # load data + train_loader_dict, test_loader = get_dataloaders(num_users, batch_size) + + for epoch in range(epochs): + # training + for idx in range(num_users): + if epoch != 0: + _, _ = mindspore.load_param_into_net(client, w_locals_client[idx]) + + server.w = copy.deepcopy(w_locals_server[idx]) + optimizer = mspnn.SGD(client.trainable_params(), learning_rate=lr) + for _, (inputs, targets) in enumerate(train_loader_dict[idx].create_tuple_iterator()): + # client forward + plain_intermediate = client(inputs) + ct_feedforward, ct_backpropagation, y_onehot = client.encrypt( + intermediate=plain_intermediate.asnumpy(), y=targets.asnumpy()) + # server forward + z, a = server.feedforward_secure(ct_batch=ct_feedforward, w=copy.deepcopy(server.w)) + # server backward + grad, d_intermediate = server.get_gradient_secure(ct_batch=ct_backpropagation, y_encode=y_onehot, + a=copy.deepcopy(a), z=copy.deepcopy(z), + w=copy.deepcopy(server.w)) + delta_w = [server.eta * grad[i] for i in range(len(server.w))] + for i in range(len(server.w)): + server.w[i] -= delta_w[i] + # client backward + client_grads = GradNet(net=client, grad_wrt_output=d_intermediate)(inputs) + optimizer(client_grads) + + w_locals_client[idx] = copy.deepcopy(client.parameters_dict()) + w_locals_server[idx] = copy.deepcopy(server.w) + + # aggregate weights after each global epoch + w_locals_client, w_locals_server = aggregate_weights(w_locals_client, w_locals_server) + + # testing + _, _ = mindspore.load_param_into_net(client, w_locals_client[0]) + server.w = copy.deepcopy(w_locals_server[0]) + correct, total = 0, 0 + + for inputs, targets in test_loader.create_tuple_iterator(): + plain_intermediate = client(inputs) + pred = server.predict(plain_intermediate) + correct += np.sum(targets.asnumpy() == pred, axis=0) + total += len(targets) + test_accuracy = correct / total + print('epoch {}, test accuracy = {:.2f}%'.format(epoch + 1, 100 * test_accuracy)) + + +if __name__ == "__main__": + rounds = 20 + user_counts = 5 + bs = 50 + client_net = 'mlp64' + learning_rate = 1e-3 + train_and_test(rounds, user_counts, bs, client_net, learning_rate) diff --git a/examples/community/cryptosfl/nn/__init__.py b/examples/community/cryptosfl/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed17eb616f704841eb8c7bedccf4442b5c2f9b00 --- /dev/null +++ b/examples/community/cryptosfl/nn/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/examples/community/cryptosfl/nn/smc.py b/examples/community/cryptosfl/nn/smc.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8ceca2f6d350bd0a76ac537fcb5f2378c0f179 --- /dev/null +++ b/examples/community/cryptosfl/nn/smc.py @@ -0,0 +1,81 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""implementation of the crypto system within client and server""" +import logging + +import numpy as np + +logger = logging.getLogger(__name__) + + +class Secure2PCClient: + """Client for secure two-party computation using functional encryption.""" + + def __init__(self, sife, precision): + self.sife_tpa, self.sife_client = sife + self.precision = precision + + def execute(self, data_array): + data_list = (data_array * pow(10, self.precision)).astype(int).flatten().tolist() + pk = self.sife_tpa.generate_public_key(len(data_list)) + ct_data = self.sife_client.encrypt(pk, data_list) + return ct_data + + def execute_ndarray(self, data_ndarray): + assert isinstance(data_ndarray, np.ndarray), 'input data should be in numpy array format' + assert len(data_ndarray.shape) == 2, 'at present, only address 2d array' + + ct_list = [self.execute(data_ndarray[i, :]) for i in range(data_ndarray.shape[0])] + return ct_list + + +class Secure2PCServer: + """Server for secure two-party computation using functional encryption.""" + + def __init__(self, sife, precision): + self.sife_tpa, self.sife_client = sife + self.precision_client, self.precision_server = precision + self.common_pk = self.sife_tpa.generate_common_public_key() + + def request_key(self, data_array): + data_list = (data_array * pow(10, self.precision_server)).astype(int).flatten().tolist() + sk = self.sife_tpa.generate_private_key(data_list) + return sk + + def execute(self, sk, ct, data_array): + data_list = (data_array * pow(10, self.precision_server)).astype(int).flatten().tolist() + max_inner_prod = 2100000000 # max_value * max_value * self.vec_len + dec_prod = self.sife_client.decrypt(self.common_pk, sk, data_list, ct, max_inner_prod) + if dec_prod is None: + logger.debug('find a bad case - decryption: ') + assert False + return float(dec_prod) / pow(10, self.precision_server) / pow(10, self.precision_client) + + def request_key_ndarray(self, data_ndarray): + assert isinstance(data_ndarray, np.ndarray), 'input weight should be a numpy array' + assert len(data_ndarray.shape) == 2, 'only address 2d array' + + sk_list = [self.request_key(data_ndarray[i, :]) for i in range(data_ndarray.shape[0])] + return sk_list + + def execute_ndarray(self, sk_list, ct_list, data_ndarray): + assert isinstance(data_ndarray, np.ndarray), 'input weight should be a numpy array' + assert len(data_ndarray.shape) == 2, 'only address 2d array' + assert len(sk_list) == data_ndarray.shape[0] + + res = np.zeros((data_ndarray.shape[0], len(ct_list))) + for i in range(data_ndarray.shape[0]): + for j in range(len(ct_list)): + res[i][j] = self.execute(sk_list[i], ct_list[j], data_ndarray[i, :]) + return res diff --git a/examples/community/cryptosfl/nn/utils.py b/examples/community/cryptosfl/nn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c498ab7d285cd63830bb40a4535cd1851da3bf --- /dev/null +++ b/examples/community/cryptosfl/nn/utils.py @@ -0,0 +1,60 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""function of initiation in client model and server model""" +import mindspore.nn as mspnn + + +def initiate_client_model(model): + """ + Initialize the client model based on the given model name. + + Args: + model (str): The name of the model to initialize. + + Returns: + mindspore.nn.Cell: The initialized client model. + + Raises: + NameError: If the model name is not recognized. + """ + if model == 'mlp64': + net = mspnn.SequentialCell( + mspnn.Flatten(), + mspnn.Dense(784, 128), + mspnn.ReLU(), + mspnn.Dense(128, 64), + mspnn.ReLU(), + ) + return net + raise NameError('choose model from mlp64') + + +def initiate_server_model(model): + """ + Initialize the server model based on the given model name. + + Args: + model (str): The name of the model to initialize. + + Returns: + tuple: The number of features and a list of hidden layers. + + Raises: + NameError: If the model name is not recognized. + """ + if model == 'mlp64': + n_features = 64 + hidden_layers = [32, 16] + return n_features, hidden_layers + raise NameError('choose model!') diff --git a/examples/community/cryptosfl/nn/worker.py b/examples/community/cryptosfl/nn/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..9147b3f7145a8cef73045a92abc33700141180c0 --- /dev/null +++ b/examples/community/cryptosfl/nn/worker.py @@ -0,0 +1,306 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""implementation of CryptoClient and CryptoServer""" +import copy + +import mindspore.nn as mspnn +import numpy as np +from nn.smc import Secure2PCClient +from nn.smc import Secure2PCServer +from nn.utils import initiate_client_model + + +class CryptoClient(mspnn.Cell): + """Client for cryptographic operations in a secure federated learning system.""" + def __init__(self, model='mlp64', n_output=10, smc=None): + super(CryptoClient, self).__init__() + self.model = model + self.n_output = n_output + self.smc = smc + self.encoder = initiate_client_model(self.model) + + @staticmethod + def _add_bias_unit(x, how='column'): + """ + Add bias unit (column or row of 1s) to the data array. + + Args: + x (ndarray): Data array. + how (str): 'column' or 'row', determines where to add the bias unit. + + Returns: + ndarray: Data array with bias unit added. + """ + if how == 'column': + x_new = np.ones((x.shape[0], x.shape[1] + 1)) + x_new[:, 1:] = x + elif how == 'row': + x_new = np.ones((x.shape[0] + 1, x.shape[1])) + x_new[1:, :] = x + else: + raise AttributeError('`how` must be `column` or `row`') + return x_new + + @staticmethod + def _encode_labels(y, k): + """ + Encode labels into a one-hot representation. + + Args: + y (ndarray): Labels array. + k (int): Number of classes. + + Returns: + ndarray: One-hot encoded labels. + """ + onehot = np.zeros((k, y.shape[0])) + for idx, val in enumerate(y): + onehot[val, idx] = 1.0 + return onehot + + def construct(self, x): + return self.encoder(x) + + def encrypt(self, intermediate, y): + """ + Encrypt the data using secure multiparty computation. + + Args: + x (ndarray): Data array to encrypt. + + Returns: + Encrypted data. + """ + intermediate_data = copy.deepcopy(intermediate) + y_data = copy.deepcopy(y) + + y_onehot = self._encode_labels(y_data, self.n_output) + intermediate_data = self._add_bias_unit(intermediate_data, how='column') + + if self.smc and isinstance(self.smc, Secure2PCClient): + ct_feedforward = np.array(self.smc.execute_ndarray(intermediate_data)) + ct_backpropagation = np.array(self.smc.execute_ndarray(intermediate_data.T)) + return ct_feedforward, ct_backpropagation, y_onehot + return intermediate_data, y_onehot + + +class CryptoServer: + """Server for cryptographic operations in a secure federated learning system.""" + + def __init__(self, n_features, hidden_layers, n_output=10, + l1=0.0, l2=0.0, eta=0.001, + smc=None, precision=None): + self.n_output = n_output + self.n_features = n_features + self.hidden_layers = hidden_layers + self.w = self._initialize_weights() + self.l1 = l1 + self.l2 = l2 + self.eta = eta + self.smc = smc + self.precision = precision + + def _initialize_weights(self): + self.layers = [self.n_features] + self.hidden_layers + [self.n_output] + w = [self._xavier_uniform((self.layers[i + 1], self.layers[i] + 1)) + for i in range(len(self.layers) - 1)] + return w + + @staticmethod + def _xavier_uniform(shape): + fan_in, fan_out = shape[0], shape[1] + limit = np.sqrt(6.0 / (fan_in + fan_out)) + weights = np.random.uniform(low=-limit, high=limit, size=shape) + return weights + + @staticmethod + def _softmax(x): + max_vals = np.max(x, axis=0, keepdims=True) + e_x = np.exp(x - max_vals) + return e_x / e_x.sum(axis=0, keepdims=True) + + @staticmethod + def _relu(z): + return np.maximum(0, z) + + @staticmethod + def _relu_gradient(z): + return z > 0 + + @staticmethod + def _add_bias_unit(x, how='column'): + """ + Add bias unit (column or row of 1s) to the data array. + + Args: + x (ndarray): Data array. + how (str): 'column' or 'row', determines where to add the bias unit. + + Returns: + ndarray: Data array with bias unit added. + """ + if how == 'column': + x_new = np.ones((x.shape[0], x.shape[1] + 1)) + x_new[:, 1:] = x + elif how == 'row': + x_new = np.ones((x.shape[0] + 1, x.shape[1])) + x_new[1:, :] = x + else: + raise AttributeError('`how` must be `column` or `row`') + return x_new + + def feedforward(self, x, w): + """ + Perform feedforward computation. + + Args: + x (ndarray): Input data array. + + Returns: + tuple: Activations and linear transformations. + """ + z = [None for _ in range(len(w))] + a = [None for _ in range(len(w))] + if self.precision: + z[0] = (w[0] * pow(10, self.precision)).astype(int).dot(x.T) / pow(10, self.precision) + else: + z[0] = w[0].dot(x.T) + a[0] = self._add_bias_unit(self._relu(z[0]), how='row') + for i in range(1, len(w)): + z[i] = w[i].dot(a[i - 1]) + if i != len(w) - 1: + a[i] = self._add_bias_unit(self._relu(z[i]), how='row') + else: + a[i] = self._softmax(z[i]) + return z, a + + def feedforward_secure(self, ct_batch, w): + """ + Perform secure feedforward computation. + + Args: + ct_batch: Encrypted batch of data. + w: Weights for computation. + + Returns: + tuple: Secure activations and linear transformations. + """ + z = [None for _ in range(len(w))] + a = [None for _ in range(len(w))] + + if isinstance(self.smc, Secure2PCServer): + sk_w0 = self.smc.request_key_ndarray(w[0]) + z[0] = self.smc.execute_ndarray(sk_w0, ct_batch.tolist(), w[0]) + + a[0] = self._add_bias_unit(self._relu(z[0]), how='row') + + for i in range(1, len(w)): + z[i] = w[i].dot(a[i - 1]) + if i != len(w) - 1: + a[i] = self._add_bias_unit(self._relu(z[i]), how='row') + else: + a[i] = self._softmax(z[i]) + return z, a + + def get_gradient(self, x, y_encode, a, z, w): + """ + Compute gradient for backpropagation. + + Args: + y_encode: Encoded labels. + a: Activations. + z: Linear transformations. + w: Weights. + + Returns: + tuple: Gradients and delta inputs. + """ + sigma = [None for i in range(len(w))] + grad = [None for i in range(len(w))] + sigma[-1] = a[-1] - y_encode + for i in range(len(w) - 2, -1, -1): + sigma[i] = w[i + 1].T.dot(sigma[i + 1]) * self._relu_gradient(self._add_bias_unit(z[i], how='row')) + sigma[i] = sigma[i][1:, :] + if self.precision: + grad[0] = (sigma[0] * pow(10, self.precision)).astype(int).dot(x) / pow(10, self.precision) + else: + grad[0] = sigma[0].dot(x) + + for i in range(1, len(w)): + grad[i] = sigma[i].dot(a[i - 1].T) + + for i in range(len(w)): + grad[i][:, 1:] += self.l2 * w[i][:, 1:] + grad[i][:, 1:] += self.l1 * np.sign(w[i][:, 1:]) + + d_inputs = w[0].T.dot(sigma[0]) + d_inputs = d_inputs[1:, :].T + + return grad, d_inputs + + def get_gradient_secure(self, ct_batch, y_encode, a, z, w): + """ + Compute secure gradient for backpropagation. + + Args: + ct_batch: Encrypted batch of data. + y_encode: Encoded labels. + a: Activations. + z: Linear transformations. + w: Weights. + + Returns: + tuple: Secure gradients and delta inputs. + """ + sigma = [None for _ in range(len(w))] + grad = [None for _ in range(len(w))] + sigma[-1] = a[-1] - y_encode + for i in range(len(w) - 2, -1, -1): + sigma[i] = w[i + 1].T.dot(sigma[i + 1]) * self._relu_gradient(self._add_bias_unit(z[i], how='row')) + sigma[i] = sigma[i][1:, :] + + if isinstance(self.smc, Secure2PCServer): + sk_sigma0 = self.smc.request_key_ndarray(sigma[0]) + grad[0] = self.smc.execute_ndarray(sk_sigma0, ct_batch.tolist(), sigma[0]) + + for i in range(1, len(w)): + grad[i] = sigma[i].dot(a[i - 1].T) + + for i in range(len(w)): + grad[i][:, 1:] += self.l2 * w[i][:, 1:] + grad[i][:, 1:] += self.l1 * np.sign(w[i][:, 1:]) + + d_inputs = w[0].T.dot(sigma[0]) + d_inputs = d_inputs[1:, :].T + + return grad, d_inputs + + def predict(self, x): + """ + Predict class labels for samples in x. + + Args: + x (ndarray): Data array. + + Returns: + ndarray: Predicted class labels. + """ + if len(x.shape) != 2: + raise AttributeError('X must be a [n_samples, n_features] array.\n' + 'Use X[:,None] for 1-feature classification,' + '\nor X[[i]] for 1-sample classification') + x = self._add_bias_unit(x, how='column') + z, _ = self.feedforward(x, self.w) + y_pred = np.argmax(z[len(z) - 1], axis=0) + return y_pred