diff --git a/tests/st/troubleshooter/migrator/test_api_dump_communication.py b/tests/st/troubleshooter/migrator/test_api_dump_communication.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a43b2af23205f61283e6055e6f9b2e4216c3b91
--- /dev/null
+++ b/tests/st/troubleshooter/migrator/test_api_dump_communication.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+from mindspore.communication import init
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+import tempfile
+import shutil
+from mindspore import dtype as mstype
+from pathlib import Path
+from troubleshooter.migrator import api_dump_init, api_dump_start, api_dump_stop
+from tests.st.troubleshooter.migrator.dump.utils import get_pkl_npy_stack_list
+
+
+# comm_func is referenced through ms.communication (not imported directly) so
+# this module still imports on MindSpore versions without comm_func.
+class NetNull(nn.Cell):
+    def __init__(self):
+        super(NetNull, self).__init__()
+        pass
+
+    def construct(self, x):
+        output = ms.communication.comm_func.all_reduce(ms.Tensor(x, mstype.float32))
+        return output
+
+
+def all_reduce_dump(x):
+    return ms.communication.comm_func.all_reduce(x)
+
+
+def test_api_dump_communicate():
+    """Check api_dump captures forward/backward tensors of comm_func.all_reduce."""
+    init()
+    if "comm_func" not in dir(ms.communication):
+        return
+    net = NetNull()
+    dump_path = Path(tempfile.mkdtemp(prefix="ms_api_dump_communication"))
+    try:
+        api_dump_init(net, dump_path, retain_backward=True)
+        api_dump_start()
+        input_x = ms.Tensor(np.ones([3, 4]).astype(np.float32))
+        expect_output = [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
+        output = ops.grad(all_reduce_dump)(input_x)
+        api_dump_stop()
+        pkl_list, npy_list, stack_list = get_pkl_npy_stack_list(dump_path, 'mindspore')
+        assert np.allclose(output.asnumpy(), expect_output)
+        assert 'Functional_all_reduce_0_backward_input' in npy_list
+        assert 'Functional_all_reduce_0_forward_input.0' in npy_list
+        assert 'Functional_all_reduce_0_forward_output' in npy_list
+    finally:
+        shutil.rmtree(dump_path)
diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py
index a40fa58cc8cacefa034c9f21e267f74798647042..1ba3fea73992b220f22faef89c8ffcc37c3bef08 100644
--- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py
+++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py
@@ -23,11 +23,16 @@
     for attr_name in dir(wrap_functional.HOOKFunctionalOP):
         if attr_name.startswith("wrap_"):
             if attr_name.startswith("wrap_ops."):
-                setattr(ms.ops, attr_name[len("wrap_ops."):], getattr(wrap_functional.HOOKFunctionalOP, attr_name))
+                setattr(ms.ops, attr_name[len("wrap_ops."):],
+                        getattr(wrap_functional.HOOKFunctionalOP, attr_name))
             if attr_name.startswith("wrap_mint.ops."):
-                setattr(ms.mint, attr_name[len("wrap_mint.ops."):], getattr(wrap_functional.HOOKFunctionalOP,
-                                                                            attr_name))
+                setattr(ms.mint, attr_name[len("wrap_mint.ops."):],
+                        getattr(wrap_functional.HOOKFunctionalOP, attr_name))
             if attr_name.startswith("wrap_mint.nn.functional."):
-                setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], getattr(wrap_functional.HOOKFunctionalOP, attr_name))
+                setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):],
+                        getattr(wrap_functional.HOOKFunctionalOP, attr_name))
+            if attr_name.startswith("wrap_communication.comm_func."):
+                setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):],
+                        getattr(wrap_functional.HOOKFunctionalOP, attr_name))
 
     wrap_nn.wrap_nn_cell_and_bind()
diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml
index 37408440c2f8d6882b283fb29abcca3d4ae09969..b4777616196dedbb84713aad04391a52e5a0f728 100644
--- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml
+++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml
@@ -14,6 +14,11 @@
 # ============================================================================
 # List of ops that register hooks
 
+communication.comm_func:
+  - all_reduce
+  - all_gather_into_tensor
+  - reduce
+  - reduce_scatter_tensor
 ops:
   - adaptive_avg_pool1d
   - adaptive_avg_pool2d
diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py
index b7073b181ea77027c66106bc503371561147cd14..bf8911259ed320ae72a8f4021b1d61b7acf247d9 100644
--- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py
+++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py
@@ -16,12 +16,14 @@
 import mindspore as ms
 import os
 import yaml
 
 from .hook_cell import HOOKCell
 
 ops_label = "ops."
 mint_ops_label = "mint.ops."
 mint_nn_func_label = "mint.nn.functional."
+communication_comm_func_label = "communication.comm_func."
+
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
 WrapFunctionalOps = []
@@ -32,6 +34,8 @@
     WrapFunctionalOps.extend([mint_ops_label + f for f in yaml.safe_load(f).get('mint.ops')])
 with open(yaml_path, 'r') as f:
     WrapFunctionalOps.extend([mint_nn_func_label + f for f in yaml.safe_load(f).get('mint.nn.functional')])
+with open(yaml_path, 'r') as f:
+    WrapFunctionalOps.extend([communication_comm_func_label + f for f in yaml.safe_load(f).get('communication.comm_func')])
 
 OpsFunc = {}
 for f in dir(ms.ops):
@@ -41,6 +45,10 @@ if "mint" in dir(ms):
         OpsFunc[mint_ops_label + f] = getattr(ms.mint, f)
     for f in dir(ms.mint.nn.functional):
         OpsFunc[mint_nn_func_label + f] = getattr(ms.mint.nn.functional, f)
+if "comm_func" in dir(ms.communication):
+    for f in dir(ms.communication.comm_func):
+        OpsFunc[communication_comm_func_label + f] = getattr(ms.communication.comm_func, f)
+
 
 
 def get_functional_ops():
@@ -50,6 +58,8 @@ def get_functional_ops():
     if "mint" in dir(ms):
         _all_functional_ops.extend([mint_ops_label + f for f in dir(ms.mint)])
         _all_functional_ops.extend([mint_nn_func_label + f for f in dir(ms.mint.nn.functional)])
+    if "comm_func" in dir(ms.communication):
+        _all_functional_ops.extend([communication_comm_func_label + f for f in dir(ms.communication.comm_func)])
 
     return set(WrapFunctionalOps) & set(_all_functional_ops)
 
diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e246d7327fb62d0c2c5e17ba8cf0dafed09ecfa
--- /dev/null
+++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Hook wrappers for MindSpore communication functional ops (comm_func)."""
+import mindspore as ms
+import os
+import yaml
+
+from .hook_cell import HOOKCell
+
+
+cur_path = os.path.dirname(os.path.realpath(__file__))
+yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+with open(yaml_path, 'r') as f:
+    WrapFunctionalOps = yaml.safe_load(f).get('communication.comm_func')
+
+# comm_func only exists on newer MindSpore; guard so importing this module
+# is safe on older versions.
+OpsFunc = {}
+if "comm_func" in dir(ms.communication):
+    for f in dir(ms.communication.comm_func):
+        OpsFunc[f] = getattr(ms.communication.comm_func, f)
+
+
+def get_functional_ops():
+    """Return supported comm_func op names present in this MindSpore build."""
+    if "comm_func" not in dir(ms.communication):
+        return set()
+    return set(WrapFunctionalOps) & set(dir(ms.communication.comm_func))
+
+
+class HOOKFunctionalOP(object):
+    pass
+
+
+class FunctionalOPTemplate(HOOKCell):
+    """HOOKCell wrapper that forwards calls to the original comm_func op."""
+
+    def __init__(self, op_name, hook):
+        self.op_name_ = op_name
+        self.prefix_op_name_ = "Functional_" + str(op_name) + "_"
+        super().__init__(hook)
+
+    def construct(self, *args, **kwargs):
+        return OpsFunc[self.op_name_](*args, **kwargs)
+
+
+def wrap_functional_op(op_name, hook):
+    """Build a hooked replacement for the comm_func op named op_name."""
+    def functional_op_template(*args, **kwargs):
+        return FunctionalOPTemplate(op_name, hook)(*args, **kwargs)
+
+    return functional_op_template
+
+
+def wrap_hccl_functional_ops_and_bind(hook):
+    """Attach a wrap_&lt;op&gt; attribute on HOOKFunctionalOP for each supported op."""
+    for op_name in get_functional_ops():
+        if callable(OpsFunc[op_name]):
+            setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook))