Skip to content

refactor: split forward inputs from model input params [3 / 3].#1469

Open
RobbieLeung wants to merge 1 commit into
jd-opensource:mainfrom
RobbieLeung:refactor/remove_input_params
Open

refactor: split forward inputs from model input params [3 / 3].#1469
RobbieLeung wants to merge 1 commit into
jd-opensource:mainfrom
RobbieLeung:refactor/remove_input_params

Conversation

@RobbieLeung
Copy link
Copy Markdown
Collaborator

Summary

This PR refactors the forward input path by separating the old ModelInputParams payload into structured input groups and moving execution-facing state into ForwardInput.

The goal is to make the runtime input flow clearer and more maintainable across batch builders, workers, executors, model adapters, and shared-memory transport, without changing the intended inference behavior.

What Changed

1. Split forward input data into structured groups

Refactored the previous flat input layout into explicit components, including:

  • batch metadata
  • attention inputs
  • embedding inputs
  • parallel inputs
  • block copy inputs
  • multimodal inputs
  • expert inputs
  • graph inputs
  • rec / dit-specific inputs

ForwardInput is now the main execution payload and carries both the tensor inputs (token_ids, positions) and the structured runtime/model input state.

2. Move execution code to the new ForwardInput interface

Updated the runtime execution path to consume ForwardInput directly, including:

  • worker implementations
  • speculative / MTP input preparation
  • CUDA graph executor
  • MLU graph executor
  • ACL graph executor
  • model forward interfaces and adapters

This removes the previous dependency on passing a separate ModelInputParams object through the execution stack.

3. Align attention metadata construction with the refactor

Refactored attention metadata building to consume explicit input sections instead of the old monolithic params object.

This keeps the metadata path consistent for:

  • regular eager execution
  • graph execution
  • decode / prefill branches
  • FlashInfer / xattention related metadata preparation

4. Keep serialization and decode builders in sync

Updated the shared-memory forward input serialization/deserialization flow and speculative decode input builders so they match the new structured input layout.

This includes:

  • forward input packing/unpacking
  • host/device tensor reconstruction
  • decode row context construction
  • speculative validation / draft input preparation

5. Update tests and call sites

Adjusted affected tests and call sites to use the new input layout, especially around:

  • CUDA graph executor behavior
  • MLU graph executor behavior
  • speculative decode input building
  • batch and framework-level input construction

Motivation

The previous input flow mixed execution tensors and model/runtime parameters in a single broad structure, which made the data contract hard to follow and harder to evolve safely.

This refactor improves:

  • readability of the forward input contract
  • separation of responsibilities between input domains
  • consistency across runtime backends
  • maintainability of serialization and graph execution paths

Scope

This change is primarily a structural refactor.
It is intended to preserve the existing runtime behavior while making the forward input pipeline easier to extend and reason about.

Validation

  • Updated unit tests and affected call sites to match the new interfaces
  • clang-format checks passed during commit

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the model input handling by introducing a new ForwardInput struct and replacing the previous ModelInputParams and RawForwardInput structures. This change standardizes input data transport across the codebase, including distributed and shared memory communication. The review comments correctly identify that the newly introduced structs (ForwardInputBufferPlan, ForwardInput, AttentionInput) contain member functions, which violates the repository's style guide requiring structs to be used only for plain data aggregation. Additionally, one comment correctly points out a violation regarding the use of plain int instead of fixed-width integers, and another highlights a potential regression in memory efficiency in input_write due to unnecessary copying.

Comment thread xllm/core/runtime/forward_params.h
Comment thread xllm/core/runtime/forward_params.h
Comment thread xllm/core/runtime/forward_params.h
Comment thread xllm/core/runtime/forward_params.h
Comment thread xllm/core/runtime/forward_shared_memory_manager.cpp
@RobbieLeung RobbieLeung changed the title refactor: optimize forward input device views. refactor: split forward inputs from model input params [3 / 3]. May 18, 2026
@RobbieLeung RobbieLeung force-pushed the refactor/remove_input_params branch from 4494340 to 42fa2d0 Compare May 19, 2026 07:24
- separate batch, attention, embedding, parallel and graph fields into structured forward input data.

- update runtime workers, graph executors, model adapters and tests to consume the new input layout.

- keep shared-memory serialization and speculative decode builders aligned with the refactored parameter flow.
@RobbieLeung RobbieLeung force-pushed the refactor/remove_input_params branch from 42fa2d0 to d5252b8 Compare May 19, 2026 08:22
@RobbieLeung
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the codebase by replacing the legacy ModelInputParams structure with a new ForwardInput structure across various models, layers, and test files. This significant change involves updating function signatures, member access patterns, and helper functions to integrate the new input structure throughout the codebase. I have no review feedback to provide.

