diff --git a/troubleshooter/README.md b/troubleshooter/README.md index a14a4bacdccc20eebb935312b31a0c35d6bbcf57..b3ff18ffa0ee04066b83b515d4ba97c98c50a388 100644 --- a/troubleshooter/README.md +++ b/troubleshooter/README.md @@ -25,6 +25,7 @@ pip install troubleshooter * 应用场景1:pth到ckpt权重转换 * 应用场景2:将转换后的ckpt与MindSpore网络生成的ckpt进行对比 * 应用场景3:比较两组tensor值(npy文件)是否相等 +* 应用场景4:比较pytorch和mindspore的网络输出是否相等 ## 其他调试功能 * 应用场景1:tracking在指定epoch/step停止跟踪(使用model.train训练) @@ -553,6 +554,52 @@ x = self.sqrt(y) 出现 nan, 给出“User Warning 'nan' is detected”报错。 # 将调整后的名称传入比较接口 dif.compare_npy_dir(name_map_list=name_list) +### 应用场景4:比较mindspore和pytorch网络输出是否一致 + +在进行网络迁移时,由于大多数网络是使用pytorch搭建的,在迁移到mindspore过程中,我们需要比较mindspore和pytorch网络输出结果是否一致。此功能实现对比mindspore和pytorch的输出结果。 + +#### 接口参数 + +| 参数 | 类型 | 说明 | +| ------------ | ------------------------- | ------------------------------------------------------------ | +| ms_net | mindspore.nn.Cell | mindspore模型实例 | +| pt_net | torch.nn.Module | torch模型实例 | +| inputs | Union(list[Iterable[np.array]], list[Iterable[str]]) | 模型的输入。支持多输入,每个输入使用一个list。当list中为array时,将会依序使用其中的array作为模型的输入,每个array形状与模型输入形状相同,以图像分类任务为例,输入形状为[batch_size, num_channel, w, h];当为str时,将会加载相应位置的npy文件作为模型输入。 对于单输入情况,用户需要传入[[input1, input2, ...]] 或者 [['input1.npy', 'input2.npy', ...]];对于多输入,用户应当传入[[input1_1, input1_2, ...], [input2_1, input2_2, ...]]或者[['input1_1.npy', 'input1_2.npy', ...], ['input2_1.npy', 'input2_2.npy', ...]]| +| out_path | str | 结果保存的文件夹,例如,'troubleshooter/results' | +| print_result | bool | 是否打印输出结果,以及中间过程 | + +#### 如何使用 + +可以参考troubleshooter/tests/diff_handler/test_netdifffinder.py中的使用方法,或者下面的使用方法: + +```python +# 构造输入,这里假设测试用例有两个,分别为input1,input2; +input1 = [np.random.randn(1, 12).astype(np.float32), np.random.randn(1, 13).astype(np.float32)] +input2 = [np.random.randn(1, 12).astype(np.float32), np.random.randn(1, 13).astype(np.float32)] +# 实例化mindspore模型以及torch模型 +ms_net = MSNet() +pt_net = TorchNet() +diff_finder = NetDifferenceFinder( + 
"""Smoke tests for ``NetDifferenceFinder``.

Each test runs a small two-input PyTorch network and its MindSpore
counterpart on the same random inputs and checks that the comparison
completes and reports the per-test-case inference timing on stdout.
"""
import re
import sys

# Must run before the ``troubleshooter`` import below; appended afterwards it
# cannot influence how that import is resolved.
sys.path.append('troubleshooter')

import mindspore.nn as m_nn  # noqa: E402
import numpy as np  # noqa: E402
import torch.nn as t_nn  # noqa: E402
from troubleshooter.migrator.diff_handler import NetDifferenceFinder  # noqa: E402

# Timing line printed once per framework and per test case by ``infer_net``.
INFO_PATTERN = re.compile(
    r"In test case \d+, the .*? net inference completed cost .*? seconds\.")
