refactor: split forward inputs from model input params [3 / 3]. #1469
RobbieLeung wants to merge 1 commit into
Conversation
Code Review
This pull request refactors the model input handling by introducing a new ForwardInput struct and replacing the previous ModelInputParams and RawForwardInput structures. This change standardizes input data transport across the codebase, including distributed and shared memory communication. The review comments correctly identify that the newly introduced structs (ForwardInputBufferPlan, ForwardInput, AttentionInput) contain member functions, which violates the repository's style guide requiring structs to be used only for plain data aggregation. Additionally, one comment correctly points out a violation regarding the use of plain int instead of fixed-width integers, and another highlights a potential regression in memory efficiency in input_write due to unnecessary copying.
4494340 to
42fa2d0
- separate batch, attention, embedding, parallel and graph fields into structured forward input data.
- update runtime workers, graph executors, model adapters and tests to consume the new input layout.
- keep shared-memory serialization and speculative decode builders aligned with the refactored parameter flow.
42fa2d0 to
d5252b8
/gemini review
Code Review
This pull request refactors the codebase by replacing the legacy ModelInputParams structure with a new ForwardInput structure across various models, layers, and test files. This significant change involves updating function signatures, member access patterns, and helper functions to integrate the new input structure throughout the codebase. I have no review feedback to provide.
    CHECK(seq_idx < attention.host.q_seq_lens.size()) << "seq_idx out of range";
    return attention.host.q_seq_lens[seq_idx];
    #else
    CHECK(seq_idx < attention.host.q_seq_lens.size() - 1)
Potential unsigned integer underflow if attention.host.q_seq_lens.size() is 0. Since size() returns size_t, 0 - 1 results in SIZE_MAX, causing the CHECK to pass for any non-negative seq_idx and leading to an out-of-bounds access at line 642. The condition should be restructured to avoid subtraction from an unsigned type.
    - CHECK(seq_idx < attention.host.q_seq_lens.size() - 1)
    + CHECK(seq_idx + 1 < static_cast<int32_t>(attention.host.q_seq_lens.size()))
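The underflow described in this comment can be reproduced in isolation. The following is a minimal sketch, not code from the PR: `unsafe_in_range`, `safe_in_range`, and `q_seq_len_at` are hypothetical helper names, a plain `std::vector<int32_t>` stands in for `attention.host.q_seq_lens`, and the cumulative-length layout in `q_seq_len_at` is an assumption suggested by the `size() - 1` bound.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the flagged condition: size() - 1 is evaluated in size_t, so an
// empty vector yields SIZE_MAX and the check accepts every valid index.
bool unsafe_in_range(int32_t seq_idx, const std::vector<int32_t>& lens) {
  return static_cast<std::size_t>(seq_idx) < lens.size() - 1;
}

// Rearranged so that nothing is ever subtracted from an unsigned value.
bool safe_in_range(int32_t seq_idx, const std::vector<int32_t>& lens) {
  return seq_idx >= 0 && static_cast<std::size_t>(seq_idx) + 1 < lens.size();
}

// Assumption: lens holds cumulative offsets, so entry i + 1 minus entry i
// is the length of sequence i.
int32_t q_seq_len_at(int32_t seq_idx, const std::vector<int32_t>& cu_lens) {
  assert(safe_in_range(seq_idx, cu_lens));
  return cu_lens[seq_idx + 1] - cu_lens[seq_idx];
}
```

With an empty vector, `unsafe_in_range(0, lens)` returns true while `safe_in_range(0, lens)` returns false, which is the difference the suggested change is after.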
    inputs.skip_sampling_for_logits_only = skip_sampling_for_logits_only;
    }

    void copy_model_inputs_to(ForwardInput& inputs) const {
The ForwardInput struct contains multiple member functions (e.g., copy_model_inputs_to, get_q_seq_len, onerec_params), which violates the repository style guide. Rule 160 states that structs must not have member functions and are for plain data aggregation only. If methods are required, ForwardInput should be defined as a class.
References
- Structs must not have member functions. If you need methods, use a class. Structs are for plain data aggregation only. (link)
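One way to satisfy this rule is sketched below, under stated assumptions: `AttentionHostLens` and `SeqLensView` are hypothetical names, and the real `ForwardInput` carries tensors and many more fields. The first variant keeps the struct as a plain aggregate and moves the accessor into a free function; the second follows the reviewer's suggestion and uses a class when methods are needed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Variant 1: plain aggregate, data only, no member functions.
struct AttentionHostLens {
  std::vector<int32_t> q_seq_lens;
  std::vector<int32_t> kv_seq_lens;
};

// The accessor lives outside the struct and takes it by const reference.
int32_t get_q_seq_len(const AttentionHostLens& host, int32_t seq_idx) {
  assert(seq_idx >= 0 &&
         static_cast<std::size_t>(seq_idx) < host.q_seq_lens.size());
  return host.q_seq_lens[seq_idx];
}

// Variant 2: a class with private data when behaviour belongs to the type.
class SeqLensView {
 public:
  explicit SeqLensView(std::vector<int32_t> q_seq_lens)
      : q_seq_lens_(std::move(q_seq_lens)) {}

  int32_t get_q_seq_len(int32_t seq_idx) const {
    assert(seq_idx >= 0 &&
           static_cast<std::size_t>(seq_idx) < q_seq_lens_.size());
    return q_seq_lens_[seq_idx];
  }

 private:
  std::vector<int32_t> q_seq_lens_;
};
```

Which variant fits depends on whether callers need direct field access, as serialization code usually does, or only the accessor.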
      int batch_size = input.attention.host.q_seq_lens.size();
      if (batch_size > 0) {
      std::vector<torch::Tensor> req_mask_vec;
      req_mask_vec.reserve(batch_size);

      for (int j = 0; j < batch_size; j++) {
    - int start = input_params.attention.host.kv_seq_lens[j] -
    - input_params.attention.host.q_seq_lens[j];
    - int end = input_params.attention.host.kv_seq_lens[j];
    + int start = input.attention.host.kv_seq_lens[j] -
    + input.attention.host.q_seq_lens[j];
    + int end = input.attention.host.kv_seq_lens[j];
Multiple usages of plain int for variables and loop counters violate the style guide requirement for fixed-width integers.
    - int batch_size = input.attention.host.q_seq_lens.size();
    - if (batch_size > 0) {
    - std::vector<torch::Tensor> req_mask_vec;
    - req_mask_vec.reserve(batch_size);
    - for (int j = 0; j < batch_size; j++) {
    - int start = input_params.attention.host.kv_seq_lens[j] -
    - input_params.attention.host.q_seq_lens[j];
    - int end = input_params.attention.host.kv_seq_lens[j];
    - int start = input.attention.host.kv_seq_lens[j] -
    - input.attention.host.q_seq_lens[j];
    - int end = input.attention.host.kv_seq_lens[j];
    + const int32_t batch_size = static_cast<int32_t>(input.attention.host.q_seq_lens.size());
    + if (batch_size > 0) {
    + std::vector<torch::Tensor> req_mask_vec;
    + req_mask_vec.reserve(batch_size);
    + for (int32_t j = 0; j < batch_size; ++j) {
    + int32_t start = input.attention.host.kv_seq_lens[j] -
    + input.attention.host.q_seq_lens[j];
    + int32_t end = input.attention.host.kv_seq_lens[j];
References
- Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)
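The suggested pattern can be tried outside the model code. This is a sketch under assumptions, not the PR's implementation: `to_int32` and `new_token_ranges` are hypothetical helpers, plain vectors stand in for the host-side `q_seq_lens` and `kv_seq_lens`, and the lengths are assumed to be per-sequence rather than cumulative. It also guards the `size_t` to `int32_t` narrowing, which a bare `static_cast` would truncate silently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Narrow a container size to int32_t, asserting that it fits first.
int32_t to_int32(std::size_t n) {
  assert(n <= static_cast<std::size_t>(std::numeric_limits<int32_t>::max()));
  return static_cast<int32_t>(n);
}

// For each sequence, return the [start, end) span of KV positions covered by
// its new query tokens: start = kv_len - q_len, end = kv_len.
std::vector<std::pair<int32_t, int32_t>> new_token_ranges(
    const std::vector<int32_t>& q_seq_lens,
    const std::vector<int32_t>& kv_seq_lens) {
  assert(q_seq_lens.size() == kv_seq_lens.size());
  const int32_t batch_size = to_int32(q_seq_lens.size());
  std::vector<std::pair<int32_t, int32_t>> ranges;
  ranges.reserve(q_seq_lens.size());
  for (int32_t j = 0; j < batch_size; ++j) {
    const int32_t end = kv_seq_lens[j];
    const int32_t start = end - q_seq_lens[j];
    ranges.emplace_back(start, end);
  }
  return ranges;
}
```

For `q_seq_lens = {2, 1}` and `kv_seq_lens = {5, 4}`, the function returns the spans (3, 5) and (3, 4).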
      int num_sequences = input.meta.num_sequences;
      if (num_sequences > 0) {
      std::vector<torch::Tensor> req_mask_vec;
      req_mask_vec.reserve(num_sequences);

      for (int j = 0; j < num_sequences; j++) {
    - auto mask = attn_mask_.gen_append_mask(
    - input_params.attention.host.q_seq_lens[j],
    - input_params.attention.host.kv_seq_lens[j],
    - max_kv_seq,
    - cos_pos.dtype().toScalarType(),
    - cos_pos.device());
    + auto mask =
Plain int is used for sequence lengths and loop counters, which should be replaced with fixed-width integers per the style guide.
    - int num_sequences = input.meta.num_sequences;
    - if (num_sequences > 0) {
    - std::vector<torch::Tensor> req_mask_vec;
    - req_mask_vec.reserve(num_sequences);
    - for (int j = 0; j < num_sequences; j++) {
    - auto mask = attn_mask_.gen_append_mask(
    - input_params.attention.host.q_seq_lens[j],
    - input_params.attention.host.kv_seq_lens[j],
    - max_kv_seq,
    - cos_pos.dtype().toScalarType(),
    - cos_pos.device());
    - auto mask =
    + int32_t max_kv_seq = input.meta.kv_max_seq_len;
    + int32_t num_sequences = input.meta.num_sequences;
    + if (num_sequences > 0) {
    + std::vector<torch::Tensor> req_mask_vec;
    + req_mask_vec.reserve(num_sequences);
    + for (int32_t j = 0; j < num_sequences; ++j) {
References
- Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)
    torch::Tensor& cu_seq_len,
    std::vector<int>& cu_seq_len_vec,
Function parameters should use fixed-width integers instead of plain int to adhere to the project's type system rules.
    std::vector<int32_t>& cu_seq_len_vec,
    int32_t node_id) {
References
- Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)
    ModelInputParams& input_params_new =
    const_cast<ModelInputParams&>(input_params);
    torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
    std::vector<int> cu_seqlens_vec(
Use std::vector<int32_t> instead of std::vector<int> to comply with the project's fixed-width integer requirement.
    - std::vector<int> cu_seqlens_vec(
    + std::vector<int32_t> cu_seqlens_vec(
References
- Use fixed-width integers (int32_t, int64_t) instead of plain int. (link)
    const torch::Tensor& v_cache,
    uint32_t padded_num_token,
    bool return_capture_input = false);
Summary
This PR refactors the forward input path by separating the old ModelInputParams payload into structured input groups and moving execution-facing state into ForwardInput.
The goal is to make the runtime input flow clearer and more maintainable across batch builders, workers, executors, model adapters, and shared-memory transport, without changing the intended inference behavior.
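To make the grouping concrete, here is a rough sketch of what a structured forward input could look like. It is an illustration only: every type name ending in `Sketch` and the `make_forward_input` builder are hypothetical, plain vectors stand in for the tensors, and only fields that appear in the review snippets (`token_ids`, `positions`, `meta.num_sequences`, `meta.kv_max_seq_len`, `attention.host.q_seq_lens`, `attention.host.kv_seq_lens`) are modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Host-side attention lengths (hypothetical stand-in).
struct AttentionHostSketch {
  std::vector<int32_t> q_seq_lens;
  std::vector<int32_t> kv_seq_lens;
};

struct AttentionSketch {
  AttentionHostSketch host;
};

// Batch-level scalars derived from the per-sequence lengths.
struct BatchMetaSketch {
  int32_t num_sequences = 0;
  int32_t kv_max_seq_len = 0;
};

// One payload carrying the execution inputs plus the grouped state.
struct ForwardInputSketch {
  std::vector<int32_t> token_ids;
  std::vector<int32_t> positions;
  BatchMetaSketch meta;
  AttentionSketch attention;
};

// Free-function builder, so the structs stay plain data aggregates.
ForwardInputSketch make_forward_input(std::vector<int32_t> token_ids,
                                      std::vector<int32_t> positions,
                                      std::vector<int32_t> q_seq_lens,
                                      std::vector<int32_t> kv_seq_lens) {
  assert(token_ids.size() == positions.size());
  assert(q_seq_lens.size() == kv_seq_lens.size());
  ForwardInputSketch input;
  input.meta.num_sequences = static_cast<int32_t>(q_seq_lens.size());
  if (!kv_seq_lens.empty()) {
    input.meta.kv_max_seq_len =
        *std::max_element(kv_seq_lens.begin(), kv_seq_lens.end());
  }
  input.token_ids = std::move(token_ids);
  input.positions = std::move(positions);
  input.attention.host.q_seq_lens = std::move(q_seq_lens);
  input.attention.host.kv_seq_lens = std::move(kv_seq_lens);
  return input;
}
```

A consumer then reads `input.attention.host.q_seq_lens[j]` or `input.meta.kv_max_seq_len` from the single payload instead of receiving a second params object.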
What Changed
1. Split forward input data into structured groups
Refactored the previous flat input layout into explicit components, including:
ForwardInput is now the main execution payload and carries both the tensor inputs (token_ids, positions) and the structured runtime/model input state.
2. Move execution code to the new ForwardInput interface
Updated the runtime execution path to consume ForwardInput directly, including:
This removes the previous dependency on passing a separate ModelInputParams object through the execution stack.
3. Align attention metadata construction with the refactor
Refactored attention metadata building to consume explicit input sections instead of the old monolithic params object.
This keeps the metadata path consistent for:
4. Keep serialization and decode builders in sync
Updated the shared-memory forward input serialization/deserialization flow and speculative decode input builders so they match the new structured input layout.
This includes:
5. Update tests and call sites
Adjusted affected tests and call sites to use the new input layout, especially around:
Motivation
The previous input flow mixed execution tensors and model/runtime parameters in a single broad structure, which made the data contract hard to follow and harder to evolve safely.
This refactor improves:
Scope
This change is primarily a structural refactor.
It is intended to preserve the existing runtime behavior while making the forward input pipeline easier to extend and reason about.
Validation
clang-format checks passed during commit