From ab729cddc4f313d0784397f1ce296a1b452dc64a Mon Sep 17 00:00:00 2001 From: fandawei Date: Wed, 10 May 2023 14:23:18 +0800 Subject: [PATCH] fix NetDifferenceFinder bug --- troubleshooter/troubleshooter/__init__.py | 1 + troubleshooter/troubleshooter/migrator/diff_handler.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/troubleshooter/troubleshooter/__init__.py b/troubleshooter/troubleshooter/__init__.py index 8e78ac6..4acfb02 100644 --- a/troubleshooter/troubleshooter/__init__.py +++ b/troubleshooter/troubleshooter/__init__.py @@ -26,6 +26,7 @@ and use @snooping(...) to print the running result information of echo line code from .migrator.diff_handler import TensorRecorder as tensor_recorder from .migrator.diff_handler import DifferenceFinder as diff_finder from .migrator.diff_handler import WeightMigrator as weight_migrator +from .migrator.diff_handler import NetDifferenceFinder from .proposer import ProposalAction as proposal from .tracker import Tracker as tracking from .common.util import save diff --git a/troubleshooter/troubleshooter/migrator/diff_handler.py b/troubleshooter/troubleshooter/migrator/diff_handler.py index df970ff..3b9a177 100644 --- a/troubleshooter/troubleshooter/migrator/diff_handler.py +++ b/troubleshooter/troubleshooter/migrator/diff_handler.py @@ -373,10 +373,10 @@ class NetDifferenceFinder: end_ms = time.time() print(f"In test case {idx}, the MindSpore net inference completed cost %.5f seconds." % ( end_ms - start_ms)) - if isinstance(result_ms, tuple): + if isinstance(result_ms, (tuple, list, ms.Tensor)): result_ms = {f"result_{idx}": result for idx, result in enumerate(result_ms)} - if isinstance(result_pt, tuple): + if isinstance(result_pt, (tuple, list)) or torch.is_tensor(result_pt): result_pt = {f"result_{idx}": result for idx, result in enumerate(result_pt)} return result_ms, result_pt -- Gitee