diff --git a/tests/api_service/sample_service_impl_test.cpp b/tests/api_service/sample_service_impl_test.cpp index 62c42912d9..44062693ba 100644 --- a/tests/api_service/sample_service_impl_test.cpp +++ b/tests/api_service/sample_service_impl_test.cpp @@ -32,7 +32,8 @@ class CharTokenizer final : public Tokenizer { public: bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override { if (ids == nullptr) { return false; } @@ -83,7 +84,8 @@ class UnstableLiteralTokenizer final : public Tokenizer { public: bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override { if (ids == nullptr) { return false; } diff --git a/tests/core/framework/request/sample_slot_test.cpp b/tests/core/framework/request/sample_slot_test.cpp index 6c255c22f1..49acd6763c 100644 --- a/tests/core/framework/request/sample_slot_test.cpp +++ b/tests/core/framework/request/sample_slot_test.cpp @@ -37,7 +37,8 @@ class CharTokenizer final : public Tokenizer { public: bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override { if (ids == nullptr) { return false; } @@ -138,7 +139,8 @@ class UnstableLiteralTokenizer final : public Tokenizer { public: bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override { if (ids == nullptr) { return false; } diff --git a/tests/core/scheduler/fixed_steps_scheduler_test.cpp b/tests/core/scheduler/fixed_steps_scheduler_test.cpp index 836fe08f98..930af82b2c 100644 --- a/tests/core/scheduler/fixed_steps_scheduler_test.cpp +++ b/tests/core/scheduler/fixed_steps_scheduler_test.cpp @@ -34,10 +34,12 @@ class FakeTokenizer : public Tokenizer { public: bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override { (void)text; (void)ids; (void)add_special_tokens; + (void)max_sequence_length; return false; } std::string decode(const Slice& ids, diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 0a2e23e47a..f3de39579e 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -205,6 +205,8 @@ DECLARE_uint32(prefetch_timeout); DECLARE_uint32(prefetch_batch_size); +DECLARE_uint32(prefetch_bacth_size); + DECLARE_uint32(layers_wise_copy_batchs); DECLARE_double(host_blocks_factor); @@ -335,6 +337,9 @@ DECLARE_bool(dit_debug_print); DECLARE_bool(use_audio_in_video); +// --- mistral prompt to message config --- +DECLARE_bool(enable_mistral_prompt_to_message); + // --- kernel config --- #if defined(USE_NPU) DECLARE_bool(enable_customize_mla_kernel); @@ -347,4 +352,4 @@ DECLARE_bool(enable_intralayer_addnorm); // --- chat template config --- DECLARE_bool(use_cpp_chat_template); -DECLARE_int32(health_check_interval_ms); +DECLARE_int32(health_check_interval_ms); \ No newline at end of file diff --git a/xllm/core/distributed_runtime/llm_master.cpp b/xllm/core/distributed_runtime/llm_master.cpp index 449268912b..5f3bd7ee93 100644 --- a/xllm/core/distributed_runtime/llm_master.cpp +++ b/xllm/core/distributed_runtime/llm_master.cpp @@ -292,7 +292,76 @@ std::shared_ptr LLMMaster::generate_request( sp.source_xservice_addr); return nullptr; } + if (FLAGS_enable_mistral_prompt_to_message) { + // Check if prompt is already a formatted chat template string (not JSON) + // If it starts with '<' (e.g., "[SYSTEM_PROMPT]..."), skip JSON parsing + bool is_json_input = false; + std::string trimmed = prompt; + // Trim leading whitespace + size_t start = trimmed.find_first_not_of(" \t\n\r"); + if (start != std::string::npos && start > 0) { + trimmed = trimmed.substr(start); + } + is_json_input = + !trimmed.empty() && (trimmed[0] == '{' || trimmed[0] == '['); + + if (is_json_input) { + LOG(INFO) << "llm_master prompt (JSON input):" << prompt; + // 1. Preprocess using LlmChatJsonParser + const ChatJsonParser& parser = + xllm::ChatJsonParser::get(ServingMode::LLM); + auto [status, processed_json] = parser.preprocess(prompt); + if (!status.ok()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Preprocess failed: " + status.message(), + sp.service_request_id, + sp.source_xservice_addr); + LOG(ERROR) << "Preprocess failed: " << status.message(); + return nullptr; + } + + // 2. Parse into Message objects + std::vector messages; + try { + auto json = nlohmann::json::parse(processed_json); + if (json.contains("messages")) { + for (const auto& msg : json["messages"]) { + messages.emplace_back(msg["role"].get(), + msg["content"].get()); + } + } + } catch (const nlohmann::json::exception& e) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "JSON parse failed: " + std::string(e.what()), + sp.service_request_id, + sp.source_xservice_addr); + LOG(ERROR) << "JSON parse failed: " << e.what(); + return nullptr; + } + // 3. Call the message version of generate_request + Timer timer; + std::optional formatted_prompt; + formatted_prompt = + chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs); + if (!formatted_prompt.has_value()) { + CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, + "Failed to construct prompt from messages", + sp.service_request_id, + sp.source_xservice_addr); + LOG(ERROR) << "Failed to construct prompt from messages"; + return nullptr; + } + + COUNTER_ADD(chat_template_latency_seconds, timer.elapsed_seconds()); + + prompt = std::move(formatted_prompt.value()); + prompt_tokens = std::nullopt; + } else { + // Prompt is already a formatted chat template string, use directly + LOG(INFO) << "llm_master prompt (already formatted):" << prompt; + } + } // encode the prompt Timer timer; std::vector local_prompt_tokens; @@ -300,8 +369,12 @@ std::shared_ptr LLMMaster::generate_request( if (prompt_tokens.has_value()) { local_prompt_tokens = std::move(prompt_tokens.value()); } else { - if (!tokenizer_->encode( - prompt, &local_prompt_tokens, sp.add_special_tokens)) { + if (!tokenizer_->encode(prompt, + &local_prompt_tokens, + FLAGS_enable_mistral_prompt_to_message + ? false + : sp.add_special_tokens, + FLAGS_enable_mistral_prompt_to_message ? 512 : 0)) { LOG(ERROR) << "Failed to encode prompt: " << prompt; CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "Failed to encode prompt", diff --git a/xllm/core/distributed_runtime/llm_master.h b/xllm/core/distributed_runtime/llm_master.h index dcce1c76e4..27b803b80c 100644 --- a/xllm/core/distributed_runtime/llm_master.h +++ b/xllm/core/distributed_runtime/llm_master.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "api_service/chat_json_parser.h" +#include "api_service/serving_mode.h" #include "common/options.h" #include "common/rate_limiter.h" #include "framework/chat_template/chat_template.h" @@ -31,7 +33,6 @@ limitations under the License. #include "llm_engine.h" #include "master.h" #include "scheduler/continuous_scheduler.h" - namespace xllm { class Call; diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index ac5ec7f86f..a71eafe670 100644 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -486,7 +486,7 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output, const auto sequences = get_sequences(); for (auto* seq : sequences) { int64_t n_images = seq->get_mm_data().size(); - if (n_images <= 0) { + if (!FLAGS_enable_mistral_prompt_to_message && n_images <= 0) { continue; } std::vector seq_mm_embeddings; diff --git a/xllm/core/framework/config/model_config.cpp b/xllm/core/framework/config/model_config.cpp index 8133e7740e..4db93d25a2 100644 --- a/xllm/core/framework/config/model_config.cpp +++ b/xllm/core/framework/config/model_config.cpp @@ -71,6 +71,10 @@ DEFINE_bool( false, "Whether to decode both audio and video when the input is a video."); +DEFINE_bool(enable_mistral_prompt_to_message, + false, + "Whether to enable mistral prompt to message conversion."); + // NOTE: This is an experimental flag, // it needs to be removed after the function is stable. DEFINE_bool(use_cpp_chat_template, diff --git a/xllm/core/framework/tokenizer/fast_tokenizer.cpp b/xllm/core/framework/tokenizer/fast_tokenizer.cpp index 7dafb0358d..36c5beb244 100644 --- a/xllm/core/framework/tokenizer/fast_tokenizer.cpp +++ b/xllm/core/framework/tokenizer/fast_tokenizer.cpp @@ -71,7 +71,8 @@ bool add_special_token_id(const std::string& token, bool FastTokenizer::encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens) const { + bool add_special_tokens, + int32_t max_sequence_length) const { TokenizerEncodeResult result; tokenizers_encode( handle_, text.data(), text.size(), add_special_tokens, &result); @@ -103,6 +104,16 @@ bool FastTokenizer::encode(const std::string_view& text, ids, /*prepend=*/false); } + // Add pad to max_sequence_length if configured + if (max_sequence_length > 0 && !tokenizer_args_.pad_token().empty()) { + const auto pad_id = token_to_id(tokenizer_args_.pad_token()); + if (pad_id.has_value() && + static_cast(ids->size()) < max_sequence_length) { + int32_t pad_count = + max_sequence_length - static_cast(ids->size()); + ids->insert(ids->begin(), pad_count, pad_id.value()); + } + } return true; } diff --git a/xllm/core/framework/tokenizer/fast_tokenizer.h b/xllm/core/framework/tokenizer/fast_tokenizer.h index c62b4f3ee1..f1464f4893 100644 --- a/xllm/core/framework/tokenizer/fast_tokenizer.h +++ b/xllm/core/framework/tokenizer/fast_tokenizer.h @@ -30,7 +30,8 @@ class FastTokenizer : public Tokenizer { bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override; + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override; std::string decode(const Slice& ids, bool skip_special_tokens) const override; diff --git a/xllm/core/framework/tokenizer/sentencepiece_tokenizer.cpp b/xllm/core/framework/tokenizer/sentencepiece_tokenizer.cpp index 31820f1dd5..3aa7dc8728 100644 --- a/xllm/core/framework/tokenizer/sentencepiece_tokenizer.cpp +++ b/xllm/core/framework/tokenizer/sentencepiece_tokenizer.cpp @@ -129,7 +129,8 @@ bool SentencePieceTokenizer::encode_internal(const std::string_view& text, bool SentencePieceTokenizer::encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens) const { + bool add_special_tokens, + int32_t max_sequence_length) const { // prepend prefix tokens if exists if (!prefix_token_ids_.empty()) { ids->insert( diff --git a/xllm/core/framework/tokenizer/sentencepiece_tokenizer.h b/xllm/core/framework/tokenizer/sentencepiece_tokenizer.h index 6f19f33a13..c616fb65fa 100644 --- a/xllm/core/framework/tokenizer/sentencepiece_tokenizer.h +++ b/xllm/core/framework/tokenizer/sentencepiece_tokenizer.h @@ -34,7 +34,8 @@ class SentencePieceTokenizer : public Tokenizer { bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override; + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override; std::string decode(const Slice& ids, bool skip_special_tokens) const override; diff --git a/xllm/core/framework/tokenizer/tiktoken_tokenizer.cpp b/xllm/core/framework/tokenizer/tiktoken_tokenizer.cpp index fdd87d6914..3590b3498b 100644 --- a/xllm/core/framework/tokenizer/tiktoken_tokenizer.cpp +++ b/xllm/core/framework/tokenizer/tiktoken_tokenizer.cpp @@ -255,7 +255,8 @@ void TiktokenTokenizer::encode_internal(const std::string_view& text, bool TiktokenTokenizer::encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens) const { + bool add_special_tokens, + int32_t max_sequence_length) const { // prepend prefix tokens if exists if (!prefix_token_ids_.empty()) { ids->insert( diff --git a/xllm/core/framework/tokenizer/tiktoken_tokenizer.h b/xllm/core/framework/tokenizer/tiktoken_tokenizer.h index 00d79dbcb1..aa7eee2ad8 100644 --- a/xllm/core/framework/tokenizer/tiktoken_tokenizer.h +++ b/xllm/core/framework/tokenizer/tiktoken_tokenizer.h @@ -35,7 +35,8 @@ class TiktokenTokenizer : public Tokenizer { bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override; + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override; std::string decode(const Slice& ids, bool skip_special_tokens) const override; diff --git a/xllm/core/framework/tokenizer/tokenizer.h b/xllm/core/framework/tokenizer/tokenizer.h index 3492330401..a8004c0483 100644 --- a/xllm/core/framework/tokenizer/tokenizer.h +++ b/xllm/core/framework/tokenizer/tokenizer.h @@ -31,7 +31,8 @@ class Tokenizer { virtual bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const { + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const { return false; } diff --git a/xllm/core/framework/tokenizer/tokenizer_proxy.cpp b/xllm/core/framework/tokenizer/tokenizer_proxy.cpp index 7f3119be2d..4f13058e15 100644 --- a/xllm/core/framework/tokenizer/tokenizer_proxy.cpp +++ b/xllm/core/framework/tokenizer/tokenizer_proxy.cpp @@ -32,8 +32,10 @@ std::unique_ptr TokenizerProxy::clone() const { bool TokenizerProxy::encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens) const { - return get_tls_tokenizer()->encode(text, ids, add_special_tokens); + bool add_special_tokens, + int32_t max_sequence_length) const { + return get_tls_tokenizer()->encode( + text, ids, add_special_tokens, max_sequence_length); } bool TokenizerProxy::encode(int64_t item_id, diff --git a/xllm/core/framework/tokenizer/tokenizer_proxy.h b/xllm/core/framework/tokenizer/tokenizer_proxy.h index e22b02f75e..1bf46dac5e 100644 --- a/xllm/core/framework/tokenizer/tokenizer_proxy.h +++ b/xllm/core/framework/tokenizer/tokenizer_proxy.h @@ -27,7 +27,8 @@ class TokenizerProxy : public Tokenizer { bool encode(const std::string_view& text, std::vector* ids, - bool add_special_tokens = true) const override; + bool add_special_tokens = true, + int32_t max_sequence_length = 0) const override; bool encode(int64_t item_id, std::vector* token_ids) const override; diff --git a/xllm/core/layers/common/add_matmul.cpp b/xllm/core/layers/common/add_matmul.cpp index 4771b63383..7c0f1f874a 100644 --- a/xllm/core/layers/common/add_matmul.cpp +++ b/xllm/core/layers/common/add_matmul.cpp @@ -129,10 +129,9 @@ void AddMatmulWeightTransposedImpl::load_state_dict( if (state_dict.has("weight")) { xllm::weight::load_weight(state_dict, "weight", weight_, weight_is_loaded_); // weight need to be transposed when using addmm - if (with_bias_) { - torch::Tensor transposed = weight_.data().transpose(0, 1).contiguous(); - weight_.set_data(transposed); - } + + torch::Tensor transposed = weight_.data().transpose(0, 1).contiguous(); + weight_.set_data(transposed); } if (with_bias_) { weight::load_weight(state_dict, "bias", bias_, bias_is_loaded_); diff --git a/xllm/core/runtime/embed_worker_impl.cpp b/xllm/core/runtime/embed_worker_impl.cpp index bc137ba842..e3c7ed6719 100644 --- a/xllm/core/runtime/embed_worker_impl.cpp +++ b/xllm/core/runtime/embed_worker_impl.cpp @@ -86,6 +86,17 @@ std::optional EmbedWorkerImpl::step(const ForwardInput& input) { auto embeddings = model_->pooler(hidden_states, sampling_params.selected_token_idxes); sample_output.embeddings = embeddings; + if (FLAGS_enable_return_mm_full_embeddings) { + auto q_seq_len_vec = input.input_params.attention.host.q_seq_lens; + sample_output.mm_embeddings.reserve(q_seq_len_vec.size()); + int32_t token_start_idx = 0; + for (auto seq_len : q_seq_len_vec) { + auto image_embed = + embeddings.slice(0, token_start_idx, token_start_idx + seq_len); + sample_output.mm_embeddings.emplace_back(image_embed); + token_start_idx += seq_len; + } + } COUNTER_ADD(execution_latency_seconds_sampling, timer.elapsed_seconds()); // set sample output to output diff --git a/xllm/core/scheduler/request_priority_queue.h b/xllm/core/scheduler/request_priority_queue.h index 911fdc6a0b..8b02b3d9b2 100644 --- a/xllm/core/scheduler/request_priority_queue.h +++ b/xllm/core/scheduler/request_priority_queue.h @@ -156,8 +156,8 @@ class HeapQueue final : public RequestPriorityQueue { class SetQueue final : public RequestPriorityQueue { private: using QueueType = std::set, Comparator>; - QueueType queue_; Comparator lower_priority_comparator_; + QueueType queue_; public: explicit SetQueue(Comparator lower_priority_comparator)