diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py index 62a91312893d70303ba0d9cf694742e6fe7a34fd..395aa37197e6ecaf5ed4fbc72ef19a3b2cc31d53 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/hooks.py @@ -159,8 +159,7 @@ class DumpUtil(object): class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape, md5_nume, l2norm): - self.data = data + def __init__(self, save_data, summary_data, dtype, shape, md5_nume, l2norm): self.save_data = save_data self.summary_data = summary_data self.dtype = dtype @@ -181,9 +180,27 @@ def is_unsupport_type(data): return is_unsupport_type +my_is_complex = wrap_functional.OpsFunc[wrap_functional.ops_label + 'is_complex'] +my_square = wrap_functional.OpsFunc[wrap_functional.ops_label + 'square'] +my_sqrt = wrap_functional.OpsFunc[wrap_functional.ops_label + 'sqrt'] +def complex_squre(A): + if my_is_complex(A): + return ops.conj(A) * A + return my_square(A) + + +def myl2norm(A): + ndim = A.ndim + dim = tuple(range(ndim)) + ret = my_sqrt(ops.reduce_sum(complex_squre(A), dim)) + return ret + + def cal_l2norm(data): Key_ops = "wrap_ops." if(universal_interface.API_DUMP_FRAMEWORK_TYPE == "mindtorch"): + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) saved_tensor = data.asnumpy() l2norm = np.linalg.norm(saved_tensor).item() else: @@ -191,11 +208,11 @@ def cal_l2norm(data): l2norm = None return l2norm - setattr(ms.ops, 'norm', wrap_functional.OpsFunc[wrap_functional.ops_label + 'norm']) - setattr(ms.ops, 'square', wrap_functional.OpsFunc[wrap_functional.ops_label + 'square']) - setattr(ms.ops, 'sqrt', wrap_functional.OpsFunc[wrap_functional.ops_label + 'sqrt']) - setattr(ms.ops, 'is_complex', wrap_functional.OpsFunc[wrap_functional.ops_label + 'is_complex']) - l2norm = ms.ops.norm(data).tolist() + l2norm = myl2norm(data) + if l2norm.dtype == ms.bfloat16: + l2norm = ops.Cast()(l2norm, dtype = ms.float32) + l2norm = l2norm.tolist() + if DumpUtil.dump_overflow: if _ascend_target(): check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE') @@ -203,14 +220,12 @@ def cal_l2norm(data): (_ascend_910bc_target() and check_overflow_mode == "SATURATION_MODE"): status = Tensor([0] * 8, mstype.int32) _get_cache_prim(NPUClearFloatStatusV2)()(status) - setattr(ms.ops, 'norm', getattr(wrap_functional.HOOKFunctionalOP, Key_ops + 'norm')) - setattr(ms.ops, 'square', getattr(wrap_functional.HOOKFunctionalOP, Key_ops + 'square')) - setattr(ms.ops, 'sqrt', getattr(wrap_functional.HOOKFunctionalOP, Key_ops + 'sqrt')) - setattr(ms.ops, 'is_complex', getattr(wrap_functional.HOOKFunctionalOP, Key_ops + 'is_complex')) return l2norm def cal_max(data): if(universal_interface.API_DUMP_FRAMEWORK_TYPE == "mindtorch"): + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) saved_tensor = data.asnumpy() tensor_max = saved_tensor.max().astype(np.float32).tolist() else: @@ -228,6 +243,8 @@ def cal_max(data): def cal_min(data): if(universal_interface.API_DUMP_FRAMEWORK_TYPE == "mindtorch"): + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) saved_tensor = data.asnumpy() tensor_min = saved_tensor.min().astype(np.float32).tolist() else: @@ -245,6 +262,8 @@ def cal_min(data): def cal_mean(data): if(universal_interface.API_DUMP_FRAMEWORK_TYPE == "mindtorch"): + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) saved_tensor = data.asnumpy() tensor_mean = saved_tensor.mean().astype(np.float32).tolist() else: @@ -284,15 +303,15 @@ def get_not_float_tensor_info(data, compute_summary, statistic_category): if 'md5' in statistic_category and 'l2norm' in statistic_category: md5_nume = hashlib.md5(saved_tensor).hexdigest() l2norm = cal_l2norm(data) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) elif 'md5' in statistic_category: md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume,[]) + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume,[]) elif 'l2norm' in statistic_category: l2norm = cal_l2norm(data) - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), [], l2norm) + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), [], l2norm) summary_data = [tensor_max, tensor_min, tensor_mean] - return DataInfo(data, saved_tensor, summary_data, str(data.dtype), tuple(data.shape), [], []) + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), [], []) def get_scalar_data_info(data, compute_summary, statistic_category): if compute_summary: @@ -300,22 +319,19 @@ def get_scalar_data_info(data, compute_summary, statistic_category): if 'md5' in statistic_category and 'l2norm' in statistic_category: md5_nume = hashlib.md5(str(data).encode()).hexdigest() l2norm = np.linalg.norm(data).item() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, l2norm) + return DataInfo(data, summary_data, str(type(data)), [], md5_nume, l2norm) elif 'md5' in statistic_category: md5_nume = hashlib.md5(str(data).encode()).hexdigest() - return DataInfo(data, data, summary_data, str(type(data)), [], md5_nume, []) + return DataInfo(data, summary_data, str(type(data)), [], md5_nume, []) elif 'l2norm' in statistic_category: l2norm = np.linalg.norm(data).item() - return DataInfo(data, data, summary_data, str(type(data)), [], [], l2norm) + return DataInfo(data, summary_data, str(type(data)), [], [], l2norm) else: summary_data = [math.nan] * 3 - return DataInfo(data, data, summary_data, str(type(data)), [], [], []) + return DataInfo(data, summary_data, str(type(data)), [], [], []) def get_float_tensor_info(data, compute_summary,statistic_category): dtype = str(data.dtype) - if data.dtype == mstype.bfloat16: - data = ops.Cast()(data, dtype=mstype.float32) - saved_tensor = data.asnumpy() tensor_max, tensor_min, tensor_mean = math.nan, math.nan, math.nan if compute_summary: if 'max' in statistic_category: @@ -326,17 +342,30 @@ def get_float_tensor_info(data, compute_summary,statistic_category): tensor_mean = cal_mean(data) summary_data = [tensor_max, tensor_min, tensor_mean] if 'md5' in statistic_category and 'l2norm' in statistic_category: - md5_nume = hashlib.md5(saved_tensor).hexdigest() l2norm = cal_l2norm(data) - return DataInfo(data, saved_tensor, summary_data, dtype, tuple(data.shape), md5_nume, l2norm) + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() + md5_nume = hashlib.md5(saved_tensor).hexdigest() + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), md5_nume, l2norm) elif 'md5' in statistic_category: + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() md5_nume = hashlib.md5(saved_tensor).hexdigest() - return DataInfo(data, saved_tensor, summary_data, dtype, tuple(data.shape), md5_nume, []) + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), md5_nume, []) elif 'l2norm' in statistic_category: l2norm = cal_l2norm(data) - return DataInfo(data, saved_tensor, summary_data, dtype, tuple(data.shape), [], l2norm) + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), [], l2norm) summary_data = [tensor_max, tensor_min, tensor_mean] - return DataInfo(data, saved_tensor, summary_data, dtype, tuple(data.shape), [], []) + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), [], []) + def set_dump_path(fpath=None): if fpath is None: