From 8a2202afeef8f02c83152ede5f617ea2b5b854bd Mon Sep 17 00:00:00 2001 From: jijiarong Date: Mon, 22 Apr 2024 15:54:06 +0800 Subject: [PATCH] cell dump --- troubleshooter/docs/api/dump_init.md | 183 ++++++++++++++++++ troubleshooter/docs/api_summary.md | 1 + .../troubleshooter/migrator/__init__.py | 3 +- .../troubleshooter/migrator/cell_dump.py | 51 +++++ 4 files changed, 237 insertions(+), 1 deletion(-) create mode 100644 troubleshooter/docs/api/dump_init.md create mode 100644 troubleshooter/troubleshooter/migrator/cell_dump.py diff --git a/troubleshooter/docs/api/dump_init.md b/troubleshooter/docs/api/dump_init.md new file mode 100644 index 0000000..22e6333 --- /dev/null +++ b/troubleshooter/docs/api/dump_init.md @@ -0,0 +1,183 @@ +### troubleshooter.dump_init + +> troubleshooter.dump_init(net:cell, path:str, feature_list=[‘input’, ‘output’, ‘backward input’, ‘backward output’], block_list=None) + +Mindspore 和 PyTorch 的统一模块级数据保存接口,以cell(pytorch侧module)为粒度进行dump。 + +### 参数 + +- net(cell): 待进行数据dump的层。类型为:mindspore.nn.Cell +- path(str):dump数据的保存路径。默认值:./dump/ +- feature_list(list,可选): 默认值:[‘input’, ‘output’, ‘backward_input’, ‘backward_output’] + - input: 保存模型的正向输入 + - output: 保存模块的正向输出 + - backward_input: 保存模块的反向输入 + - backward_output: 保存模块反向输出 +- block_list(list,可选): 默认值:None,dump net 中的全部cell。开启后,仅dump block_list中的cell。 + +> Warning: +> +> - 如果在短时间内保存大量数据,可能会导致内存溢出。 建议手动调小[max_device_memory](https://www.mindspore.cn/docs/zh-CN/r2.2/api_python/mindspore/mindspore.set_context.html?highlight=max_device_memory),保证预留出足够的device侧内存。 + + + +### 样例: + +```python +import numpy as np +import troubleshooter as ts +from mindspore import nn, Tensor, context +from mindspore.dataset import GeneratorDataset +from mindspore.train import Model +from cell_dump import dump_init +import mindspore as ms + +class RandomAccessDataset: + def __init__(self): + self._data = np.ones((4 , 3, 512)).astype(np.float32) + self._label = np.zeros((4, 3)).astype(np.int32) + + def __getitem__(self, 
index): + return self._data[index], self._label[index] + + def __len__(self): + return len(self._data) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Dense(512, 128) + self.bn1 = nn.BatchNorm1d(128) + self.fc2 = nn.Dense(128, 1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def construct(self, x): + x = self.relu(self.bn1(self.fc1(x))) + x = self.sigmoid(self.fc2(x)) + return x + + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +net = Net() + +# 使用dump_init 进行标记 +dump_init(net, path="./cell_dump_path/", feature_list=['input', 'output', 'backward input', 'backward output']) + +loader = RandomAccessDataset() +dataset = GeneratorDataset(source=loader, column_names=["data", "label"]) + +loss_fn = nn.CrossEntropyLoss() +optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01) + +model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'}) +model.train(1, dataset) +``` + + + +### 结果展示 + +```sh +cell_dump_path +├── bn1 +│ ├── backward_input +│ │ ├── 17_data.0_backward.npy +│ │ ├── 36_data.0_backward.npy +│ │ ├── 55_data.0_backward.npy +│ │ └── 74_data.0_backward.npy +│ ├── backward_output +│ │ ├── 16_data.npy +│ │ ├── 35_data.npy +│ │ ├── 54_data.npy +│ │ └── 73_data.npy +│ ├── input +│ │ ├── 21_data.0.npy +│ │ ├── 2_data.0.npy +│ │ ├── 40_data.0.npy +│ │ └── 59_data.0.npy +│ └── output +│ ├── 22_data.npy +│ ├── 3_data.npy +│ ├── 41_data.npy +│ └── 60_data.npy +├── fc1 +│ ├── backward_output +│ │ ├── 18_data.npy +│ │ ├── 37_data.npy +│ │ ├── 56_data.npy +│ │ └── 75_data.npy +│ ├── input +│ │ ├── 0_data.0.npy +│ │ ├── 19_data.0.npy +│ │ ├── 38_data.0.npy +│ │ └── 57_data.0.npy +│ └── output +│ ├── 1_data.npy +│ ├── 20_data.npy +│ ├── 39_data.npy +│ └── 58_data.npy +├── fc2 +│ ├── backward_input +│ │ ├── 13_data.0_backward.npy +│ │ ├── 32_data.0_backward.npy +│ │ ├── 51_data.0_backward.npy +│ │ └── 70_data.0_backward.npy +│ ├── backward_output +│ │ ├── 12_data.npy +│ │ ├── 
31_data.npy +│ │ ├── 50_data.npy +│ │ └── 69_data.npy +│ ├── input +│ │ ├── 25_data.0.npy +│ │ ├── 44_data.0.npy +│ │ ├── 63_data.0.npy +│ │ └── 6_data.0.npy +│ └── output +│ ├── 26_data.npy +│ ├── 45_data.npy +│ ├── 64_data.npy +│ └── 7_data.npy +├── relu +│ ├── backward_input +│ │ ├── 15_data.0_backward.npy +│ │ ├── 34_data.0_backward.npy +│ │ ├── 53_data.0_backward.npy +│ │ └── 72_data.0_backward.npy +│ ├── backward_output +│ │ ├── 14_data.npy +│ │ ├── 33_data.npy +│ │ ├── 52_data.npy +│ │ └── 71_data.npy +│ ├── input +│ │ ├── 23_data.0.npy +│ │ ├── 42_data.0.npy +│ │ ├── 4_data.0.npy +│ │ └── 61_data.0.npy +│ └── output +│ ├── 24_data.npy +│ ├── 43_data.npy +│ ├── 5_data.npy +│ └── 62_data.npy +└── sigmoid + ├── backward_input + │ ├── 11_data.0_backward.npy + │ ├── 30_data.0_backward.npy + │ ├── 49_data.0_backward.npy + │ └── 68_data.0_backward.npy + ├── backward_output + │ ├── 10_data.npy + │ ├── 29_data.npy + │ ├── 48_data.npy + │ └── 67_data.npy + ├── input + │ ├── 27_data.0.npy + │ ├── 46_data.0.npy + │ ├── 65_data.0.npy + │ └── 8_data.0.npy + └── output + ├── 28_data.npy + ├── 47_data.npy + ├── 66_data.npy + └── 9_data.npy +``` diff --git a/troubleshooter/docs/api_summary.md b/troubleshooter/docs/api_summary.md index e897372..66ee078 100644 --- a/troubleshooter/docs/api_summary.md +++ b/troubleshooter/docs/api_summary.md @@ -27,6 +27,7 @@ | [troubleshooter.save_grad](api/save_grad.md) | 用于保存MindSpore和PyTorch的Tensor对应的反向梯度数据 | | [troubleshooter.widget.save_convert](api/widget/save_convert.md) | 将save使用print方式保存的文件转化为npy文件 | | [troubleshooter.migrator.save_net_and_weight_params](api/migrator/save_net_and_weight_params.md)|将网络对象保存成文件| +| [troubleshooter.migrator.dump_init](api/dump_init.md) | 以module为粒度保存MindSpore和PyTorch的Tensor数据和对应的反向梯度数据 | ## 数据比较 diff --git a/troubleshooter/troubleshooter/migrator/__init__.py b/troubleshooter/troubleshooter/migrator/__init__.py index a8b2490..da9245c 100644 --- a/troubleshooter/troubleshooter/migrator/__init__.py +++ 
import os
import troubleshooter as ts
from mindspore import ops, Tensor, Parameter


def save_wrapper_func_args(func, self, path, feature_list):
    """Return a bound replacement for ``self.construct`` that dumps tensors.

    Every Tensor positional argument and output of the wrapped ``construct``
    is saved with ``ts.save`` / ``ts.save_grad`` according to ``feature_list``
    (any of 'input', 'output', 'backward input', 'backward output').

    Args:
        func: the cell's original (bound) ``construct`` callable.
        self: the cell instance being wrapped; ``self.param_prefix`` names the
            dump sub-directory (populated by ``net.update_cell_prefix()``).
        path (str): root directory for the dumped ``.npy`` files.
        feature_list (list[str]): which features to dump.

    Returns:
        A method bound to ``self``, suitable for assigning to ``self.construct``.
    """
    def new_construct(self, *args, **kwargs):
        # ---- forward / backward input dump --------------------------------
        new_args = []
        for index, arg in enumerate(args):
            if isinstance(arg, Tensor):
                if "input" in feature_list:
                    ts.save(f"{path}/{self.param_prefix}/input/data.{index}", arg)
                if "backward input" in feature_list:
                    # save_grad returns a pass-through tensor that records the
                    # gradient flowing back through this argument.
                    arg = ts.save_grad(f"{path}/{self.param_prefix}/backward_input/data.{index}", arg)
            new_args.append(arg)

        outputs = func(*new_args, **kwargs)

        # ---- forward / backward output dump -------------------------------
        if isinstance(outputs, tuple):
            new_outputs = []
            for index, out in enumerate(outputs):
                if isinstance(out, Tensor):
                    if "output" in feature_list:
                        ts.save(f"{path}/{self.param_prefix}/output/data.{index}", out)
                    if "backward output" in feature_list:
                        out = ts.save_grad(f"{path}/{self.param_prefix}/backward_output/data.{index}",
                                           out, suffix=None)
                new_outputs.append(out)
            # Bug fix: preserve the container type.  Returning a plain list
            # breaks callers that rely on the construct returning a tuple.
            return tuple(new_outputs)

        if isinstance(outputs, Tensor):
            if "output" in feature_list:
                ts.save(f"{path}/{self.param_prefix}/output/data", outputs)
            if "backward output" in feature_list:
                outputs = ts.save_grad(f"{path}/{self.param_prefix}/backward_output/data",
                                       outputs, suffix=None)
        # Bug fix: always return the (possibly wrapped) outputs.  The original
        # code returned a leftover empty list whenever the output was not a
        # tuple and 'backward output' dumping was disabled, silently replacing
        # the forward result.
        return outputs

    # Bind so that assigning the result to ``cell.construct`` behaves like a
    # regular method call (``self`` is supplied automatically).
    return new_construct.__get__(self, type(self))


def dump_init(net, path="./dump/", feature_list=None, block_list=None):
    """Instrument every sub-cell of ``net`` to dump tensor data under ``path``.

    Args:
        net (mindspore.nn.Cell): network whose sub-cells are instrumented.
        path (str): root directory for dumped data. Default ``"./dump/"``.
        feature_list (list[str], optional): subset of
            ['input', 'output', 'backward input', 'backward output'].
            The underscore spellings used in the documentation
            ('backward_input', 'backward_output') are also accepted.
            ``None`` selects all four features.
        block_list (list[str], optional): when given, only cells whose name is
            in this list are instrumented; ``None`` instruments every sub-cell.

    Raises:
        ValueError: if ``feature_list`` contains an unsupported entry.
    """
    support_feature_list = ['input', 'output', 'backward input', 'backward output']
    if feature_list is None:
        feature_list = support_feature_list
    else:
        # Consistency fix: the docs advertise 'backward_input'/'backward_output'
        # while the implementation used spaces -- accept both spellings.
        feature_list = [item.replace("backward_", "backward ") for item in feature_list]
    not_support_feature = [item for item in feature_list if item not in support_feature_list]
    if not_support_feature:
        raise ValueError(f"Not support feature: {not_support_feature}.\nOnly support {support_feature_list}.")

    # Assigns each sub-cell a unique ``param_prefix`` used as its dump folder.
    net.update_cell_prefix()
    for name, cell in net.cells_and_names():
        # The root cell is reported with name "" -- only wrap named sub-cells.
        if name != "":
            if not block_list or name in block_list:
                cell.construct = save_wrapper_func_args(cell.construct, cell, path, feature_list)