diff --git a/npu_tuned_model/llm/deepseek_v3/README.md b/npu_tuned_model/llm/deepseek_v3/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c042bc2155298a50f9f7a39b762ea00a04f6eee9
--- /dev/null
+++ b/npu_tuned_model/llm/deepseek_v3/README.md
@@ -0,0 +1,207 @@
+# DeepseekV3
+
+This sample describes the inference adaptation points of the DeepseekV3 model on NPU. It uses transformers==4.40.0 and is migrated from the open-source DeepseekV3 implementation [modeling_deepseek.py](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py).
+
+---
+
+# 1. Quick Start: Running the Sample
+
+This directory provides reference examples for manual tensor parallelism and DeepseekV3 inference.
+
+## 1.1. Environment Setup
+**In your prepared conda environment, install the matching transformers version**
+
+```shell
+pip3 install transformers==4.40.0
+```
+
+**The GMM operator provided by MindSpeed is required; install [MindSpeed](https://gitee.com/ascend/MindSpeed)**
+
+```shell
+git clone https://gitee.com/ascend/MindSpeed.git
+pip install -e MindSpeed
+```
+
+**Set the environment variables**
+
+```shell
+cann_path=/usr/local/Ascend
+source ${cann_path}/latest/bin/setenv.bash # Ascend CANN package installation directory
+
+export ASCEND_HOME_PATH=${cann_path}/latest
+export HCCL_OP_EXPANSION_MODE=AIV # HCCL AI Vector Core acceleration
+```
+
+## 1.2. Weight Preparation
+**Manually split the weights**
+
+Call scripts/split_weight.py in a loop to produce the weights for each device. WORLD_SIZE is the number of devices used for inference, path_to_deepseek_model_origin is the path to the original full weights, and path_to_deepseek_model_after_tp is the output path where the TP-split weights are written.
+
+```shell
+export WORLD_SIZE=8
+
+for((i=0; i<${WORLD_SIZE}; i++))
+do
+    export LOCAL_RANK=$i
+    python scripts/split_weight.py --model-path "path_to_deepseek_model_origin" --output-path "path_to_deepseek_model_after_tp"
+done
+```
+- The weight-splitting script `split_weight.py` is provided
+
+## 1.3. DeepseekV3 Multi-Device Inference
+
+Multi-device inference is implemented by launching one process per device at the same time; the inference script `infer.py` is provided as a reference. The environment variables used below are described in the appendix (Section 4).
+
+```shell
+export WORLD_SIZE=8
+export ENABLE_PROFILE=1
+export PROFILING_PATH="profiling_path"
+export HCCL_DETERMINISTIC=true
+
+for((i=0; i<${WORLD_SIZE}; i++))
+do
+    export LOCAL_RANK=$i
+    export RANK_ID=$i
+    python3 infer.py \
+      --model_name=${MODEL_NAME} --model_path=${MODEL_DIR} \
+      --input_max_len=${INPUT_MAX_LEN} --max_new_tokens=${MAX_NEW_TOKENS} --batch_size=${BATCH_SIZE} \
+      --tokenizer_mode=${TOKENIZER_MODE} --execute_mode=${EXE_MODE} \
+      --profiling_path=${PROFILING_PATH} &
+done
+```
+
+---
+
+# 2. Directory Structure
+
+The directory structure and files of this sample are as follows:
+- `engine`: the generic model execution engine `model_runner.py`
+  - `model_runner.py`: model execution engine, with common base-class methods for model initialization, model loading, tokenizer initialization, and model inference.
+- `scripts`: scripts used to run the DeepseekV3 model
+  - `models`: model sources
+    - `configuration_deepseek.py`: DeepseekV3 model config
+    - `modeling_deepseek.py`: DeepseekV3 model code
+    - `runner_deepseek.py`: inherits from the generic model execution engine and adapts everything this model needs
+  - `split_weight.py`: weight-splitting tool
+  - `infer.py`: DeepseekV3 inference entry script
+
+---
+
+# 3. Model Migration, Adaptation, and Optimization
+
+[Model migration guide](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/ptmoddevg/trainingmigrguide/PT_LMTMOG_0002.html)
+
+## 3.1. Weight Splitting and Manual Tensor Parallelism
+
+Taking the open-source DeepseekV3 implementation as the starting point, we apply tensor parallelism. For inference, the model weights must be split so that the memory footprint stays below the memory available on each device. This sample splits the DeepseekV3 weights manually, using tensor parallelism as the example.
+
+Manual tensor parallelism involves the following steps (a minimal slicing sketch follows the list):
+
+- Define the split model weights; this involves the DeepseekV3Attention and DeepseekV3MLP classes
+- Split the model weights, following the split_w function in scripts/split_weight.py: the attention q/k/v weights are split into TP shards along the N axis; within each MoE expert, the w1/w3 weights are split along the N axis and w2 along the K axis
+- Insert an allreduce operator at the end of the attention and MoE layers
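+
+A minimal sketch of the N-axis/K-axis slicing, assuming `nn.Linear`-style weights of shape `[out_features, in_features]`; the helper names are illustrative and this is not the actual `split_w` implementation:
+
+```python
+import torch
+
+def split_n(weight: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
+    # N-axis (output-dimension) shard, as used for q/k/v and the experts' w1/w3
+    return torch.chunk(weight, tp_size, dim=0)[rank].contiguous()
+
+def split_k(weight: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
+    # K-axis (input-dimension) shard, as used for w2; the per-rank partial sums
+    # are combined by the allreduce inserted at the end of the layer
+    return torch.chunk(weight, tp_size, dim=1)[rank].contiguous()
+```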
+
+## 3.2. Performance Optimization
+
+**Note**: In modeling_deepseek.py, every modified original function is kept with a `__` prefix so it can be diffed against its replacement. The non-MoE parts of the DeepSeek architecture are similar to Llama; for the common optimization points (fixed-size kv cache, cos/sin optimization, Add+RMSNorm fusion, reducing LM Head computation in the prefill stage), see the [Llama](https://gitee.com/ascend/torchair/tree/master/npu_tuned_model/llm/llama) changes. This sample focuses on the remaining changes.
+
+### 3.2.1. Algorithm Optimization
+
+**DeepSeek-V2 low-rank compression**
+Following the low-rank compression method described in the [DeepSeek-V2 paper](https://arxiv.org/pdf/2405.04434), this sample modifies the `DeepseekV3Attention` class:
+- The original `kv_b_proj` is split into `kv_b_proj_w_k` and `kv_b_proj_w_v`; see `split_weight.py` for how the weights are split
+- The corresponding computation is carried out in the `forward` method
+
+### 3.2.2. Operator Fusion
+**Enabling GMM && routing optimization**
+
+The original Hugging Face MoE implementation is naive: it loops over the experts with a `for` loop and computes expert_num FFNs one by one, which is inefficient.
+
+CANN provides the [GroupedMatmul](https://gitee.com/ascend/MindSpeed/blob/master/docs/ops/gmm.md) operator, which computes multiple experts at once and improves computation and data-movement efficiency. To enable GroupedMatmul, the routing logic must be modified to construct the corresponding inputs.
+
+- `DeepseekV3MoE` is restructured as a whole; the original implementation is preserved in the `__DeepseekV3MoE` class
+  - The routed-expert computation, originally handled by the `DeepseekV3MLP` class, is adapted to GroupedMatmul in the `DeepseekV3MLPGMM` class
+  - In `DeepseekV3MLPGMM` the routed experts' weights are merged into a single weight tensor (see `split_weight.py`); when splitting for tensor parallelism, `DeepseekV3MLPGMM` is sharded as well
+  - Shared experts continue to use the `DeepseekV3MLP` class
+
+- The baseline expert-routing flow follows [Enabling GMM && routing optimization](https://gitee.com/ascend/torchair/blob/master/npu_tuned_model/llm/mixtral/README.md) and is implemented in the `moe_infer_normal` function of `DeepseekV3MoE`
+- Alternatively, the torch_npu moe_routing operators provided by CANN can be enabled for further optimization; see the `moe_infer_fusion` function of `DeepseekV3MoE`.
+  - Enable it by setting `self.npu_routing_kernel = True` (the default)
+  - [torch_npu.npu_moe_init_routing](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/apiref/apilist/ptaoplist_000780.html) replaces the expert-arrangement step of the baseline flow
+  - [torch_npu.npu_moe_compute_expert_tokens](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/apiref/apilist/ptaoplist_000782.html) replaces the per-expert token-count computation of the baseline flow
+  - [torch_npu.npu_moe_finalize_routing](https://www.hiascend.com/document/detail/zh/Pytorch/60RC3/apiref/apilist/ptaoplist_000781.html) replaces the re-arrangement after expert computation, producing the final output
+
+**MoEGate affinity optimization**
+The original implementation builds `group_mask` with `torch.zeros_like` and `scatter`:
+```python
+group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+```
+
+This sample replaces it with an equivalent `one_hot` + `sum`:
+```python
+def one_hot(tensor, num_classes):
+    index = torch.arange(0, num_classes, dtype=tensor.dtype, device=tensor.device)
+    return (
+        tensor.view([*tensor.shape, 1]) == index.view([1] * tensor.ndim + [num_classes])
+    ).to(torch.float32)
+
+group_mask = one_hot(group_idx, self.n_group)  # [n, topk_group, n_group]
+group_mask = torch.sum(group_mask, dim=1)  # [n, n_group]
+```
+
+**MLP merge optimization**
+The original `DeepseekV3MLP` performs three matmuls: `gate_proj`, `up_proj`, and `down_proj`. Merging `gate_proj` and `up_proj` into a single matmul improves overall compute efficiency.
+- `DeepseekV3MLP` is restructured as a whole; the original implementation is preserved in the `__DeepseekV3MLP` class
+- During weight splitting, the `gate_proj` and `up_proj` weights are additionally concatenated; see `split_weight.py`
+
+### 3.2.3. Graph-Mode Adaptation
+
+**Note** when adapting to graph mode:
+
+ - First make sure the model works correctly with acceptable accuracy in eager mode on NPU, and only then migrate and adapt to graph mode.
+
+In the LLM prefill phase the query sequence length frequently changes, while in the decode phase it is usually fixed. This sample therefore pads inputs to a preset length up front, so that both prefill and decode execute as static graphs.
+
+For CompilerConfig options, see the [torchair documentation](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/modthirdparty/torchairuseguide/torchair_0021.html). A minimal compile sketch appears at the end of this section.
+
+- torchair provides graph construction, compilation, and execution on NPU. These capabilities are all integrated into the NPU graph backend and are enabled by specifying that backend in `torch.compile`; switches and configs control the compilation and execution flow.
+- With the NPU graph backend, torchair supports both static-graph and dynamic-graph execution; the `dynamic` argument selects between them.
+
+### 3.2.4. Enabling AIV for HCCL
+
+AllReduce can be accelerated using the device's AI Vector Core units; see the [HCCL_OP_EXPANSION_MODE environment variable](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha003/apiref/envref/envref_07_0088.html).
+
+```shell
+export HCCL_OP_EXPANSION_MODE=AIV
+```
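+
+A minimal graph-mode compile sketch, mirroring `graph_compile()` in `engine/model_runner.py`; here `model` is assumed to be an already-initialized DeepseekV3 module on NPU:
+
+```python
+import torch
+import torchair as tng
+# patch so that hcom allreduce is converted in graph mode, as done in model_runner.py
+import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
+from torchair.configs.compiler_config import CompilerConfig
+
+config = CompilerConfig()
+config.experimental_config.frozen_parameter = True          # treat weights as frozen
+config.experimental_config.tiling_schedule_optimize = True  # optimize the tiling schedule
+npu_backend = tng.get_npu_backend(compiler_config=config)
+
+# dynamic selects dynamic- vs. static-graph execution; since this sample pads
+# inputs to fixed lengths, prefill & decode still run as static graphs
+model.model = torch.compile(model.model, dynamic=True, fullgraph=False, backend=npu_backend)
+```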
+
+---
+
+# 4. Appendix: Environment Variables
+
+| Category | Variable | Description |
+| --- | --- | --- |
+| Environment | WORLD_SIZE | Number of devices available for multi-device execution |
+| Environment | LOCAL_RANK | The rank_id this process uses within the overall communication domain |
+| Environment | RANK_ID | The rank_id this process uses within the overall communication domain |
+| Model | MODEL_NAME | Model name |
+| Model | MODEL_DIR | Weight path; must point to the folder containing the model weights |
+| Model | INPUT_MAX_LEN | This sample pads inputs to this fixed length before execution |
+| Model | MAX_NEW_TOKENS | Maximum number of tokens generated during decode |
+| Model | BATCH_SIZE | By default the sample runs prefill with 1 batch and decode with n batches; set this variable to enable multi-batch decode inference. Defaults to 1 |
+| Model | TOKENIZER_MODE | Selects the tokenizer used to build prompts for inference; supports default and chat. Defaults to default |
+| Execution mode | EXE_MODE | Distinguishes graph mode from single-operator mode: eager means single-operator mode, dynamo means graph mode. Defaults to single-operator mode |
+| Debugging | ENABLE_PROFILE | Whether to collect profiling data for performance analysis; disabled by default |
+| Debugging | PROFILING_PATH | Output path for the generated profiling data |
+| Debugging | HCCL_DETERMINISTIC | Set to true to enable deterministic computation across devices; defaults to false |
+| Debugging | HCCL_OP_EXPANSION_MODE | Uses the device AI Vector Core units to accelerate AllReduce; mutually exclusive with HCCL_DETERMINISTIC |
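+
+For reference, a hypothetical single-run configuration combining the variables above (values are illustrative only):
+
+```shell
+export WORLD_SIZE=8
+export MODEL_NAME=deepseek_v3
+export MODEL_DIR="path_to_deepseek_model_after_tp"
+export INPUT_MAX_LEN=1024
+export MAX_NEW_TOKENS=32
+export BATCH_SIZE=1
+export TOKENIZER_MODE=default
+export EXE_MODE=dynamo
+```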
\ No newline at end of file diff --git a/npu_tuned_model/llm/deepseek_v3/engine/model_runner.py b/npu_tuned_model/llm/deepseek_v3/engine/model_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..92f6d38e3b4a47f9d34b4c4bcf1f0826ca0e6917 --- /dev/null +++ b/npu_tuned_model/llm/deepseek_v3/engine/model_runner.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import argparse +import logging +import copy +import numpy as np +import torch +import torch_npu + +from transformers import AutoTokenizer + +root_logger = logging.getLogger() +root_logger.handlers.clear() +logging.basicConfig(format='%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s', + level=logging.INFO) +logging.getLogger("paramiko").setLevel(logging.ERROR) + +torch.manual_seed(42) +torch.npu.manual_seed_all(42) + + +class InferenceContextManager: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + +class ModelRunner: + def __init__(self, model_path, execute_mode, **kwargs): + self.model_name = kwargs.get("model_name", "default_model_name") + self.dtype = kwargs.get("dtype", torch.bfloat16) + self.max_position_embeddings = kwargs.get("max_position_embeddings", 131072) + self.input_max_len = kwargs.get("input_max_len", 1024) + self.max_new_tokens = kwargs.get("max_new_tokens", 32) + self.batch_size = kwargs.get("batch_size", 72) + self.tokenizer = None + self.model = None + self.device = None + self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.rank_offset = int(os.getenv("RANK_OFFSET", "0")) + self.global_rank = self.local_rank + self.rank_offset + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + if self.world_size == 1: + self.model_path = model_path + else: + self.model_path = os.path.join(model_path, f"rank_{self.local_rank}") + self.use_pretrained_model = True + self.execute_mode = execute_mode + self.tokenizer_mode = kwargs.get("tokenizer_mode", "default") + self.profiling_path = kwargs.get("profiling_path", "") + self.enable_profile = False + self.init_device() + + def init_device(self): + logging.info("Set execution using npu index: %s, global: %s", self.local_rank, self.global_rank) + self.device = torch.device("%s:%s" % ("npu", self.local_rank)) + torch.npu.set_device(self.device) + + master_addr = os.environ["MASTER_ADDR"] + master_port = int(os.environ["MASTER_PORT"]) + + if torch.npu.is_available() and self.world_size > 1: + torch.distributed.init_process_group( + backend="hccl", world_size=self.world_size, rank=self.global_rank) + + def init_model(self, model, config=None): + if self.use_pretrained_model: + self.load_model(model) + else: + self.init_model_from_config(model, config=config) + self.to_device() + self.cast_format() + self.compile_model() + self.init_tokenizer() + + def init_model_from_config(self, model, config): + assert config is not None + config_file = "*.json" + 
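+        # NOTE: "*.json" is a placeholder; this config-only path runs only when
+        # use_pretrained_model is False (it defaults to True above), in which case
+        # config_file should point to a real config.json on disk.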
model_config = config.from_pretrained(config_file, torch_dtype=self.dtype, + max_position_embeddings=self.max_position_embeddings) + self.model = model(model_config).to(self.dtype) + + def load_model(self, model): + logging.info("Try to load pretrained model in path: %s", self.model_path) + self.model = model.from_pretrained(self.model_path, + low_cpu_mem_usage=True, + ignore_mismatched_sizes=True, + torch_dtype=self.dtype, + max_position_embeddings=self.max_position_embeddings) + + def save_model(self): + pass + + def to_device(self): + self.model.to(self.device) + + def cast_format(self): + pass + + def compile_model(self): + logging.info("The final model structure is: \n %s", self.model) + if self.execute_mode == "dynamo": + logging.info("Try to compile model") + self.graph_compile() + + def init_tokenizer(self): + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right", truncation_side='right') + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + def graph_compile(self): + import torchair as tng + import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce + from torchair.configs.compiler_config import CompilerConfig + + compiler_config = CompilerConfig() + compiler_config.experimental_config.frozen_parameter = True + compiler_config.experimental_config.tiling_schedule_optimize = True + npu_backend = tng.get_npu_backend(compiler_config=compiler_config) + self.model.model = torch.compile(self.model.model, dynamic=True, fullgraph=False, backend=npu_backend) + + def mark_inputs(self, model_inputs): + if self.execute_mode == "dynamo": + pass + + def model_input_prepare(self, input_dict): + return None + + def repeat_batch(self, tensor, N): + if N == 1: + return tensor + return tensor.repeat(N, *[1]*(tensor.dim() - 1)) + + def model_output_process(self, model_inputs, outputs, input_dict): + pass + + def _define_profiling(self, profile_switch=False, profile_save_path="prof"): + if profile_switch: + os.makedirs(profile_save_path, exist_ok=True) + experimental_config = torch_npu.profiler._ExperimentalConfig( + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization + ) + profiler = torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.NPU, + torch_npu.profiler.ProfilerActivity.CPU], + with_stack=False, + record_shapes=False, + profile_memory=False, + experimental_config=experimental_config, + schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=1, repeat=1, skip_first=0), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path) + ) + else: + profiler = InferenceContextManager() + return profiler + + def model_inference(self, model_inputs, warm_up=False, profile_switch=False, profile_save_path=""): + torch.npu.synchronize() + if warm_up: + self.mark_inputs(model_inputs) + profiler = self._define_profiling(profile_switch, profile_save_path) + start_time = time.time() + with profiler as prof: + with torch.no_grad(): + logits = self.model(**model_inputs) + + torch.npu.synchronize() + end_time = time.time() + logging.info(f"{self.model_name} inference time cost {(end_time - start_time)*1000:.2f} ms") + return logits + + def model_generate(self, prompts, warm_up=False, **kwargs): + pass diff --git a/npu_tuned_model/llm/deepseek_v3/scripts/infer.py b/npu_tuned_model/llm/deepseek_v3/scripts/infer.py new file mode 100644 index 
0000000000000000000000000000000000000000..5ba6bf472b5c9cbd4379096873f723be69e157a4
--- /dev/null
+++ b/npu_tuned_model/llm/deepseek_v3/scripts/infer.py
@@ -0,0 +1,125 @@
+# coding=utf-8
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import time
+import argparse
+import logging
+import json
+import torch
+
+CUR_DIR = os.path.dirname(__file__)
+ROOT_DIR = os.path.realpath(os.path.join(CUR_DIR, ".."))
+sys.path.append(ROOT_DIR)
+from runner_deepseek import DeepSeekRunner
+
+root_logger = logging.getLogger()
+root_logger.handlers.clear()
+logging.basicConfig(format='%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s',
+                    level=logging.INFO)
+logging.getLogger("paramiko").setLevel(logging.ERROR)
+torch.manual_seed(42)
+torch.npu.manual_seed_all(42)
+
+
+# basic token generator
+def generate_default_prompt():
+    # the number of prompts determines the batch size when the model runs
+    _PROMPTS = [
+        "用一句话描述地球为什么是独一无二的。",
+        "给出一段对话,使用合适的语气和回答方式继续对话。\n对话:\nA:你今天看起来很高兴,发生了什么好事?\nB:是的,我刚刚得到一份来自"
+        # "梅西银行的工作通知书。\nA:哇,恭喜你!你打算什么时候开始工作?\nB:下个月开始,所以我现在正为这份工作做准备。",
+        # "Let x = 1. What is x << 3 in Python 3? the answer is",
+        # "In Python 3, what is ['a', 'Chemistry', 0, 1][-3]?",
+        # "The study of older adults and aging is reffered to as",
+        # "Why is the sky blue?",
+        # "What's your name?",
+        # "Hello my name is",
+    ]
+    return _PROMPTS
+
+
+def generate_chat_prompt(bs):
+    _PROMPTS = [
+        {"role": "user", "content": "Write a piece of quicksort code in C++"},
+    ]
+    _PROMPTS = [_PROMPTS] * (bs // len(_PROMPTS) + 1)
+    _PROMPTS = _PROMPTS[:bs]
+    logging.info(f"chat prompt batch size: {bs}")
+    return _PROMPTS
+
+
+def generate_prompt(bs, tokenizer_mode):
+    if tokenizer_mode == "default":
+        return generate_default_prompt()
+    else:
+        return generate_chat_prompt(bs)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="llm run parameters")
+    parser.add_argument('--model_path', type=str, help="Path of model weights")
+    parser.add_argument('--model_name', type=str, help="Model name")
+    parser.add_argument('--execute_mode', type=str, default="eager", choices=["dynamo", "eager"],
+                        help="eager or dynamo")
+    parser.add_argument('--tokenizer_mode', type=str, default="default", choices=["default", "chat"],
+                        help="tokenizer_mode should be default or chat")
+    parser.add_argument('--profiling_path', type=str, help="Path of profiling, not set means no dump")
+    parser.add_argument('--local_rank', type=int, default=0, help="local rank id for torch distributed launch")
+    parser.add_argument('--input_max_len', type=int, default=1024, help="Max number of input tokens")
+    parser.add_argument('--max_new_tokens', type=int, default=32, help="Max number of new tokens")
+    parser.add_argument('--batch_size', type=int, default=2, help="Batch size for testing")
+    parser.add_argument('--json_path', type=str, help="Path of settings")
+    parser_args = parser.parse_args()
+    return parser_args
+
+
+def run_deepseek(model_path, execute_mode, **kwargs):
+    _PROMPTS = generate_prompt(1, args.tokenizer_mode)
+    model_runner = DeepSeekRunner(model_path, execute_mode, **kwargs)
+    # enable operator binary reuse in graph mode to speed up the compilation phase
+    torch.npu.set_compile_mode(jit_compile=False)
+    model_runner.init_model()
+    # warmup
+    model_runner.model_generate(_PROMPTS, warm_up=True, **kwargs)
+    # generate perf data
+    model_runner.model_generate(_PROMPTS, **kwargs)
+    if model_runner.profiling_path:
+        model_runner.set_enable_profile(True)
+        model_runner.model_generate(_PROMPTS, **kwargs)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    input_max_len = args.input_max_len  # length that inputs are padded to
+    max_new_tokens = args.max_new_tokens  # maximum number of output tokens
+    max_position_embeddings = input_max_len + max_new_tokens  # seq_len used when allocating the kv_cache
+    model_config = {
+        "dtype": torch.bfloat16,
+        "input_max_len": input_max_len,
+        "max_new_tokens": max_new_tokens,
+        "max_position_embeddings": max_position_embeddings
+    }
+    run_config = {
+        "tokenizer_mode": args.tokenizer_mode,
+        "profiling_path": args.profiling_path,
+        "batch_size": args.batch_size,
+        "model_name": args.model_name
+    }
+    config = {**model_config, **run_config}
+    os.environ["EXE_MODE"] = args.execute_mode
+    run_deepseek(args.model_path, args.execute_mode, **config)
+    logging.info("model run success")
diff --git a/npu_tuned_model/llm/deepseek_v3/scripts/models/configuration_deepseek.py b/npu_tuned_model/llm/deepseek_v3/scripts/models/configuration_deepseek.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd2e9615f6e77fd61f8c45370cf803c936423ec
--- /dev/null
+++ b/npu_tuned_model/llm/deepseek_v3/scripts/models/configuration_deepseek.py
@@ -0,0 +1,206 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+class DeepseekV3Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the DeepSeek-V2.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 102400):
+            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`DeepseekV3Model`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        moe_intermediate_size (`int`, *optional*, defaults to 1407):
+            Dimension of the MoE representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        n_shared_experts (`int`, *optional*, defaults to None):
+            Number of shared experts, None means dense model.
+        n_routed_experts (`int`, *optional*, defaults to None):
+            Number of routed experts, None means dense model.
+        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor of routed experts.
+ topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV3Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = 
attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/npu_tuned_model/llm/deepseek_v3/scripts/models/modeling_deepseek.py b/npu_tuned_model/llm/deepseek_v3/scripts/models/modeling_deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4d7f33e1ad501a472986b43e93691bb7fda1cd --- /dev/null +++ b/npu_tuned_model/llm/deepseek_v3/scripts/models/modeling_deepseek.py @@ -0,0 +1,2414 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import os +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch_npu +from mindspeed.ops import gmm + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV3Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV3Config" + + +def _use_return_dict(self): + # return self.config.use_return_dict(self) + return False + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def __forward(self, hidden_states, residual: Optional[torch.Tensor] = None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def ln_npu(self, hidden_states): + result = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + return result + + def forward(self, hidden_states, *args): + if len(args) == 0: # only hidden _states exists + result = self.ln_npu(hidden_states) + return result + elif len(args) == 1 and args[0] is None: # residual is None + result = self.ln_npu(hidden_states) + residual = hidden_states + return (result, residual) + elif len(args) == 1: # residual is not None + residual = args[0] + y, _, x = torch_npu.npu_add_rms_norm(residual, hidden_states, self.weight, self.variance_epsilon) + return (y, x) + else: + raise NotImplementedError( + f"insupportable DeepseekV3RMSNorm for input_args len as (include hid): {len(args) + 1}" + ) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) + + +class DeepseekV3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + # self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def __forward(self, x, kv_len=None, max_seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or kv_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=kv_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:kv_len].to(dtype=x.dtype), + self.sin_cached[:kv_len].to(dtype=x.dtype), + ) + + def forward(self, x, kv_len, max_seq_len=None): + # x shape is [bs, num_attention_heads, seq_len, head_size] + if max_seq_len is None: + self._set_cos_sin_cache(seq_len=kv_len, device=x.device, dtype=x.dtype) + elif max_seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype) + + batch_size, seq_len, _ = x.size() + if seq_len == 1: + # BD -> BNSD + cos = torch.index_select(self.cos_cached, dim=0, index=kv_len).unsqueeze(1).unsqueeze(1) + sin = torch.index_select(self.sin_cached, dim=0, index=kv_len).unsqueeze(1).unsqueeze(1) + else: + # SD -> BSND + cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, 1, 1) + sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, 1, 1) + + cos = cos[0,:,0,:] + sin = sin[0,:,0,:] + return ( + cos.to(dtype=x.dtype), + sin.to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 +class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): + """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, 
persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # BSND->BNSD + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # BSND->BNSD + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class __DeepseekV3MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +class DeepseekV3MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.intermediate_size_per_rank = self.intermediate_size // self.world_size + self.merge_up_gate_proj = nn.Linear(self.hidden_size, self.intermediate_size_per_rank * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size_per_rank, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + merged_x = self.merge_up_gate_proj(x) + gate_state, up_state = merged_x.chunk(2, dim=-1) + intermediate_hidden_states = self.act_fn(gate_state) * up_state + down_proj = self.down_proj(intermediate_hidden_states) + if self.world_size > 1: + dist.all_reduce(down_proj) + return down_proj + + +class DeepseekV3MLPGMM(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.num_experts = config.n_routed_experts + + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + self.intermediate_size_per_rank = self.intermediate_size // self.world_size + + self.act_fn = ACT2FN[config.hidden_act] + + self.group_w1_w3 = nn.Parameter(torch.ones(self.num_experts, self.intermediate_size_per_rank * 2, self.hidden_size), + requires_grad=False) + self.group_w2 = nn.Parameter(torch.ones(self.num_experts, self.hidden_size, self.intermediate_size_per_rank), + requires_grad=False) + + def forward(self, hidden_states, expert_tokens, seq_len=None): + mm1_mm3 = gmm.npu_gmm(hidden_states, torch.transpose(self.group_w1_w3, 1, 2), + bias=None, group_list=expert_tokens, group_type=0) + mm1, mm3 = mm1_mm3.chunk(2, dim=-1) + intermediate_hidden_states = self.act_fn(mm1) * mm3 + hidden_states = gmm.npu_gmm(intermediate_hidden_states, torch.transpose(self.group_w2, 1, 2), + bias=None, group_list=expert_tokens, group_type=0) + return hidden_states + + 
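+# NOTE: the group_list fed to gmm.npu_gmm above is the *cumulative* per-expert token
+# count in int64: __get_idx_info builds it with torch.cumsum, and moe_infer_fusion
+# builds it with torch_npu.npu_moe_compute_expert_tokens.
+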
+def one_hot(tensor, num_classes): + index = torch.arange(0, num_classes, dtype=tensor.dtype, device=tensor.device) + return ( + tensor.view([*tensor.shape, 1]) == index.view([1] * tensor.ndim + [num_classes]) + ).to(torch.float32) + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter( + torch.empty((self.n_routed_experts)) + ) + + def reset_parameters(self) -> None: + pass + # import torch.nn.init as init + + # init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.to(torch.float32), self.weight.to(torch.float32), None + ) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "noaux_tc": + assert not self.training + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + # group_mask = torch.zeros_like(group_scores) # [n, n_group] + # group_mask.scatter_(1, group_idx, 1) # [n, n_group] + group_mask = one_hot(group_idx, self.n_group) # [n, n_group] + group_mask = torch.sum(group_mask, dim=1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + topk_weight = scores.gather(1, topk_idx) + else: + raise NotImplementedError( + f"insupportable TopK function for MoE gating: {self.topk_method}" + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor + + return topk_idx, topk_weight, None + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. 
+ """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class __DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV3MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV3MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.to(hidden_states.dtype).view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = 
tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_dim = config.hidden_size + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.batch_size_decode = int(os.getenv("BATCH_SIZE", "1")) + self.batch_size_prefill = 1 + self.npu_routing_kernel = True + self.num_experts_per_tok = config.num_experts_per_tok + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = DeepseekV3MLPGMM(config, intermediate_size=config.moe_intermediate_size) + + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP(config, intermediate_size=intermediate_size) + if self.npu_routing_kernel: + self.row_idx_decode_len = self.batch_size_decode * self.top_k + self.row_idx_decode = torch.arange( + 0, self.row_idx_decode_len, + dtype=torch.int32).view(self.top_k, -1).permute(1,0).int().contiguous().npu() + + def forward(self, hidden_states): + identity = hidden_states + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + y = self.moe_infer(hidden_states, topk_idx, topk_weight) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + def __get_idx_info(self, selected_experts): + # input_shape: selected_experts --> [bs*seq, topk] + selected_experts = selected_experts.view(-1) + selected_experts_fp32 = selected_experts.to(torch.int32).to(torch.float) # [bs*seq*topk] + + # get expert_cumsum mask + # expert_mask shape is [bs*seq*topk, expert_num] + expert_mask = one_hot(selected_experts_fp32, num_classes=self.experts_per_rank) + # expert_tokens shape is [expert_num, ], represent token_num performed by expert_i + expert_tokens = torch.sum(expert_mask, dim=0) + expert_tokens = torch.cumsum(expert_tokens, dim=0).to(torch.int64) + + # get sorted / 
unsort indices + _, sorted_indices = torch.sort(selected_experts_fp32, dim=-1) + sorted_indices_fp32 = sorted_indices.to(torch.int32).to(torch.float) + _, unsort_indices = torch.sort(sorted_indices_fp32, dim=-1) + return expert_tokens, sorted_indices, unsort_indices + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + if self.npu_routing_kernel: + return self.moe_infer_fusion(x, topk_ids, topk_weight) + else: + return self.moe_infer_normal(x, topk_ids, topk_weight) + + def moe_infer_normal(self, x, topk_ids, topk_weight): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + topk_weight = topk_weight.to(x.dtype) + expert_tokens, sorted_indices, unsort_indices = self.__get_idx_info(topk_ids) + + # get hid states + hidden_states = x[:, None, ...].repeat((1, self.top_k, 1)).view((-1, x.shape[-1])) # [bs*seq*topk, hidden_size] + hidden_states_sorted_by_experts = torch.index_select(hidden_states, 0, sorted_indices) + + # hidden_states_sorted_by_experts shape is [bs*seq*topk, hidden_size] + hidden_states_sorted_by_experts = self.experts(hidden_states_sorted_by_experts, expert_tokens, seq_len=orig_shape[1]) + + # hidden_states shape is [bs*seq*topk, hidden_size] + hidden_states = torch.index_select(hidden_states_sorted_by_experts, 0, unsort_indices) + # hidden_states shape is [bs*seq, topk, hidden_size] + hidden_states = hidden_states.view(-1, self.top_k, x.shape[-1]) + # hidden_states shape is [bs*seq, topk, hidden_size] + hidden_states = hidden_states * topk_weight.unsqueeze(-1) + # hidden_states shape is [bs*seq, hidden_size] + hidden_states = torch.sum(hidden_states, dim=1) + + if self.world_size > 1: + dist.all_reduce(hidden_states) + hidden_states = hidden_states.view(*orig_shape) + return hidden_states + + def moe_infer_fusion(self, x, topk_ids, topk_weight): + batch_size, sequence_length, h = x.shape + hidden_states = x.view(-1, x.shape[-1]) + + routing_weights = topk_weight.to(x.dtype) + expert_idx = topk_ids.int() + if sequence_length == 1: + row_idx = self.row_idx_decode + else: + row_idx_prefill_len = self.batch_size_prefill * sequence_length * self.top_k + row_idx_prefill = torch.arange( + 0, row_idx_prefill_len, dtype=torch.int32, + device=topk_weight.device).view(self.top_k, -1).permute(1,0).int().contiguous() + row_idx = row_idx_prefill + + expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=expert_idx, + active_num=batch_size*sequence_length + ) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, self.num_experts) + expert_tokens = expert_tokens.to(torch.int64) + + hidden_states_ordered_by_experts = self.experts(expanded_x, expert_tokens, seq_len=sequence_length) + + hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states_ordered_by_experts, + skip1=None, skip2=None, + bias=None, + scales=routing_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=expert_idx + ) + + if self.world_size > 1: + dist.all_reduce(hidden_states) + hidden_states = hidden_states.view(batch_size, -1, self.hidden_dim) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _init_rope(self):
+    if self.config.rope_scaling is None:
+        self.rotary_emb = DeepseekV3RotaryEmbedding(
+            self.config.qk_rope_head_dim,
+            max_position_embeddings=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+        )
+    else:
+        scaling_type = self.config.rope_scaling["type"]
+        scaling_factor = self.config.rope_scaling["factor"]
+        if scaling_type == "linear":
+            self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(
+                self.config.qk_rope_head_dim,
+                max_position_embeddings=self.config.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.config.rope_theta,
+            )
+        elif scaling_type == "dynamic":
+            self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(
+                self.config.qk_rope_head_dim,
+                max_position_embeddings=self.config.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.config.rope_theta,
+            )
+        elif scaling_type == "yarn":
+            kwargs = {
+                key: self.config.rope_scaling[key]
+                for key in [
+                    "original_max_position_embeddings",
+                    "beta_fast",
+                    "beta_slow",
+                    "mscale",
+                    "mscale_all_dim",
+                ]
+                if key in self.config.rope_scaling
+            }
+            self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
+                self.config.qk_rope_head_dim,
+                max_position_embeddings=self.config.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.config.rope_theta,
+                **kwargs,
+            )
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3
+class DeepseekV3Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_heads_per_rank = self.num_heads // self.world_size + self.num_key_value_heads_per_rank = self.num_heads_per_rank + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads_per_rank * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads_per_rank * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + + self.kv_b_proj_w_k = nn.Parameter( + torch.zeros(self.num_heads_per_rank, self.qk_nope_head_dim, self.kv_lora_rank) + ) + self.kv_b_proj_w_v = nn.Parameter( + torch.zeros(self.num_heads_per_rank, self.kv_lora_rank, self.v_head_dim) + ) + + self.o_proj = nn.Linear( + self.num_heads_per_rank * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def _bmm(self, x, y): + b, s, n, _, d = x.shape + x = x.view(b*s, n, d).transpose(0,1) # n, bs, d + output = torch.matmul(x, y) # n, bs, rank + output = output.transpose(1, 0).view(b, s, n, -1) + return output + + def __prepare_qkv( + self, + hidden_states: torch.Tensor, + cos_sin: torch.Tensor = None, + kv_len: torch.IntTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + q = q.view(bsz, q_len, self.num_heads_per_rank, self.q_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = q_pe.transpose(1, 2) + q_nope = self._bmm( + q_nope.view(bsz, q_len, self.num_heads_per_rank, 1, self.qk_nope_head_dim), + self.kv_b_proj_w_k + ) + q_nope = q_nope.view(bsz, q_len, self.num_heads_per_rank, self.kv_lora_rank) + q_nope = q_nope.transpose(1, 2) + + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + k_nope = ( + 
self.kv_a_layernorm(compressed_kv) + .view(bsz, -1, 1, self.kv_lora_rank) + .transpose(1, 2) + ) # (bs, 1, q_len, kv_lora_rank) + + # rope + cos, sin = cos_sin + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = torch.cat([q_nope, q_pe], dim=-1) + key_states = torch.cat([k_nope, k_pe], dim=-1) + + kv_seq_len = k_nope.shape[-2] + if past_key_value is not None: + past_key_states = past_key_value[self.layer_idx][0] + torch_npu.scatter_update_(past_key_states, kv_len, key_states, -2) + if q_len == 1: + key_states = past_key_states + kv_seq_len = past_key_value[0][0].size()[-2] + value_states = key_states + return query_states, key_states, value_states, kv_seq_len + + def __apply_attention_npu( + self, + query_states, key_states, value_states, kv_seq_len, + attention_mask: Optional[torch.Tensor] = None, + actual_seq_lengths_kv: list = None, + output_attentions: bool = False, + past_key_value: Optional[Cache] = None, + ): + # repeat k/v heads if n_kv_heads < n_heads + bsz, _, q_len, _ = query_states.size() + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + assert attention_mask is not None + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + value_states = value_states[..., :self.kv_lora_rank] + attn_output = torch.matmul(attn_weights, value_states) + + # kv rank opt + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = self._bmm( + attn_output.unsqueeze(3), + self.kv_b_proj_w_v + ) # (bs, q_len, num_heads, kv_lora_rank) + attn_output = self.o_proj(attn_output.reshape(bsz, q_len, -1)) + if self.world_size > 1: + dist.all_reduce(attn_output) + return attn_output + + def forward( + self, + hidden_states: torch.Tensor, + kv_len: torch.IntTensor = None, + actual_seq_lengths_kv: list = None, + cos_sin: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead."
+            )
+        query_states, key_states, value_states, kv_seq_len = self.__prepare_qkv(
+            hidden_states=hidden_states,
+            cos_sin=cos_sin,
+            kv_len=kv_len,
+            position_ids=position_ids,
+            past_key_value=past_key_value
+        )
+        output = self.__apply_attention_npu(
+            query_states=query_states, key_states=key_states, value_states=value_states,
+            kv_seq_len=kv_seq_len,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            past_key_value=past_key_value
+        )
+        return output
+
+    def __forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+            )
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 +class DeepseekV3FlashAttention2(DeepseekV3Attention): + """ + DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
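+        # NOTE: this flash-attention class is kept from the upstream modeling file;
+        # the NPU-adapted inference path registers "eager" attention (the
+        # DeepseekV3Attention above), whose forward dispatches to __apply_attention_npu.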
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV3FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, kv_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. 
Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV3RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. 
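+            # For query_length == 1 a causal mask is unnecessary (the single query
+            # attends to the whole cache), and with flash_attn<2.1's top-left
+            # alignment it would be wrong, so causality is disabled in that case.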
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
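+            # Keep only the mask columns that cover the current query window
+            # before unpadding the query states.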
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV3Attention, + "flash_attention_2": DeepseekV3FlashAttention2, +} + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV3RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + kv_len: torch.IntTensor, + actual_seq_lengths_kv: list, + cos_sin: torch.Tensor, + past_residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor]: + hidden_states, residual = self.input_layernorm(hidden_states, past_residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_len=kv_len, + actual_seq_lengths_kv=actual_seq_lengths_kv, + cos_sin=cos_sin, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (residual, hidden_states) + return outputs + + +DeepseekV3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3PreTrainedModel(PreTrainedModel): + config_class = DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + pass + # std = self.config.initializer_range + # if isinstance(module, nn.Linear): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.bias is not None: + # module.bias.data.zero_() + # elif isinstance(module, nn.Embedding): + # module.weight.data.normal_(mean=0.0, std=std) + # if module.padding_idx is not None: + # module.weight.data[module.padding_idx].zero_() + + +DeepseekV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. 
+ + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3Model(DeepseekV3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`] + + Args: + config: DeepseekV3Config + """ + + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.config = config + self.rank_id = int(os.getenv("LOCAL_RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.vocab_size_per_rank = self.vocab_size // self.world_size + + self.embed_tokens = nn.Embedding( + self.vocab_size_per_rank, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV3DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + _init_rope(self) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def __forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else 
self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor, + kv_len: torch.IntTensor = None, + actual_seq_lengths_kv: list = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + ): + + batch_size, seq_length = input_ids.shape + past_key_values_length = past_key_values[0][0].size()[-2] + + if position_ids is None: + device = input_ids.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, 
seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if self.world_size > 1: + new_input_ids = input_ids - self.rank_id * self.vocab_size_per_rank + mask = (new_input_ids >= 0) & (new_input_ids < self.vocab_size_per_rank) # (bs, qlen) + new_input_ids_per_rank = new_input_ids * mask + inputs_embeds = self.embed_tokens(new_input_ids_per_rank) * mask.unsqueeze(-1) + dist.all_reduce(inputs_embeds) + else: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + cos_sin = self.rotary_emb(hidden_states, kv_len, self.config.max_position_embeddings) + residual = None + + for decoder_layer in self.layers: + residual, hidden_states = decoder_layer( + hidden_states, + kv_len, + actual_seq_lengths_kv, + cos_sin=cos_sin, + past_residual=residual, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.config = config + self.input_max_len = int(os.getenv("INPUT_MAX_LEN", 1024)) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size // self.world_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def __forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + kv_len: torch.IntTensor = None, + actual_seq_lengths_kv: list = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + ): + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + kv_len=kv_len, + actual_seq_lengths_kv=actual_seq_lengths_kv, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + ) + + hidden_states = outputs + + if hidden_states.size()[1] > 1: + gather_index, _ = torch.max(position_ids, dim=-1) + gather_index = gather_index.unsqueeze(1).unsqueeze(2).repeat(1, 1, hidden_states.shape[-1]) + hidden_states = torch.gather(hidden_states, 1, gather_index) + + logits = self.lm_head(hidden_states) + if self.world_size > 1: + new_logits = [logits.clone().detach() for _ in range(self.world_size)] + dist.all_gather(new_logits, logits) + logits = torch.cat(new_logits, dim=-1) + logits = logits.float() + + return logits + + def init_cache( + self, + input_ids, + world_size=1, + ): + batch_size, seq_len = input_ids.size() + + cache_seq_len = self.config.max_position_embeddings + + past_key_values = () + cache_key_shape = ( + batch_size, + 1, + cache_seq_len, + self.config.kv_lora_rank + self.config.qk_rope_head_dim + ) + dtype = self.config.torch_dtype + + for i in range(self.config.num_hidden_layers): + key_cache = torch.zeros(cache_key_shape, dtype=dtype, device=input_ids.device) + past_key_values += ((key_cache, ),) + + return past_key_values + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, 
+ inputs_embeds=None, + is_prefill=None, + kv_len=None, + share_mask_tril=None, + world_size=1, + **kwargs + ): + batch_size, seq_len = input_ids.size() + if past_key_values is None: + past_key_values = self.init_cache(input_ids, world_size) + if is_prefill: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + attention_mask = share_mask_tril + kv_len = torch.zeros((position_ids.size()[0]), dtype=torch.int32, device=input_ids.device) + actual_seq_lengths_kv = None + else: + attention_mask = None + position_ids = kv_len.unsqueeze(1) + actual_seq_lengths_kv = (kv_len + 1).cpu().detach().numpy().tolist() + + # attention_mask set + if is_prefill: + past_key_values_length = 0 + sliding_window = self.input_max_len + input_mask = None + else: + past_key_values_length = self.config.max_position_embeddings - seq_len + sliding_window = min(self.config.max_position_embeddings, kwargs.get("input_lens")) + input_mask = share_mask_tril + + attention_mask = _prepare_4d_causal_attention_mask( + input_mask, + (batch_size, seq_len), + input_ids.float(), + past_key_values_length, + sliding_window + ) + + model_inputs = {} + model_inputs.update( + { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + "kv_len": kv_len, + "actual_seq_lengths_kv": actual_seq_lengths_kv + } + ) + return model_inputs + + def __prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
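+            # The crop keeps the rightmost (most recent) max_cache_length positions.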
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV3 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV3_START_DOCSTRING, +) +class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/npu_tuned_model/llm/deepseek_v3/scripts/runner_deepseek.py b/npu_tuned_model/llm/deepseek_v3/scripts/runner_deepseek.py new file mode 100644 index 0000000000000000000000000000000000000000..404546a5228f0b50add6eb7f104f268110d7373e --- /dev/null +++ b/npu_tuned_model/llm/deepseek_v3/scripts/runner_deepseek.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import time +import argparse +import logging +import copy +import numpy as np +import torch +import torch_npu + +from functools import wraps +from engine.model_runner import ModelRunner +from models.modeling_deepseek import DeepseekV3ForCausalLM + +root_logger = logging.getLogger() +root_logger.handlers.clear() +logging.basicConfig(format='%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s', + level=logging.INFO) +logging.getLogger("paramiko").setLevel(logging.ERROR) + +torch.manual_seed(42) +torch.npu.manual_seed_all(42) + + +def override(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + + +def get_init_attn_mask(mask_length, device, valid_len=None): + share_mask_tril = ~torch.tril( + torch.ones((mask_length, mask_length), + dtype=torch.bool, device=device)) + if valid_len is not None: + share_mask_tril[-valid_len:, :] = torch.zeros(valid_len, mask_length) + return share_mask_tril + + +def get_decode_mask(mask_length, device, position): + decode_mask = torch.zeros((1, mask_length), device=device) + decode_mask[0, :position] = 1 + return decode_mask + + +class DeepSeekRunner(ModelRunner): + def __init__(self, model_path, execute_mode, **kwargs): + super().__init__(model_path, execute_mode, **kwargs) + self.enable_mla = kwargs.get("enable_mla", 0) + self.no_ckpt = int(os.getenv("NO_CKPT", "0")) + self.enable_mix = int(os.getenv("ENABLE_MIX", "0")) + if self.enable_mix: + self.attn_dp_size = int(os.getenv("ATTN_DP_SIZE", "0")) + else: + self.attn_dp_size = 1 + + def init_model(self): + if not self.no_ckpt: + self.use_pretrained_model = True + config = None + else: + self.use_pretrained_model = False + try: + from models.configuration_deepseek import DeepseekV3Config as config + except: + config = None + logging.info(f"using default DeepseekV3ForCausalLM: for model name is %s", self.model_name) + super().init_model(DeepseekV3ForCausalLM, config) + + @override + def mark_inputs(self, model_inputs): + if self.execute_mode == "dynamo": + input_ids = model_inputs.get("input_ids") + kv_len = model_inputs.get("kv_len") + attention_mask = model_inputs.get("attention_mask") + position_ids = model_inputs.get("position_ids") + past_key_values = model_inputs.get("past_key_values") + + # prefill with dynamic sequence length, decode with static sequence length + torch._dynamo.mark_static(kv_len) + for item in past_key_values: + for sub_item in item: + torch._dynamo.mark_static(sub_item) + + torch._dynamo.mark_static(input_ids) + if attention_mask is not None: + torch._dynamo.mark_static(attention_mask) + torch._dynamo.mark_static(position_ids) + + @override + def model_input_prepare(self, input_dict): + input_ids = input_dict.get("input_ids") + attention_mask = input_dict.get("attention_mask") + past_key_values = input_dict.get("past_key_values") + is_prefill = input_dict.get("is_prefill") + kv_len = input_dict.get("kv_len") + share_mask_tril = input_dict.get("share_mask_tril") + model_inputs = self.model.prepare_inputs_for_generation( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + is_prefill=is_prefill, + kv_len=kv_len, + input_lens=input_dict.get("input_lens"), + share_mask_tril=share_mask_tril, + world_size=self.world_size) + return model_inputs + + @override + def model_output_process(self, model_inputs, outputs, input_dict): + next_batch = self.batch_size if input_dict["is_prefill"] else 1 + next_batch_dp = next_batch // self.attn_dp_size if input_dict["is_prefill"] 
+
+
+class DeepSeekRunner(ModelRunner):
+    def __init__(self, model_path, execute_mode, **kwargs):
+        super().__init__(model_path, execute_mode, **kwargs)
+        self.enable_mla = kwargs.get("enable_mla", 0)
+        self.no_ckpt = int(os.getenv("NO_CKPT", "0"))
+        self.enable_mix = int(os.getenv("ENABLE_MIX", "0"))
+        if self.enable_mix:
+            # ATTN_DP_SIZE must be set when ENABLE_MIX is on; default to 1
+            # to avoid a zero divisor in model_output_process
+            self.attn_dp_size = int(os.getenv("ATTN_DP_SIZE", "1"))
+        else:
+            self.attn_dp_size = 1
+
+    def init_model(self):
+        if not self.no_ckpt:
+            self.use_pretrained_model = True
+            config = None
+        else:
+            self.use_pretrained_model = False
+            try:
+                from models.configuration_deepseek import DeepseekV3Config as config
+            except ImportError:
+                config = None
+                logging.info("Using default DeepseekV3ForCausalLM, model name is %s", self.model_name)
+        super().init_model(DeepseekV3ForCausalLM, config)
+
+    @override
+    def mark_inputs(self, model_inputs):
+        if self.execute_mode == "dynamo":
+            input_ids = model_inputs.get("input_ids")
+            kv_len = model_inputs.get("kv_len")
+            attention_mask = model_inputs.get("attention_mask")
+            position_ids = model_inputs.get("position_ids")
+            past_key_values = model_inputs.get("past_key_values")
+
+            # inputs are padded to preset lengths, so mark every graph input
+            # static for both the prefill and the decode graph
+            torch._dynamo.mark_static(kv_len)
+            for item in past_key_values:
+                for sub_item in item:
+                    torch._dynamo.mark_static(sub_item)
+
+            torch._dynamo.mark_static(input_ids)
+            if attention_mask is not None:
+                torch._dynamo.mark_static(attention_mask)
+            torch._dynamo.mark_static(position_ids)
+
+    @override
+    def model_input_prepare(self, input_dict):
+        input_ids = input_dict.get("input_ids")
+        attention_mask = input_dict.get("attention_mask")
+        past_key_values = input_dict.get("past_key_values")
+        is_prefill = input_dict.get("is_prefill")
+        kv_len = input_dict.get("kv_len")
+        share_mask_tril = input_dict.get("share_mask_tril")
+        model_inputs = self.model.prepare_inputs_for_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            is_prefill=is_prefill,
+            kv_len=kv_len,
+            input_lens=input_dict.get("input_lens"),
+            share_mask_tril=share_mask_tril,
+            world_size=self.world_size)
+        return model_inputs
+
+    @override
+    def model_output_process(self, model_inputs, outputs, input_dict):
+        next_batch = self.batch_size if input_dict["is_prefill"] else 1
+        next_batch_dp = next_batch // self.attn_dp_size if input_dict["is_prefill"] else 1
+        input_dict['is_prefill'] = False
+        input_dict['input_lens'] = input_dict['input_lens'] + 1
+
+        kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1
+        input_dict['kv_len'] = self.repeat_batch(kv_len, next_batch_dp)
+
+        logits = outputs
+        past_key_values = model_inputs.get("past_key_values")
+        past_key_values_batch = ()
+        for i in range(len(past_key_values)):
+            past_key_values_layer_i = past_key_values[i]
+            cache_new_i = ()
+            for cache_j in past_key_values_layer_i:
+                cache_j_new = self.repeat_batch(cache_j, next_batch_dp)
+                cache_new_i += (cache_j_new, )
+            past_key_values_batch += (cache_new_i, )
+        input_dict["past_key_values"] = past_key_values_batch
+
+        attention_mask = None
+
+        share_mask_tril = get_decode_mask(mask_length=self.max_position_embeddings,
+                                          device=self.device,
+                                          position=input_dict["input_lens"])
+        share_mask_tril = share_mask_tril[None, None, ...]
+
+        input_dict['attention_mask'] = attention_mask
+        input_dict['share_mask_tril'] = self.repeat_batch(share_mask_tril, self.batch_size)
+
+        next_tokens = torch.argmax(logits, dim=-1)[:, -1:]
+        input_dict['input_ids'] = self.repeat_batch(next_tokens, next_batch)
+        input_dict['generate_ids'] = self.repeat_batch(
+            torch.cat([input_dict['generate_ids'], next_tokens], dim=-1),
+            next_batch
+        )
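+
+    # Shape walkthrough (illustrative): if the prefill position_ids row is
+    # [0, 1, ..., 15], torch.max(..., axis=1)[0] + 1 above gives kv_len == 16,
+    # the number of occupied KV-cache slots before the first decode step.
+    # repeat_batch (inherited from ModelRunner) is assumed to tile a tensor
+    # along the batch axis.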
+
+    @override
+    def model_generate(self, prompts, warm_up=False, **kwargs):
+        assert self.input_max_len > 0
+        calling_func = {
+            "default": self.tokenizer,
+            "chat": self.tokenizer.apply_chat_template
+        }
+        tokenizer_kwargs = {
+            "return_tensors": "pt",
+            "truncation": True,
+            "padding": "max_length",
+            "max_length": self.input_max_len
+        }
+        if self.tokenizer_mode == "chat":
+            chat_kwargs = {
+                "add_generation_prompt": True, "return_dict": True
+            }
+            tokenizer_kwargs.update(chat_kwargs)
+
+        tokenizer = calling_func[self.tokenizer_mode]
+        inputs = tokenizer(prompts, **tokenizer_kwargs).to(self.device)
+        if int(os.getenv("ENABLE_PROFILE", "0")):
+            # mark every padded position as valid so profiled runs do identical work
+            inputs.attention_mask = inputs.attention_mask * 0 + 1
+
+        # build the initial input_dict
+        share_mask_tril = get_init_attn_mask(
+            self.max_position_embeddings, self.device,
+            valid_len=self.input_max_len)
+        share_mask_tril = share_mask_tril[None, None, ...]
+
+        input_lens = inputs.input_ids.size()[1]
+        logging.info("Prompt length is %d", input_lens)
+        input_dict = {
+            "input_ids": inputs.input_ids,
+            "input_lens": input_lens,
+            "attention_mask": inputs.attention_mask,
+            "past_key_values": None,
+            "is_prefill": True,
+            "kv_len": None,
+            "share_mask_tril": share_mask_tril,
+            "generate_ids": inputs.input_ids,
+        }
+
+        generate_tokens = 0
+        cnt = 0
+        while True:
+            jump_flag, profile_switch, profile_save_path = self._get_running_config(cnt, warm_up, generate_tokens)
+            if jump_flag:
+                break
+
+            model_inputs = self.model_input_prepare(input_dict)
+            outputs = self.model_inference(model_inputs, warm_up=warm_up, profile_switch=profile_switch,
+                                           profile_save_path=profile_save_path)
+            self.model_output_process(model_inputs, outputs, input_dict)
+            generate_tokens += 1
+            cnt += 1
+
+        generate_ids = input_dict["generate_ids"][0:1, input_lens:].clip(0, self.model.config.vocab_size - 1)
+        res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
+
+        if isinstance(res, list):
+            for answer in res:
+                logging.info("Inference decode result: \n%s", answer)
+        else:
+            logging.info("Inference decode result: \n%s", res)
+        return res
+
+    def set_enable_profile(self, flag):
+        self.enable_profile = flag
+        logging.info(">>> Runner set enable_profile as: %d", flag)
+
+    def _get_running_config(self, cnt, warm_up, generate_tokens):
+        default_decode_dump = 2
+        # warm-up runs only default_decode_dump steps (prefill + first decode)
+        jump_flag_warm = warm_up and cnt >= default_decode_dump
+        # stop once max_new_tokens tokens have been generated
+        jump_flag_oversize = generate_tokens >= self.max_new_tokens
+        jump_flag = jump_flag_oversize or jump_flag_warm
+
+        # profile only the first default_decode_dump steps of a non-warm-up run
+        profile_switch = self.enable_profile and (cnt < default_decode_dump) and (not warm_up)
+
+        path_prefill = f"{self.profiling_path}/prefill"
+        path_decode = f"{self.profiling_path}/decode"
+        profile_save_path_dict = {0: path_prefill, 1: path_decode}
+        profile_save_path = profile_save_path_dict.get(cnt, path_decode)
+        return jump_flag, profile_switch, profile_save_path
diff --git a/npu_tuned_model/llm/deepseek_v3/scripts/split_weight.py b/npu_tuned_model/llm/deepseek_v3/scripts/split_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b3cf79d0fae266463704784988634aebe45431
--- /dev/null
+++ b/npu_tuned_model/llm/deepseek_v3/scripts/split_weight.py
@@ -0,0 +1,188 @@
+import os
+import argparse
+import logging
+import shutil
+import numpy as np
+import torch
+from torch import nn
+from transformers import AutoModelForCausalLM
+from models.modeling_deepseek import DeepseekV3ForCausalLM
+
+root_logger = logging.getLogger()
+root_logger.handlers.clear()
+logging.basicConfig(format='%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s',
+                    level=logging.INFO)
+logging.getLogger("paramiko").setLevel(logging.ERROR)
+
+
+def split_w(src_model, dst_model, world_size, local_rank, use_gmm_kernel=True):
+    def _to_parameter(data):
+        return nn.Parameter(data, requires_grad=False)
+
+    vocab_size = src_model.model.vocab_size // world_size
+
+    # lm_head and embedding are split along the vocabulary axis
+    dst_model.lm_head.weight.data = \
+        src_model.lm_head.weight.data[local_rank * vocab_size: (local_rank + 1) * vocab_size, :]
+    dst_model.model.embed_tokens.weight.data = \
+        src_model.model.embed_tokens.weight.data[local_rank * vocab_size: (local_rank + 1) * vocab_size, :]
+
+    dst_model.model.norm.weight.data = src_model.model.norm.weight.data
+    q_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * \
+        dst_model.model.layers[0].self_attn.q_head_dim
+    k_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * \
+        (dst_model.model.layers[0].self_attn.qk_nope_head_dim + dst_model.model.layers[0].self_attn.v_head_dim)
+    o_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * \
+        dst_model.model.layers[0].self_attn.v_head_dim
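+    # Worked example (assuming DeepSeek-V3's published head sizes: 128 heads,
+    # qk_nope_head_dim=128, qk_rope_head_dim=64 -> q_head_dim=192,
+    # v_head_dim=128) with world_size=8: num_heads_per_rank=16, giving
+    # q_dim=3072, k_dim=4096 and o_dim=2048 per rank.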
+
+    for i, block in enumerate(src_model.model.layers):
+        if dst_model.model.layers[i].self_attn.q_lora_rank is None:
+            dst_model.model.layers[i].self_attn.q_proj.weight.data = \
+                block.self_attn.q_proj.weight.data[local_rank * q_dim: (local_rank + 1) * q_dim, :].contiguous()
+        else:
+            dst_model.model.layers[i].self_attn.q_a_proj.weight.data = \
+                block.self_attn.q_a_proj.weight.data
+            dst_model.model.layers[i].self_attn.q_a_layernorm.weight.data = \
+                block.self_attn.q_a_layernorm.weight.data
+            dst_model.model.layers[i].self_attn.q_ab_proj.weight.data = \
+                block.self_attn.q_b_proj.weight.data[local_rank * q_dim: (local_rank + 1) * q_dim, :].contiguous()
+
+        dst_model.model.layers[i].self_attn.kv_a_proj_with_mqa.weight.data = \
+            block.self_attn.kv_a_proj_with_mqa.weight.data
+
+        dst_model.model.layers[i].self_attn.kv_a_layernorm.weight.data = \
+            block.self_attn.kv_a_layernorm.weight.data
+        dst_model.model.layers[i].self_attn.o_proj.weight.data = \
+            block.self_attn.o_proj.weight.data[:, local_rank * o_dim: (local_rank + 1) * o_dim].contiguous()
+        dst_model.model.layers[i].self_attn.input_layernorm.weight.data = \
+            block.input_layernorm.weight.data
+        dst_model.model.layers[i].self_attn.post_attention_layernorm.weight.data = \
+            block.post_attention_layernorm.weight.data
+
+        kv_b_proj_weight_data = \
+            block.self_attn.kv_b_proj.weight.data[local_rank * k_dim: (local_rank + 1) * k_dim, :].contiguous()
+        qk_nope_head_dim = dst_model.model.layers[i].self_attn.qk_nope_head_dim
+        num_heads_per_rank = dst_model.model.layers[i].self_attn.num_heads_per_rank
+        kv_lora_rank = dst_model.model.layers[i].self_attn.kv_lora_rank
+        v_head_dim = dst_model.model.layers[i].self_attn.v_head_dim
+
+        # kv_b_proj packs, per head, qk_nope_head_dim rows of W_k followed by
+        # v_head_dim rows of W_v; first gather the W_k rows of every head
+        index_tensor = torch.arange(qk_nope_head_dim).repeat(num_heads_per_rank) + \
+            torch.arange(num_heads_per_rank).repeat_interleave(qk_nope_head_dim) * (qk_nope_head_dim + v_head_dim)
+        kv_b_proj_w_k = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
+        dst_model.model.layers[i].self_attn.kv_b_proj_w_k.data = \
+            kv_b_proj_w_k.view(num_heads_per_rank, qk_nope_head_dim, kv_lora_rank).contiguous()
+        # ... then the W_v rows of every head
+        index_tensor = torch.arange(qk_nope_head_dim, qk_nope_head_dim + v_head_dim).repeat(num_heads_per_rank) + \
+            torch.arange(num_heads_per_rank).repeat_interleave(v_head_dim) * (qk_nope_head_dim + v_head_dim)
+        kv_b_proj_w_v = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
+        dst_model.model.layers[i].self_attn.kv_b_proj_w_v.data = \
+            kv_b_proj_w_v.view(num_heads_per_rank, v_head_dim, kv_lora_rank).transpose(1, 2).contiguous()
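+        # Worked example (illustrative): with qk_nope_head_dim=2, v_head_dim=1
+        # and num_heads_per_rank=2, kv_b_proj rows are laid out per head as
+        # [k0a, k0b, v0, k1a, k1b, v1]; the first index_tensor selects the W_k
+        # rows [0, 1, 3, 4] and the second selects the W_v rows [2, 5].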
+
+        # layers below first_k_dense_replace (or off the MoE frequency) keep a
+        # plain dense MLP; all other layers are MoE layers
+        if not (i >= dst_model.config.first_k_dense_replace and i % dst_model.config.moe_layer_freq == 0):
+            up_weight_list = []
+            ffn_dim = dst_model.model.layers[i].mlp.intermediate_size_per_rank
+            gate_weight = block.mlp.gate_proj.weight[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+            up_weight = block.mlp.up_proj.weight[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+            up_weight_list.append(_to_parameter(torch.cat([gate_weight, up_weight], axis=0)))
+
+            if len(up_weight_list) == 1:
+                dst_model.model.layers[i].mlp.merged_up_gate_proj.weight = up_weight_list[0]
+            else:
+                dst_model.model.layers[i].mlp.merged_up_gate_proj.weight = \
+                    _to_parameter(torch.cat(up_weight_list, axis=0))
+            dst_model.model.layers[i].mlp.down_proj.weight = \
+                block.mlp.down_proj.weight.data[:, local_rank * ffn_dim: (local_rank + 1) * ffn_dim].contiguous()
+
+        else:
+            shared_up_weight_list = []
+            ffn_dim = dst_model.model.layers[i].mlp.shared_expert.intermediate_size_per_rank
+            gate_weight = block.mlp.shared_experts.gate_proj.weight[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+            up_weight = block.mlp.shared_experts.up_proj.weight[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+            shared_up_weight_list.append(_to_parameter(torch.cat([gate_weight, up_weight], axis=0)))
+            if len(shared_up_weight_list) == 1:
+                dst_model.model.layers[i].mlp.shared_expert.merged_up_gate_proj.weight = shared_up_weight_list[0]
+            else:
+                dst_model.model.layers[i].mlp.shared_expert.merged_up_gate_proj.weight = \
+                    _to_parameter(torch.cat(shared_up_weight_list, axis=0))
+            dst_model.model.layers[i].mlp.shared_expert.down_proj.weight = \
+                block.mlp.shared_experts.down_proj.weight.data[:, local_rank * ffn_dim: (local_rank + 1) * ffn_dim].contiguous()
+            dst_model.model.layers[i].mlp.gate.weight.data = block.mlp.gate.weight.data
+            if dst_model.model.layers[i].mlp.gate.topk_method == "noaux_tc":
+                dst_model.model.layers[i].mlp.gate.e_score_correction_bias.data = \
+                    block.mlp.gate.e_score_correction_bias.data
+
+            expert_num = block.mlp.config.n_routed_experts
+            gate_proj_list, down_proj_list, up_proj_list = [], [], []
+            for j, src_expert in enumerate(block.mlp.experts):
+                if use_gmm_kernel:
+                    ffn_dim = dst_model.model.layers[i].mlp.experts.intermediate_size_per_rank
+                    gate_proj_list.append(
+                        src_expert.gate_proj.weight.data[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous())
+                    up_proj_list.append(
+                        src_expert.up_proj.weight.data[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous())
+                    down_proj_list.append(
+                        src_expert.down_proj.weight.data[:, local_rank * ffn_dim: (local_rank + 1) * ffn_dim].contiguous())
+                else:
+                    ffn_dim = dst_model.model.layers[i].mlp.experts[j].intermediate_size_per_rank
+                    dst_model.model.layers[i].mlp.experts[j].gate_proj.weight.data = \
+                        src_expert.gate_proj.weight.data[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+                    dst_model.model.layers[i].mlp.experts[j].up_proj.weight.data = \
+                        src_expert.up_proj.weight.data[local_rank * ffn_dim: (local_rank + 1) * ffn_dim, :].contiguous()
+                    dst_model.model.layers[i].mlp.experts[j].down_proj.weight.data = \
+                        src_expert.down_proj.weight.data[:, local_rank * ffn_dim: (local_rank + 1) * ffn_dim].contiguous()
+
+            if use_gmm_kernel:
+                # merge all routed experts into grouped weights for GroupedMatmul
+                dst_model.model.layers[i].mlp.experts.group_w2.data = \
+                    torch.cat(down_proj_list, dim=0).view(expert_num, -1, ffn_dim).contiguous()
+                group_gate_proj = torch.cat(gate_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
+                group_up_proj = torch.cat(up_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
+                dst_model.model.layers[i].mlp.experts.group_w1_w3.data = torch.cat([group_gate_proj, group_up_proj], dim=1)
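+                # Resulting grouped shapes (assuming DeepSeek-V3's published
+                # config: hidden_size=7168, moe_intermediate_size=2048,
+                # n_routed_experts=256) with world_size=8 (ffn_dim=256):
+                # group_w1_w3 -> [256, 512, 7168], group_w2 -> [256, 7168, 256].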
+
+
+def copy_files_with_prefix(src_dir, dst_dir, prefix):
+    for file in os.listdir(src_dir):
+        if file.startswith(prefix):
+            src_file = os.path.join(src_dir, file)
+            dst_file = os.path.join(dst_dir, file)
+            shutil.copy2(src_file, dst_file)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="split weight parameters with tensor parallel")
+    parser.add_argument('--model-path', type=str, help="Path of model weights")
+    parser.add_argument('--output-path', type=str, help="The output directory where the results are saved")
+    parser_args = parser.parse_args()
+    return parser_args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    output_path = args.output_path
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    rank_size = int(os.getenv("WORLD_SIZE", "1"))
+    origin_model = AutoModelForCausalLM.from_pretrained(args.model_path,
+                                                        trust_remote_code=True,
+                                                        ignore_mismatched_sizes=True,
+                                                        low_cpu_mem_usage=True,
+                                                        torch_dtype=torch.bfloat16,
+                                                        attn_implementation="eager")
+    src_param_size = 0
+    for name, params in origin_model.named_parameters():
+        size_per_param = np.prod(params.size())
+        src_param_size += size_per_param
+        logging.info("Param before tensor parallel: %s, %s, %s",
+                     name, params.size(), params.dtype)
+    logging.info("Total param size before tensor parallel: %s", src_param_size)
+
+    for rank_id in range(rank_size):
+        logging.info("rank_id=%d / rank_size=%d", rank_id, rank_size)
+        # os.environ values must be strings
+        os.environ["LOCAL_RANK"] = str(rank_id)
+
+        save_path = os.path.join(output_path, f"rank_{rank_id}")
+        logging.info("Split weight for rank %s start, save path is: %s", rank_id, save_path)
+
+        config = origin_model.config
+        part_model = DeepseekV3ForCausalLM(config)
+
+        split_w(origin_model, part_model, rank_size, rank_id)
+
+        dst_param_size = 0
+        for name, params in part_model.named_parameters():
+            size_per_param = np.prod(params.size())
+            dst_param_size += size_per_param
+            logging.info("Param after tensor parallel: %s, %s, %s, %s",
+                         name, params.size(), params.dtype, params.device)
+        logging.info("Total param size after tensor parallel: %s", dst_param_size)
+
+        part_model.save_pretrained(save_path)
+        copy_files_with_prefix(args.model_path, save_path, "tokenizer")
+        logging.info("Split weight for rank %s finished, save path is: %s", rank_id, save_path)
+
+        del part_model
\ No newline at end of file