Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/api_service/sample_service_impl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class CharTokenizer final : public Tokenizer {
public:
bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override {
if (ids == nullptr) {
return false;
}
Expand Down Expand Up @@ -83,7 +84,8 @@ class UnstableLiteralTokenizer final : public Tokenizer {
public:
bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override {
if (ids == nullptr) {
return false;
}
Expand Down
6 changes: 4 additions & 2 deletions tests/core/framework/request/sample_slot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CharTokenizer final : public Tokenizer {
public:
bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override {
if (ids == nullptr) {
return false;
}
Expand Down Expand Up @@ -138,7 +139,8 @@ class UnstableLiteralTokenizer final : public Tokenizer {
public:
bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override {
if (ids == nullptr) {
return false;
}
Expand Down
4 changes: 3 additions & 1 deletion tests/core/scheduler/fixed_steps_scheduler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class FakeTokenizer : public Tokenizer {
public:
bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override {
(void)text;
(void)ids;
(void)add_special_tokens;
(void)max_sequence_length;
return false;
}
std::string decode(const Slice<int32_t>& ids,
Expand Down
7 changes: 6 additions & 1 deletion xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ DECLARE_uint32(prefetch_timeout);

DECLARE_uint32(prefetch_batch_size);

// NOTE(review): "bacth" looks like a misspelling of prefetch_batch_size, which
// is already declared directly above. Confirm that a flag with this exact
// spelling is defined somewhere; if not, this declaration is unused and should
// be dropped.
DECLARE_uint32(prefetch_bacth_size);

DECLARE_uint32(layers_wise_copy_batchs);

DECLARE_double(host_blocks_factor);
Expand Down Expand Up @@ -335,6 +337,9 @@ DECLARE_bool(dit_debug_print);

DECLARE_bool(use_audio_in_video);

// --- mistral prompt to message config ---
DECLARE_bool(enable_mistral_prompt_to_message);

// --- kernel config ---
#if defined(USE_NPU)
DECLARE_bool(enable_customize_mla_kernel);
Expand All @@ -347,4 +352,4 @@ DECLARE_bool(enable_intralayer_addnorm);
// --- chat template config ---
DECLARE_bool(use_cpp_chat_template);

DECLARE_int32(health_check_interval_ms);
DECLARE_int32(health_check_interval_ms);
77 changes: 75 additions & 2 deletions xllm/core/distributed_runtime/llm_master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,89 @@ std::shared_ptr<Request> LLMMaster::generate_request(
sp.source_xservice_addr);
return nullptr;
}
if (FLAGS_enable_mistral_prompt_to_message) {
  // Mistral prompt-to-message conversion. When the prompt is a JSON chat
  // payload, it is preprocessed, turned into Message objects and rendered
  // through the chat template; the rendered text then replaces `prompt` and
  // any caller-supplied prompt tokens are cleared so the text is tokenized
  // below. A prompt that is already rendered (e.g. "<s>[SYSTEM_PROMPT]...")
  // is left untouched.
  //
  // The input counts as JSON when its first non-whitespace character is '{'
  // or '['. Only that one character is inspected, so the prompt is not copied.
  // NOTE(review): a pre-rendered prompt that itself begins with '[' (for
  // example "[INST] ...") also takes the JSON path; presumably preprocess()
  // or the parse below then rejects it - confirm this is intended.
  const size_t first_char_pos = prompt.find_first_not_of(" \t\n\r");
  const bool is_json_input =
      first_char_pos != std::string::npos &&
      (prompt[first_char_pos] == '{' || prompt[first_char_pos] == '[');

  if (is_json_input) {
    // NOTE(review): this writes the complete raw prompt to the INFO log.
    LOG(INFO) << "llm_master prompt (JSON input):" << prompt;

    // 1. Preprocess the raw JSON with the LLM-mode chat JSON parser.
    const ChatJsonParser& parser =
        xllm::ChatJsonParser::get(ServingMode::LLM);
    auto [status, processed_json] = parser.preprocess(prompt);
    if (!status.ok()) {
      CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
                          "Preprocess failed: " + status.message(),
                          sp.service_request_id,
                          sp.source_xservice_addr);
      LOG(ERROR) << "Preprocess failed: " << status.message();
      return nullptr;
    }

    // 2. Convert the "messages" entries into Message objects.
    std::vector<Message> messages;
    try {
      const auto json = nlohmann::json::parse(processed_json);
      if (json.contains("messages")) {
        const auto& json_messages = json.at("messages");
        messages.reserve(json_messages.size());
        for (const auto& msg : json_messages) {
          // at() throws a json::exception when "role"/"content" is missing or
          // msg is not an object; operator[] on a const json value is
          // undefined behaviour for a missing key and would bypass the
          // handler below.
          messages.emplace_back(msg.at("role").get<std::string>(),
                                msg.at("content").get<std::string>());
        }
      }
    } catch (const nlohmann::json::exception& e) {
      CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
                          "JSON parse failed: " + std::string(e.what()),
                          sp.service_request_id,
                          sp.source_xservice_addr);
      LOG(ERROR) << "JSON parse failed: " << e.what();
      return nullptr;
    }

    // 3. Render the messages through the chat template. The timer covers the
    //    render only and is recorded only when rendering succeeds.
    Timer template_timer;
    std::optional<std::string> formatted_prompt =
        chat_template_->apply(messages, sp.tools, sp.chat_template_kwargs);

    if (!formatted_prompt.has_value()) {
      CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
                          "Failed to construct prompt from messages",
                          sp.service_request_id,
                          sp.source_xservice_addr);
      LOG(ERROR) << "Failed to construct prompt from messages";
      return nullptr;
    }

    COUNTER_ADD(chat_template_latency_seconds,
                template_timer.elapsed_seconds());

    // The rendered text supersedes both the raw prompt and any pre-tokenized
    // input, forcing the encode step below to run on the new prompt.
    prompt = std::move(formatted_prompt.value());
    prompt_tokens = std::nullopt;
  } else {
    // Prompt is already a formatted chat template string, use directly
    LOG(INFO) << "llm_master prompt (already formatted):" << prompt;
  }
}
// encode the prompt
Timer timer;
std::vector<int> local_prompt_tokens;

if (prompt_tokens.has_value()) {
local_prompt_tokens = std::move(prompt_tokens.value());
} else {
if (!tokenizer_->encode(
prompt, &local_prompt_tokens, sp.add_special_tokens)) {
if (!tokenizer_->encode(prompt,
&local_prompt_tokens,
FLAGS_enable_mistral_prompt_to_message
? false
: sp.add_special_tokens,
FLAGS_enable_mistral_prompt_to_message ? 512 : 0)) {
LOG(ERROR) << "Failed to encode prompt: " << prompt;
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
"Failed to encode prompt",
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/llm_master.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ limitations under the License.
#include <string>
#include <vector>

#include "api_service/chat_json_parser.h"
#include "api_service/serving_mode.h"
#include "common/options.h"
#include "common/rate_limiter.h"
#include "framework/chat_template/chat_template.h"
Expand All @@ -31,7 +33,6 @@ limitations under the License.
#include "llm_engine.h"
#include "master.h"
#include "scheduler/continuous_scheduler.h"

namespace xllm {

class Call;
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ void Batch::process_sample_output(const RawForwardOutput& raw_output,
const auto sequences = get_sequences();
for (auto* seq : sequences) {
int64_t n_images = seq->get_mm_data().size();
if (n_images <= 0) {
if (!FLAGS_enable_mistral_prompt_to_message && n_images <= 0) {
continue;
}
std::vector<torch::Tensor> seq_mm_embeddings;
Expand Down
4 changes: 4 additions & 0 deletions xllm/core/framework/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ DEFINE_bool(
false,
"Whether to decode both audio and video when the input is a video.");

// Gates the Mistral prompt-to-message path. When set,
// LLMMaster::generate_request renders JSON prompts through the chat template
// before tokenizing and encodes with add_special_tokens=false and
// max_sequence_length=512; Batch::process_sample_output also stops skipping
// sequences that carry no multimodal data. Off by default.
DEFINE_bool(enable_mistral_prompt_to_message,
false,
"Whether to enable mistral prompt to message conversion.");

// NOTE: This is an experimental flag,
// it needs to be removed after the function is stable.
DEFINE_bool(use_cpp_chat_template,
Expand Down
13 changes: 12 additions & 1 deletion xllm/core/framework/tokenizer/fast_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ bool add_special_token_id(const std::string& token,

bool FastTokenizer::encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens) const {
bool add_special_tokens,
int32_t max_sequence_length) const {
TokenizerEncodeResult result;
tokenizers_encode(
handle_, text.data(), text.size(), add_special_tokens, &result);
Expand Down Expand Up @@ -103,6 +104,16 @@ bool FastTokenizer::encode(const std::string_view& text,
ids,
/*prepend=*/false);
}
// Left-pad the encoded ids with the pad token until they reach
// max_sequence_length. Padding is skipped when max_sequence_length is not
// positive, when no pad token is configured, when the pad token has no id, or
// when ids is already at least that long (nothing is truncated here).
if (max_sequence_length > 0) {
  if (!tokenizer_args_.pad_token().empty()) {
    const auto pad_token_id = token_to_id(tokenizer_args_.pad_token());
    const auto current_length = static_cast<int32_t>(ids->size());
    if (pad_token_id.has_value() && current_length < max_sequence_length) {
      // Pad tokens go in front of the existing ids, not after them.
      const int32_t missing = max_sequence_length - current_length;
      ids->insert(ids->begin(), missing, pad_token_id.value());
    }
  }
}

return true;
}
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/fast_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class FastTokenizer : public Tokenizer {

bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override;
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override;

std::string decode(const Slice<int32_t>& ids,
bool skip_special_tokens) const override;
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/sentencepiece_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ bool SentencePieceTokenizer::encode_internal(const std::string_view& text,

bool SentencePieceTokenizer::encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens) const {
bool add_special_tokens,
int32_t max_sequence_length) const {
// prepend prefix tokens if exists
if (!prefix_token_ids_.empty()) {
ids->insert(
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/sentencepiece_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class SentencePieceTokenizer : public Tokenizer {

bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override;
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override;

std::string decode(const Slice<int32_t>& ids,
bool skip_special_tokens) const override;
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/tiktoken_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ void TiktokenTokenizer::encode_internal(const std::string_view& text,

bool TiktokenTokenizer::encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens) const {
bool add_special_tokens,
int32_t max_sequence_length) const {
// prepend prefix tokens if exists
if (!prefix_token_ids_.empty()) {
ids->insert(
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/tiktoken_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class TiktokenTokenizer : public Tokenizer {

bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override;
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override;

std::string decode(const Slice<int32_t>& ids,
bool skip_special_tokens) const override;
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Tokenizer {

virtual bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const {
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const {
return false;
}

Expand Down
6 changes: 4 additions & 2 deletions xllm/core/framework/tokenizer/tokenizer_proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ std::unique_ptr<Tokenizer> TokenizerProxy::clone() const {

bool TokenizerProxy::encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens) const {
return get_tls_tokenizer()->encode(text, ids, add_special_tokens);
bool add_special_tokens,
int32_t max_sequence_length) const {
return get_tls_tokenizer()->encode(
text, ids, add_special_tokens, max_sequence_length);
}

bool TokenizerProxy::encode(int64_t item_id,
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/tokenizer/tokenizer_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class TokenizerProxy : public Tokenizer {

bool encode(const std::string_view& text,
std::vector<int32_t>* ids,
bool add_special_tokens = true) const override;
bool add_special_tokens = true,
int32_t max_sequence_length = 0) const override;

bool encode(int64_t item_id, std::vector<int32_t>* token_ids) const override;

Expand Down
7 changes: 3 additions & 4 deletions xllm/core/layers/common/add_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,9 @@ void AddMatmulWeightTransposedImpl::load_state_dict(
if (state_dict.has("weight")) {
xllm::weight::load_weight(state_dict, "weight", weight_, weight_is_loaded_);
// weight need to be transposed when using addmm
if (with_bias_) {
torch::Tensor transposed = weight_.data().transpose(0, 1).contiguous();
weight_.set_data(transposed);
}

torch::Tensor transposed = weight_.data().transpose(0, 1).contiguous();
weight_.set_data(transposed);
}
if (with_bias_) {
weight::load_weight(state_dict, "bias", bias_, bias_is_loaded_);
Expand Down
11 changes: 11 additions & 0 deletions xllm/core/runtime/embed_worker_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ std::optional<ForwardOutput> EmbedWorkerImpl::step(const ForwardInput& input) {
auto embeddings =
model_->pooler(hidden_states, sampling_params.selected_token_idxes);
sample_output.embeddings = embeddings;
if (FLAGS_enable_return_mm_full_embeddings) {
  // Split `embeddings` along dim 0 into one consecutive slice per sequence,
  // using each sequence's query length as the slice height, and append the
  // slices to sample_output.mm_embeddings in sequence order.
  // NOTE(review): assumes `embeddings` has one row per input token so the
  // q_seq_lens offsets line up - confirm against the pooler's output shape.
  //
  // Bound by const reference: the lengths are only read here, so the
  // container is not copied on every step.
  const auto& q_seq_lens = input.input_params.attention.host.q_seq_lens;
  sample_output.mm_embeddings.reserve(q_seq_lens.size());
  int32_t token_start_idx = 0;
  for (const auto seq_len : q_seq_lens) {
    sample_output.mm_embeddings.emplace_back(
        embeddings.slice(0, token_start_idx, token_start_idx + seq_len));
    token_start_idx += seq_len;
  }
}
COUNTER_ADD(execution_latency_seconds_sampling, timer.elapsed_seconds());

// set sample output to output
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/scheduler/request_priority_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ class HeapQueue final : public RequestPriorityQueue {
class SetQueue final : public RequestPriorityQueue {
private:
using QueueType = std::set<std::shared_ptr<Request>, Comparator>;
QueueType queue_;
Comparator lower_priority_comparator_;
QueueType queue_;

public:
explicit SetQueue(Comparator lower_priority_comparator)
Expand Down