From 080aea8576267bc8795d71fae26fb2b123fc4fee Mon Sep 17 00:00:00 2001 From: fuchao Date: Mon, 3 Jun 2024 15:32:18 +0800 Subject: [PATCH 1/3] =?UTF-8?q?api=E5=89=8D=E7=BC=80=E5=8D=95=E7=8B=AC?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migrator/api_dump/ms_dump/initialize.py | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py index c7ec352..9dfdc60 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py @@ -16,49 +16,54 @@ try: except ImportError: comm_func_label = False +key_wrap = "wrap_" +Key_ops = "wrap_ops." +key_mint_ops = "wrap_mint.ops." +key_mint_nn_functional = "wrap_mint.nn.functional." +key_communication_comm_func = "wrap_communication.comm_func." def hook_apis(): for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(key_wrap): setattr(ms.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) setattr(ms.common._stub_tensor.StubTensor, attr_name[5:], getattr(wrap_sub_tensor.HOOKSubTensor, attr_name)) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - if attr_name.startswith("wrap_ops."): - setattr(ms.ops, attr_name[len("wrap_ops."):], + if attr_name.startswith(key_wrap): + if attr_name.startswith(Key_ops): + setattr(ms.ops, attr_name[len(Key_ops):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - if attr_name.startswith("wrap_mint.ops."): - setattr(ms.mint, attr_name[len("wrap_mint.ops."):], + if attr_name.startswith(key_mint_ops): + setattr(ms.mint, attr_name[len(key_mint_ops):], getattr(wrap_functional.HOOKFunctionalOP,attr_name)) - if attr_name.startswith("wrap_mint.nn.functional."): - setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], + if attr_name.startswith(key_mint_nn_functional): + setattr(ms.mint.nn.functional, attr_name[len(key_mint_nn_functional):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - if comm_func_label: - setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):], + if comm_func_label and attr_name.startswith(key_communication_comm_func): + setattr(ms.communication.comm_func, attr_name[len(key_communication_comm_func):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) def restore_apis(): for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(key_wrap): setattr(ms.Tensor, attr_name[5:], wrap_tensor.TensorFunc[attr_name[5:]]) setattr(ms.common._stub_tensor.StubTensor, attr_name[5:], wrap_sub_tensor.SubTensorFunc[attr_name[5:]]) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - if attr_name.startswith("wrap_ops."): - setattr(ms.ops, attr_name[len("wrap_ops."):], - wrap_functional.OpsFunc[wrap_functional.ops_label + attr_name[len("wrap_ops."):]]) - if attr_name.startswith("wrap_mint.ops."): - setattr(ms.mint, attr_name[len("wrap_mint.ops."):], - wrap_functional.OpsFunc[wrap_functional.mint_ops_label + attr_name[len("wrap_mint.ops."):]]) - if attr_name.startswith("wrap_mint.nn.functional."): - setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], - wrap_functional.OpsFunc[wrap_functional.mint_nn_func_label + attr_name[len("wrap_mint.nn.functional."):]]) - if comm_func_label: - setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):], - wrap_functional.OpsFunc[wrap_functional.communication_comm_func_label + attr_name[len("wrap_communication.comm_func."):]]) + if attr_name.startswith(key_wrap): + if attr_name.startswith(Key_ops): + setattr(ms.ops, attr_name[len(Key_ops):], + wrap_functional.OpsFunc[wrap_functional.ops_label + attr_name[len(Key_ops):]]) + if attr_name.startswith(key_mint_ops): + setattr(ms.mint, attr_name[len(key_mint_ops):], + wrap_functional.OpsFunc[wrap_functional.mint_ops_label + attr_name[len(key_mint_ops):]]) + if attr_name.startswith(key_mint_nn_functional): + setattr(ms.mint.nn.functional, attr_name[len(key_mint_nn_functional):], + wrap_functional.OpsFunc[wrap_functional.mint_nn_func_label + attr_name[len(key_mint_nn_functional):]]) + if comm_func_label and attr_name.startswith(key_communication_comm_func): + setattr(ms.communication.comm_func, attr_name[len(key_communication_comm_func):], + wrap_functional.OpsFunc[wrap_functional.communication_comm_func_label + attr_name[len(key_communication_comm_func):]]) class MyMindsporeFunctionExecutor(_MindsporeFunctionExecutor): -- Gitee From d3a29c134e39a96be112efccd7ad5cdda2f1e41c Mon Sep 17 00:00:00 2001 From: fuchao Date: Mon, 3 Jun 2024 16:15:37 +0800 Subject: [PATCH 2/3] =?UTF-8?q?api=20dump=E7=BB=9F=E8=AE=A1=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E6=94=AF=E6=8C=81l2norm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- troubleshooter/docs/api/migrator/api_dump.md | 2 +- .../migrator/api_dump/api_dump_compare.py | 6 +++--- .../migrator/api_dump/apis_match/apis_match.py | 2 +- .../migrator/api_dump/ms_dump/hooks.py | 14 +++++++++----- .../migrator/api_dump/pt_dump/dump/dump.py | 14 +++++++++----- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/troubleshooter/docs/api/migrator/api_dump.md b/troubleshooter/docs/api/migrator/api_dump.md index 93f6ad4..c4580cb 100644 --- a/troubleshooter/docs/api/migrator/api_dump.md +++ b/troubleshooter/docs/api/migrator/api_dump.md @@ -75,7 +75,7 @@ output_path # 输出目录 - `api_dump_info.pkl`文件为网络在dump时按照API的执行顺序保存的信息,文件项格式如下: ``` - [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值] + [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值, l2norm值] ``` 当数据为bool类型或关闭统计信息保存时,最大值/最小值/均值会显示为`NAN`。 diff --git a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py index f004270..f85b6a7 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py @@ -70,7 +70,7 @@ def _get_npy_list(apis, io, file_dict): def _get_npy_shape_map(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: data_shape} ret = {} @@ -481,7 +481,7 @@ def print_mindtorch_summary_result( def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, frame_names, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: (data_type, data_shape, data_summary)} ret = {} @@ -530,7 +530,7 @@ def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, f def compare_summary(origin_pkl_path, target_pkl_path, name_map_list, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: (data_shape, data_summary)} ret = {} diff --git a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py index acc2570..a8ce621 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py @@ -176,7 +176,7 @@ class APIList: _get_uni_io(self.api_list, self.framework) def _read_line(self, line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line api_data = APIDataNode(data_shape, data_type, data_summary) def _read_prefix(prefix): diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py index 6dfbb1b..7da46d6 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py @@ -154,13 +154,14 @@ class DumpUtil(object): class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume + self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -184,7 +185,8 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def get_scalar_data_info(data, compute_summary): @@ -193,7 +195,8 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) + l2norm = np.linalg.norm(data) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) def get_float_tensor_info(data, compute_summary): @@ -212,7 +215,8 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def set_dump_path(fpath=None): @@ -265,7 +269,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_type): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) f.write('\n') diff --git a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py index 4c19a62..24f11e0 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py @@ -54,13 +54,14 @@ NNCount = defaultdict(int) class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume + self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -84,7 +85,8 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def get_scalar_data_info(data, compute_summary): @@ -93,7 +95,8 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) + l2norm = np.linalg.norm(data) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) def get_float_tensor_info(data, compute_summary): @@ -108,7 +111,8 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def json_dump_condition(prefix): @@ -167,7 +171,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_npy): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) f.write('\n') -- Gitee From 3f63bb87ea79f52a9435013842da9373a44e2db6 Mon Sep 17 00:00:00 2001 From: fuchao Date: Mon, 3 Jun 2024 16:22:53 +0800 Subject: [PATCH 3/3] =?UTF-8?q?Revert=20"api=20dump=E7=BB=9F=E8=AE=A1?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E6=94=AF=E6=8C=81l2norm"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d3a29c134e39a96be112efccd7ad5cdda2f1e41c. --- troubleshooter/docs/api/migrator/api_dump.md | 2 +- .../migrator/api_dump/api_dump_compare.py | 6 +++--- .../migrator/api_dump/apis_match/apis_match.py | 2 +- .../migrator/api_dump/ms_dump/hooks.py | 14 +++++--------- .../migrator/api_dump/pt_dump/dump/dump.py | 14 +++++--------- 5 files changed, 15 insertions(+), 23 deletions(-) diff --git a/troubleshooter/docs/api/migrator/api_dump.md b/troubleshooter/docs/api/migrator/api_dump.md index c4580cb..93f6ad4 100644 --- a/troubleshooter/docs/api/migrator/api_dump.md +++ b/troubleshooter/docs/api/migrator/api_dump.md @@ -75,7 +75,7 @@ output_path # 输出目录 - `api_dump_info.pkl`文件为网络在dump时按照API的执行顺序保存的信息,文件项格式如下: ``` - [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值, l2norm值] + [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值] ``` 当数据为bool类型或关闭统计信息保存时,最大值/最小值/均值会显示为`NAN`。 diff --git a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py index f85b6a7..f004270 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py @@ -70,7 +70,7 @@ def _get_npy_list(apis, io, file_dict): def _get_npy_shape_map(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line return {prefix: data_shape} ret = {} @@ -481,7 +481,7 @@ def print_mindtorch_summary_result( def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, frame_names, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line return {prefix: (data_type, data_shape, data_summary)} ret = {} @@ -530,7 +530,7 @@ def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, f def compare_summary(origin_pkl_path, target_pkl_path, name_map_list, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line return {prefix: (data_shape, data_summary)} ret = {} diff --git a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py index a8ce621..acc2570 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py @@ -176,7 +176,7 @@ class APIList: _get_uni_io(self.api_list, self.framework) def _read_line(self, line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line api_data = APIDataNode(data_shape, data_type, data_summary) def _read_prefix(prefix): diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py index 7da46d6..6dfbb1b 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py @@ -154,14 +154,13 @@ class DumpUtil(object): class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume - self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -185,8 +184,7 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - l2norm = np.linalg.norm(saved_tensor) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) def get_scalar_data_info(data, compute_summary): @@ -195,8 +193,7 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - l2norm = np.linalg.norm(data) - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) def get_float_tensor_info(data, compute_summary): @@ -215,8 +212,7 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - l2norm = np.linalg.norm(saved_tensor) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) def set_dump_path(fpath=None): @@ -269,7 +265,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_type): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) f.write('\n') diff --git a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py index 24f11e0..4c19a62 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py @@ -54,14 +54,13 @@ NNCount = defaultdict(int) class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume - self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -85,8 +84,7 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - l2norm = np.linalg.norm(saved_tensor) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) def get_scalar_data_info(data, compute_summary): @@ -95,8 +93,7 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - l2norm = np.linalg.norm(data) - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) def get_float_tensor_info(data, compute_summary): @@ -111,8 +108,7 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - l2norm = np.linalg.norm(saved_tensor) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) def json_dump_condition(prefix): @@ -171,7 +167,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_npy): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) f.write('\n') -- Gitee