diff --git a/troubleshooter/tests/common/test_save.py b/troubleshooter/tests/common/test_save.py index 8dc74c2429cd6e13c0064e781f1fdd970580e6f3..dc366efb84aef7566a907716c549095c4639ea42 100644 --- a/troubleshooter/tests/common/test_save.py +++ b/troubleshooter/tests/common/test_save.py @@ -33,10 +33,6 @@ def test_ms_save(mode): ts.save.cnt.set_data(Tensor(0, ms.int32)) x1 = Tensor(-0.5962, ms.float32) x2 = Tensor(0.4985, ms.float32) - single_input = x1 - list_input = [x1, x2] - tuple_input = (x2, x1) - dict_input = {"x1": x1, "x2": x2} net = NetWorkSave('/tmp/save/numpy', True, "ms") try: @@ -45,28 +41,13 @@ def test_ms_save(mode): pass os.makedirs("/tmp/save/") - single_output = net(single_input) - list_output = net(list_input) - tuple_output = net(tuple_input) - dict_output = net(dict_input) + x1 = net(x1) + x2 = net(x2) time.sleep(1) assert np.allclose(np.load("/tmp/save/0_numpy_ms.npy"), - single_output.asnumpy()) - - assert np.allclose(np.load("/tmp/save/1_numpy_0_ms.npy"), - list_output[0].asnumpy()) - assert np.allclose(np.load("/tmp/save/1_numpy_1_ms.npy"), - list_output[1].asnumpy()) - - assert np.allclose(np.load("/tmp/save/2_numpy_0_ms.npy"), - tuple_output[0].asnumpy()) - assert np.allclose(np.load("/tmp/save/2_numpy_1_ms.npy"), - tuple_output[1].asnumpy()) - - assert np.allclose(np.load("/tmp/save/3_numpy_x1_ms.npy"), - dict_output["x1"].asnumpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x2_ms.npy"), - dict_output["x2"].asnumpy()) + x1.asnumpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_ms.npy"), + x2.asnumpy()) @pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) @@ -80,10 +61,6 @@ def test_torch_save(mode): ts.save.cnt.set_data(Tensor(0, ms.int32)) x1 = torch.tensor(-0.5962, dtype=torch.float32) x2 = torch.tensor(0.4985, dtype=torch.float32) - single_input = x1 - list_input = [x1, x2] - tuple_input = (x2, x1) - dict_input = {"x1": x1, "x2": x2} file = '/tmp/save/numpy' try: shutil.rmtree("/tmp/save/") @@ -91,26 +68,12 @@ def test_torch_save(mode): pass os.makedirs("/tmp/save/") - ts.save(file, single_input, True, "torch") - ts.save(file, list_input, True, "torch") - ts.save(file, tuple_input, True, "torch") - ts.save(file, dict_input, True, "torch") + ts.save(file, x1, True, "torch") + ts.save(file, x2, True, "torch") time.sleep(1) assert np.allclose(np.load("/tmp/save/0_numpy_torch.npy"), - single_input.cpu().detach().numpy()) - - assert np.allclose(np.load("/tmp/save/1_numpy_0_torch.npy"), - list_input[0].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/1_numpy_1_torch.npy"), - list_input[1].cpu().detach().numpy()) - - assert np.allclose(np.load("/tmp/save/2_numpy_0_torch.npy"), - tuple_input[0].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/2_numpy_1_torch.npy"), - tuple_input[1].cpu().detach().numpy()) + x1.cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x1_torch.npy"), - dict_input["x1"].cpu().detach().numpy()) - assert np.allclose(np.load("/tmp/save/3_numpy_x2_torch.npy"), - dict_input["x2"].cpu().detach().numpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_torch.npy"), + x2.cpu().detach().numpy()) diff --git a/troubleshooter/troubleshooter/common/util.py b/troubleshooter/troubleshooter/common/util.py index 4b7ccd9ecd8d70a429f334a11ae86babb584c1b1..2a1a87b75f7a95c0f2804a7d0305e6705a1aded4 100644 --- a/troubleshooter/troubleshooter/common/util.py +++ b/troubleshooter/troubleshooter/common/util.py @@ -140,16 +140,13 @@ class SaveNet(nn.Cell): Inputs: file (str): The name of the file to be stored. - data (Union(Tensor, list[Tensor], Tuple[Tensor], dict[str, Tensor])): Supports data types of Tensor, - list[Tensor], tuple(Tensor), and dict[str, Tensor] for both MindSpore and PyTorch. When the input is - a list or tuple of Tensor, the file name will be numbered according to the index of the Tensor. - When the input is a dictionary of Tensor, the corresponding key will be added to the file name. + data (Union(Tensor)): Supports data types of Tensor for both MindSpore and PyTorch. auto_id (bool): Whether to enable automatic numbering. If set to True, an incremental number will be added before the saved file name. If set to False, no numbering will be added to the file name. suffix (str): The suffix of the saved file name. Outputs: - The output storage name is '{id}_name_{idx/key}_{suffix}.npy'. + The output storage name is '{id}_name_{suffix}.npy'. """ def __init__(self): @@ -183,31 +180,14 @@ class SaveNet(nn.Cell): def construct(self, file, data, auto_id, suffix): path, name = self.handle_path(file) - if isinstance(data, (list, tuple)): - for idx, val in enumerate(data): - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{idx}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}_{idx}", self.numpy(val)) - else: - np.save(f"{file}_{idx}_{suffix}" if suffix else - f"{file}_{idx}", self.numpy(val)) - elif isinstance(data, (dict, OrderedDict)): - for key, val in data.items(): - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{key}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}_{key}", self.numpy(val)) - else: - np.save(f"{file}_{key}_{suffix}" if suffix else - f"{file}_{key}", self.numpy(val)) + if auto_id: + np.save(f"{path}{int(self.cnt)}_{name}_{suffix}" if suffix else + f"{path}{int(self.cnt)}_{name}", self.numpy(data)) else: - if auto_id: - np.save(f"{path}{int(self.cnt)}_{name}_{suffix}" if suffix else - f"{path}{int(self.cnt)}_{name}", self.numpy(data)) - else: - np.save(f"{file}_{suffix}" if suffix else file, - self.numpy(data)) + np.save(f"{file}_{suffix}" if suffix else file, + self.numpy(data)) if auto_id: self.cnt += 1 - + return save = SaveNet()