From fd44191dec69c7c1660f8667cdbc1151e853c428 Mon Sep 17 00:00:00 2001
From: fandawei
Date: Mon, 5 Feb 2024 17:57:32 +0800
Subject: [PATCH 1/2] llama adapts sdc check

---
 mindformers/models/llama/llama_layer.py | 18 ++++++++++++++++--
 mindformers/wrapper/wrapper.py          | 21 ++++++++++++++-------
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/mindformers/models/llama/llama_layer.py b/mindformers/models/llama/llama_layer.py
index 92f3edd9..5d533f18 100644
--- a/mindformers/models/llama/llama_layer.py
+++ b/mindformers/models/llama/llama_layer.py
@@ -16,7 +16,9 @@
 from enum import Enum
 
 import numpy as np
+import os
 
+from mindspore.ops.operations._inner_ops import _MirrorSilentCheck
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore import nn
@@ -358,6 +360,14 @@ class LlamaRMSNorm(nn.Cell):
             self.rsqrt = P.Rsqrt()
             self.rms_norm = self._self_norm
             self.self_define = True
+        self._enable_npu_silent_detect = False
+        if os.environ.get('NPU_DETECT') == "1":
+            self._enable_npu_silent_detect = True
+            self._pre_val = Parameter(Tensor(0, mstype.float32), name="pre_val", requires_grad=False)
+            self._min_val = Parameter(Tensor(0, mstype.float32), name="min_val", requires_grad=False)
+            self._max_val = Parameter(Tensor(0, mstype.float32), name="max_val", requires_grad=False)
+            self._sdc_result = Parameter(Tensor(False, mstype.bool_), name="sdc_result", requires_grad=False)
+            self.check_weight = _MirrorSilentCheck()
 
     def _self_norm(self, x):
         original_type = x.dtype
@@ -371,7 +381,11 @@ class LlamaRMSNorm(nn.Cell):
 
     def _rms_norm(self, x):
         original_type = x.dtype
-        return self.norm(x, self.cast(self.weight, original_type))[0]
+        if self._enable_npu_silent_detect:
+            weight = self.check_weight(self.weight, self._pre_val, self._min_val, self._max_val, None, self._sdc_result, None)
+            return self.norm(x, self.cast(weight, original_type))[0]
+        else:
+            return self.norm(x, self.cast(self.weight, original_type))[0]
 
     def construct(self, x):
         """Forward of RMSNorm."""
@@ -387,7 +401,7 @@ class LlamaRMSNorm(nn.Cell):
             self.mul.shard((strategy_in, strategy_in))
             self.mul2.shard((strategy_in, (1,)))
         else:
-            self.norm.shard((strategy_in,))
+            self.norm.shard((strategy_in, (1,)))
 
 
 class LlamaFeedForward(Cell):
diff --git a/mindformers/wrapper/wrapper.py b/mindformers/wrapper/wrapper.py
index 4b6209d3..664d74eb 100644
--- a/mindformers/wrapper/wrapper.py
+++ b/mindformers/wrapper/wrapper.py
@@ -116,6 +116,8 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
         cond = self.get_overflow_status(status, grads)
         overflow = self.process_loss_scale(cond)
 
+        is_silent_fault = self._get_silent_check_status()
+
         learning_rate = self.learning_rate
         if self.optimizer.dynamic_lr:
             if self.optimizer.is_group_lr:
@@ -125,14 +127,16 @@
 
         # if there is no overflow, do optimize
         if not overflow:
-            if self.use_clip_grad:
-                grads, grad_norm = self.clip_grad_norm(grads)
-                if self.use_grad_norm:
-                    loss = F.depend(loss, self.optimizer(grads, grad_norm))
+            # if there is not sdc, do optimize
+            if not self._enable_npu_silent_recovery or not is_silent_fault:
+                if self.use_clip_grad:
+                    grads, grad_norm = self.clip_grad_norm(grads)
+                    if self.use_grad_norm:
+                        loss = F.depend(loss, self.optimizer(grads, grad_norm))
+                    else:
+                        loss = F.depend(loss, self.optimizer(grads))
                 else:
                     loss = F.depend(loss, self.optimizer(grads))
-            else:
-                loss = F.depend(loss, self.optimizer(grads))
 
         return loss, overflow, scaling_sens, learning_rate
 
@@ -264,7 +268,10 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
         cond = F.depend(cond, grads)
         overflow = self.process_loss_scale(cond)
 
+        is_silent_fault = self._get_silent_check_status()
+
         if not overflow:
-            loss = F.depend(loss, self.optimizer(grads))
+            if not self._enable_npu_silent_recovery or not is_silent_fault:
+                loss = F.depend(loss, self.optimizer(grads))
 
         return loss, overflow, scaling_sens.value(), learning_rate
-- 
Gitee

From 88a109c88e3e724922895818fd91da0137450c38 Mon Sep 17 00:00:00 2001
From: fandawei
Date: Mon, 5 Feb 2024 21:53:04 +0800
Subject: [PATCH 2/2] fix bug

---
 mindformers/wrapper/wrapper.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mindformers/wrapper/wrapper.py b/mindformers/wrapper/wrapper.py
index 664d74eb..eed36c28 100644
--- a/mindformers/wrapper/wrapper.py
+++ b/mindformers/wrapper/wrapper.py
@@ -116,7 +116,7 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell):
         cond = self.get_overflow_status(status, grads)
         overflow = self.process_loss_scale(cond)
 
-        is_silent_fault = self._get_silent_check_status()
+        is_silent_fault = self._get_silent_check_status(grads)
 
         learning_rate = self.learning_rate
         if self.optimizer.dynamic_lr:
@@ -268,7 +268,7 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
         cond = F.depend(cond, grads)
         overflow = self.process_loss_scale(cond)
 
-        is_silent_fault = self._get_silent_check_status()
+        is_silent_fault = self._get_silent_check_status(grads)
 
         if not overflow:
             if not self._enable_npu_silent_recovery or not is_silent_fault:
-- 
Gitee
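
A note on enabling the check: patch 1 reads the NPU_DETECT environment
variable once, in LlamaRMSNorm.__init__, so detection is fixed at model
construction time and cannot be toggled per step, and it only takes effect
on the fused-RmsNorm path (_rms_norm); _self_norm is unchanged. A minimal
usage sketch, assuming only what the diff shows (the flag name and the
__init__-time read); the commented construction call is illustrative, not a
prescribed API:

    import os

    # Must be set before the network is built; __init__ reads it once.
    os.environ['NPU_DETECT'] = "1"

    # Any LlamaRMSNorm built after this point creates the pre/min/max state
    # Parameters plus the boolean sdc_result flag, and routes its weight
    # through _MirrorSilentCheck before the fused RmsNorm call.
    # from mindformers.models.llama.llama_layer import LlamaRMSNorm
    # norm = LlamaRMSNorm(dim)  # dim as configured for the model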
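
The wrapper-side logic is easier to see stripped of the MindSpore plumbing.
Below is a framework-free sketch of the gate that both MFTrainOneStepCell
and MFPipelineWithLossScaleCell end up with; check_silent_fault and
optimizer are hypothetical stand-ins for _get_silent_check_status and the
real optimizer call, and per patch 2 the status check receives the current
step's gradients:

    def maybe_apply_update(optimizer, grads, overflow, enable_recovery,
                           check_silent_fault):
        """Apply the optimizer only when the step looks safe."""
        # Patch 2's fix: the checker inspects this step's gradients.
        is_silent_fault = check_silent_fault(grads)
        # Skip the update on loss-scale overflow, or when silent-fault
        # recovery is enabled and a fault was flagged for this step.
        if not overflow:
            if not enable_recovery or not is_silent_fault:
                optimizer(grads)
        return is_silent_fault

Note the two independent switches the series introduces: NPU_DETECT turns
the detection op on in the norm layers, while _enable_npu_silent_recovery
decides whether a flagged step is actually dropped; with recovery disabled,
the condition above is always true and the update still goes through.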