Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions xllm_service/chat_template/jinja_chat_template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,29 @@ JinjaChatTemplate::JinjaChatTemplate(const TokenizerArgs& args) : args_(args) {
std::optional<std::string> JinjaChatTemplate::apply(
const ChatMessages& messages) const {
const std::vector<xllm_service::JsonTool> empty_tools;
return apply(messages, empty_tools);
const nlohmann::ordered_json chat_template_kwargs = nlohmann::json::object();
Comment thread
JimHsiung marked this conversation as resolved.
return apply(messages, empty_tools, chat_template_kwargs);
}

std::optional<std::string> JinjaChatTemplate::apply(
nlohmann::ordered_json& messages) const {
// Call the overloaded method with empty tools
nlohmann::ordered_json empty_tools = nlohmann::json::array();
return apply(messages, empty_tools);
const nlohmann::ordered_json chat_template_kwargs = nlohmann::json::object();
return apply(messages, empty_tools, chat_template_kwargs);
Comment thread
JimHsiung marked this conversation as resolved.
}

std::optional<std::string> JinjaChatTemplate::apply(
const ChatMessages& messages,
const std::vector<xllm_service::JsonTool>& json_tools) const {
const nlohmann::ordered_json chat_template_kwargs = nlohmann::json::object();
Comment thread
JimHsiung marked this conversation as resolved.
return apply(messages, json_tools, chat_template_kwargs);
}

std::optional<std::string> JinjaChatTemplate::apply(
const ChatMessages& messages,
const std::vector<xllm_service::JsonTool>& json_tools,
const nlohmann::ordered_json& chat_template_kwargs) const {
// convert the messages to json object
nlohmann::ordered_json messages_json = nlohmann::json::array();
for (const auto& message : messages) {
Expand Down Expand Up @@ -82,16 +92,25 @@ std::optional<std::string> JinjaChatTemplate::apply(
tools_json.push_back(tool_json);
}
// apply the template
return apply(messages_json, tools_json);
return apply(messages_json, tools_json, chat_template_kwargs);
}

std::optional<std::string> JinjaChatTemplate::apply(
nlohmann::ordered_json& messages,
const nlohmann::ordered_json& tools) const {
const nlohmann::ordered_json chat_template_kwargs = nlohmann::json::object();
Comment thread
JimHsiung marked this conversation as resolved.
return apply(messages, tools, chat_template_kwargs);
}

std::optional<std::string> JinjaChatTemplate::apply(
nlohmann::ordered_json& messages,
const nlohmann::ordered_json& tools,
const nlohmann::ordered_json& chat_template_kwargs) const {
minja::chat_template_inputs input;
input.messages = messages;
input.tools = tools;
input.add_generation_prompt = true;
input.extra_context = chat_template_kwargs;
minja::chat_template_options options;

return template_->apply(input, options);
Expand Down
10 changes: 10 additions & 0 deletions xllm_service/chat_template/jinja_chat_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,23 @@ class JinjaChatTemplate {
const ChatMessages& messages,
const std::vector<xllm_service::JsonTool>& json_tools) const;

std::optional<std::string> apply(
const ChatMessages& messages,
const std::vector<xllm_service::JsonTool>& json_tools,
const nlohmann::ordered_json& chat_template_kwargs) const;

// expose this function for testing
// apply the template to the values in the json object
std::optional<std::string> apply(nlohmann::ordered_json& messages) const;

std::optional<std::string> apply(nlohmann::ordered_json& messages,
const nlohmann::ordered_json& tools) const;

std::optional<std::string> apply(
nlohmann::ordered_json& messages,
const nlohmann::ordered_json& tools,
const nlohmann::ordered_json& chat_template_kwargs) const;

private:
Comment thread
JimHsiung marked this conversation as resolved.
nlohmann::ordered_json get_mm_content(const Message::MMContentVec& vec) const;

Expand Down
20 changes: 20 additions & 0 deletions xllm_service/chat_template/jinja_chat_template_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,24 @@ TEST(JinjaChatTemplate, OpenChatModel) {
EXPECT_EQ(result.value(), expected);
}

TEST(JinjaChatTemplate, ApplyChatTemplateKwargs) {
const std::string template_str =
"{% if enable_thinking %}<think>{% endif %}"
"{% for message in messages %}{{ message['content'] }}{% endfor %}";

nlohmann::ordered_json messages = {{{"role", "user"}, {"content", "hello"}}};
nlohmann::ordered_json chat_template_kwargs = {{"enable_thinking", false}};

TokenizerArgs args;
args.chat_template(template_str);
args.bos_token("");
args.eos_token("");
JinjaChatTemplate template_(args);

auto result = template_.apply(
messages, nlohmann::ordered_json::array(), chat_template_kwargs);
ASSERT_TRUE(result.has_value());
EXPECT_EQ(result.value(), "hello");
}

} // namespace xllm_service
4 changes: 4 additions & 0 deletions xllm_service/http_service/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,10 @@ void XllmHttpServiceImpl::ChatCompletions(
for (const auto& message : req_pb->messages()) {
service_request->messages.emplace_back(message.role(), message.content());
}
if (req_pb->has_chat_template_kwargs()) {
service_request->chat_template_kwargs =
proto_struct_to_json(req_pb->chat_template_kwargs());
}
service_request->tools = parse_tools_from_proto(req_pb->tools());
if (req_pb->has_tool_choice()) {
service_request->tool_choice = req_pb->tool_choice();
Expand Down
3 changes: 3 additions & 0 deletions xllm_service/request/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct Request {
// controls tool usage behavior, e.g. auto/none/required
std::string tool_choice = "auto";

// extra template context such as {"enable_thinking": false}
nlohmann::json chat_template_kwargs = nlohmann::json::object();

// token ids of prompt
std::vector<int32_t> token_ids;

Expand Down
3 changes: 2 additions & 1 deletion xllm_service/scheduler/scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ bool Scheduler::schedule(std::shared_ptr<Request> request) {
const std::vector<JsonTool> empty_tools;
const std::vector<JsonTool>& tools_for_template =
request->tool_choice == "none" ? empty_tools : request->tools;
auto prompt = chat_template_->apply(request->messages, tools_for_template);
auto prompt = chat_template_->apply(
request->messages, tools_for_template, request->chat_template_kwargs);
if (!prompt.has_value()) {
LOG(ERROR) << "Failed to construct prompt from messages";
return false;
Expand Down
Loading