From 9a5d0bd466cf1669a3ec9b033495ed04ced5a18e Mon Sep 17 00:00:00 2001
From: wangnan39
Date: Tue, 30 Apr 2024 12:00:12 +0800
Subject: [PATCH] support save max min

---
 .../troubleshooter/migrator/save.py           | 154 ++++++++++++------
 1 file changed, 104 insertions(+), 50 deletions(-)

diff --git a/troubleshooter/troubleshooter/migrator/save.py b/troubleshooter/troubleshooter/migrator/save.py
index d3e604e..ccf1735 100644
--- a/troubleshooter/troubleshooter/migrator/save.py
+++ b/troubleshooter/troubleshooter/migrator/save.py
@@ -98,21 +98,23 @@ def _add_id_prefix_to_filename(filename):
     return new_filename
 
 
-class _SaveBase:
-    def __init__(self, file):
-        super(_SaveBase, self).__init__()
+class TensorSaverBase:
+    def __init__(self, file, suffix, output_mode, statistics):
+        super(TensorSaverBase, self).__init__()
+        _check_save_mode(output_mode, "save")
         path, name = _handle_path(file)
         self.path = path
         self.name = name
-        self.save_func = {'npy': _npy_save, 'print': _print_save}
+        self.save_func_dict = {'npy': _npy_save, 'print': _print_save}
+        self.save_func = self.save_func_dict[output_mode]
+        self.suffix = suffix
+        self.statistics = statistics
 
-    def get_save_func(self, mode):
-        return self.save_func[mode]
-
 
-class _SaveGradBase:
+class GradientSaverBase:
     def __init__(self, file, suffix, output_mode):
-        super(_SaveGradBase, self).__init__()
+        super(GradientSaverBase, self).__init__()
+        _check_save_mode(output_mode, "save_grad")
         path, name = _handle_path(file)
         if suffix:
             name = f"{name}_{suffix}"
@@ -136,15 +138,27 @@ def torch_TensorDump(file, data):
     os.chmod(file, stat.S_IRUSR)
 
 
-def _wrapper_torch_save_grad(file, output_mode):
+def _wrapper_torch_save_grad(file, output_mode, statistics):
     def _save_grad_func(grad):
         if grad is None:
             return
         if output_mode == 'print':
             format_name = f"{SAVE_NAME_MARK}{file}"
-            print(format_name, grad)
+            if statistics:
+                print(format_name + "_norm", grad.norm())
+                print(format_name + "_max", grad.max())
+                print(format_name + "_min", grad.min())
+                print(format_name + "_mean", grad.mean())
+            else:
+                print(format_name, grad)
         else:
-            torch_TensorDump(file, grad)
+            if statistics:
+                torch_TensorDump(file + "_norm", grad.norm())
+                torch_TensorDump(file + "_max", grad.max())
+                torch_TensorDump(file + "_min", grad.min())
+                torch_TensorDump(file + "_mean", grad.mean())
+            else:
+                torch_TensorDump(file, grad)
         return grad
     return _save_grad_func
 
@@ -164,26 +178,41 @@ if {"torch", "mindspore"}.issubset(FRAMEWORK_TYPE):
        raise TypeError(f"For 'ts.save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
                        f"but got {type(data)}")
 
-    def _wrapper_save_grad_func(file, output_mode):
+    def _wrapper_save_grad_func(file, output_mode, statistics):
         def _save_grad_func(grad):
             if output_mode == 'print':
                 format_name = f"{SAVE_NAME_MARK}{file}"
-                print(format_name, grad)
+                data = grad
+                if data.dtype == ms.bfloat16:
+                    data = data.float()
+                if statistics:
+                    print(format_name + "_norm", data.norm())
+                    print(format_name + "_max", data.max())
+                    print(format_name + "_min", data.min())
+                    print(format_name + "_mean", data.mean())
+                else:
+                    print(format_name, data)
             else:
                 data = grad
                 if data.dtype == ms.bfloat16:
                     data = data.float()
