diff --git a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py index c7ec352be0b17006cc1527490fbf849cc1e794f3..9dfdc606be8dac1de91a25eea49469a02ccd0943 100644 --- a/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py +++ b/troubleshooter/troubleshooter/migrator/api_dump/ms_dump/initialize.py @@ -16,49 +16,54 @@ try: except ImportError: comm_func_label = False +key_wrap = "wrap_" +Key_ops = "wrap_ops." +key_mint_ops = "wrap_mint.ops." +key_mint_nn_functional = "wrap_mint.nn.functional." +key_communication_comm_func = "wrap_communication.comm_func." def hook_apis(): for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(key_wrap): setattr(ms.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) setattr(ms.common._stub_tensor.StubTensor, attr_name[5:], getattr(wrap_sub_tensor.HOOKSubTensor, attr_name)) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - if attr_name.startswith("wrap_ops."): - setattr(ms.ops, attr_name[len("wrap_ops."):], + if attr_name.startswith(key_wrap): + if attr_name.startswith(Key_ops): + setattr(ms.ops, attr_name[len(Key_ops):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - if attr_name.startswith("wrap_mint.ops."): - setattr(ms.mint, attr_name[len("wrap_mint.ops."):], + if attr_name.startswith(key_mint_ops): + setattr(ms.mint, attr_name[len(key_mint_ops):], getattr(wrap_functional.HOOKFunctionalOP,attr_name)) - if attr_name.startswith("wrap_mint.nn.functional."): - setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], + if attr_name.startswith(key_mint_nn_functional): + setattr(ms.mint.nn.functional, attr_name[len(key_mint_nn_functional):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - if comm_func_label: - setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):], + if comm_func_label and attr_name.startswith(key_communication_comm_func): + setattr(ms.communication.comm_func, attr_name[len(key_communication_comm_func):], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) def restore_apis(): for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): + if attr_name.startswith(key_wrap): setattr(ms.Tensor, attr_name[5:], wrap_tensor.TensorFunc[attr_name[5:]]) setattr(ms.common._stub_tensor.StubTensor, attr_name[5:], wrap_sub_tensor.SubTensorFunc[attr_name[5:]]) for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - if attr_name.startswith("wrap_ops."): - setattr(ms.ops, attr_name[len("wrap_ops."):], - wrap_functional.OpsFunc[wrap_functional.ops_label + attr_name[len("wrap_ops."):]]) - if attr_name.startswith("wrap_mint.ops."): - setattr(ms.mint, attr_name[len("wrap_mint.ops."):], - wrap_functional.OpsFunc[wrap_functional.mint_ops_label + attr_name[len("wrap_mint.ops."):]]) - if attr_name.startswith("wrap_mint.nn.functional."): - setattr(ms.mint.nn.functional, attr_name[len("wrap_mint.nn.functional."):], - wrap_functional.OpsFunc[wrap_functional.mint_nn_func_label + attr_name[len("wrap_mint.nn.functional."):]]) - if comm_func_label: - setattr(ms.communication.comm_func, attr_name[len("wrap_communication.comm_func."):], - wrap_functional.OpsFunc[wrap_functional.communication_comm_func_label + attr_name[len("wrap_communication.comm_func."):]]) + if attr_name.startswith(key_wrap): + if attr_name.startswith(Key_ops): + setattr(ms.ops, attr_name[len(Key_ops):], + wrap_functional.OpsFunc[wrap_functional.ops_label + attr_name[len(Key_ops):]]) + if attr_name.startswith(key_mint_ops): + setattr(ms.mint, attr_name[len(key_mint_ops):], + wrap_functional.OpsFunc[wrap_functional.mint_ops_label + attr_name[len(key_mint_ops):]]) + if attr_name.startswith(key_mint_nn_functional): + setattr(ms.mint.nn.functional, attr_name[len(key_mint_nn_functional):], + wrap_functional.OpsFunc[wrap_functional.mint_nn_func_label + attr_name[len(key_mint_nn_functional):]]) + if comm_func_label and attr_name.startswith(key_communication_comm_func): + setattr(ms.communication.comm_func, attr_name[len(key_communication_comm_func):], + wrap_functional.OpsFunc[wrap_functional.communication_comm_func_label + attr_name[len(key_communication_comm_func):]]) class MyMindsporeFunctionExecutor(_MindsporeFunctionExecutor):