OUT_PATH = 'troubleshooter/tests/diff_handler/results'


class TorchNet(t_nn.Module):
    """Two independent linear layers; returns its outputs as a dict."""

    def __init__(self) -> None:
        super().__init__()
        self.net1 = t_nn.Linear(12, 21)
        self.net2 = t_nn.Linear(13, 22)

    def forward(self, x, y):
        return {'a': self.net1(x), 'b': self.net2(y)}


class MSNet(m_nn.Cell):
    """MindSpore counterpart of ``TorchNet``; returns its outputs as a tuple."""

    def __init__(self, n=0) -> None:
        super().__init__()
        self.net1 = m_nn.Dense(12, 21)
        self.net2 = m_nn.Dense(13, 22)
        self.n = n

    def construct(self, x, y):
        if self.n:
            print(self.n)
        return self.net1(x), self.net2(y)


def _random_inputs():
    """Return one ``(x, y)`` pair of float32 arrays shaped for the two nets."""
    return (np.random.randn(1, 12).astype(np.float32),
            np.random.randn(1, 13).astype(np.float32))


def _run_and_check(capsys, inputs):
    """Run the comparison on ``inputs`` and assert on the captured output."""
    diff_finder = NetDifferenceFinder(
        ms_net=MSNet(),
        pt_net=TorchNet(),
        inputs=inputs,
        out_path=OUT_PATH,
        print_result=False,
    )
    diff_finder.start_compare()
    out, err = capsys.readouterr()
    # ``search`` rather than ``match``: the timing line does not have to be
    # the very first line written to stdout.
    assert INFO_PATTERN.search(out) is not None
    assert err == ''


def test_model(capsys):
    """Inputs given as tuples of arrays."""
    _run_and_check(capsys, [_random_inputs(), _random_inputs()])


def test_dict(capsys):
    """Inputs given as dicts; the values are used in insertion order."""
    cases = [dict(zip(('a', 'b'), _random_inputs())) for _ in range(2)]
    _run_and_check(capsys, cases)
def _diff_row_cells(result, with_cosine):
    """Return the table cells for one ``print_diff_result`` row.

    Rows are either ``(orig, target, equal, detail)`` or
    ``(orig, target, equal, cosine, detail)``. When the table has a cosine
    column, 4-element rows get an empty cosine cell so every row has the
    same width.
    """
    cells = list(result)
    if not with_cosine:
        return cells[:4]
    if len(cells) > 4:
        return cells[:5]
    return cells[:3] + ['', cells[3]]


def print_diff_result(result_list):
    """Print the comparison table for two directories of ``.npy`` files.

    Args:
        result_list: iterable of rows, each either
            ``(orig_name, target_name, equal, diff_detail)`` or
            ``(orig_name, target_name, equal, cosine_sim, diff_detail)``.

    A "cosine similarity" column is shown only when at least one row carries
    a cosine value, so callers that pass 4-element rows get the same table
    as before.
    """
    rows = [list(result) for result in result_list]
    with_cosine = any(len(row) > 4 for row in rows)

    field_names = ["orig array name", "target array name",
                   "Results of comparison"]
    if with_cosine:
        field_names.append("cosine similarity")
    field_names.append("(mean,max,min)")

    x = PrettyTable()
    x.title = 'The list of comparison results'
    x.field_names = field_names
    for row in rows:
        x.add_row(_diff_row_cells(row, with_cosine))
    print(x.get_string())


def print_net_infer_diff_result(result_list):
    """Print the table comparing PyTorch and MindSpore network outputs.

    Args:
        result_list: iterable of 5-element rows ordered as
            (PyTorch data, MindSpore data, comparison result,
            cosine similarity, (mean, max, min)).
    """
    x = PrettyTable()
    x.title = 'The list of comparison results'
    x.field_names = ["Pytorch data", "MindSpore data",
                     "Results of comparison", "cosine similarity", "(mean, max, min)"]
    for result in result_list:
        x.add_row([result[0], result[1], result[2], result[3], result[4]])
    print(x.get_string())
