From 53e8f7577b892f18da5670bb562809e12082af67 Mon Sep 17 00:00:00 2001 From: fandawei <fandawei2@huawei.com> Date: Sun, 17 Dec 2023 12:02:13 +0800 Subject: [PATCH] fix torch save grad bug --- troubleshooter/troubleshooter/migrator/save.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/troubleshooter/troubleshooter/migrator/save.py b/troubleshooter/troubleshooter/migrator/save.py index fd0226d..894608a 100644 --- a/troubleshooter/troubleshooter/migrator/save.py +++ b/troubleshooter/troubleshooter/migrator/save.py @@ -202,7 +202,7 @@ elif "torch" in FRAMEWORK_TYPE: class _SaveGradCell(_SaveGradBase): def __init__(self, file, suffix, output_mode): - super(_SaveGradCell, _SaveGradBase).__init__(file, suffix, output_mode) + super(_SaveGradCell, self).__init__(file, suffix, output_mode) self.pt_save_func = _wrapper_torch_save_grad(self.file, output_mode) def __call__(self, x): -- Gitee