From 01e1446f1e14c6500b050e7a64abbaa7eb6f18f8 Mon Sep 17 00:00:00 2001 From: shihlCST <1665105642@qq.com> Date: Fri, 14 Jun 2024 17:06:22 +0800 Subject: [PATCH] add_test_npy --- .../migrator/test_bfloat16_ms.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/st/troubleshooter/migrator/test_bfloat16_ms.py b/tests/st/troubleshooter/migrator/test_bfloat16_ms.py index 61eba3d..ce75d26 100644 --- a/tests/st/troubleshooter/migrator/test_bfloat16_ms.py +++ b/tests/st/troubleshooter/migrator/test_bfloat16_ms.py @@ -48,7 +48,8 @@ def conv2d_backward_func(x, weight): @pytest.mark.level0 @pytest.mark.env_onecard @pytest.mark.platform_arm_ascend910b_training -def test_conv2d_bfloat16(): +@pytest.mark.platform_x86_ascend_training +def test_conv2d_bfloat16_all(): """ Feature: api dump support bfloat16. Description: api dump collects tensor data for conv2d. @@ -74,3 +75,35 @@ def test_conv2d_bfloat16(): assert 'Functional_conv2d_0_forward_stack_info' in stack_list finally: shutil.rmtree(dump_path) + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.platform_x86_ascend_training +def test_conv2d_bfloat16_npy(): + """ + Feature: api dump support bfloat16. + Description: api dump collects tensor data for conv2d. + Expectation: collect tensor data, stack, and statistics for conv2d. + """ + context.set_context(mode=context.PYNATIVE_MODE) + dump_path = Path(tempfile.mkdtemp(prefix="ms_api_dump_bfloat16")) + try: + net = Net() + api_dump_init(net, dump_path, retain_backward=True) + api_dump_start(dump_type = 'npy',statistic_category = ['max', 'min', 'avg', 'md5', 'l2norm']) + x = Tensor(np.ones([10, 32, 32, 32]), ms.bfloat16) + weight = Tensor(np.ones([32, 32, 3, 3]), ms.bfloat16) + grads = conv2d_backward_func(x, weight) + dx, dw = grads + api_dump_stop() + csv_list, npy_list, stack_list = get_csv_npy_stack_list( + dump_path, 'mindspore') + assert 'Functional_conv2d_0_forward_input.0' in npy_list + assert 'Functional_conv2d_0_forward_input.1' in npy_list + assert 'Functional_conv2d_0_forward_output' in npy_list + assert 'Functional_conv2d_0_backward_input' in npy_list + assert 'Functional_conv2d_0_forward_stack_info' in stack_list + finally: + shutil.rmtree(dump_path) \ No newline at end of file -- Gitee