Skip to content
2 changes: 1 addition & 1 deletion third_party/xllm_ops
Submodule xllm_ops updated from 79eb46 to 409d1d
9 changes: 9 additions & 0 deletions xllm/core/framework/model/model_input_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,12 @@ struct ModelInputParams {
safe_to(graph_buffer.attn_mask, device, true);
params.graph_buffer.tiling_data =
safe_to(graph_buffer.tiling_data, device, true);
params.graph_buffer.xfa_q_cu_seq_lens =
safe_to(graph_buffer.xfa_q_cu_seq_lens, device, true);
params.graph_buffer.xfa_extra_tiling =
safe_to(graph_buffer.xfa_extra_tiling, device, true);
params.graph_buffer.xfa_attn_mask =
safe_to(graph_buffer.xfa_attn_mask, device, true);

// params for flashinfer
params.paged_kv_indptr = safe_to(paged_kv_indptr, device);
Expand Down Expand Up @@ -666,6 +672,9 @@ struct ModelInputParams {
struct GraphBuffer {
torch::Tensor attn_mask;
torch::Tensor tiling_data;
torch::Tensor xfa_q_cu_seq_lens;
torch::Tensor xfa_extra_tiling;
torch::Tensor xfa_attn_mask;
bool use_expanded_decode_for_spec_verify_attention = false;
torch::Tensor expanded_kv_seq_lens;
torch::Tensor expanded_block_tables;
Expand Down
1 change: 1 addition & 0 deletions xllm/core/kernels/npu/xllm_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cc_library(
beam_search.cpp
select_unshared_kv.cpp
beam_search_rec.cpp
x_flash_attention_infer.cpp
DEPS
atb
torch_npu
Expand Down
4 changes: 4 additions & 0 deletions xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ void beam_search_rec(const torch::Tensor& logprobs,
create_acltensor(&out_beam_count_prefix_sums_ids, out_beam_count_prefix_sums);
create_acltensor(&out_sequence_ids, out_sequence);

CHECK_GT(top_tokens.dim(), 0)
<< "beam_search_rec: top_tokens must have at least one dimension";
const int64_t top_k = top_tokens.size(-1);
uint64_t workspace_size = 0;
aclOpExecutor* executor = nullptr;
CHECK_ACL_SUCCESS(
Expand All @@ -91,6 +94,7 @@ void beam_search_rec(const torch::Tensor& logprobs,
top_logprobs_ids,
sequence_group_ids,
current_step,
top_k,
out_token_ids_ids,
out_token_index_ids,
out_log_probs_ids,
Expand Down
Loading