Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/api_service/usage_json_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <gtest/gtest.h>
#include <json2pb/pb_to_json.h>

#include <nlohmann/json.hpp>
#include <string>

#include "api_service/utils.h"
#include "chat.pb.h"

namespace xllm {
namespace {

// Fills a ChatResponse's usage from a `Usage` with a non-zero cached-token
// count, serializes it to JSON, and checks the top-level counters together
// with the nested `prompt_tokens_details` / `completion_tokens_details`
// objects.
TEST(UsageJsonTest, ChatUsageSerializesOpenAICachedTokensField) {
  Usage token_usage;
  token_usage.num_prompt_tokens = 1024;
  token_usage.num_cached_tokens = 896;
  token_usage.num_generated_tokens = 50;
  token_usage.num_total_tokens = 1074;

  proto::ChatResponse chat_response;
  api_service::set_proto_usage(chat_response.mutable_usage(), token_usage);

  // Primitive fields are always printed so zero-valued counters still appear.
  json2pb::Pb2JsonOptions json_options;
  json_options.always_print_primitive_fields = true;
  json_options.jsonify_empty_array = true;
  json_options.bytes_to_base64 = false;

  std::string serialized;
  std::string serialize_error;
  const bool serialized_ok = json2pb::ProtoMessageToJson(
      chat_response, &serialized, json_options, &serialize_error);
  ASSERT_TRUE(serialized_ok) << serialize_error;

  nlohmann::json root = nlohmann::json::parse(serialized);
  ASSERT_TRUE(root.contains("usage"));

  // Top-level token counters.
  nlohmann::json& usage_json = root["usage"];
  EXPECT_EQ(usage_json["prompt_tokens"], 1024);
  EXPECT_EQ(usage_json["completion_tokens"], 50);
  EXPECT_EQ(usage_json["total_tokens"], 1074);

  // Prompt-side breakdown carries the cached-token count.
  ASSERT_TRUE(usage_json.contains("prompt_tokens_details"));
  nlohmann::json& prompt_details = usage_json["prompt_tokens_details"];
  EXPECT_EQ(prompt_details["cached_tokens"], 896);
  EXPECT_EQ(prompt_details["audio_tokens"], 0);

  // Completion-side breakdown holds exactly two zero-valued fields.
  ASSERT_TRUE(usage_json.contains("completion_tokens_details"));
  nlohmann::json& completion_details = usage_json["completion_tokens_details"];
  EXPECT_EQ(completion_details["reasoning_tokens"], 0);
  EXPECT_EQ(completion_details["audio_tokens"], 0);
  EXPECT_EQ(completion_details.size(), 2);
}

// Serializes a ChatResponse whose usage reports zero cached tokens and checks
// that the nested detail objects are still emitted with explicit zero values.
TEST(UsageJsonTest, ChatUsagePrintsZeroCachedTokens) {
  Usage usage;
  usage.num_prompt_tokens = 12;
  usage.num_generated_tokens = 3;
  usage.num_total_tokens = 15;
  usage.num_cached_tokens = 0;

  proto::ChatResponse response;
  api_service::set_proto_usage(response.mutable_usage(), usage);

  // Primitive fields are always printed so the zero counters are not dropped.
  json2pb::Pb2JsonOptions options;
  options.bytes_to_base64 = false;
  options.jsonify_empty_array = true;
  options.always_print_primitive_fields = true;

  std::string json_text;
  std::string error_message;
  ASSERT_TRUE(json2pb::ProtoMessageToJson(
      response, &json_text, options, &error_message))
      << error_message;

  nlohmann::json json = nlohmann::json::parse(json_text);
  // Indexing a non-const json with a missing key inserts a null value, so
  // assert the `usage` object exists first; otherwise its absence would be
  // reported as a missing `prompt_tokens_details`.
  ASSERT_TRUE(json.contains("usage"));
  ASSERT_TRUE(json["usage"].contains("prompt_tokens_details"));
  EXPECT_EQ(json["usage"]["prompt_tokens_details"]["cached_tokens"], 0);
  EXPECT_EQ(json["usage"]["prompt_tokens_details"]["audio_tokens"], 0);
  ASSERT_TRUE(json["usage"].contains("completion_tokens_details"));
  EXPECT_EQ(json["usage"]["completion_tokens_details"]["reasoning_tokens"], 0);
  EXPECT_EQ(json["usage"]["completion_tokens_details"]["audio_tokens"], 0);
  EXPECT_EQ(json["usage"]["completion_tokens_details"].size(), 2);
}

} // namespace
} // namespace xllm
48 changes: 48 additions & 0 deletions tests/core/framework/block/block_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,52 @@ TEST(BlockManagerPoolTest, SequenceCopyDoesNotReuseSingleBlockSlot) {
EXPECT_NE(GetSingleBlockIdOrFail(clone), GetSingleBlockIdOrFail(src));
}

