From db2e84c42708592306f16e68ed5108519a1aecce Mon Sep 17 00:00:00 2001 From: Super User Date: Fri, 22 May 2026 16:40:17 +0800 Subject: [PATCH] support text_encoder for flux2 --- xllm/core/layers/npu/CMakeLists.txt | 4 + .../npu/loader/mistral_decoder_loader.cpp | 153 ++++++ .../npu/loader/mistral_decoder_loader.h | 39 ++ .../npu/npu_mistral_decoder_layer_impl.cpp | 270 +++++++++ .../npu/npu_mistral_decoder_layer_impl.h | 102 ++++ xllm/core/scheduler/request_priority_queue.h | 2 +- xllm/models/llm/npu/mistral.h | 518 ++++++++++++++++++ xllm/models/llm/npu/mistral3.h | 250 +++++++++ xllm/models/models.h | 1 + 9 files changed, 1338 insertions(+), 1 deletion(-) create mode 100644 xllm/core/layers/npu/loader/mistral_decoder_loader.cpp create mode 100644 xllm/core/layers/npu/loader/mistral_decoder_loader.h create mode 100644 xllm/core/layers/npu/npu_mistral_decoder_layer_impl.cpp create mode 100644 xllm/core/layers/npu/npu_mistral_decoder_layer_impl.h create mode 100644 xllm/models/llm/npu/mistral.h create mode 100644 xllm/models/llm/npu/mistral3.h diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt index 1f56d48605..2a69e66b50 100644 --- a/xllm/core/layers/npu/CMakeLists.txt +++ b/xllm/core/layers/npu/CMakeLists.txt @@ -28,6 +28,7 @@ cc_library( npu_deepseek_v2_decoder_layer_impl.h npu_deepseek_v32_decoder_layer_impl.h npu_llama_decoder_layer_impl.h + npu_mistral_decoder_layer_impl.h npu_qwen2_decoder_layer_impl.h npu_qwen3_decoder_layer_impl.h npu_onerec_block_layer_impl.h @@ -50,6 +51,7 @@ cc_library( loader/glm4_moe_decoder_loader.h loader/glm4_moe_lite_decoder_loader.h loader/llama_decoder_loader.h + loader/mistral_decoder_loader.h loader/qwen2_vision_encoder_loader.h loader/qwen2dot5_vision_encoder_loader.h loader/qwen3_vision_encoder_loader.h @@ -79,6 +81,7 @@ cc_library( npu_deepseek_v2_decoder_layer_impl.cpp npu_deepseek_v32_decoder_layer_impl.cpp npu_llama_decoder_layer_impl.cpp + npu_mistral_decoder_layer_impl.cpp npu_qwen2_decoder_layer_impl.cpp npu_qwen3_decoder_layer_impl.cpp npu_onerec_block_layer_impl.cpp @@ -101,6 +104,7 @@ cc_library( loader/glm4_moe_decoder_loader.cpp loader/glm4_moe_lite_decoder_loader.cpp loader/llama_decoder_loader.cpp + loader/mistral_decoder_loader.cpp loader/qwen2_vision_encoder_loader.cpp loader/qwen2dot5_vision_encoder_loader.cpp loader/qwen3_vision_encoder_loader.cpp diff --git a/xllm/core/layers/npu/loader/mistral_decoder_loader.cpp b/xllm/core/layers/npu/loader/mistral_decoder_loader.cpp new file mode 100644 index 0000000000..41d89421d4 --- /dev/null +++ b/xllm/core/layers/npu/loader/mistral_decoder_loader.cpp @@ -0,0 +1,153 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "core/layers/npu/loader/mistral_decoder_loader.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + + IN_NORM_WEIGHT = 0, // weight + IN_NORM_BIAS, // bias + IN_NORM_NEW_WEIGHT, // new weight + IN_NORM_NEW_BIAS, // new bias + + IN_Q_WEIGHT, // weight + IN_Q_BIAS, // bias + IN_Q_DEQSCALE, // deq_scale + IN_Q_OFFSET, // offset + IN_Q_SCALE, // scale + IN_Q_COMPRESS_IDX, + + IN_K_WEIGHT, // weight + IN_K_BIAS, // bias + IN_K_DEQSCALE, // deq_scale + IN_K_OFFSET, // offset + IN_K_SCALE, // scale + IN_K_COMPRESS_IDX, + + IN_V_WEIGHT, // weight + IN_V_BIAS, // bias + IN_V_DEQSCALE, // deq_scale + IN_V_OFFSET, // offset + IN_V_SCALE, // scale + IN_V_COMPRESS_IDX, + + IN_ATTENTION_OUT_WEIGHT, // weight + IN_ATTENTION_OUT_BIAS, // bias + IN_ATTENTION_OUT_DEQSCALE, // deq_scale + IN_ATTENTION_OUT_OFFSET, // offset + IN_ATTENTION_OUT_SCALE, // scale + IN_ATTENTION_OUT_COMPRESS_IDX, + + IN_SELFOUT_NORM_WEIGHT, // weight + IN_SELFOUT_NORM_BIAS, // bias + IN_SELFOUT_NORM_NEW_WEIGHT, // new weight + IN_SELFOUT_NORM_NEW_BIAS, // new bias + + IN_MLP_W2_WEIGHT, // weight + IN_MLP_W2_BIAS, // bias + IN_MLP_W2_DEQSCALE, // deq_scale + IN_MLP_W2_OFFSET, // offset + IN_MLP_W2_SCALE, // scale + IN_MLP_W2_COMPRESS_IDX, + + IN_MLP_W1_WEIGHT, // weight + IN_MLP_W1_BIAS, // bias + IN_MLP_W1_DEQSCALE, // deq_scale + IN_MLP_W1_OFFSET, // offset + IN_MLP_W1_SCALE, // scale + IN_MLP_W1_COMPRESS_IDX, + + IN_MLP_CPROJ_WEIGHT, // weight + IN_MLP_CPROJ_BIAS, // bias + IN_MLP_CPROJ_DEQSCALE, // deq_scale + IN_MLP_CPROJ_OFFSET, // offset + IN_MLP_CPROJ_SCALE, // scale + IN_MLP_CPROJ_COMPRESS_IDX, +}; + +static std::vector> WEIGHT_MAPPING = { + {IN_NORM_WEIGHT, "input_layernorm.weight"}, + {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, + {IN_K_WEIGHT, "self_attn.k_proj.weight"}, + {IN_V_WEIGHT, "self_attn.v_proj.weight"}, + {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, + {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, + {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}, +}; +static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, + {IN_K_WEIGHT, 0}, + {IN_V_WEIGHT, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +MistralDecoderLoader::MistralDecoderLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) { + at_weight_tensors_.resize(weight_count); + + auto options = context.get_tensor_options(); + dtype_ = torch::typeMetaToScalarType(options.dtype()); + + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void MistralDecoderLoader::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void MistralDecoderLoader::merge_loaded_weights() { + auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], + at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0); + at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; + + at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); + + auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], + at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0); + at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight; + + at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); +} + +void MistralDecoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [index, name] : WEIGHT_MAPPING) { + auto original_tensor = state_dict.get_tensor(name); + + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/npu/loader/mistral_decoder_loader.h b/xllm/core/layers/npu/loader/mistral_decoder_loader.h new file mode 100644 index 0000000000..d117e7aea3 --- /dev/null +++ b/xllm/core/layers/npu/loader/mistral_decoder_loader.h @@ -0,0 +1,39 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class MistralDecoderLoader final : public BaseLoader { + public: + MistralDecoderLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + bool enable_add_norm_; + int32_t rank_id_; +}; + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.cpp new file mode 100644 index 0000000000..70c2883e10 --- /dev/null +++ b/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.cpp @@ -0,0 +1,270 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "core/layers/npu/npu_mistral_decoder_layer_impl.h" + +#include +#include + +#include + +#include "common/global_flags.h" +#include "core/layers/common/attention_mask.h" +#include "loader/mistral_decoder_loader.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" + +namespace xllm { +namespace layer { + +const uint64_t kWeightCountPerLayer = 50; + +NpuMistralDecoderLayerImpl::NpuMistralDecoderLayerImpl( + const ModelContext& context) + : BaseLayer(context) { + param_from_args(prefill_param_, + context.get_model_args(), + context.get_parallel_args(), + true); + param_from_args(decode_param_, + context.get_model_args(), + context.get_parallel_args(), + false); + + atb_weight_tensors_.resize(kWeightCountPerLayer); + placeholder_vec_ = {1}; + + auto options = context.get_tensor_options(); + dtype_ = c10::typeMetaToScalarType(options.dtype()); + device_id_ = options.device().index(); + placeholder_ = atb_speed::Utils::AtTensor2Tensor( + torch::zeros({1}).to(device_).to(dtype_)); + + loader_ = + std::make_unique(kWeightCountPerLayer, context); + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); +} + +// fix param +void NpuMistralDecoderLayerImpl::param_from_args( + atb_speed::mistral::MistralLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill) { + // Basic settings + param.isFA = false; + param.isPrefill = isPrefill; + param.isBF16 = args.dtype() == "bfloat16"; + param.enableSwiGLU = true; + param.enableLcoc = isPrefill; + param.enableSpeculate = false; + param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill; + param.enableLora = false; + param.loraEnableGMM = false; + + // Quantization settings (adjust as needed) + param.packQuantType = {1, 1}; + param.linearQuantType = {0, -1, -1, 0, 0, -1, 0}; + param.linearTransposeType = {1, -1, -1, 1, 1, -1, 1}; + param.enableKvQuant = false; + param.quantGroupSize = 0; + + // Normalization parameters + param.normEps = args.rms_norm_eps(); // Mistral 7B is typically 1e-5 + param.enableFA3 = false; + + // Parallel settings + param.worldSize = parallel_args.world_size(); + param.numAttentionHeadsPerRank = args.n_heads() / param.worldSize; + + int64_t config_head_dim = args.head_dim(); // Get value directly + if (config_head_dim > 0) { + // Use head_dim from config + param.hiddenSizePerAttentionHead = config_head_dim; + LOG(INFO) << "Using head_dim from config: " << config_head_dim; + } else { + // Computed from + param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); + LOG(INFO) << "head_dim not in config or invalid, computed: " + << param.hiddenSizePerAttentionHead; + } + + // GQA settings - Mistral-specific + int n_kv_heads = args.n_kv_heads().value_or(8); // Mistral 7B default is 8 + param.numKeyValueHeadsPerRank = n_kv_heads / param.worldSize; + + param.rank = parallel_args.rank(); + param.backend = "lccl"; + param.tensorParallelInfo = { + parallel_args.rank(), parallel_args.world_size(), "lccl"}; +} + +void NpuMistralDecoderLayerImpl::merge_loaded_weights() { + loader_->merge_loaded_weights(); + + auto& at_weight_tensors = loader_->get_at_weight_tensors(); + Device::empty_cache(device_.index()); + for (int i = 0; i < kWeightCountPerLayer; ++i) { + atb_weight_tensors_[i] = + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); + } + + init_layer(); +} + +int64_t NpuMistralDecoderLayerImpl::init_layer() { + init_attn_mask(); + name_ = "mistral_decoder_layer"; + model_name_ = "mistral"; + CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); + CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_)); + + return atb::NO_ERROR; +} + +int64_t NpuMistralDecoderLayerImpl::init_attn_mask() { + torch::Dtype dtype = + prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; + decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); + + return atb::NO_ERROR; +} + +int64_t NpuMistralDecoderLayerImpl::init_node( + atb_speed::Model::Node& node, + atb_speed::mistral::MistralLayerParam& param) { + atb::Operation* operation = nullptr; + atb_speed::mistral::MistralDecoderLayer decoder_layer(param); + decoder_layer.BuildGraph(&operation); + node.operation.reset(operation); + if (node.operation == nullptr) { + LOG(ERROR) << "node.operation is null"; + return -1; + } + if (node.operation->GetInputNum() < 1) { + LOG(ERROR) << "Can not resize number which is smaller than 1"; + return -1; + } + node.inTensors.resize(node.operation->GetInputNum()); + node.outTensors.resize(1); + size_t inTensorId = 1; + + for (size_t weightTensorId = 0; weightTensorId < kWeightCountPerLayer; + ++weightTensorId) { + node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId]; + } + + node.variantPack.inTensors.reserve(node.inTensors.size()); + node.variantPack.inTensors.resize(node.inTensors.size()); + node.variantPack.outTensors.reserve(1); + node.variantPack.outTensors.resize(1); + + return atb::NO_ERROR; +} + +torch::Tensor NpuMistralDecoderLayerImpl::forward( + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + int node_id) { + atb::Status st; + + if (!input_params.meta.batch_forward_type.is_decode()) { + build_node_variant_pack(prefill_node_, + x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + true); + // mstxRangeEnd(id); + st = execute_node(prefill_node_, node_id); + LOG_IF(FATAL, st != 0) << model_name_ + << "excute prefill layer fail, error code: " << st; + } else { + build_node_variant_pack(decode_node_, + x, + cos_pos, + sin_pos, + decode_attn_mask_, + kv_cache, + input_params, + false); + st = execute_node(decode_node_, node_id + 1000); + LOG_IF(FATAL, st != 0) << model_name_ + << "excute decode layer fail, error code: " << st; + } + + return at_placeholder_; +} + +void NpuMistralDecoderLayerImpl::build_node_variant_pack( + atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + at::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill) { + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); + node.variantPack.inTensors.at(kWeightCountPerLayer) = internal_tensors_; + node.variantPack.inTensors.at(kWeightCountPerLayer + 1) = + atb_speed::Utils::AtTensor2Tensor(cos_pos); + node.variantPack.inTensors.at(kWeightCountPerLayer + 2) = + atb_speed::Utils::AtTensor2Tensor(sin_pos); + node.variantPack.inTensors.at(kWeightCountPerLayer + 3) = + atb_speed::Utils::AtTensor2Tensor(attn_mask); + node.variantPack.inTensors.at(kWeightCountPerLayer + 4) = + atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache()); + node.variantPack.inTensors.at(kWeightCountPerLayer + 5) = + atb_speed::Utils::AtTensor2Tensor(kv_cache.get_v_cache()); + node.variantPack.inTensors.at(kWeightCountPerLayer + 6) = + atb_speed::Utils::AtTensor2Tensor( + input_params.attention.device.kv_seq_lens); + node.variantPack.inTensors.at(kWeightCountPerLayer + 6).hostData = + input_params.attention.host.kv_seq_lens.data(); + node.variantPack.inTensors.at(kWeightCountPerLayer + 7) = placeholder_; + node.variantPack.inTensors.at(kWeightCountPerLayer + 7).hostData = + placeholder_vec_.data(); + node.variantPack.inTensors.at(kWeightCountPerLayer + 8) = placeholder_; + node.variantPack.inTensors.at(kWeightCountPerLayer + 9) = + atb_speed::Utils::AtTensor2Tensor( + input_params.attention.device.block_tables); + node.variantPack.inTensors.at(kWeightCountPerLayer + 10) = + atb_speed::Utils::AtTensor2Tensor( + input_params.attention.device.new_cache_slots); + if (is_prefill && FLAGS_enable_chunked_prefill) { + node.variantPack.inTensors.at(kWeightCountPerLayer + 11) = + atb_speed::Utils::AtTensor2Tensor( + input_params.attention.device.q_seq_lens); + node.variantPack.inTensors.at(kWeightCountPerLayer + 11).hostData = + input_params.attention.host.q_seq_lens.data(); + } + for (size_t i = 0; i < kWeightCountPerLayer; ++i) { + CHECK_THROW(node.inTensors.at(i) == nullptr, + model_name_ << "inTensor " << i << "is NULL"); + node.variantPack.inTensors.at(i) = *node.inTensors.at(i); + } + + node.variantPack.outTensors.at(0) = internal_tensors_; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.h b/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.h new file mode 100644 index 0000000000..f9534b73fe --- /dev/null +++ b/xllm/core/layers/npu/npu_mistral_decoder_layer_impl.h @@ -0,0 +1,102 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#include +#else +#include +#include +#endif + +#include + +#include + +#include "atb/atb_infer.h" +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "framework/state_dict/state_dict.h" +#include "nlohmann/json.hpp" +#include "npu_base_layer.h" +#include "pytorch/adapter/utils/utils.h" +#include "xllm_atb_layers/core/include/atb_speed/base/hosttensor_binder.h" +#include "xllm_atb_layers/core/include/atb_speed/base/model.h" +#include "xllm_atb_layers/core/include/atb_speed/log.h" +#include "xllm_atb_layers/core/include/atb_speed/utils/model_factory.h" +#include "xllm_atb_layers/models/mistral/layer/decoder_layer.h" + +namespace xllm { +namespace layer { + +class NpuMistralDecoderLayerImpl final : public BaseLayer { + public: + explicit NpuMistralDecoderLayerImpl(const ModelContext& context); + + ~NpuMistralDecoderLayerImpl() override = default; + + void merge_loaded_weights() override; + + virtual int64_t init_layer() override; + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + int node_id = 0); + + private: + void build_node_variant_pack(atb_speed::Model::Node& node, + torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + bool is_prefill); + + int64_t init_node(atb_speed::Model::Node& node, + atb_speed::mistral::MistralLayerParam& param); + + int64_t init_attn_mask(); + + void param_from_args(atb_speed::mistral::MistralLayerParam& param, + const ModelArgs& args, + const ParallelArgs& parallel_args, + bool isPrefill); + + atb_speed::Model::Node prefill_node_; + atb_speed::Model::Node decode_node_; + std::string model_name_; + atb_speed::mistral::MistralLayerParam prefill_param_; + atb_speed::mistral::MistralLayerParam decode_param_; + atb::Tensor internal_tensors_; + atb::Tensor placeholder_; + + // at::Tensor encode_attn_mask_; + torch::Tensor decode_attn_mask_; + + at::Tensor at_placeholder_; + + int device_id_; +}; +TORCH_MODULE(NpuMistralDecoderLayer); + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/scheduler/request_priority_queue.h b/xllm/core/scheduler/request_priority_queue.h index 911fdc6a0b..8b02b3d9b2 100644 --- a/xllm/core/scheduler/request_priority_queue.h +++ b/xllm/core/scheduler/request_priority_queue.h @@ -156,8 +156,8 @@ class HeapQueue final : public RequestPriorityQueue { class SetQueue final : public RequestPriorityQueue { private: using QueueType = std::set, Comparator>; - QueueType queue_; Comparator lower_priority_comparator_; + QueueType queue_; public: explicit SetQueue(Comparator lower_priority_comparator) diff --git a/xllm/models/llm/npu/mistral.h b/xllm/models/llm/npu/mistral.h new file mode 100644 index 0000000000..fc420b21b9 --- /dev/null +++ b/xllm/models/llm/npu/mistral.h @@ -0,0 +1,518 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_args.h" +#include "core/framework/model/model_output.h" +#include "core/layers/common/activation.h" +#include "core/layers/common/attention.h" +#include "core/layers/common/linear.h" +#include "core/layers/common/rotary_embedding.h" // Use existing rotary_embedding +#include "core/layers/npu/npu_mistral_decoder_layer_impl.h" +#include "framework/parallel_state/parallel_args.h" +#include "framework/state_dict/state_dict.h" +#include "models/model_registry.h" + +// Mistral model compatible with huggingface weights +namespace xllm { +torch::Tensor silu(torch::Tensor x) { return x * torch::sigmoid(x); } + +// ==================== Mistral Decoder Layer ==================== + +class MistralDecoderLayerImpl : public torch::nn::Module { + public: + MistralDecoderLayerImpl(const ModelContext& context) { + decoder_layer_ = register_module("decoder_layer", + layer::NpuMistralDecoderLayer(context)); + } + + ModelOutput forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + int node_id) { + auto hidden_states = decoder_layer_( + x, cos_pos, sin_pos, attn_mask, kv_cache, input_params, node_id); + return ModelOutput(hidden_states); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + decoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const {} + + // Add missing lifecycle functions + void merge_loaded_weights() { + if (decoder_layer_) { + decoder_layer_->merge_loaded_weights(); + } + } + + void merge_and_move_pinned_host() { + if (decoder_layer_) { + decoder_layer_->merge_and_move_pinned_host(); + } + } + + void free_weights() { + if (decoder_layer_) { + decoder_layer_->free_weights(); + } + } + + void reload_weights() { + if (decoder_layer_) { + decoder_layer_->reload_weights(); + } + } + + void reload_weights_from_device() { + if (decoder_layer_) { + decoder_layer_->reload_weights_from_device(); + } + } + + torch::Tensor _create_4d_causal_attention_mask(torch::IntArrayRef input_shape, + torch::Dtype dtype, + torch::Device device) { + const int64_t bsz = input_shape[0]; + const int64_t tgt_len = input_shape[1]; + + auto options = torch::TensorOptions().dtype(dtype).device(device); + auto causal_mask = torch::full( + {tgt_len, tgt_len}, -std::numeric_limits::infinity(), options); + causal_mask.triu_(1); + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0); + causal_mask = causal_mask.expand({bsz, 1, tgt_len, tgt_len}); + return causal_mask; + } + + private: + layer::NpuMistralDecoderLayer decoder_layer_{nullptr}; +}; +TORCH_MODULE(MistralDecoderLayer); + +std::tuple get_mistral_rotary_embedding( + int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options) { + auto options_new = + torch::device(options.device()).dtype(at::ScalarType::Double); + auto inv_freq = + 1.0 / torch::pow(rope_theta, torch::arange(0, dim, 2, options_new) / dim) + .to(at::ScalarType::Float); + auto seq_idx = torch::arange(seq_len, options_new); + + auto freqs = torch::ger(seq_idx, inv_freq).to(torch::kFloat32); + auto emb = torch::cat({freqs, freqs}, -1); + auto rope_cos = torch::cos(emb); + auto rope_sin = torch::sin(emb); + + auto dtype = options.dtype(); + if (dtype == torch::kFloat16 || dtype == torch::kBFloat16 || + dtype == torch::kInt8) { + if (dtype == torch::kBFloat16) { + rope_cos = rope_cos.to(torch::kBFloat16); + rope_sin = rope_sin.to(torch::kBFloat16); + } else { + rope_cos = rope_cos.to(torch::kFloat16); + rope_sin = rope_sin.to(torch::kFloat16); + } + } + return std::make_tuple(rope_cos, rope_sin); +} + +// ==================== Mistral Model ==================== + +class MistralModelImpl : public torch::nn::Module { + public: + MistralModelImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(context.get_model_args().n_layers()); + npu_embed_tokens_ = + register_module("npu_embed_tokens", layer::NpuWordEmbedding(context)); + norm_ = register_module("norm", layer::NpuRMSNorm(context)); + std::tie(cos_pos_, sin_pos_) = + get_mistral_rotary_embedding(128, + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); + + int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + max_seq_len_ = 0; + + for (int32_t i = 0; i < model_args.n_layers(); i++) { + auto block = MistralDecoderLayer(context); + layers_.push_back(block); + blocks_->push_back(block); + } + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + ModelOutput forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::Tensor h = npu_embed_tokens_(tokens, 0); + auto cos_pos = cos_pos_.index_select(0, positions); + auto sin_pos = sin_pos_.index_select(0, positions); + ModelInputParams& input_params_new = + const_cast(input_params); + torch::Tensor max_of_seq = + torch::max(input_params.attention.device.kv_seq_lens); + max_seq_len_ = FLAGS_enable_chunked_prefill + ? std::max(max_of_seq.item(), max_seq_len_) + : 128; + + torch::Tensor attn_mask; + + if (FLAGS_enable_chunked_prefill) { + LOG(FATAL) + << "Flux2 text encoder (Mistral) does not support chunked_prefill. " + << "Please set --enable_chunked_prefill=false and restart."; + // Use the original logic + int max_kv_seq = input_params.meta.kv_max_seq_len; + int num_sequences = input_params.meta.num_sequences; + if (num_sequences > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(num_sequences); + for (int j = 0; j < num_sequences; j++) { + auto mask = attn_mask_.gen_append_mask( + input_params.attention.host.q_seq_lens[j], + input_params.attention.host.kv_seq_lens[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); + req_mask_vec.emplace_back(mask); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else if (input_params.meta.batch_forward_type.is_prefill()) { + int64_t seq_len = h.size(0); + bool is_bf16 = (cos_pos.scalar_type() == torch::kBFloat16); + // float min_dtype = is_bf16 ? -3.389538e+38f : -65504.0f; + float min_dtype = 1.0f; + auto opts = torch::TensorOptions() + .dtype(cos_pos.dtype().toScalarType()) + .device(cos_pos.device()); + + // Detect left-padding from token ID (pad_token_id=770) + auto is_pad = (tokens == 11); // [seqLen] + int64_t pad_count = 0; + // left-padding: consecutive pad tokens starting from position 0 + auto is_pad_cpu = is_pad.cpu(); + for (int64_t i = 0; i < seq_len; ++i) { + if (is_pad_cpu[i].item()) { + pad_count++; + } else { + break; + } + } + // Create causal mask [seq_len, seq_len] + auto causal = torch::zeros({seq_len, seq_len}, opts); + auto upper = torch::ones({seq_len, seq_len}, opts); + upper.triu_(1); + causal = causal + upper * min_dtype; + + if (pad_count > 0) { + // left-padding: mask all rows and columns of the first pad_count + causal.slice(0, 0, pad_count) = min_dtype; + causal.slice(1, 0, pad_count) = min_dtype; + } + attn_mask = causal; + } + + std::vector all_layer_hidden_states; + all_layer_hidden_states.reserve(layers_.size()); + + for (size_t i = 0; i < layers_.size(); i++) { + auto& layer = layers_[i]; + layer(h, cos_pos, sin_pos, attn_mask, kv_caches[i], input_params_new, i); + + all_layer_hidden_states.emplace_back( + h.clone()); // Collect the output of each layer + } + + auto hidden_states = norm_(h, 0); + auto stacked = torch::stack(all_layer_hidden_states, /*dim=*/0); + return ModelOutput(hidden_states, torch::Tensor(), stacked); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + npu_embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // rotary_emb has no weights to load (all are buffers) + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + npu_embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "norm."); + } + + // Add missing lifecycle functions + void merge_loaded_weights() { + LOG(INFO) << "Merging loaded weights for MistralModel"; + + if (npu_embed_tokens_) { + npu_embed_tokens_->merge_loaded_weights(); + } + + for (auto& layer : layers_) { + if (layer) { + layer->merge_loaded_weights(); + } + } + + if (norm_) { + norm_->merge_loaded_weights(); + } + + LOG(INFO) << "MistralModel merge_loaded_weights completed"; + } + + void merge_and_move_pinned_host() { + if (npu_embed_tokens_) { + npu_embed_tokens_->merge_and_move_pinned_host(); + } + + for (auto& layer : layers_) { + if (layer) { + layer->merge_and_move_pinned_host(); + } + } + + if (norm_) { + norm_->merge_and_move_pinned_host(); + } + } + + void free_weights() { + if (npu_embed_tokens_) { + npu_embed_tokens_->free_weights(); + } + + for (auto& layer : layers_) { + if (layer) { + layer->free_weights(); + } + } + + if (norm_) { + norm_->free_weights(); + } + } + + void reload_weights() { + if (npu_embed_tokens_) { + npu_embed_tokens_->reload_weights(); + } + + for (auto& layer : layers_) { + if (layer) { + layer->reload_weights(); + } + } + + if (norm_) { + norm_->reload_weights(); + } + } + + void reload_weights_from_device() { + if (npu_embed_tokens_) { + npu_embed_tokens_->reload_weights_from_device(); + } + + for (auto& layer : layers_) { + if (layer) { + layer->reload_weights_from_device(); + } + } + + if (norm_) { + norm_->reload_weights_from_device(); + } + } + + private: + torch::Tensor cos_pos_; + torch::Tensor sin_pos_; + int max_seq_len_ = 0; + int device_id_ = 0; + layer::AttentionMask attn_mask_; + layer::NpuWordEmbedding npu_embed_tokens_{nullptr}; + layer::NpuRMSNorm norm_{nullptr}; + + torch::nn::ModuleList blocks_{nullptr}; + // hold same data but different type as blocks_ to avoid type cast + std::vector layers_; +}; +TORCH_MODULE(MistralModel); + +// ==================== Mistral For Causal LM ==================== + +class MistralForCausalLMImpl : public torch::nn::Module { + public: + MistralForCausalLMImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + + // register submodules + model_ = register_module("model", MistralModel(context)); + + lm_head_ = register_module("npu_lm_head", layer::NpuLmHead(context)); + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + // returns: [num_tokens, hidden_size] + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + auto hidden_states = model_(tokens, positions, kv_caches, input_params); + return ModelOutput(hidden_states); + } + + virtual torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + return h; + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return lm_head_(hidden_states, seleted_idxes, 0); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + model_->load_state_dict(state_dict.get_dict_with_prefix("model.")); + lm_head_->load_state_dict(state_dict.get_dict_with_prefix("lm_head.")); + } + + void verify_loaded_weights() const { + model_->verify_loaded_weights("model."); + lm_head_->verify_loaded_weights("lm_head."); + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + + void load_model(std::unique_ptr loader) { + LOG(INFO) << "Loading MistralForCausalLM from ModelLoader..."; + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict( + state_dict->get_dict_with_prefix("language_model.")); + lm_head_->load_state_dict( + state_dict->get_dict_with_prefix("language_model.lm_head.")); + } + // Critical: add merge_loaded_weights call! + if (model_) { + model_->merge_loaded_weights(); + } + if (lm_head_) { + lm_head_->merge_loaded_weights(); + } + model_->verify_loaded_weights("language_model."); + lm_head_->verify_loaded_weights("language_model.lm_head."); + LOG(INFO) << "MistralForCausalLM loaded successfully."; + } + + private: + // parameter members, must be registered + MistralModel model_{nullptr}; + layer::NpuLmHead lm_head_{nullptr}; +}; +TORCH_MODULE(MistralForCausalLM); + +// ==================== Registration ==================== + +REGISTER_CAUSAL_MODEL(mistral, MistralForCausalLM); +REGISTER_MODEL_ARGS(mistral, [&] { + LOAD_ARG_OR(model_type, "model_type", "mistral"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 32000); + LOAD_ARG_OR(hidden_size, "hidden_size", 4096); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 32); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 8); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 14336); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 4096 * 32); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-5); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 1); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 2); + LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f); + + LOAD_ARG_OR(rope_scaling_rope_type, "rope_scaling_rope_type", "default"); + LOAD_ARG_OR(rope_scaling_factor, "rope_scaling_factor", 1.0f); + LOAD_ARG_OR(rope_scaling_original_max_position_embeddings, + "rope_scaling_original_max_position_embeddings", + 4096); + LOAD_ARG_OR(rope_extrapolation_factor, "rope_extrapolation_factor", 1.0f); + LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling_attn_factor", 1.0f); + LOAD_ARG_OR(rope_scaling_beta_fast, "rope_scaling_beta_fast", 32.0f); + LOAD_ARG_OR(rope_scaling_beta_slow, "rope_scaling_beta_slow", 1.0f); + LOAD_ARG_OR(rope_scaling_mscale, "rope_scaling_mscale", 1.0f); + LOAD_ARG_OR(rope_scaling_mscale_all_dim, "rope_scaling_mscale_all_dim", 1.0f); + + // head_dim needs to be calculated from hidden_size and n_heads, must use + // LOAD_ARG_OR_FUNC + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); +}); + +} // namespace xllm \ No newline at end of file diff --git a/xllm/models/llm/npu/mistral3.h b/xllm/models/llm/npu/mistral3.h new file mode 100644 index 0000000000..40a4e094ca --- /dev/null +++ b/xllm/models/llm/npu/mistral3.h @@ -0,0 +1,250 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model_context.h" +#include "core/util/timer.h" +#include "llm_model_base.h" +#include "mistral.h" +#include "models/model_registry.h" +namespace xllm { +// Mistral3 model (without LM head) +class Mistral3ModelImpl : public torch::nn::Module { + public: + explicit Mistral3ModelImpl(const ModelContext& context) { + language_model_ = register_module("language_model", MistralModel(context)); + } + + ModelOutput forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + auto output = + language_model_->forward(tokens, positions, kv_caches, input_params); + return output; // Return directly, including aux_hidden_states + } + + void load_state_dict(const StateDict& state_dict) { + language_model_->load_state_dict(state_dict.get_dict_with_prefix("model.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + language_model_->verify_loaded_weights(prefix + "model."); + } + + // Add lifecycle functions + void merge_loaded_weights() { + LOG(INFO) << "Merging loaded weights for Mistral3Model"; + if (language_model_) { + language_model_->merge_loaded_weights(); + } + } + + void merge_and_move_pinned_host() { + if (language_model_) { + language_model_->merge_and_move_pinned_host(); + } + } + + void free_weights() { + if (language_model_) { + language_model_->free_weights(); + } + } + + void reload_weights() { + if (language_model_) { + language_model_->reload_weights(); + } + } + + void reload_weights_from_device() { + if (language_model_) { + language_model_->reload_weights_from_device(); + } + } + + private: + MistralModel language_model_{nullptr}; +}; +TORCH_MODULE(Mistral3Model); + +// Mistral3 model for conditional generation (text-only) +class Mistral3ForConditionalGenerationImpl : public torch::nn::Module { + public: + Mistral3ForConditionalGenerationImpl(const ModelContext& context) { + // register submodules + model_ = register_module("model", Mistral3Model(context)); + + lm_head_ = register_module("npu_lm_head", layer::NpuLmHead(context)); + } + + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + auto model_output = model_(tokens, positions, kv_caches, input_params); + auto indices = torch::tensor({9, 19, 29}, torch::kLong) + .to(model_output.aux_hidden_states.device()); + auto selected = + model_output.aux_hidden_states.index_select(/*dim=*/0, indices); + + return ModelOutput(selected); + } + + virtual torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + // Add FLAGS_enable_return_mm_full_embeddings + auto h = hidden_states; + // return full embeddings if set flag + if (FLAGS_enable_return_mm_full_embeddings) { + return h; + } + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + auto pooler_output = torch::nn::functional::normalize( + h, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1)); + return pooler_output; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return lm_head_(hidden_states, seleted_idxes, 0); + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + + void load_state_dict(const StateDict& state_dict) { + model_->load_state_dict(state_dict.get_dict_with_prefix("language_model.")); + lm_head_->load_state_dict( + state_dict.get_dict_with_prefix("language_model.lm_head.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + model_->verify_loaded_weights(prefix); + lm_head_->verify_loaded_weights(prefix + "lm_head."); + } + + void load_model(std::unique_ptr loader) { + LOG(INFO) << "Loading Mistral3ForConditionalGeneration from ModelLoader..."; + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict( + state_dict->get_dict_with_prefix("language_model.")); + lm_head_->load_state_dict( + state_dict->get_dict_with_prefix("language_model.lm_head.")); + } + + // Critical: add merge_loaded_weights call! + if (model_) { + model_->merge_loaded_weights(); + } + if (lm_head_) { + lm_head_->merge_loaded_weights(); + } + + model_->verify_loaded_weights("language_model."); + lm_head_->verify_loaded_weights("language_model.lm_head."); + LOG(INFO) << "Mistral3ForConditionalGeneration loaded successfully."; + } + + // Add lifecycle functions + void merge_loaded_weights() { + if (model_) { + model_->merge_loaded_weights(); + } + if (lm_head_) { + lm_head_->merge_loaded_weights(); + } + } + + void merge_and_move_pinned_host() { + if (model_) { + model_->merge_and_move_pinned_host(); + } + if (lm_head_) { + lm_head_->merge_and_move_pinned_host(); + } + } + + void free_weights() { + if (model_) { + model_->free_weights(); + } + if (lm_head_) { + lm_head_->free_weights(); + } + } + + void reload_weights() { + if (model_) { + model_->reload_weights(); + } + if (lm_head_) { + lm_head_->reload_weights(); + } + } + + void reload_weights_from_device() { + if (model_) { + model_->reload_weights_from_device(); + } + if (lm_head_) { + lm_head_->reload_weights_from_device(); + } + } + + private: + // parameter members, must be registered + Mistral3Model model_{nullptr}; + layer::NpuLmHead lm_head_{nullptr}; +}; +TORCH_MODULE(Mistral3ForConditionalGeneration); + +// Model registration +REGISTER_CAUSAL_MODEL(mistral3, Mistral3ForConditionalGeneration); + +REGISTER_MODEL_ARGS(mistral3, [&] { + LOAD_ARG_OR(model_type, "model_type", "mistral3"); + LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16"); + LOAD_ARG_OR(vocab_size, "vocab_size", 131072); + LOAD_ARG_OR(hidden_size, "hidden_size", 5120); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 32768); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 40); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 8); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 131072); + LOAD_ARG_OR(head_dim, "head_dim", 128); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-5); + LOAD_ARG_OR(rope_theta, "rope_theta", 1e9); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 1); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 2); +}); + +} // namespace xllm diff --git a/xllm/models/models.h b/xllm/models/models.h index 0253387c58..1efd77bb59 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -35,6 +35,7 @@ limitations under the License. #include "llm/npu/kimi_k2.h" // IWYU pragma: keep #include "llm/npu/llama.h" // IWYU pragma: keep #include "llm/npu/llama3.h" // IWYU pragma: keep +#include "llm/npu/mistral3.h" // IWYU pragma: keep #include "llm/npu/oxygen.h" // IWYU pragma: keep #include "llm/npu/qwen2.h" // IWYU pragma: keep #include "llm/npu/qwen3.h" // IWYU pragma: keep