From 55fe8e19bd2b647d07e2e02034e0a850e367288d Mon Sep 17 00:00:00 2001 From: fandawei Date: Tue, 9 May 2023 16:44:07 +0800 Subject: [PATCH] add ts.save on r2.0 --- troubleshooter/tests/common/test_save.py | 57 ++++---------------- troubleshooter/troubleshooter/common/util.py | 36 +++---------- 2 files changed, 18 insertions(+), 75 deletions(-) diff --git a/troubleshooter/tests/common/test_save.py b/troubleshooter/tests/common/test_save.py index 8dc74c2..dc366ef 100644 --- a/troubleshooter/tests/common/test_save.py +++ b/troubleshooter/tests/common/test_save.py @@ -33,10 +33,6 @@ def test_ms_save(mode): ts.save.cnt.set_data(Tensor(0, ms.int32)) x1 = Tensor(-0.5962, ms.float32) x2 = Tensor(0.4985, ms.float32) - single_input = x1 - list_input = [x1, x2] - tuple_input = (x2, x1) - dict_input = {"x1": x1, "x2": x2} net = NetWorkSave('/tmp/save/numpy', True, "ms") try: @@ -45,28 +41,13 @@ def test_ms_save(mode): pass os.makedirs("/tmp/save/") - single_output = net(single_input) - list_output = net(list_input) - tuple_output = net(tuple_input) - dict_output = net(dict_input) + x1 = net(x1) + x2 = net(x2) time.sleep(1) assert np.allclose(np.load("/tmp/save/0_numpy_ms.npy"), - single_output.asnumpy()) - - assert np.allclose(np.load("/tmp/save/1_numpy_0_ms.npy"), - list_output[0].asnumpy()) - assert np.allclose(np.load("/tmp/save/1_numpy_1_ms.npy"), - list_output[1].asnumpy()) - - assert np.allclose(np.load("/tmp/save/2_numpy_0_ms.npy"), - tuple_output[0].asnumpy()) - assert np.allclose(np.load("/tmp/save/2_numpy_1_ms.npy"), - tuple_output[1].asnumpy()) - - assert np.allclose(np.load("/tmp/save/3_numpy_x1_ms.npy"), - dict_output["x1"].asnumpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x2_ms.npy"), - dict_output["x2"].asnumpy()) + x1.asnumpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_ms.npy"), + x2.asnumpy()) @pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) @@ -80,10 +61,6 @@ def test_torch_save(mode): ts.save.cnt.set_data(Tensor(0, ms.int32)) x1 = torch.tensor(-0.5962, dtype=torch.float32) x2 = torch.tensor(0.4985, dtype=torch.float32) - single_input = x1 - list_input = [x1, x2] - tuple_input = (x2, x1) - dict_input = {"x1": x1, "x2": x2} file = '/tmp/save/numpy' try: shutil.rmtree("/tmp/save/") @@ -91,26 +68,12 @@ def test_torch_save(mode): pass os.makedirs("/tmp/save/") - ts.save(file, single_input, True, "torch") - ts.save(file, list_input, True, "torch") - ts.save(file, tuple_input, True, "torch") - ts.save(file, dict_input, True, "torch") + ts.save(file, x1, True, "torch") + ts.save(file, x2, True, "torch") time.sleep(1) assert np.allclose(np.load("/tmp/save/0_numpy_torch.npy"), - single_input.cpu().detach().numpy()) - - assert np.allclose(np.load("/tmp/save/1_numpy_0_torch.npy"), - list_input[0].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/1_numpy_1_torch.npy"), - list_input[1].cpu().detach().numpy()) - - assert np.allclose(np.load("/tmp/save/2_numpy_0_torch.npy"), - tuple_input[0].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/2_numpy_1_torch.npy"), - tuple_input[1].cpu().detach().numpy()) + x1.cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x1_torch.npy"), - dict_input["x1"].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x2_torch.npy"), - dict_input["x2"].cpu().detach().numpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_torch.npy"), + x2.cpu().detach().numpy()) diff --git a/troubleshooter/troubleshooter/common/util.py b/troubleshooter/troubleshooter/common/util.py index 4b7ccd9..2a1a87b 100644 --- a/troubleshooter/troubleshooter/common/util.py +++ b/troubleshooter/troubleshooter/common/util.py @@ -140,16 +140,13 @@ class SaveNet(nn.Cell): Inputs: file (str): The name of the file to be stored. - data (Union(Tensor, list[Tensor], Tuple[Tensor], dict[str, Tensor])): Supports data types of Tensor, - list[Tensor], tuple(Tensor), and dict[str, Tensor] for both MindSpore and PyTorch. When the input is - a list or tuple of Tensor, the file name will be numbered according to the index of the Tensor. - When the input is a dictionary of Tensor, the corresponding key will be added to the file name. + data (Union(Tensor)): Supports data types of Tensor for both MindSpore and PyTorch. auto_id (bool): Whether to enable automatic numbering. If set to True, an incremental number will be added before the saved file name. If set to False, no numbering will be added to the file name. suffix (str): The suffix of the saved file name. Outputs: - The output storage name is '{id}_name_{idx/key}_{suffix}.npy'. + The output storage name is '{id}_name_{suffix}.npy'. """ def __init__(self): @@ -183,31 +180,14 @@ class SaveNet(nn.Cell): def construct(self, file, data, auto_id, suffix): path, name = self.handle_path(file) - if isinstance(data, (list, tuple)): - for idx, val in enumerate(data): - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{idx}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}_{idx}", self.numpy(val)) - else: - np.save(f"{file}_{idx}_{suffix}" if suffix else - f"{file}_{idx}", self.numpy(val)) - elif isinstance(data, (dict, OrderedDict)): - for key, val in data.items(): - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{key}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}_{key}", self.numpy(val)) - else: - np.save(f"{file}_{key}_{suffix}" if suffix else - f"{file}_{key}", self.numpy(val)) + if auto_id: + np.save(f"{path}{int(self.cnt)}_{name}_{suffix}" if suffix else + f"{path}{int(self.cnt)}_{name}", self.numpy(data)) else: - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}", self.numpy(data)) - else: - np.save(f"{file}_{suffix}" if suffix else file, - self.numpy(data)) + np.save(f"{file}_{suffix}" if suffix else file, + self.numpy(data)) if auto_id: self.cnt += 1 - + return save = SaveNet() -- Gitee