// With prefix caching enabled, a sequence whose first 8 prompt tokens match a
// previously cached sequence is expected to report 8 cached tokens. Once both
// sequences are out of scope, the test expects all but one block to be
// allocatable and the prefix cache to be empty afterwards.
TEST(BlockManagerPoolTest, PrefixCacheHitsAreReportedAsCachedTokens) {
  ScopedValue<int32_t> max_seqs_guard(
      &SchedulerConfig::get_instance().max_seqs_per_batch(), 2);

  // 16 device blocks of 4 tokens each, no host blocks.
  BlockManagerPool::Options options;
  options.num_blocks(16);
  options.host_num_blocks(0);
  options.block_size(4);
  options.enable_prefix_cache(true);
  BlockManagerPool pool(options, /*dp_size=*/1);

  {
    // First sequence: nothing is cached yet, so it reports zero cached tokens
    // before it is inserted into the prefix cache.
    Sequence warm_seq =
        MakeSequence(0, /*prompt_tokens=*/{1, 2, 3, 4, 5, 6, 7, 8, 9});
    ASSERT_TRUE(pool.allocate(&warm_seq));
    warm_seq.kv_state().set_kv_cache_tokens_num(warm_seq.num_prompt_tokens());
    EXPECT_EQ(warm_seq.num_cached_tokens(), 0);
    pool.cache(&warm_seq);

    // Second sequence differs only in its last token, so the first two
    // 4-token blocks are expected to be shared.
    Sequence hit_seq =
        MakeSequence(0, /*prompt_tokens=*/{1, 2, 3, 4, 5, 6, 7, 8, 10});
    ASSERT_TRUE(pool.allocate(&hit_seq));

    EXPECT_EQ(hit_seq.num_cached_tokens(), 8);
    EXPECT_LE(hit_seq.num_cached_tokens(), hit_seq.num_prompt_tokens());
    EXPECT_EQ(hit_seq.num_cached_tokens() % options.block_size(), 0);
  }

  // Request enough tokens to fill every block but one; the test expects this
  // to succeed and to leave no blocks in the prefix cache.
  int32_t dp_rank = 0;
  const auto requested_blocks = options.num_blocks() - 1;
  auto eviction_blocks =
      pool.allocate(requested_blocks * options.block_size(), dp_rank);
  EXPECT_EQ(eviction_blocks.size(), requested_blocks);
  EXPECT_EQ(pool.num_blocks_in_prefix_cache()[0], 0);
}