CHECK(seq_idx < attention.host.q_seq_lens.size()) << "seq_idx out of range";
return attention.host.q_seq_lens[seq_idx];
#else
CHECK(seq_idx < attention.host.q_seq_lens.size() - 1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Potential unsigned integer underflow if attention.host.q_seq_lens.size() is 0. Since size() returns size_t, 0 - 1 results in SIZE_MAX, causing the CHECK to pass for any non-negative seq_idx and leading to an out-of-bounds access at line 642. The condition should be restructured to avoid subtraction from an unsigned type.

Suggested change
CHECK(seq_idx < attention.host.q_seq_lens.size() - 1)
CHECK(seq_idx + 1 < static_cast<int32_t>(attention.host.q_seq_lens.size()))

inputs.skip_sampling_for_logits_only = skip_sampling_for_logits_only;
}

void copy_model_inputs_to(ForwardInput& inputs) const {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The ForwardInput struct contains multiple member functions (e.g., copy_model_inputs_to, get_q_seq_len, onerec_params), which violates the repository style guide. Rule 160 states that structs must not have member functions and are for plain data aggregation only. If methods are required, ForwardInput should be defined as a class.

References
  1. Structs must not have member functions. If you need methods, use a class. Structs are for plain data aggregation only. (link)

Comment on lines +285 to +293
int batch_size = input.attention.host.q_seq_lens.size();
if (batch_size > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(batch_size);

for (int j = 0; j < batch_size; j++) {
int start = input_params.attention.host.kv_seq_lens[j] -
input_params.attention.host.q_seq_lens[j];
int end = input_params.attention.host.kv_seq_lens[j];
int start = input.attention.host.kv_seq_lens[j] -
input.attention.host.q_seq_lens[j];
int end = input.attention.host.kv_seq_lens[j];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Multiple usages of plain int for variables and loop counters violate the style guide requirement for fixed-width integers.

Suggested change
int batch_size = input.attention.host.q_seq_lens.size();
if (batch_size > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(batch_size);
for (int j = 0; j < batch_size; j++) {
int start = input_params.attention.host.kv_seq_lens[j] -
input_params.attention.host.q_seq_lens[j];
int end = input_params.attention.host.kv_seq_lens[j];
int start = input.attention.host.kv_seq_lens[j] -
input.attention.host.q_seq_lens[j];
int end = input.attention.host.kv_seq_lens[j];
const int32_t batch_size = static_cast<int32_t>(input.attention.host.q_seq_lens.size());
if (batch_size > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(batch_size);
for (int32_t j = 0; j < batch_size; ++j) {
int32_t start = input.attention.host.kv_seq_lens[j] -
input.attention.host.q_seq_lens[j];
int32_t end = input.attention.host.kv_seq_lens[j];
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)

Comment on lines +122 to +128
int num_sequences = input.meta.num_sequences;
if (num_sequences > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(num_sequences);

for (int j = 0; j < num_sequences; j++) {
auto mask = attn_mask_.gen_append_mask(
input_params.attention.host.q_seq_lens[j],
input_params.attention.host.kv_seq_lens[j],
max_kv_seq,
cos_pos.dtype().toScalarType(),
cos_pos.device());
auto mask =
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Plain int is used for sequence lengths and loop counters, which should be replaced with fixed-width integers per the style guide.

Suggested change
int num_sequences = input.meta.num_sequences;
if (num_sequences > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(num_sequences);
for (int j = 0; j < num_sequences; j++) {
auto mask = attn_mask_.gen_append_mask(
input_params.attention.host.q_seq_lens[j],
input_params.attention.host.kv_seq_lens[j],
max_kv_seq,
cos_pos.dtype().toScalarType(),
cos_pos.device());
auto mask =
int32_t max_kv_seq = input.meta.kv_max_seq_len;
int32_t num_sequences = input.meta.num_sequences;
if (num_sequences > 0) {
std::vector<torch::Tensor> req_mask_vec;
req_mask_vec.reserve(num_sequences);
for (int32_t j = 0; j < num_sequences; ++j) {
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)

Comment on lines 102 to 103
torch::Tensor& cu_seq_len,
std::vector<int>& cu_seq_len_vec,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Function parameters should use fixed-width integers instead of plain int to adhere to the project's type system rules.

                        std::vector<int32_t>& cu_seq_len_vec,
                        int32_t node_id) {
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)

ModelInputParams& input_params_new =
const_cast<ModelInputParams&>(input_params);
torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
std::vector<int> cu_seqlens_vec(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Use std::vector<int32_t> instead of std::vector<int> to comply with the project's fixed-width integer requirement.

Suggested change
std::vector<int> cu_seqlens_vec(
std::vector<int32_t> cu_seqlens_vec(
References
  1. Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)

const torch::Tensor& v_cache,
uint32_t padded_num_token,
bool return_capture_input = false);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The parameter name padded_num_token is singular here but plural (padded_num_tokens) in the implementation. It should be consistent with the implementation and other similar classes (like CudaGraphPersistentParam).

                                     uint32_t padded_num_tokens,

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant