diff --git a/troubleshooter/troubleshooter/__init__.py b/troubleshooter/troubleshooter/__init__.py index 8e78ac60aa4d8cb373943689ebff84ba58bc90ac..4acfb02b8efa8bb5b11342ccf7a22b918a428308 100644 --- a/troubleshooter/troubleshooter/__init__.py +++ b/troubleshooter/troubleshooter/__init__.py @@ -26,6 +26,7 @@ and use @snooping(...) to print the running result information of echo line code from .migrator.diff_handler import TensorRecorder as tensor_recorder from .migrator.diff_handler import DifferenceFinder as diff_finder from .migrator.diff_handler import WeightMigrator as weight_migrator +from .migrator.diff_handler import NetDifferenceFinder from .proposer import ProposalAction as proposal from .tracker import Tracker as tracking from .common.util import save diff --git a/troubleshooter/troubleshooter/migrator/diff_handler.py b/troubleshooter/troubleshooter/migrator/diff_handler.py index df970ffafc07ee8a420fc0076c0d5d1abb6a7ff9..3b9a177b2068dc149fc7d326d9242a429e4addfb 100644 --- a/troubleshooter/troubleshooter/migrator/diff_handler.py +++ b/troubleshooter/troubleshooter/migrator/diff_handler.py @@ -373,10 +373,10 @@ class NetDifferenceFinder: end_ms = time.time() print(f"In test case {idx}, the MindSpore net inference completed cost %.5f seconds." % ( end_ms - start_ms)) - if isinstance(result_ms, tuple): + if isinstance(result_ms, (tuple, list, ms.Tensor)): result_ms = {f"result_{idx}": result for idx, result in enumerate(result_ms)} - if isinstance(result_pt, tuple): + if isinstance(result_pt, (tuple, list)) or torch.is_tensor(result_pt): result_pt = {f"result_{idx}": result for idx, result in enumerate(result_pt)} return result_ms, result_pt