x.add_row([item_desc.get("ms_device"), fill( + mindspore_device, width=TABLE_WIDTH)]) ms_status = expert_experience.get("ms_status") code_line = expert_experience.get("code_line") sink_mode = expert_experience.get("Sink Mode") if ms_status: - x.add_row([item_desc.get("ms_status"), fill(ms_status, width=TABLE_WIDTH)]) + x.add_row([item_desc.get("ms_status"), + fill(ms_status, width=TABLE_WIDTH)]) if code_line: - x.add_row([item_desc.get("code_line"), fill(code_line, width=TABLE_WIDTH)]) + x.add_row([item_desc.get("code_line"), + fill(code_line, width=TABLE_WIDTH)]) if sink_mode: - x.add_row([item_desc.get("sink_mode"), fill(sink_mode, width=TABLE_WIDTH)]) + x.add_row([item_desc.get("sink_mode"), + fill(sink_mode, width=TABLE_WIDTH)]) # 可能原因 fault_cause = expert_experience.get('Fault Cause') @@ -131,7 +153,8 @@ def print_result(expert_experience, write_file_path): if write_file_path: case_id = expert_experience.get("ID") _add_row(x, "case_id", case_id) - file = os.path.join(write_file_path, "mindspore_failure_analysis_report.log") + file = os.path.join( + write_file_path, "mindspore_failure_analysis_report.log") with open(file, "w") as f: f.write(x.get_string() + os.linesep) print(x.get_string()) @@ -243,7 +266,8 @@ def _format_case_str(content, mindspore_version): # no mindspore link, return if match: link_version = line[match.start():match.end()] - line = _replace_link_version(line, link_version, mindspore_version) + line = _replace_link_version( + line, link_version, mindspore_version) result += line + os.linesep else: # link version same with mindspore version, no replace return content @@ -271,8 +295,10 @@ def _filter_stack(stack): def print_clear_exception(exc_type, exc_value, exc_traceback_obj): if exc_traceback_obj: - _print_msg("[TroubleShooter-Clear Stack] Python Traceback (most recent call last):", "NULL", False) - org_err_stack = traceback.format_exception(exc_type, exc_value, exc_traceback_obj) + _print_msg( + "[TroubleShooter-Clear Stack] Python 
def cal_cosine_sim(a, b):
    """Return the cosine similarity of two arrays, rescaled to ``[0, 1]``.

    Both inputs are flattened first. The raw cosine (in ``[-1, 1]``) is mapped
    to ``0.5 + 0.5 * cos``: 1.0 means same direction, 0.5 orthogonal and 0.0
    opposite. If either array has zero norm the cosine is undefined and 0.0
    is returned.
    """
    a, b = np.asarray(a).flatten(), np.asarray(b).flatten()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.:
        return 0.
    return 0.5 + 0.5 * (np.dot(a, b) / denom)


def cal_similarity(ms_data, th_data, index, **kwargs):
    """Compare a MindSpore output array with a PyTorch output array.

    Args:
        ms_data: numpy array produced by the MindSpore network.
        th_data: numpy array produced by the PyTorch network.
        index: output index, used only to label the row.
        **kwargs: optional ``rtol``, ``atol`` and ``equal_nan`` forwarded to
            ``np.allclose`` (defaults 1e-05, 1e-08, False).

    Returns:
        A list ``[ms_label, torch_label, equal, cosine_sim, diff_detail]``.
        ``diff_detail`` is ``()`` when the arrays are close,
        ``(mean, max, min)`` of the absolute difference when they are not,
        and ``("Shape is inconsistent", ms_shape, th_shape)`` when the shapes
        differ. ``cosine_sim`` is NaN when the shapes differ.
    """
    rtol = kwargs.get('rtol', 1e-05)
    atol = kwargs.get('atol', 1e-08)
    equal_nan = kwargs.get('equal_nan', False)

    # Defined up front so every branch returns a value; previously the "equal"
    # and "shape mismatch" branches left it unbound.
    cosine_sim = float('nan')
    if ms_data.shape == th_data.shape:
        result = np.allclose(ms_data, th_data, rtol=rtol,
                             atol=atol, equal_nan=equal_nan)
        cosine_sim = cal_cosine_sim(ms_data, th_data)
        if result:
            diff_detail = ()
        else:
            value_diff = np.abs(ms_data - th_data)
            diff_detail = value_diff.mean(), value_diff.max(), value_diff.min()
    else:
        result = False
        diff_detail = ("Shape is inconsistent", ms_data.shape, th_data.shape)

    return ['mindspore output {}'.format(index),
            'torch output {}'.format(index),
            result, cosine_sim, diff_detail]