-                ms.ops.TensorDump()(file, data)
+                if statistics:
+                    ms.ops.TensorDump()(file + "_norm", data.norm())
+                    ms.ops.TensorDump()(file + "_max", data.max())
+                    ms.ops.TensorDump()(file + "_min", data.min())
+                    ms.ops.TensorDump()(file + "_mean", data.mean())
+                else:
+                    ms.ops.TensorDump()(file, data)
             return grad
         return _save_grad_func
 
     @ms.jit_class
-    class _SaveGradCell(_SaveGradBase):
-        def __init__(self, file, suffix, output_mode):
-            super(_SaveGradCell, self).__init__(file, suffix, output_mode)
+    class GradientSaver(GradientSaverBase):
+        def __init__(self, file, suffix, output_mode='npy', statistics=False):
+            super(GradientSaver, self).__init__(file, suffix, output_mode)
             self.ms_save_grad = ms.ops.InsertGradientOf(
-                _wrapper_save_grad_func(self.file, output_mode))
-            self.pt_save_func = _wrapper_torch_save_grad(self.file, output_mode)
+                _wrapper_save_grad_func(self.file, output_mode, statistics))
+            self.pt_save_func = _wrapper_torch_save_grad(self.file, output_mode, statistics)
 
         def __call__(self, x):
             if isinstance(x, ms.Tensor):
@@ -206,9 +235,9 @@ elif "torch" in FRAMEWORK_TYPE:
        raise TypeError(f"For 'ts.save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
                        f"but got {type(data)}")
 
-    class _SaveGradCell(_SaveGradBase):
-        def __init__(self, file, suffix, output_mode):
-            super(_SaveGradCell, self).__init__(file, suffix, output_mode)
-            self.pt_save_func = _wrapper_torch_save_grad(self.file, output_mode)
+    class GradientSaver(GradientSaverBase):
+        def __init__(self, file, suffix, output_mode='npy', statistics=False):
+            super(GradientSaver, self).__init__(file, suffix, output_mode)
+            self.pt_save_func = _wrapper_torch_save_grad(self.file, output_mode, statistics)
 
         def __call__(self, x):
@@ -231,25 +260,40 @@ elif "mindspore" in FRAMEWORK_TYPE:
        raise TypeError(f"For 'ts.save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, "
                        f"but got {type(data)}")
 
-    def _wrapper_save_grad_func(file, output_mode):
+    def _wrapper_save_grad_func(file, output_mode, statistics):
         def _save_grad_func(grad):
             if output_mode == 'print':
                 format_name = f"{SAVE_NAME_MARK}{file}"
-                print(format_name, grad)
+                data = grad
+                if data.dtype == ms.bfloat16:
+                    data = data.float()
+                if statistics:
+                    print(format_name + "_norm", data.norm())
+                    print(format_name + "_max", data.max())
+                    print(format_name + "_min", data.min())
+                    print(format_name + "_mean", data.mean())
+                else:
+                    print(format_name, data)
             else:
                 data = grad
                 if data.dtype == ms.bfloat16:
                     data = data.float()
-                ms.ops.TensorDump()(file, data)
+                if statistics:
+                    ms.ops.TensorDump()(file + "_norm", data.norm())
+                    ms.ops.TensorDump()(file + "_max", data.max())
+                    ms.ops.TensorDump()(file + "_min", data.min())
+                    ms.ops.TensorDump()(file + "_mean", data.mean())
+                else:
+                    ms.ops.TensorDump()(file, data)
             return grad
         return _save_grad_func
 
     @ms.jit_class
-    class _SaveGradCell(_SaveGradBase):
-        def __init__(self, file, suffix, output_mode):
-            super(_SaveGradCell, self).__init__(file, suffix, output_mode)
+    class GradientSaver(GradientSaverBase):
+        def __init__(self, file, suffix, output_mode, statistics):
+            super(GradientSaver, self).__init__(file, suffix, output_mode)
             self.ms_save_grad = ms.ops.InsertGradientOf(
-                _wrapper_save_grad_func(self.file, output_mode))
+                _wrapper_save_grad_func(self.file, output_mode, statistics))
 
         def __call__(self, x):
             if isinstance(x, ms.Tensor):