// With prefix caching disabled, a freshly allocated sequence is expected to
// report zero cached tokens.
TEST(BlockManagerPoolTest, PrefixCacheDisabledReportsZeroCachedTokens) {
  ScopedValue<int32_t> max_seqs_guard(
      &SchedulerConfig::get_instance().max_seqs_per_batch(), 2);

  // Same pool geometry as the cache-hit test, but with the prefix cache off.
  BlockManagerPool::Options options;
  options.num_blocks(16);
  options.host_num_blocks(0);
  options.block_size(4);
  options.enable_prefix_cache(false);
  BlockManagerPool pool(options, /*dp_size=*/1);

  Sequence uncached_seq =
      MakeSequence(0, /*prompt_tokens=*/{1, 2, 3, 4, 5, 6, 7, 8, 9});
  ASSERT_TRUE(pool.allocate(&uncached_seq));

  EXPECT_EQ(uncached_seq.num_cached_tokens(), 0);
}

} // namespace xllm
8 changes: 2 additions & 6 deletions xllm/api_service/chat_service_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,7 @@ bool send_delta_to_client_brpc(
response.set_created(created_time);
response.set_model(model);
auto* proto_usage = response.mutable_usage();
proto_usage->set_prompt_tokens(usage.num_prompt_tokens);
proto_usage->set_completion_tokens(usage.num_generated_tokens);
proto_usage->set_total_tokens(usage.num_total_tokens);
api_service::set_proto_usage(proto_usage, usage);
if (!call->write(response)) {
return false;
}
Expand Down Expand Up @@ -460,9 +458,7 @@ bool send_result_to_client_brpc(std::shared_ptr<ChatCall> call,
if (req_output.usage.has_value()) {
const auto& usage = req_output.usage.value();
auto* proto_usage = response.mutable_usage();
proto_usage->set_prompt_tokens(usage.num_prompt_tokens);
proto_usage->set_completion_tokens(usage.num_generated_tokens);
proto_usage->set_total_tokens(usage.num_total_tokens);
api_service::set_proto_usage(proto_usage, usage);
}

return call->write_and_finish(response);
Expand Down
21 changes: 20 additions & 1 deletion xllm/api_service/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ limitations under the License.

#include "api_service/stream_output_parser.h"
#include "chat.pb.h"
#include "common.pb.h"
#include "core/common/types.h"
#include "core/framework/request/request_output.h"
#include "function_call/function_call.h"

namespace xllm {
Expand All @@ -36,6 +38,23 @@ namespace api_service {
// Check for unstreamed tool arguments and send them using the provided sender
// This is shared between Chat API and Anthropic API implementations
using SendFunc = std::function<bool(const std::string&, int)>;

// Copies the token accounting in `usage` into the protobuf message
// `proto_usage`.
//
// Besides the three top-level counters, both detail sub-messages are always
// populated: `prompt_tokens_details` receives the cached-token count and a
// zero audio-token count, and `completion_tokens_details` receives zero
// reasoning and audio token counts.
//
// `proto_usage` must not be null (enforced with CHECK).
inline void set_proto_usage(proto::Usage* proto_usage,
                            const xllm::Usage& usage) {
  CHECK(proto_usage != nullptr);
  proto::Usage& out = *proto_usage;

  // Top-level counters.
  out.set_prompt_tokens(usage.num_prompt_tokens);
  out.set_completion_tokens(usage.num_generated_tokens);
  out.set_total_tokens(usage.num_total_tokens);

  // Prompt-side breakdown; only the cached-token count comes from `usage`.
  auto& prompt_details = *out.mutable_prompt_tokens_details();
  prompt_details.set_cached_tokens(usage.num_cached_tokens);
  prompt_details.set_audio_tokens(0);

  // Completion-side breakdown; both values are fixed to zero here.
  auto& completion_details = *out.mutable_completion_tokens_details();
  completion_details.set_reasoning_tokens(0);
  completion_details.set_audio_tokens(0);
}

inline bool check_for_unstreamed_tool_args(
std::shared_ptr<StreamOutputParser> stream_parser,
size_t index,
Expand Down Expand Up @@ -154,4 +173,4 @@ inline nlohmann::json struct_to_json(
}

} // namespace api_service
} // namespace xllm
} // namespace xllm
1 change: 1 addition & 0 deletions xllm/c_api/internal/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ XLLM_Response* build_success_response(const InferenceType& inference_type,
response->usage.prompt_tokens = usage.num_prompt_tokens;
response->usage.completion_tokens = usage.num_generated_tokens;
response->usage.total_tokens = usage.num_total_tokens;
response->usage.cached_tokens = usage.num_cached_tokens;
}

return response;
Expand Down
2 changes: 2 additions & 0 deletions xllm/c_api/test/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,14 @@ void PbToXllmUsage(const c_api_test::XLLM_Usage& pb, XLLM_Usage* out) {
out->prompt_tokens = pb.prompt_tokens();
out->completion_tokens = pb.completion_tokens();
out->total_tokens = pb.total_tokens();
out->cached_tokens = pb.cached_tokens();
}

void XllmUsageToPb(const XLLM_Usage& in, c_api_test::XLLM_Usage* pb) {
pb->set_prompt_tokens(in.prompt_tokens);
pb->set_completion_tokens(in.completion_tokens);
pb->set_total_tokens(in.total_tokens);
pb->set_cached_tokens(in.cached_tokens);
}

void PbToXllmLogProbs(const c_api_test::XLLM_LogProbs& pb,
Expand Down
3 changes: 2 additions & 1 deletion xllm/c_api/test/xllm_test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ message XLLM_Usage {
int32 prompt_tokens = 1;
int32 completion_tokens = 2;
int32 total_tokens = 3;
int32 cached_tokens = 4;
}

// --- XLLM_LogProb / XLLM_LogProbs ---
Expand Down Expand Up @@ -221,4 +222,4 @@ message XLLM_DumpRecord {
// --backend: rec -> xllm_rec_*; llm -> xllm_llm_*).
service XllmRecCapiService {
rpc Inference(XLLM_Request) returns (XLLM_Response);
}
}
3 changes: 3 additions & 0 deletions xllm/c_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ typedef struct XLLM_CAPI_EXPORT XLLM_Usage {

/** Total tokens used (prompt + completion) */
int32_t total_tokens;

/** Number of prompt tokens served from prefix cache */
int32_t cached_tokens;
} XLLM_Usage;

/**
Expand Down
1 change: 1 addition & 0 deletions xllm/cc_api/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ XLLM_Response build_success_response(const RequestOutput& output,
response.usage.prompt_tokens = usage.num_prompt_tokens;
response.usage.completion_tokens = usage.num_generated_tokens;
response.usage.total_tokens = usage.num_total_tokens;
response.usage.cached_tokens = usage.num_cached_tokens;
}

return response;
Expand Down
3 changes: 3 additions & 0 deletions xllm/cc_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ struct XLLM_CAPI_EXPORT XLLM_Usage {

// The total number of tokens used in the request (prompt + completion).
int32_t total_tokens;

// The number of prompt tokens served from prefix cache.
int32_t cached_tokens;
};

struct XLLM_CAPI_EXPORT XLLM_LogProbData {
Expand Down
8 changes: 8 additions & 0 deletions xllm/core/framework/request/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ limitations under the License.
#include <absl/time/time.h>
#include <glog/logging.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

Expand Down Expand Up @@ -99,6 +101,7 @@ void Request::log_statistic(double total_latency) {
<< "finish_reason: "
<< seq->finish_reason().to_string().value_or("") << ", "
<< "prompt_tokens: " << seq->num_prompt_tokens() << ", "
<< "cached_tokens: " << seq->num_cached_tokens() << ", "
<< "generated_tokens: " << gen_tokens << ", " << std::fixed
<< std::setprecision(1) << "ttft: " << ttft * 1000 << "ms, "
<< "total_latency: " << total_latency * 1000 << "ms, "
Expand Down Expand Up @@ -155,13 +158,18 @@ RequestOutput Request::generate_output(const Tokenizer& tokenizer,
// summarize statistics for all sequences
Usage usage;
usage.num_prompt_tokens = state_.prompt_tokens.size();
size_t num_cached_tokens = 0;
for (const auto& seq : sequences()) {
usage.num_generated_tokens += seq->num_generated_tokens();
num_cached_tokens = std::max(num_cached_tokens, seq->num_cached_tokens());
// NOTE: Avoid counting the extra execution step in overlap scenario.
if (state_.enable_schedule_overlap) {
usage.num_generated_tokens--;
}
}
CHECK_LE(num_cached_tokens,
static_cast<size_t>(std::numeric_limits<int32_t>::max()));
usage.num_cached_tokens = static_cast<int32_t>(num_cached_tokens);
usage.num_total_tokens = usage.num_prompt_tokens + usage.num_generated_tokens;

RequestOutput output;
Expand Down
3 changes: 3 additions & 0 deletions xllm/core/framework/request/request_output.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ struct Usage {

// the total number of tokens used in the request (prompt + completion).
int32_t num_total_tokens = 0;

// the number of prompt tokens served from prefix cache.
int32_t num_cached_tokens = 0;
};

struct LogProbData {
Expand Down
34 changes: 34 additions & 0 deletions xllm/core/framework/request/sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ Sequence::Sequence(const Sequence& other)
num_tokens_(other.num_tokens_),
token_to_count_map_(other.token_to_count_map_),
num_prompt_tokens_(other.num_prompt_tokens_),
num_cached_tokens_(other.num_cached_tokens_),
onerec_state_(other.onerec_state_),
volatile_num_prompt_tokens_(other.volatile_num_prompt_tokens_),
request_id_(other.request_id_),
Expand Down Expand Up @@ -690,8 +691,38 @@ void Sequence::add_host_kv_blocks(const std::vector<Block>& blocks) {
host_kv_state_.add_kv_blocks(blocks);
}

// Returns the cached-token count derived from the current device and host KV
// states.
//
// The count is the larger of the two states' shared-token numbers. If that
// exceeds the prompt length, the prompt length rounded down to a whole number
// of blocks is returned instead, using the size of the first block of
// whichever state holds shared blocks (device first, then host). Returns 0
// when no such block size can be determined.
//
// NOTE(review): rounding down means a partially filled trailing prompt block
// is not counted in that case — confirm this is the intended reporting.
size_t Sequence::current_num_cached_tokens() const {
  const size_t shared_tokens = std::max(kv_state_.shared_kv_tokens_num(),
                                        host_kv_state_.shared_kv_tokens_num());
  // Common case: the shared prefix lies entirely within the prompt.
  if (!(shared_tokens > num_prompt_tokens_)) {
    return shared_tokens;
  }

  // True when `state` has shared blocks and at least one block to inspect.
  const auto has_shared_blocks = [](const auto& state) {
    return state.shared_kv_blocks_num() > 0 && state.num_kv_blocks() > 0;
  };

  size_t tokens_per_block = 0;
  if (has_shared_blocks(kv_state_)) {
    tokens_per_block = kv_state_.kv_blocks()[0].size();
  } else if (has_shared_blocks(host_kv_state_)) {
    tokens_per_block = host_kv_state_.kv_blocks()[0].size();
  }

  // Without a usable block size nothing can be attributed to the prompt.
  if (tokens_per_block == 0) {
    return 0;
  }

  const size_t full_prompt_blocks = num_prompt_tokens_ / tokens_per_block;
  return full_prompt_blocks * tokens_per_block;
}
Comment on lines +694 to +712
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for calculating cached tokens under-reports the count when the prompt is not block-aligned and the cache match covers the entire prompt.

For example, if block_size is 16 and num_prompt_tokens_ is 10, and the prefix cache matches the first block (16 tokens), cached_tokens will be 16. The current logic (16 <= 10) is false, and it returns (10 / 16) * 16 = 0, even though all 10 prompt tokens were served from the cache.

Similarly, if num_prompt_tokens_ is 20 and 2 blocks are shared (32 tokens), it returns (20 / 16) * 16 = 16, missing the 4 tokens in the second block.

The correct value should be the minimum of the cached tokens and the prompt tokens. This also simplifies the implementation by removing the need to calculate block_size.

size_t Sequence::current_num_cached_tokens() const {
  size_t cached_tokens = std::max(kv_state_.shared_kv_tokens_num(),
                                  host_kv_state_.shared_kv_tokens_num());
  return std::min(cached_tokens, num_prompt_tokens_);
}


// Raises the stored cached-token count to the current value when the current
// value is larger; the stored count never decreases here.
void Sequence::record_cached_tokens() {
  const size_t observed = current_num_cached_tokens();
  if (num_cached_tokens_ < observed) {
    num_cached_tokens_ = observed;
  }
}

// Returns the larger of the stored cached-token count and the count derived
// from the current KV states, without modifying the stored value.
size_t Sequence::num_cached_tokens() const {
  const size_t live = current_num_cached_tokens();
  return num_cached_tokens_ < live ? live : num_cached_tokens_;
}

// release all cache blocks
void Sequence::reset() {
record_cached_tokens();
kv_state_.reset();
host_kv_state_.reset();
timer_.reset();
Expand All @@ -702,10 +733,12 @@ void Sequence::reset() {

// Hands `blocks` to the device KV state as shared blocks, passing the
// sequence's current token count, then updates the recorded cached-token
// count.
void Sequence::add_shared_kv_blocks(std::vector<Block>&& blocks) {
kv_state_.add_shared_kv_blocks(std::move(blocks), num_tokens_);
// Record after the KV state changed so the new shared blocks are counted.
record_cached_tokens();
}

// Hands `blocks` to the host KV state as shared blocks, passing the
// sequence's current token count, then updates the recorded cached-token
// count.
void Sequence::add_shared_host_kv_blocks(std::vector<Block>&& blocks) {
host_kv_state_.add_shared_kv_blocks(std::move(blocks), num_tokens_);
// Record after the host KV state changed so the new shared blocks are counted.
record_cached_tokens();
}

bool Sequence::finished() const {
Expand Down Expand Up @@ -805,6 +838,7 @@ bool Sequence::update_prefetch_result(uint32_t timeout, uint32_t& success_cnt) {
host_kv_state_.incr_kv_cache_tokens_num(
success_cnt * host_kv_state_.kv_blocks()[0].size());
host_kv_state_.incr_shared_kv_blocks_num(success_cnt);
record_cached_tokens();
}
prefetch_results_.clear();
return true;
Expand Down
Loading
Loading