def save_numpy_data(file_path, data):
    """Save ``data`` to ``file_path`` with ``np.save``, creating parent directories.

    A bare filename (no directory part) is saved in the current directory.
    """
    parent = os.path.dirname(file_path)
    if parent:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(parent, exist_ok=True)
    np.save(file_path, data)
cal_similarity, cal_cosine_sim, save_numpy_data from troubleshooter.migrator.mapping_relation.weight_mapping_lib import weight_name_map, weight_value_map -from troubleshooter.common.format_msg import print_diff_result, print_weight_compare_result, print_convert_result +from troubleshooter.common.format_msg import print_diff_result, print_weight_compare_result, \ + print_convert_result, print_net_infer_diff_result FRAMEWORK_TYPE = "ms" +MS_OUTPUT_PATH = "data/output/MindSpore" +PT_OUTPUT_PATH = "data/output/PyTorch" +RESULT_COLUMNS = ["Pytorch data", "MindSpore data", + "Results of comparison", "cosine similarity", "(mean, max, min)"] try: import mindspore as ms @@ -35,6 +42,7 @@ except ModuleNotFoundError as e: else: raise e + class TensorRecorder: def __init__(self): self.summary_record = None @@ -45,7 +53,8 @@ class TensorRecorder: record_mode = ms.get_context("mode") if record_mode == 0: - self.summary_record = ms.SummaryRecord(record_path, export_options={'tensor_format': 'npy'}) + self.summary_record = ms.SummaryRecord( + record_path, export_options={'tensor_format': 'npy'}) def record(self): if not self.summary_record: @@ -63,7 +72,7 @@ class TensorRecorder: normal_dir = validate_and_normalize_path(record_dir) make_directory(normal_dir) - if framework=="ms": + if framework == "ms": if record_mode is None: record_mode = ms.get_context("mode") if record_mode == 0: @@ -86,11 +95,13 @@ class TensorRecorder: return save_op + class DifferenceFinder: def __init__(self, orig_dir, target_dir): self.orig_dir = orig_dir self.target_dir = target_dir + def get_filename_map_list(self): name_map_list = [] orig_name_list = find_file(self.orig_dir) @@ -99,8 +110,8 @@ class DifferenceFinder: none_flag = False if not (orig_name_list and target_name_list): - logger.user_error("The comparison file is not found in the directory. " - "Please check whether the directory is correct") + logger.user_error("The comparison file is not found in the directory. 
Please \ + check whether the directory is correct") exit(1) for name in orig_name_list: @@ -123,7 +134,6 @@ class DifferenceFinder: print("filename mapping list:" + str(name_map_list)) return name_map_list - def compare_npy_dir(self, name_map_list=None, **kwargs): """ """ @@ -144,7 +154,8 @@ class DifferenceFinder: if orig_name is None or target_name is None: result = False diff_detail = () - result_list.append((orig_name, target_name, result, diff_detail)) + result_list.append( + (orig_name, target_name, result, diff_detail)) continue orig_file = os.path.join(normal_orig_dir, orig_name) @@ -156,61 +167,40 @@ class DifferenceFinder: orig_value = np.load(orig_file) target_value = np.load(target_file) if orig_value.shape == target_value.shape: - result = np.allclose(orig_value, target_value, rtol=rtol, atol=atol, equal_nan=equal_nan) + result = np.allclose( + orig_value, target_value, rtol=rtol, atol=atol, equal_nan=equal_nan) if not result: value_diff = np.abs(orig_value - target_value) value_mean = value_diff.mean() value_max = value_diff.max() value_min = value_diff.min() + cosine_sim = cal_cosine_sim(orig_value, target_value) diff_detail = value_mean, value_max, value_min else: diff_detail = () + cosine_sim = cal_cosine_sim(orig_value, target_value) else: result = False - diff_detail = ("Shape is inconsistent", orig_value.shape, target_value.shape) + diff_detail = ("Shape is inconsistent", + orig_value.shape, target_value.shape) - result_list.append((orig_name, target_name, result, diff_detail)) + result_list.append( + (orig_name, target_name, result, cosine_sim, diff_detail)) logger.user_attention("The compare directory information:\n The orig dir: %s \n The target dir: %s", self.orig_dir, self.target_dir) print_diff_result(result_list) + class WeightMigrator: - def __init__(self, pt_model=None, pth_file_path=None, pth_para_dict=None, ckpt_save_path=None): + def __init__(self, pt_model=None, pth_file_path=None, ckpt_save_path=None): self.weight_map = 
weight_name_map self.ckpt_path = ckpt_save_path self.pt_model = pt_model - self.pt_para_dict = self._get_para_dict(pth_file_path, pth_para_dict) + self.pt_para_dict = torch.load(pth_file_path, map_location='cpu') self.print_params_list = [] - - def _get_para_dict(self, pth_file_path, pth_para_dict): - if pth_para_dict: - return pth_para_dict - - pt_para_dict = {} - pt_object = torch.load(pth_file_path, map_location='cpu') - if isinstance(pt_object, OrderedDict): - pt_para_dict = pt_object - elif isinstance(pt_object, torch.nn.Module): - pt_para_dict = pt_object.state_dict() - else: - raise ValueError("PTH file parsing failed, possible reasons: " - "1) If using a custom method to save parameter files, please load and set " - "the 'pth_para_dict' parameter yourself to use the conversion tool." - "2) If the input is an optimizer parameter, this tool does not support " - "the conversion of optimizer parameters.") - - values = list(pt_para_dict.values()) - if values and not isinstance(values[0], torch.Tensor): - raise ValueError("PTH file parsing failed, possible reasons: " - "1) If using a custom method to save parameter files, please load and set " - "the 'pth_para_dict' parameter yourself to use the conversion tool." - "2) If the input is an optimizer parameter, this tool does not support " - "the conversion of optimizer parameters.") - return pt_para_dict - def _get_object(self, name): object_res = None index = name.rfind(".") @@ -222,7 +212,6 @@ class WeightMigrator: object_res = getattr(imp_module, class_name) return object_res - def _get_trans_map(self, weight_name, module, weight_map, igone_name=False): res_weight_map = {} for api_name in weight_map: @@ -240,42 +229,22 @@ class WeightMigrator: return res_weight_map - - def _custorm_weight_name_prefix(self, weight_name_map, prefix=None): - if prefix: - custorm_name_map = {} - for key, value in weight_name_map.items(): - # print(key, ":", prefix + '.' + value) - custorm_name_map[key] = str(prefix) + '.' 
+ str(value) - return custorm_name_map - else: - return weight_name_map - - - def get_weight_map(self, print_map=False, full_name_map=False): + def get_weight_map(self, print_map=False): res_weight_name_map = {} res_weight_value_map = {} - full_weight_name_map = {} - for name, module in self.pt_model.named_modules(): tmp_name_map = self._get_trans_map(name, module, weight_name_map) if tmp_name_map: res_weight_name_map.update(tmp_name_map) - tmp_value_map = self._get_trans_map(name, module, weight_value_map, igone_name=True) + tmp_value_map = self._get_trans_map( + name, module, weight_value_map, igone_name=True) if tmp_value_map: res_weight_value_map.update(tmp_value_map) - if full_name_map: - for key, value in self.pt_para_dict.items(): - full_weight_name_map[key]=key - full_weight_name_map.update(res_weight_name_map) - res_weight_name_map = full_weight_name_map - if print_map: pprint(res_weight_name_map) pprint(res_weight_value_map) return res_weight_name_map, res_weight_value_map - def _get_name_and_value(self, pth_param_name, name_map, value_map): new_name = pth_param_name parameter = self.pt_para_dict[pth_param_name] @@ -294,45 +263,38 @@ class WeightMigrator: def_get_value = self._get_object(fun) ms_tensor = def_get_value(ms_tensor) - self.print_params_list.append((pth_param_name, new_name, bool(ms_para_item), bool(fun) , parameter.size(), + self.print_params_list.append((pth_param_name, new_name, bool(ms_para_item), bool(fun), parameter.size(), ms_tensor.shape)) return new_name, ms_tensor - - def convert(self, weight_name_map=None, weight_value_map=None ,weight_name_prefix=None, print_conv_info=True): - - if weight_name_prefix: - name_map, value_map = self.get_weight_map(full_name_map=True) - name_map = self._custorm_weight_name_prefix(name_map, weight_name_prefix) - else: - name_map, value_map = self.get_weight_map() - + def convert(self, weight_name_map=None, weight_value_map=None, print_conv_info=True): + name_map, value_map = self.get_weight_map() if 
weight_name_map is not None: - name_map = weight_name_map + name_map = weight_name_map if weight_value_map is not None: - value_map = weight_value_map + value_map = weight_value_map new_params_list = [] for pth_param_name in self.pt_para_dict: # get ckpt name and value - new_name, ms_tensor = self._get_name_and_value(pth_param_name,name_map,value_map) + new_name, ms_tensor = self._get_name_and_value( + pth_param_name, name_map, value_map) # add name and value to list new_params_list.append({"name": new_name, "data": ms_tensor}) if new_params_list: - ms.save_checkpoint(new_params_list , self.ckpt_path) + ms.save_checkpoint(new_params_list, self.ckpt_path) else: logger.user_warning("There are no parameters to be converted. Parameter conversion failed. " "Please check whether the configuration is correct") if print_conv_info: - print_convert_result(self.print_params_list) + print_convert_result(self.print_params_list) logger.user_attention("The PTH has been converted to the checkpoint of MindSpore. " "Please check whether the conversion result is correct. 
class NetDifferenceFinder:
    """Run a MindSpore net and a PyTorch net on the same inputs and compare outputs.

    Args:
        ms_net: MindSpore network instance, called as ``ms_net(*tensors)``.
        pt_net: PyTorch network instance, called as ``pt_net(*tensors)``.
        inputs: list of test cases. Each case is an iterable (or a dict whose
            values are used) of numpy arrays or paths to ``.npy`` files, one
            entry per network input.
        out_path: directory receiving the saved outputs and ``compare_result.csv``.
        print_result: when true, print intermediate progress and data.
    """

    def __init__(self, ms_net, pt_net, inputs,
                 out_path, print_result):
        self.ms_net = ms_net
        self.pt_net = pt_net
        self.inputs = inputs
        self.out_path = out_path
        self.print_result = print_result

    def start_compare(self):
        """Run every test case, then save and print the comparison table."""
        compare_results = []
        for idx, case in enumerate(self.inputs):
            input_data = self.get_input_data(case)
            result_ms, result_pt = self.infer_net(
                input_data, self.ms_net, self.pt_net, idx)
            self.check_output(result_ms, result_pt)
            if idx != 0:
                # Blank separator row between test cases.
                compare_results.append(['', '', '', '', ''])
            compare_results.extend(
                self.compare_results(result_ms, result_pt, idx))
        self.save_results(compare_results)
        print_net_infer_diff_result(compare_results)

    def get_input_data(self, input):
        """Return one test case as a list of numpy arrays.

        ``input`` may be a dict (its values are used, in insertion order) or
        any iterable. String entries are loaded with ``np.load``; ndarray
        entries are used as they are. Any other type logs an error and exits.
        """
        entries = input.values() if isinstance(input, dict) else input
        input_data = []
        for data in entries:
            if isinstance(data, str):
                input_data.append(np.load(data))
            elif isinstance(data, np.ndarray):
                input_data.append(data)
            else:
                logger.user_error(
                    'Unknow input data type {}'.format(type(data)))
                exit(1)
        return input_data

    @staticmethod
    def _as_result_dict(result):
        """Normalise a network output to a dict of named outputs.

        Dicts are returned unchanged, tuples and lists become
        ``{"result_0": ..., "result_1": ...}`` and a single output becomes
        ``{"result_0": output}``.
        """
        if isinstance(result, dict):
            return result
        if isinstance(result, (tuple, list)):
            return {f"result_{i}": item for i, item in enumerate(result)}
        return {"result_0": result}

    @staticmethod
    def _format_diff_detail(detail):
        """Format each entry of a diff-detail tuple for the result table.

        Numeric entries are rendered with five decimals; anything else (for
        example a shape-mismatch message) is rendered with ``str``.
        """
        formatted = []
        for item in detail:
            try:
                formatted.append('%.5f' % (item,))
            except TypeError:
                formatted.append(str(item))
        return formatted

    def infer_net(self, input_data, ms_net, pt_net, idx):
        """Run both networks on ``input_data`` and return ``(result_ms, result_pt)``.

        Both results are returned as dicts of named outputs. The timing lines
        are always printed, independent of ``print_result``.
        """
        if self.print_result:
            print(
                "\n=================================Start inference net=================================")
        start_pt = time.perf_counter()
        result_pt = self.run_pt_net(pt_net, input_data)
        cost_pt = time.perf_counter() - start_pt
        print(f"In test case {idx}, the PyTorch net inference completed cost {cost_pt:.5f} seconds.")
        start_ms = time.perf_counter()
        result_ms = self.run_ms_net(ms_net, input_data)
        cost_ms = time.perf_counter() - start_ms
        print(f"In test case {idx}, the MindSpore net inference completed cost {cost_ms:.5f} seconds.")
        return self._as_result_dict(result_ms), self._as_result_dict(result_pt)

    def run_pt_net(self, pt_net, input_data_list):
        """Call ``pt_net`` with every array converted to a ``torch`` tensor."""
        tensors = [torch.tensor(data) for data in input_data_list]
        return pt_net(*tensors)

    def run_ms_net(self, ms_net, input_data_list):
        """Call ``ms_net`` with every array converted to a MindSpore tensor."""
        tensors = [ms.Tensor(data) for data in input_data_list]
        return ms_net(*tensors)

    def check_output(self, result_ms, result_pt):
        """Check that both networks produced the same number of outputs.

        Raises:
            AssertionError: if the output counts differ.
        """
        ms_result_num = len(result_ms)
        if self.print_result:
            print("The MindSpore net inference have %s result." % ms_result_num)
        pt_result_num = len(result_pt)
        if self.print_result:
            print("The PyTorch net inference have %s result." % pt_result_num)
        # Raised explicitly so the check also runs under ``python -O``.
        if ms_result_num != pt_result_num:
            raise AssertionError("output results are in different counts!")

    def compare_results(self, result_ms, result_pt, idx):
        """Compare the outputs of one test case pairwise.

        Outputs are paired by position. Each pair is shape-checked, saved as
        ``.npy`` and compared on the flattened data. Returns one table row
        per pair, ordered like ``RESULT_COLUMNS`` (PyTorch first, then
        MindSpore).
        """
        compare_result = []
        for index, (k_pt, k_ms) in enumerate(zip(result_pt, result_ms)):
            result_pt_ = result_pt[k_pt].detach().cpu().numpy()
            result_ms_ = result_ms[k_ms].asnumpy()
            self.check_out_data_shape(index, result_ms_, result_pt_)
            # NOTE(review): files are named by output key only, so a later test
            # case overwrites the files of an earlier one - confirm intended.
            self.save_out_data(index, k_ms, k_pt, result_ms_, result_pt_)
            result = self.compare_data(
                result_ms_.reshape(1, -1), result_pt_.reshape(1, -1), index)
            # The table header lists the PyTorch column before the MindSpore one.
            result[0], result[1] = f"test{idx}-{k_pt}", f"test{idx}-{k_ms}"
            result[3] = "%.5f" % float(result[3])
            result[4] = self._format_diff_detail(result[4])
            compare_result.append(result)
        return compare_result

    def check_out_data_shape(self, index, result_ms, result_pt):
        """Check that one pair of outputs has identical shapes.

        Raises:
            AssertionError: if the shapes differ.
        """
        if self.print_result:
            print(
                "\n=========================Start Check Out Data %s ===========================" % index)
        pt_out_shape = result_pt.shape
        ms_out_shape = result_ms.shape
        if pt_out_shape != ms_out_shape:
            raise AssertionError("output results are in different shapes!")
        if self.print_result:
            print("shape of result_pt: %s" % str(pt_out_shape))
            print("shape of result_ms: %s" % str(ms_out_shape))
            print("-result_pt-: \n", result_pt)
            print("-result_ms-: \n", result_ms)

    def save_out_data(self, index, k_ms, k_pt, result_ms, result_pt):
        """Save one pair of outputs as ``.npy`` files under ``out_path``."""
        if self.print_result:
            print(
                "\n================= ======Start Save Out Data %s =========================" % index)
        result_file = "%s.npy" % k_ms
        ms_out_path = os.path.join(self.out_path, MS_OUTPUT_PATH, result_file)
        save_numpy_data(ms_out_path, result_ms)
        if self.print_result:
            print("Saved MindSpore output data at: %s" % ms_out_path)

        result_file = "%s.npy" % k_pt
        pt_out_path = os.path.join(self.out_path, PT_OUTPUT_PATH, result_file)
        save_numpy_data(pt_out_path, result_pt)
        if self.print_result:
            print("Saved PyTorch output data at: %s" % pt_out_path)

    def compare_data(self, result_ms, result_pt, index):
        """Return the ``cal_similarity`` row for one pair of output arrays."""
        if self.print_result:
            print(
                "\n=========================Start Compare Out Data %s ===========================" % index)
        return cal_similarity(result_ms, result_pt, index)

    def save_results(self, compare_results):
        """Write the comparison table to ``<out_path>/compare_result.csv``."""
        if self.print_result:
            logger.info(
                "=================================Start save result=================================")
        # The directory may not exist yet when there were no outputs to save.
        os.makedirs(self.out_path, exist_ok=True)
        result_path = os.path.join(self.out_path, "compare_result.csv")
        # newline='' is required by the csv module to avoid blank rows on Windows.
        with open(result_path, "w", encoding='utf-8', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(RESULT_COLUMNS)
            writer.writerows(compare_results)
        logger.info(
            "The comparison result have been written to %s" % result_path)
+Pytorch data,MindSpore data,Results of comparison,cosine similarity,"(mean, max, min)" +test0-result_0,test0-a,False,0.46132,"['0.44795', '1.04435', '0.06219']" +test0-result_1,test0-b,False,0.44855,"['0.36956', '1.00635', '0.00237']" +,,,, +test1-result_0,test1-a,False,0.40539,"['0.70020', '1.53871', '0.02968']" +test1-result_1,test1-b,False,0.51905,"['0.59969', '1.36461', '0.02182']"