From b9df057c1b5cad66fb9b24debb229668cd1170a1 Mon Sep 17 00:00:00 2001 From: fuchao Date: Mon, 3 Jun 2024 16:15:37 +0800 Subject: [PATCH] =?UTF-8?q?fixed=20d3a29c1=20from=20https://gitee.com/hwfu?= =?UTF-8?q?chao/toolkits/pulls/294=20api=20dump=E7=BB=9F=E8=AE=A1=E4=BF=A1?= =?UTF-8?q?=E6=81=AF=E6=94=AF=E6=8C=81l2norm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- troubleshooter/docs/api/migrator/api_dump.md | 2 +- .../migrator/api_dump/api_dump_compare.py | 6 +++--- .../migrator/api_dump/apis_match/apis_match.py | 2 +- .../migrator/api_dump/ms_dump/hooks.py | 14 +++++++++----- .../migrator/api_dump/pt_dump/dump/dump.py | 14 +++++++++----- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/troubleshooter/docs/api/migrator/api_dump.md b/troubleshooter/docs/api/migrator/api_dump.md index 93f6ad4..c4580cb 100644 --- a/troubleshooter/docs/api/migrator/api_dump.md +++ b/troubleshooter/docs/api/migrator/api_dump.md @@ -75,7 +75,7 @@ output_path # 输出目录 - `api_dump_info.pkl`文件为网络在dump时按照API的执行顺序保存的信息,文件项格式如下: ``` - [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值] + [数据名称,保留字段,保留字段,数据类型,数据shape,[最大值,最小值,均值], md5值, l2norm值] ``` 当数据为bool类型或关闭统计信息保存时,最大值/最小值/均值会显示为`NAN`。 diff --git a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py index f004270..f85b6a7 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/api_dump_compare.py @@ -70,7 +70,7 @@ def _get_npy_list(apis, io, file_dict): def _get_npy_shape_map(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: data_shape} ret = {} @@ -481,7 +481,7 @@ def print_mindtorch_summary_result( def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, frame_names, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: (data_type, data_shape, data_summary)} ret = {} @@ -530,7 +530,7 @@ def compare_mindtorch_summary(origin_pkl_path, target_pkl_path, name_map_list, f def compare_summary(origin_pkl_path, target_pkl_path, name_map_list, **print_kwargs): def get_api_info(pkl_path): def _read_line(line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line return {prefix: (data_shape, data_summary)} ret = {} diff --git a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py index acc2570..a8ce621 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/apis_match/apis_match.py @@ -176,7 +176,7 @@ class APIList: _get_uni_io(self.api_list, self.framework) def _read_line(self, line): - prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume = line + prefix, dump_step, _, data_type, data_shape, data_summary, md5_nume, l2norm = line api_data = APIDataNode(data_shape, data_type, data_summary) def _read_prefix(prefix): diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py index 6dfbb1b..7da46d6 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py @@ -154,13 +154,14 @@ class DumpUtil(object): class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume + self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -184,7 +185,8 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def get_scalar_data_info(data, compute_summary): @@ -193,7 +195,8 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) + l2norm = np.linalg.norm(data) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) def get_float_tensor_info(data, compute_summary): @@ -212,7 +215,8 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def set_dump_path(fpath=None): @@ -265,7 +269,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_type): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) f.write('\n') diff --git a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py index 4c19a62..24f11e0 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/pt_dump/dump/dump.py @@ -54,13 +54,14 @@ NNCount = defaultdict(int) class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume): + def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): self.data = data self.save_data = save_data self.summary_data = summary_data self.dtype = dtype self.shape = shape self.md5_nume = md5_nume + self.l2norm = l2norm def get_not_float_tensor_info(data, compute_summary): @@ -84,7 +85,8 @@ def get_not_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def get_scalar_data_info(data, compute_summary): @@ -93,7 +95,8 @@ def get_scalar_data_info(data, compute_summary): else: summary_data = [math.nan] * 3 md5_nume = hashlib.md5(str(data).encode()).hexdigest() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume) + l2norm = np.linalg.norm(data) + return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) def get_float_tensor_info(data, compute_summary): @@ -108,7 +111,8 @@ def get_float_tensor_info(data, compute_summary): tensor_mean = math.nan summary_data = [tensor_max, tensor_min, tensor_mean] md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume) + l2norm = np.linalg.norm(saved_tensor) + return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) def json_dump_condition(prefix): @@ -167,7 +171,7 @@ def dump_data(dump_file_name, dump_step, prefix, data_info, dump_npy): else: np.save(output_path, data_info.save_data) os.chmod(output_path, 0o400) - json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume], f) + json.dump([prefix, dump_step, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) f.write('\n') -- Gitee