From 673d6b424d2b3aa141c19103f9dbe5a90a93e8de Mon Sep 17 00:00:00 2001 From: shihlCST Date: Thu, 23 May 2024 09:01:37 +0000 Subject: [PATCH 1/6] =?UTF-8?q?api=5Fdump=E9=80=9A=E4=BF=A1=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: shihlCST --- .../migrator/api_dump/ms_dump/support_wrap_ops.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml index 3740844..b477761 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/support_wrap_ops.yaml @@ -14,6 +14,11 @@ # ============================================================================ # List of ops that register hooks +communication.comm_func: + - all_reduce + - all_gather_into_tensor + - reduce + - reduce_scatter_tensor ops: - adaptive_avg_pool1d -- Gitee From 2cdb024ec5f5901054983e0f94c317435bcac7d3 Mon Sep 17 00:00:00 2001 From: shihlCST Date: Thu, 23 May 2024 09:02:32 +0000 Subject: [PATCH 2/6] =?UTF-8?q?api=5Fdump=E9=80=9A=E4=BF=A1=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: shihlCST --- .../migrator/api_dump/ms_dump/initialize.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py index a40fa58..051799f 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py @@ -6,10 +6,11 @@ import mindspore as ms from . import hook_cell from . import hooks -from . import wrap_functional, wrap_nn, wrap_sub_tensor, wrap_tensor +from . import wrap_functional, wrap_nn, wrap_sub_tensor, wrap_tensor, wrap_hccl_functional from .utils import (CompareException, Const, check_file_or_directory_path, print_error_log) from troubleshooter import log as logger - +from mindspore.communication import comm_func +import pdb def initialize_hook(hook): wrap_tensor.wrap_tensor_ops_and_bind(hook) @@ -23,12 +24,16 @@ def initialize_hook(hook): for attr_name in dir(wrap_functional.HOOKFunctionalOP): if attr_name.startswith("wrap_"): if attr_name.startswith("wrap_ops."): - setattr(ms.ops, attr_name[len("wrap_ops."):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) + setattr(ms.ops, attr_name[len("wrap_ops."):], + getattr(wrap_functional.HOOKFunctionalOP, attr_name)) if attr_name.startswith("wrap_mint.ops."): - setattr(ms.mint, attr_name[len("wrap_mint.ops."):], getattr(wrap_functional.HOOKFunctionalOP, - attr_name)) + setattr(ms.mint, attr_name[len("wrap_mint.ops."):], + getattr(wrap_functional.HOOKFunctionalOP,attr_name)) if attr_name.startswith("wrap_mint.nn.functional."): setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], + getattr(wrap_functional.HOOKFunctionalOP, attr_name)) + if attr_name.startswith("wrap_communication.comm_func."): + setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) wrap_nn.wrap_nn_cell_and_bind() -- Gitee From 46ea3e3934b71fe655397458f6740befee26bfb2 Mon Sep 17 00:00:00 2001 From: shihlCST Date: Thu, 23 May 2024 09:03:12 +0000 Subject: [PATCH 3/6] =?UTF-8?q?api=5Fdump=E9=80=9A=E4=BF=A1=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: shihlCST --- .../migrator/api_dump/ms_dump/wrap_functional.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py index b7073b1..465cdfc 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py @@ -16,13 +16,15 @@ import mindspore as ms import os import yaml - +from mindspore.communication import comm_func +from troubleshooter import log as logger from .hook_cell import HOOKCell - ops_label = "ops." mint_ops_label = "mint.ops." mint_nn_func_label = "mint.nn.functional." +communication_comm_func_label = "communication.comm_func." + cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") WrapFunctionalOps = [] @@ -32,6 +34,8 @@ with open(yaml_path, 'r') as f: WrapFunctionalOps.extend([mint_ops_label + f for f in yaml.safe_load(f).get('mint.ops')]) with open(yaml_path, 'r') as f: WrapFunctionalOps.extend([mint_nn_func_label + f for f in yaml.safe_load(f).get('mint.nn.functional')]) +with open(yaml_path, 'r') as f: + WrapFunctionalOps.extend([communication_comm_func_label + f for f in yaml.safe_load(f).get('communication.comm_func')]) OpsFunc = {} for f in dir(ms.ops): @@ -41,6 +45,10 @@ if "mint" in dir(ms): OpsFunc[mint_ops_label + f] = getattr(ms.mint, f) for f in dir(ms.mint.nn.functional): OpsFunc[mint_nn_func_label + f] = getattr(ms.mint.nn.functional, f) +if "communication" in dir(ms): + for f in dir(ms.communication.comm_func): + OpsFunc[communication_comm_func_label + f] = getattr(ms.communication.comm_func, f) + def get_functional_ops(): @@ -50,6 +58,8 @@ def get_functional_ops(): if "mint" in dir(ms): _all_functional_ops.extend([mint_ops_label + f for f in dir(ms.mint)]) _all_functional_ops.extend([mint_nn_func_label + f for f in dir(ms.mint.nn.functional)]) + if "communication" in dir(ms): + _all_functional_ops.extend([communication_comm_func_label + f for f in dir(ms.communication.comm_func)]) return set(WrapFunctionalOps) & set(_all_functional_ops) @@ -78,6 +88,7 @@ def wrap_functional_op(op_name, hook): def wrap_functional_ops_and_bind(hook): _functional_ops = get_functional_ops() + logger.user_attention(f"------------ops allreduce---------- _fuctional_ops---- {_functional_ops}.") for op_name in _functional_ops: if callable(OpsFunc[op_name]): setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) -- Gitee From dc29a74c75ea7f9efcf1332811996814ebf24065 Mon Sep 17 00:00:00 2001 From: shihlCST Date: Thu, 23 May 2024 09:04:44 +0000 Subject: [PATCH 4/6] =?UTF-8?q?api=5Fdump=E9=80=9A=E4=BF=A1=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: shihlCST --- .../api_dump/ms_dump/wrap_hccl_functional.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py new file mode 100644 index 0000000..ad3facc --- /dev/null +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py @@ -0,0 +1,69 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore as ms +import os +import yaml +from pathlib import Path + +from .hook_cell import HOOKCell +from troubleshooter import log as logger + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with open(yaml_path, 'r') as f: + WrapFunctionalOps = yaml.safe_load(f).get('communication.comm_func') + +OpsFunc = {} +for f in dir(ms.communication.comm_func): + OpsFunc[f] = getattr(ms.communication.comm_func, f) + + +def get_functional_ops(): + global WrapFunctionalOps + _all_functional_ops = dir(ms.communication.comm_func) + return set(WrapFunctionalOps) & set(_all_functional_ops) + + +class HOOKFunctionalOP(object): + pass + + +class FunctionalOPTemplate(HOOKCell): + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Functional_" + str(op_name) + "_" + super().__init__(hook) + + def construct(self, *args, **kwargs): + if self.op_name_.startswith('dropout'): + return args[0] if args else kwargs.get('input') + return OpsFunc[self.op_name_](*args, **kwargs) + + +def wrap_functional_op(op_name, hook): + def functional_op_template(*args, **kwargs): + return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) + + return functional_op_template + + +def wrap_hccl_functional_ops_and_bind(hook): + _functional_ops = get_functional_ops() + logger.warning("--------------- %s", _functional_ops) + for op_name in _functional_ops: + if callable(OpsFunc[op_name]): + setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) -- Gitee From 6dfc5687b36dd92fae04c594ce6d7e26558e59d6 Mon Sep 17 00:00:00 2001 From: shihlCST Date: Thu, 23 May 2024 09:14:29 +0000 Subject: [PATCH 5/6] =?UTF-8?q?api=5Fdump=E9=80=9A=E4=BF=A1=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: shihlCST --- .../migrator/test_api_dump_communication.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/st/troubleshooter/migrator/test_api_dump_communication.py diff --git a/tests/st/troubleshooter/migrator/test_api_dump_communication.py b/tests/st/troubleshooter/migrator/test_api_dump_communication.py new file mode 100644 index 0000000..0dd1d7e --- /dev/null +++ b/tests/st/troubleshooter/migrator/test_api_dump_communication.py @@ -0,0 +1,61 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from mindspore.communication import init, get_rank +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +import tempfile +import shutil +import mindspore.communication.comm_func as comm_func +from mindspore import dtype as mstype +from pathlib import Path +from troubleshooter.migrator import api_dump_init, api_dump_start, api_dump_stop +from tests.st.troubleshooter.migrator.dump.utils import get_pkl_npy_stack_list + + +class NetNull(nn.Cell): + def __init__(self): + super(NetNull, self).__init__() + pass + + def construct(self, x): + output = comm_func.all_reduce(ms.Tensor(x, mstype.float32)) + return output + +def all_reduce_dump(x): + return comm_func.all_reduce(x) + +def test_api_dump_communicate(): + init() + if not "communication" in dir(ms): + return + net = NetNull() + dump_path = Path(tempfile.mkdtemp(prefix="ms_api_dump_communication")) + try: + api_dump_init(net, dump_path, retain_backward=True) + api_dump_start() + input = ms.Tensor(np.ones([3, 4]).astype(np.float32)) + expect_output = [[2, 2, 2, 2],[2, 2, 2, 2],[2, 2, 2, 2]] + output = ops.grad(all_reduce_dump)(input) + api_dump_stop() + pkl_list, npy_list, stack_list = get_pkl_npy_stack_list(dump_path, 'mindspore') + assert np.allclose(output.asnumpy(), expect_output) + assert 'Functional_all_reduce_0_backward_input' in npy_list + assert 'Functional_all_reduce_0_forward_input.0' in npy_list + assert 'Functional_all_reduce_0_forward_output' in npy_list + finally: + shutil.rmtree(dump_path) -- Gitee From 42409582ae4446501139f96a6c6b795106ad716a Mon Sep 17 00:00:00 2001 From: shihlCST Date: Fri, 24 May 2024 00:53:34 +0800 Subject: [PATCH 6/6] api_dump_communication --- .../troubleshooter/migrator/test_api_dump_communication.py | 3 +-- .../troubleshooter/migrator/api_dump/ms_dump/initialize.py | 4 +--- .../migrator/api_dump/ms_dump/wrap_functional.py | 7 ++----- .../migrator/api_dump/ms_dump/wrap_hccl_functional.py | 2 -- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/st/troubleshooter/migrator/test_api_dump_communication.py b/tests/st/troubleshooter/migrator/test_api_dump_communication.py index 0dd1d7e..4a43b2a 100644 --- a/tests/st/troubleshooter/migrator/test_api_dump_communication.py +++ b/tests/st/troubleshooter/migrator/test_api_dump_communication.py @@ -20,7 +20,6 @@ import mindspore.nn as nn import mindspore.ops as ops import tempfile import shutil -import mindspore.communication.comm_func as comm_func from mindspore import dtype as mstype from pathlib import Path from troubleshooter.migrator import api_dump_init, api_dump_start, api_dump_stop @@ -41,7 +40,7 @@ def all_reduce_dump(x): def test_api_dump_communicate(): init() - if not "communication" in dir(ms): + if not "comm_func" in dir(ms.communication): return net = NetNull() dump_path = Path(tempfile.mkdtemp(prefix="ms_api_dump_communication")) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py index 051799f..1ba3fea 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py @@ -6,11 +6,9 @@ import mindspore as ms from . import hook_cell from . import hooks -from . import wrap_functional, wrap_nn, wrap_sub_tensor, wrap_tensor, wrap_hccl_functional +from . import wrap_functional, wrap_nn, wrap_sub_tensor, wrap_tensor from .utils import (CompareException, Const, check_file_or_directory_path, print_error_log) from troubleshooter import log as logger -from mindspore.communication import comm_func -import pdb def initialize_hook(hook): wrap_tensor.wrap_tensor_ops_and_bind(hook) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py index 465cdfc..bf89112 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_functional.py @@ -16,8 +16,6 @@ import mindspore as ms import os import yaml -from mindspore.communication import comm_func -from troubleshooter import log as logger from .hook_cell import HOOKCell ops_label = "ops." @@ -45,7 +43,7 @@ if "mint" in dir(ms): OpsFunc[mint_ops_label + f] = getattr(ms.mint, f) for f in dir(ms.mint.nn.functional): OpsFunc[mint_nn_func_label + f] = getattr(ms.mint.nn.functional, f) -if "communication" in dir(ms): +if "comm_func" in dir(ms.communication): for f in dir(ms.communication.comm_func): OpsFunc[communication_comm_func_label + f] = getattr(ms.communication.comm_func, f) @@ -58,7 +56,7 @@ def get_functional_ops(): if "mint" in dir(ms): _all_functional_ops.extend([mint_ops_label + f for f in dir(ms.mint)]) _all_functional_ops.extend([mint_nn_func_label + f for f in dir(ms.mint.nn.functional)]) - if "communication" in dir(ms): + if "comm_func" in dir(ms.communication): _all_functional_ops.extend([communication_comm_func_label + f for f in dir(ms.communication.comm_func)]) return set(WrapFunctionalOps) & set(_all_functional_ops) @@ -88,7 +86,6 @@ def wrap_functional_op(op_name, hook): def wrap_functional_ops_and_bind(hook): _functional_ops = get_functional_ops() - logger.user_attention(f"------------ops allreduce---------- _fuctional_ops---- {_functional_ops}.") for op_name in _functional_ops: if callable(OpsFunc[op_name]): setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py index ad3facc..0e246d7 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/wrap_hccl_functional.py @@ -19,7 +19,6 @@ import yaml from pathlib import Path from .hook_cell import HOOKCell -from troubleshooter import log as logger cur_path = os.path.dirname(os.path.realpath(__file__)) @@ -63,7 +62,6 @@ def wrap_functional_op(op_name, hook): def wrap_hccl_functional_ops_and_bind(hook): _functional_ops = get_functional_ops() - logger.warning("--------------- %s", _functional_ops) for op_name in _functional_ops: if callable(OpsFunc[op_name]): setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) -- Gitee