From 5208142012061e61876d697e8de15450b9d6f651 Mon Sep 17 00:00:00 2001 From: fandawei Date: Fri, 5 May 2023 17:20:24 +0800 Subject: [PATCH] Add a unified save interface for pytorch and mindspore --- troubleshooter/tests/common/test_save.py | 116 +++++++++++++++++++ troubleshooter/troubleshooter/__init__.py | 1 + troubleshooter/troubleshooter/common/util.py | 89 +++++++++++++- 3 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 troubleshooter/tests/common/test_save.py diff --git a/troubleshooter/tests/common/test_save.py b/troubleshooter/tests/common/test_save.py new file mode 100644 index 0000000..8dc74c2 --- /dev/null +++ b/troubleshooter/tests/common/test_save.py @@ -0,0 +1,116 @@ +import os +import time +import shutil +import pytest + +import troubleshooter as ts +import numpy as np +import torch +import mindspore as ms +from mindspore import nn, Tensor + + +class NetWorkSave(nn.Cell): + def __init__(self, file, auto_id, suffix): + super(NetWorkSave, self).__init__() + self.auto_id = auto_id + self.suffix = suffix + self.file = file + + def construct(self, x): + ts.save(self.file, x, self.auto_id, self.suffix) + return x + + +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE, ms.GRAPH_MODE]) +def test_ms_save(mode): + """ + Feature: ts.save + Description: Verify the result of save + Expectation: success + """ + ms.set_context(mode=mode, device_target="CPU") + ts.save.cnt.set_data(Tensor(0, ms.int32)) + x1 = Tensor(-0.5962, ms.float32) + x2 = Tensor(0.4985, ms.float32) + single_input = x1 + list_input = [x1, x2] + tuple_input = (x2, x1) + dict_input = {"x1": x1, "x2": x2} + net = NetWorkSave('/tmp/save/numpy', True, "ms") + + try: + shutil.rmtree("/tmp/save/") + except FileNotFoundError: + pass + os.makedirs("/tmp/save/") + + single_output = net(single_input) + list_output = net(list_input) + tuple_output = net(tuple_input) + dict_output = net(dict_input) + time.sleep(1) + assert np.allclose(np.load("/tmp/save/0_numpy_ms.npy"), + single_output.asnumpy()) + + assert np.allclose(np.load("/tmp/save/1_numpy_0_ms.npy"), + list_output[0].asnumpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_1_ms.npy"), + list_output[1].asnumpy()) + + assert np.allclose(np.load("/tmp/save/2_numpy_0_ms.npy"), + tuple_output[0].asnumpy()) + assert np.allclose(np.load("/tmp/save/2_numpy_1_ms.npy"), + tuple_output[1].asnumpy()) + + assert np.allclose(np.load("/tmp/save/3_numpy_x1_ms.npy"), + dict_output["x1"].asnumpy()) + assert np.allclose(np.load("/tmp/save/3_numpy_x2_ms.npy"), + dict_output["x2"].asnumpy()) + + +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +def test_torch_save(mode): + """ + Feature: ts.save + Description: Verify the result of save + Expectation: success + """ + ms.set_context(mode=mode, device_target="CPU") + ts.save.cnt.set_data(Tensor(0, ms.int32)) + x1 = torch.tensor(-0.5962, dtype=torch.float32) + x2 = torch.tensor(0.4985, dtype=torch.float32) + single_input = x1 + list_input = [x1, x2] + tuple_input = (x2, x1) + dict_input = {"x1": x1, "x2": x2} + file = '/tmp/save/numpy' + try: + shutil.rmtree("/tmp/save/") + except FileNotFoundError: + pass + os.makedirs("/tmp/save/") + + ts.save(file, single_input, True, "torch") + ts.save(file, list_input, True, "torch") + ts.save(file, tuple_input, True, "torch") + ts.save(file, dict_input, True, "torch") + time.sleep(1) + + assert np.allclose(np.load("/tmp/save/0_numpy_torch.npy"), + single_input.cpu().detach().numpy()) + + assert np.allclose(np.load("/tmp/save/1_numpy_0_torch.npy"), + list_input[0].cpu().detach().numpy()) + assert np.allclose(np.load("/tmp/save/1_numpy_1_torch.npy"), + list_input[1].cpu().detach().numpy()) + + assert np.allclose(np.load("/tmp/save/2_numpy_0_torch.npy"), + tuple_input[0].cpu().detach().numpy()) + assert np.allclose(np.load("/tmp/save/2_numpy_1_torch.npy"), + tuple_input[1].cpu().detach().numpy()) + + assert np.allclose(np.load("/tmp/save/3_numpy_x1_torch.npy"), + dict_input["x1"].cpu().detach().numpy()) + assert np.allclose(np.load("/tmp/save/3_numpy_x2_torch.npy"), + dict_input["x2"].cpu().detach().numpy()) diff --git a/troubleshooter/troubleshooter/__init__.py b/troubleshooter/troubleshooter/__init__.py index c681d2a..8e78ac6 100644 --- a/troubleshooter/troubleshooter/__init__.py +++ b/troubleshooter/troubleshooter/__init__.py @@ -28,3 +28,4 @@ from .migrator.diff_handler import DifferenceFinder as diff_finder from .migrator.diff_handler import WeightMigrator as weight_migrator from .proposer import ProposalAction as proposal from .tracker import Tracker as tracking +from .common.util import save diff --git a/troubleshooter/troubleshooter/common/util.py b/troubleshooter/troubleshooter/common/util.py index 19e5311..4b7ccd9 100644 --- a/troubleshooter/troubleshooter/common/util.py +++ b/troubleshooter/troubleshooter/common/util.py @@ -16,6 +16,12 @@ import re import os import stat +from collections import OrderedDict + +from mindspore import Tensor, Parameter, nn +import mindspore as ms +import torch +import numpy as np def print_line(char, times): @@ -105,6 +111,7 @@ def find_file(dir, suffix=".npy"): file_list.append(file) return file_list + def make_directory(path: str): """Make directory.""" if path is None or not isinstance(path, str) or path.strip() == "": @@ -123,4 +130,84 @@ def make_directory(path: str): real_path = path except PermissionError: raise TypeError("No write permission on the directory `{path}`.") - return real_path \ No newline at end of file + return real_path + + +class SaveNet(nn.Cell): + """ + The SaveNet class is used to build a unified data storage interface that supports PyTorch and MindSpore + PYNATIVE_MODE as well as GRAPH_MODE, but currently does not support MindSpore GRAPH_MODE. + + Inputs: + file (str): The name of the file to be stored. + data (Union(Tensor, list[Tensor], Tuple[Tensor], dict[str, Tensor])): Supports data types of Tensor, + list[Tensor], tuple(Tensor), and dict[str, Tensor] for both MindSpore and PyTorch. When the input is + a list or tuple of Tensor, the file name will be numbered according to the index of the Tensor. + When the input is a dictionary of Tensor, the corresponding key will be added to the file name. + auto_id (bool): Whether to enable automatic numbering. If set to True, an incremental number will be + added before the saved file name. If set to False, no numbering will be added to the file name. + suffix (str): The suffix of the saved file name. + + Outputs: + The output storage name is '{id}_name_{idx/key}_{suffix}.npy'. + """ + + def __init__(self): + super(SaveNet, self).__init__() + self.cnt = Parameter(Tensor(0, ms.int32), + name="cnt", requires_grad=False) + self.sep = os.sep + + def numpy(self, data): + if isinstance(data, ms.Tensor): + return data.asnumpy() + elif torch.is_tensor(data): + return data.cpu().detach().numpy() + else: + raise TypeError(f"For ts.save, the type of argument 'data' must be mindspore.Tensor or torch.tensor, " \ + f"but got {type(data)}") + + def handle_path(self, file): + if file[-1] == self.sep: + raise ValueError(f"For ts.save, the type of argument 'file' must be a valid filename, but got {file}") + name = '' + for c in file: + if c == self.sep: + name = '' + else: + name += c + path = '' + for i in range(len(file) - len(name)): + path += file[i] + return path, name + + def construct(self, file, data, auto_id, suffix): + path, name = self.handle_path(file) + if isinstance(data, (list, tuple)): + for idx, val in enumerate(data): + if auto_id: + np.save(f"{path}{int(self.cnt)}_{name}_{idx}_{suffix}" if suffix else + f"{path}{int(self.cnt)}_{name}_{idx}", self.numpy(val)) + else: + np.save(f"{file}_{idx}_{suffix}" if suffix else + f"{file}_{idx}", self.numpy(val)) + elif isinstance(data, (dict, OrderedDict)): + for key, val in data.items(): + if auto_id: + np.save(f"{path}{int(self.cnt)}_{name}_{key}_{suffix}" if suffix else + f"{path}{int(self.cnt)}_{name}_{key}", self.numpy(val)) + else: + np.save(f"{file}_{key}_{suffix}" if suffix else + f"{file}_{key}", self.numpy(val)) + else: + if auto_id: + np.save(f"{path}{int(self.cnt)}_{name}_{suffix}" if suffix else + f"{path}{int(self.cnt)}_{name}", self.numpy(data)) + else: + np.save(f"{file}_{suffix}" if suffix else file, + self.numpy(data)) + if auto_id: + self.cnt += 1 + + +save = SaveNet() -- Gitee