@@ -265,51 +309,61 @@ if "mindspore" in FRAMEWORK_TYPE:
     import mindspore as ms
 
     @ms.jit_class
-    class _SaveCell(_SaveBase):
-        def __call__(self, data, suffix, output_mode):
-            self.get_save_func(output_mode)(self.name, data, suffix, self.path)
+    class TensorSaver(TensorSaverBase):
+        def __call__(self, data):
+            self.save_func(self.name, data, self.suffix, self.path, self.statistics)
 else:
-    class _SaveCell(_SaveBase):
-        def __call__(self, data, suffix, output_mode):
-            self.get_save_func(output_mode)(self.name, data, suffix, self.path)
+    class TensorSaver(TensorSaverBase):
+        def __call__(self, data):
+            self.save_func(self.name, data, self.suffix, self.path, self.statistics)
 
 
-def _npy_save(item_name, data, suffix, path):
+def _npy_save(item_name, data, suffix, path, statistics):
     if isinstance(data, (list, tuple, dict, OrderedDict)):
         for key, val in _iterate_items(data):
-            _npy_save(f"{item_name}.{key}", val, suffix, path)
+            _npy_save(f"{item_name}.{key}", val, suffix, path, statistics)
     else:
         if data is None:
             return
         if suffix:
             item_name = f"{item_name}_{suffix}"
-        _npy_save_ops(f"{path}{item_name}", data)
+        if statistics:
+            _npy_save_ops(f"{path}{item_name}_norm", data.norm())
+            _npy_save_ops(f"{path}{item_name}_max", data.max())
+            _npy_save_ops(f"{path}{item_name}_min", data.min())
+            _npy_save_ops(f"{path}{item_name}_mean", data.mean())
+        else:
+            _npy_save_ops(f"{path}{item_name}", data)
 
 
-def _print_save(item_name, data, suffix, path=None):
+def _print_save(item_name, data, suffix, path, statistics):
     if isinstance(data, (list, tuple, dict, OrderedDict)):
         for key, val in _iterate_items(data):
-            _print_save(f"{item_name}.{key}", val, suffix, path)
+            _print_save(f"{item_name}.{key}", val, suffix, path, statistics)
     else:
         if data is None:
             return
         if suffix:
             item_name = f"{item_name}_{suffix}"
-        format_name = f"{SAVE_NAME_MARK}{item_name}"
-        print(format_name, data)
+        if statistics:
+            print(f"{SAVE_NAME_MARK}{item_name}_norm", data.norm())
+            print(f"{SAVE_NAME_MARK}{item_name}_max", data.max())
+            print(f"{SAVE_NAME_MARK}{item_name}_min", data.min())
+            print(f"{SAVE_NAME_MARK}{item_name}_mean", data.mean())
+        else:
+            format_name = f"{SAVE_NAME_MARK}{item_name}"
+            print(format_name, data)
 
 
-def save(file, data, suffix=None, output_mode="npy"):
+def save(file, data, suffix=None, output_mode="npy", statistics=False):
     """
     save tensor.
     """
-    _check_save_mode(output_mode, "save")
-    _SaveCell(file)(data, suffix, output_mode)
+    TensorSaver(file, suffix, output_mode, statistics)(data)
 
 
-def save_grad(file, data, suffix="backward", output_mode="npy"):
+def save_grad(file, data, suffix="backward", output_mode="npy", statistics=False):
     """
     save grad.
     """
-    _check_save_mode(output_mode, "save_grad")
-    return _SaveGradCell(file, suffix, output_mode)(data)
+    return GradientSaver(file, suffix, output_mode, statistics)(data)
-- 
Gitee