From d5252b8957df1aad1cd7cbbefd231b7407ee02c8 Mon Sep 17 00:00:00 2001 From: liangzhiwei20 Date: Mon, 18 May 2026 13:55:00 +0800 Subject: [PATCH] refactor: split forward inputs from model input params. - separate batch, attention, embedding, parallel and graph fields into structured forward input data. - update runtime workers, graph executors, model adapters and tests to consume the new input layout. - keep shared-memory serialization and speculative decode builders aligned with the refactored parameter flow. --- tests/core/framework/batch/batch_test.cpp | 125 +++---- tests/core/framework/hf_model_loader_test.cpp | 10 +- .../mlu/deepseek_v2_decoder_layer_test.cpp | 55 +-- .../mlu/deepseek_v2_sparse_moe_block_test.cpp | 68 ++-- tests/core/layers/mlu/dp_utils_test.cpp | 12 +- .../mlu/qwen2_vision_attention_test.cpp | 6 +- .../core/runtime/acl_graph_executor_test.cpp | 128 +++---- .../core/runtime/cuda_graph_executor_test.cpp | 151 +++++---- .../core/runtime/mlu_graph_executor_test.cpp | 303 +++++------------ .../core/runtime/spec_input_builder_test.cpp | 24 +- .../scheduler/continuous_scheduler_test.cpp | 7 +- xllm/core/distributed_runtime/dit_engine.cpp | 2 +- xllm/core/distributed_runtime/llm_engine.cpp | 29 +- xllm/core/distributed_runtime/rec_engine.cpp | 21 +- xllm/core/distributed_runtime/vlm_engine.cpp | 21 +- .../framework/batch/batch_input_builder.cpp | 4 +- .../batch/onerec_batch_input_builder.cpp | 2 +- .../onerec_xattention_batch_input_builder.cpp | 13 +- .../rec_multi_round_batch_input_builder.cpp | 2 +- .../hierarchy_kv_cache_transfer.cpp | 16 +- .../hierarchy_kv_cache_transfer.h | 4 +- xllm/core/framework/model/causal_lm.h | 18 +- xllm/core/framework/model/causal_vlm.h | 21 +- xllm/core/framework/model/mm_embedding_vlm.h | 12 +- .../core/framework/model/model_input_params.h | 185 ---------- xllm/core/framework/model/rec_causal_lm.h | 8 +- xllm/core/layers/common/attention_metadata.h | 2 +- .../common/attention_metadata_builder.cpp | 164 +++++---- 
.../common/attention_metadata_builder.h | 28 +- xllm/core/layers/common/dp_utils.cpp | 21 +- xllm/core/layers/common/dp_utils.h | 6 +- xllm/core/layers/common/fused_moe.cpp | 2 +- xllm/core/layers/common/fused_moe.h | 2 +- .../layers/common/oxygen_vision_attention.cpp | 3 +- .../layers/common/oxygen_vision_attention.h | 3 +- .../layers/common/qwen2_vision_attention.cpp | 3 +- .../layers/common/qwen2_vision_attention.h | 3 +- xllm/core/layers/cuda/fused_moe.cpp | 2 +- xllm/core/layers/cuda/fused_moe.h | 2 +- xllm/core/layers/ilu/fused_moe.cpp | 10 +- xllm/core/layers/ilu/fused_moe.h | 2 +- .../mlu/deepseek_v2_decoder_layer_impl.cpp | 31 +- .../mlu/deepseek_v2_decoder_layer_impl.h | 10 +- .../mlu/deepseek_v2_sparse_moe_block.cpp | 28 +- .../layers/mlu/deepseek_v2_sparse_moe_block.h | 8 +- xllm/core/layers/mlu/fused_moe.cpp | 4 +- xllm/core/layers/mlu/fused_moe.h | 2 +- .../core/layers/mlu/qwen3_5_decoder_layer.cpp | 18 +- xllm/core/layers/mlu/qwen3_5_decoder_layer.h | 6 +- xllm/core/layers/musa/attention.cpp | 20 +- xllm/core/layers/musa/musa_layer_base.h | 4 +- .../musa/musa_qwen3_decoder_layer_impl.cpp | 5 +- .../musa/musa_qwen3_decoder_layer_impl.h | 7 +- .../npu_deepseek_v2_decoder_layer_impl.cpp | 29 +- .../npu/npu_deepseek_v2_decoder_layer_impl.h | 6 +- .../npu_deepseek_v32_decoder_layer_impl.cpp | 35 +- .../npu/npu_deepseek_v32_decoder_layer_impl.h | 6 +- .../npu/npu_eagle3_decoder_layer_impl.cpp | 11 +- .../npu/npu_eagle3_decoder_layer_impl.h | 6 +- .../npu/npu_glm4_decoder_layer_impl.cpp | 19 +- .../layers/npu/npu_glm4_decoder_layer_impl.h | 6 +- .../layers/npu/npu_glm4_moe_decoder_layer.cpp | 36 +- .../layers/npu/npu_glm4_moe_decoder_layer.h | 8 +- .../npu/npu_glm4_moe_lite_decoder_layer.cpp | 36 +- .../npu/npu_glm4_moe_lite_decoder_layer.h | 6 +- .../npu_glm4_vision_encoder_layer_impl.cpp | 12 +- .../npu/npu_glm4_vision_encoder_layer_impl.h | 2 - .../npu/npu_llama_decoder_layer_impl.cpp | 19 +- .../layers/npu/npu_llama_decoder_layer_impl.h | 6 +- 
.../npu/npu_onerec_block_layer_impl.cpp | 35 +- .../layers/npu/npu_onerec_block_layer_impl.h | 15 +- .../npu/npu_qwen2_decoder_layer_impl.cpp | 19 +- .../layers/npu/npu_qwen2_decoder_layer_impl.h | 6 +- .../npu_qwen2_vision_encoder_layer_impl.cpp | 12 +- .../npu/npu_qwen2_vision_encoder_layer_impl.h | 4 +- ...pu_qwen2dot5_vision_encoder_layer_impl.cpp | 12 +- .../npu_qwen2dot5_vision_encoder_layer_impl.h | 2 - .../npu/npu_qwen3_decoder_layer_impl.cpp | 11 +- .../layers/npu/npu_qwen3_decoder_layer_impl.h | 6 +- .../npu/npu_qwen3_moe_decoder_layer_impl.cpp | 11 +- .../npu/npu_qwen3_moe_decoder_layer_impl.h | 6 +- .../npu_qwen3_vision_encoder_layer_impl.cpp | 12 +- .../npu/npu_qwen3_vision_encoder_layer_impl.h | 2 - xllm/core/layers/npu_torch/fused_moe.cpp | 6 +- xllm/core/layers/npu_torch/fused_moe.h | 6 +- .../npu_torch/qwen3_gated_delta_net_base.cpp | 23 +- .../npu_torch/qwen3_gated_delta_net_base.h | 9 +- .../qwen3_next_hybrid_decoder_layer_base.cpp | 8 +- .../qwen3_next_hybrid_decoder_layer_base.h | 7 +- xllm/core/layers/onerec_block_layer.h | 5 +- xllm/core/layers/oxygen_vision_layer.cpp | 11 +- xllm/core/layers/oxygen_vision_layer.h | 1 - xllm/core/layers/qwen2_5_vision_layer.cpp | 11 +- xllm/core/layers/qwen2_5_vision_layer.h | 1 - xllm/core/layers/qwen2_decoder_layer.cpp | 2 +- xllm/core/layers/qwen2_decoder_layer.h | 5 +- xllm/core/layers/qwen3_moe_decoder_layer.cpp | 20 +- xllm/core/layers/qwen3_moe_decoder_layer.h | 8 +- xllm/core/runtime/acl_graph_executor_impl.cpp | 250 ++++++-------- xllm/core/runtime/acl_graph_executor_impl.h | 39 +-- xllm/core/runtime/base_executor_impl.cpp | 8 +- xllm/core/runtime/base_executor_impl.h | 6 +- .../core/runtime/cuda_graph_executor_impl.cpp | 320 ++++++++---------- xllm/core/runtime/cuda_graph_executor_impl.h | 43 +-- xllm/core/runtime/dit_worker_impl.cpp | 2 +- xllm/core/runtime/embed_vlm_worker_impl.cpp | 11 +- xllm/core/runtime/embed_worker_impl.cpp | 9 +- xllm/core/runtime/executor.cpp | 8 +- 
xllm/core/runtime/executor.h | 9 +- xllm/core/runtime/executor_impl.h | 10 +- xllm/core/runtime/forward_params.h | 294 ++++++++++++---- .../runtime/forward_shared_memory_manager.cpp | 7 +- xllm/core/runtime/llm_worker_impl.cpp | 12 +- xllm/core/runtime/mlu_graph_executor_impl.cpp | 123 ++++--- xllm/core/runtime/mlu_graph_executor_impl.h | 22 +- .../core/runtime/mm_embed_vlm_worker_impl.cpp | 10 +- xllm/core/runtime/mtp_worker_impl.cpp | 68 ++-- xllm/core/runtime/rec_worker_impl.cpp | 273 +++++++-------- xllm/core/runtime/spec_input_builder.cpp | 34 +- xllm/core/runtime/spec_input_builder.h | 9 +- xllm/core/runtime/speculative_worker_impl.cpp | 14 +- xllm/core/runtime/suffix_worker_impl.cpp | 15 +- xllm/core/runtime/vlm_executor_impl.cpp | 23 +- xllm/core/runtime/vlm_executor_impl.h | 8 +- xllm/core/runtime/vlm_worker_impl.cpp | 3 +- xllm/core/runtime/worker_impl.cpp | 137 ++++---- xllm/core/runtime/worker_impl.h | 9 +- xllm/models/dit/pipeline_longcat_image.h | 20 +- xllm/models/dit/pipeline_longcat_image_edit.h | 22 +- xllm/models/llm/deepseek_v2.h | 23 +- xllm/models/llm/deepseek_v32.h | 26 +- xllm/models/llm/llm_model_base.h | 30 +- xllm/models/llm/mtp_model_base.h | 31 +- xllm/models/llm/musa/qwen3.h | 55 +-- xllm/models/llm/npu/deepseek_v2.h | 35 +- xllm/models/llm/npu/deepseek_v32.h | 33 +- xllm/models/llm/npu/glm4.h | 42 +-- xllm/models/llm/npu/glm4_moe.h | 56 ++- xllm/models/llm/npu/glm4_moe_lite.h | 42 +-- xllm/models/llm/npu/glm5_moe.h | 21 +- xllm/models/llm/npu/joyai_llm_flash.h | 23 +- xllm/models/llm/npu/llama.h | 39 ++- xllm/models/llm/npu/llm_model_base.h | 66 ++-- xllm/models/llm/npu/mtp_model_base.h | 53 ++- xllm/models/llm/npu/oxygen.h | 48 +-- xllm/models/llm/npu/qwen3.h | 48 +-- xllm/models/llm/npu/qwen3_eagle3.h | 52 ++- xllm/models/llm/npu/qwen3_moe.h | 53 ++- xllm/models/llm/oxygen.h | 45 +-- xllm/models/llm/qwen3.h | 83 +++-- xllm/models/llm/qwen3_5_mtp.h | 26 +- xllm/models/llm/qwen3_moe.h | 36 +- xllm/models/llm/qwen3_next_hybrid_base.h | 
44 +-- xllm/models/rec/npu/onerec.h | 19 +- xllm/models/rec/npu/onerec_npu_impl.h | 65 ++-- xllm/models/rec/onerec.h | 10 +- xllm/models/rec/rec_model_base.h | 9 +- xllm/models/vlm/npu/glm4v.h | 52 +-- xllm/models/vlm/npu/glm4v_moe.h | 27 +- xllm/models/vlm/npu/minicpmv.h | 21 +- xllm/models/vlm/npu/oxygen_vlm.h | 27 +- xllm/models/vlm/npu/qwen2_5_vl.h | 51 +-- xllm/models/vlm/npu/qwen2_5_vl_mm_embedding.h | 19 +- xllm/models/vlm/npu/qwen2_vl.h | 48 +-- xllm/models/vlm/npu/qwen3_vl.h | 62 ++-- xllm/models/vlm/npu/qwen3_vl_mm_embedding.h | 23 +- xllm/models/vlm/npu/qwen3_vl_moe.h | 35 +- xllm/models/vlm/oxygen_vlm.h | 42 +-- xllm/models/vlm/qwen2_5_vl.h | 51 ++- xllm/models/vlm/qwen2_vl.h | 38 +-- xllm/models/vlm/qwen3_5.h | 42 +-- xllm/models/vlm/qwen3_vl.h | 15 +- xllm/models/vlm/qwen3_vl_base.h | 38 +-- xllm/models/vlm/qwen3_vl_moe.h | 35 +- 174 files changed, 2624 insertions(+), 3095 deletions(-) diff --git a/tests/core/framework/batch/batch_test.cpp b/tests/core/framework/batch/batch_test.cpp index 59505f3e04..4cd81dd964 100644 --- a/tests/core/framework/batch/batch_test.cpp +++ b/tests/core/framework/batch/batch_test.cpp @@ -324,13 +324,13 @@ TEST(BatchTest, Basic) { EXPECT_TRUE(equal(forward_input.positions, expected_pos)); // check the input parameters - const ModelInputParams& input_params = forward_input.input_params; - EXPECT_TRUE(input_params.meta.batch_forward_type.is_mixed()); - EXPECT_EQ(input_params.meta.num_sequences, 4); - EXPECT_EQ(input_params.meta.q_max_seq_len, 9); - EXPECT_EQ(input_params.meta.kv_max_seq_len, 16); - EXPECT_EQ(input_params.embedding.embedding_ids, std::vector({-1, -1, -1})); - EXPECT_EQ(input_params.embedding.linear_state_ids, + EXPECT_TRUE(forward_input.meta.batch_forward_type.is_mixed()); + EXPECT_EQ(forward_input.meta.num_sequences, 4); + EXPECT_EQ(forward_input.meta.q_max_seq_len, 9); + EXPECT_EQ(forward_input.meta.kv_max_seq_len, 16); + EXPECT_EQ(forward_input.embedding.embedding_ids, + std::vector({-1, -1, -1})); + 
EXPECT_EQ(forward_input.embedding.linear_state_ids, std::vector({-1, -1, -1, -1})); #if defined(USE_NPU) @@ -338,7 +338,7 @@ TEST(BatchTest, Basic) { #else const std::vector q_seq_lens = {0, 9, 10, 11, 15}; #endif - EXPECT_TRUE(equal(input_params.attention.device.q_seq_lens, q_seq_lens)); + EXPECT_TRUE(equal(forward_input.attention.device.q_seq_lens, q_seq_lens)); // seq4's kv_seq_len = q_len + num_cached_tokens (q_len<=max_allowed_tokens) #if defined(USE_NPU) @@ -346,7 +346,7 @@ TEST(BatchTest, Basic) { #else const std::vector kv_seq_lens = {0, 9, 17, 33, 41}; #endif - EXPECT_TRUE(equal(input_params.attention.device.kv_seq_lens, kv_seq_lens)); + EXPECT_TRUE(equal(forward_input.attention.device.kv_seq_lens, kv_seq_lens)); const std::vector new_cache_slots = { /*seq1*/ 4, 5, 6, 7, 8, 9, 10, 11, 12, @@ -354,17 +354,18 @@ TEST(BatchTest, Basic) { /*seq3*/ 47, /*seq4*/ 56,57,58,59 }; - EXPECT_TRUE(equal(input_params.attention.device.new_cache_slots, new_cache_slots)); + EXPECT_TRUE(equal(forward_input.attention.device.new_cache_slots, + new_cache_slots)); const std::vector block_tables = { /*seq1*/ 1, 2, 3, 0, 0, /*seq2*/ 4, 5, 6, 7, 0, /*seq3*/ 8, 9, 10, 11, 12, /*seq4*/ 13, 14, 15, 0, 0}; - EXPECT_TRUE(equal(input_params.attention.device.block_tables, block_tables)); + EXPECT_TRUE(equal(forward_input.attention.device.block_tables, block_tables)); // const std::vector last_token_idxes = {8, 9, 10}; - // EXPECT_TRUE(equal(input_params.last_token_idxes, last_token_idxes)); + // EXPECT_TRUE(equal(forward_input.last_token_idxes, last_token_idxes)); const auto& sampling_params = forward_input.sampling_params; const std::vector unique_ids = { @@ -578,7 +579,7 @@ TEST(BatchTest, ForwardInputPreservesTransferInfoAndBatchId) { (std::vector{1, 2})); EXPECT_EQ(input.transfer_kv_infos[0].remote_blocks_ids, (std::vector{100, 101})); - EXPECT_EQ(input.input_params.meta.batch_id, batch_id); + EXPECT_EQ(input.meta.batch_id, batch_id); } TEST(BatchTest, 
ForwardInputPackedRoundTripPreservesTransportFields) { @@ -650,7 +651,7 @@ TEST(BatchTest, ForwardInputPackedRoundTripPreservesTransportFields) { ForwardInput round_trip; reader_manager.input_read(round_trip, torch::Device(torch::kCPU)); - EXPECT_EQ(round_trip.input_params.meta.batch_id, batch_id); + EXPECT_EQ(round_trip.meta.batch_id, batch_id); EXPECT_TRUE(equal(round_trip.token_ids, std::vector({1, 2, 3, 4}))); ASSERT_EQ(round_trip.transfer_kv_infos.size(), 1u); EXPECT_EQ(round_trip.transfer_kv_infos[0].local_blocks_ids, @@ -709,18 +710,18 @@ TEST(BatchTest, ForwardInputBlockCopyKernelFieldsMatchExpectedLayout) { forward_builder.build_forward_input(/*num_decoding_tokens=*/1, /*min_decoding_batch_size=*/0); - EXPECT_TRUE(equal(forward_input.input_params.block_copy.src_block_indices, + EXPECT_TRUE(equal(forward_input.block_copy.src_block_indices, std::vector({7, 8}))); - EXPECT_TRUE(equal(forward_input.input_params.block_copy.dst_block_indices, + EXPECT_TRUE(equal(forward_input.block_copy.dst_block_indices, std::vector({10, 11, 12}))); - EXPECT_TRUE(equal(forward_input.input_params.block_copy.cum_sum, - std::vector({2, 3}))); + EXPECT_TRUE( + equal(forward_input.block_copy.cum_sum, std::vector({2, 3}))); #if defined(USE_CUDA) - EXPECT_EQ(forward_input.input_params.block_copy.swap_blocks.size(), + EXPECT_EQ(forward_input.block_copy.swap_blocks.size(), forward_swap_blocks.size()); #else - EXPECT_TRUE(forward_input.input_params.block_copy.swap_blocks.empty()); + EXPECT_TRUE(forward_input.block_copy.swap_blocks.empty()); #endif FLAGS_enable_block_copy_kernel = old_enable_block_copy_kernel; @@ -780,13 +781,13 @@ TEST(BatchTest, ForwardInputCpPartitionMatchesExpectedLayout) { equal(cp_forward_input.positions, std::vector({0, 1, 6, 7}))); EXPECT_TRUE(equal(cp_forward_input.sampling_params.selected_token_idxes, std::vector({3}))); - EXPECT_EQ(cp_forward_input.input_params.meta.q_max_seq_len, 4); - EXPECT_EQ(cp_forward_input.input_params.meta.kv_max_seq_len, 4); + 
EXPECT_EQ(cp_forward_input.meta.q_max_seq_len, 4); + EXPECT_EQ(cp_forward_input.meta.kv_max_seq_len, 4); const std::vector& q_seq_lens = - cp_forward_input.input_params.attention.host.q_seq_lens; + cp_forward_input.attention.host.q_seq_lens; const std::vector& kv_seq_lens = - cp_forward_input.input_params.attention.host.kv_seq_lens; + cp_forward_input.attention.host.kv_seq_lens; EXPECT_TRUE((q_seq_lens == std::vector({4}) || q_seq_lens == std::vector({0, 4}))); EXPECT_TRUE((kv_seq_lens == std::vector({4}) || @@ -910,13 +911,11 @@ TEST(BatchTest, SampleRequestKeepsThreadedForwardBuilderOffsetsStable) { expected_selected_token_idxes)); EXPECT_TRUE( equal(forward_input.sampling_params.sample_idxes, expected_sample_idxes)); - ASSERT_EQ(forward_input.input_params.embedding.embedding_ids.size(), - sequences.size()); - ASSERT_EQ(forward_input.input_params.embedding.linear_state_ids.size(), - sequences.size()); - EXPECT_EQ(forward_input.input_params.embedding.embedding_ids, + ASSERT_EQ(forward_input.embedding.embedding_ids.size(), sequences.size()); + ASSERT_EQ(forward_input.embedding.linear_state_ids.size(), sequences.size()); + EXPECT_EQ(forward_input.embedding.embedding_ids, std::vector({-1, -1})); - EXPECT_EQ(forward_input.input_params.embedding.linear_state_ids, + EXPECT_EQ(forward_input.embedding.linear_state_ids, std::vector({-1, -1})); } @@ -967,11 +966,10 @@ TEST(BatchTest, DecodeMinBatchSizeDoesNotPadTransportState) { builder.build_forward_input(/*num_decoding_tokens=*/1, /*min_decoding_batch_size=*/3); - EXPECT_EQ(forward_input.input_params.meta.num_sequences, 1); - EXPECT_EQ(forward_input.input_params.embedding.linear_state_ids, - std::vector({-1})); - EXPECT_EQ(forward_input.input_params.embedding.embedding_ids, + EXPECT_EQ(forward_input.meta.num_sequences, 1); + EXPECT_EQ(forward_input.embedding.linear_state_ids, std::vector({-1})); + EXPECT_EQ(forward_input.embedding.embedding_ids, std::vector({-1})); } TEST(BatchTest, 
DecodeSingleBlockIdsStaySplitInTransportButShareSlotValue) { @@ -1026,15 +1024,13 @@ TEST(BatchTest, DecodeSingleBlockIdsStaySplitInTransportButShareSlotValue) { builder.build_forward_input(/*num_decoding_tokens=*/1, /*min_decoding_batch_size=*/0); - ASSERT_EQ(forward_input.input_params.embedding.embedding_ids.size(), 1u); - ASSERT_EQ(forward_input.input_params.embedding.linear_state_ids.size(), 1u); - EXPECT_EQ(forward_input.input_params.embedding.embedding_ids[0], - expected_slot_id); - EXPECT_EQ(forward_input.input_params.embedding.linear_state_ids[0], - expected_slot_id); + ASSERT_EQ(forward_input.embedding.embedding_ids.size(), 1u); + ASSERT_EQ(forward_input.embedding.linear_state_ids.size(), 1u); + EXPECT_EQ(forward_input.embedding.embedding_ids[0], expected_slot_id); + EXPECT_EQ(forward_input.embedding.linear_state_ids[0], expected_slot_id); } -TEST(BatchTest, SharedMemoryRoundTripPreservesLinearStateIds) { +TEST(BatchTest, SharedMemoryRoundTripPreservesAndDefaultsLinearStateIds) { ForwardInput forward_input; auto int_options = torch::TensorOptions() .dtype(torch::kInt) @@ -1046,30 +1042,30 @@ TEST(BatchTest, SharedMemoryRoundTripPreservesLinearStateIds) { forward_input.positions = torch::tensor(std::vector({0, 0}), int_options); forward_input.positions_host = forward_input.positions; - forward_input.input_params.meta.batch_forward_type = BatchForwardType::DECODE; - forward_input.input_params.meta.num_sequences = 2; - forward_input.input_params.meta.kv_max_seq_len = 1; - forward_input.input_params.meta.q_max_seq_len = 1; - forward_input.input_params.attention.host.kv_seq_lens = {1, 1}; - forward_input.input_params.attention.host.q_seq_lens = {1, 1}; - forward_input.input_params.attention.host.q_cu_seq_lens = {1, 2}; - forward_input.input_params.attention.host.kv_cache_tokens_nums = {0, 0}; - forward_input.input_params.attention.host.new_cache_slots = {0, 0}; - forward_input.input_params.attention.device.kv_seq_lens = + forward_input.meta.batch_forward_type = 
BatchForwardType::DECODE; + forward_input.meta.num_sequences = 2; + forward_input.meta.kv_max_seq_len = 1; + forward_input.meta.q_max_seq_len = 1; + forward_input.attention.host.kv_seq_lens = {1, 1}; + forward_input.attention.host.q_seq_lens = {1, 1}; + forward_input.attention.host.q_cu_seq_lens = {1, 2}; + forward_input.attention.host.kv_cache_tokens_nums = {0, 0}; + forward_input.attention.host.new_cache_slots = {0, 0}; + forward_input.attention.device.kv_seq_lens = torch::tensor(std::vector({1, 1}), int_options); - forward_input.input_params.attention.device.q_seq_lens = + forward_input.attention.device.q_seq_lens = torch::tensor(std::vector({1, 1}), int_options); - forward_input.input_params.attention.device.q_cu_seq_lens = + forward_input.attention.device.q_cu_seq_lens = torch::tensor(std::vector({1, 2}), int_options); - forward_input.input_params.attention.device.kv_cache_tokens_nums = + forward_input.attention.device.kv_cache_tokens_nums = torch::tensor(std::vector({0, 0}), int_options); - forward_input.input_params.attention.device.new_cache_slots = + forward_input.attention.device.new_cache_slots = torch::tensor(std::vector({0, 0}), int_options); - forward_input.input_params.attention.device.block_tables = create_2d_tensor( + forward_input.attention.device.block_tables = create_2d_tensor( std::vector>{{0}, {0}}, torch::kInt); - forward_input.input_params.attention.host.block_tables = - forward_input.input_params.attention.device.block_tables; - forward_input.input_params.embedding.linear_state_ids = {4, 6}; + forward_input.attention.host.block_tables = + forward_input.attention.device.block_tables; + forward_input.embedding.linear_state_ids = {4, 6}; bool is_creator = false; auto shm_name = @@ -1086,8 +1082,15 @@ TEST(BatchTest, SharedMemoryRoundTripPreservesLinearStateIds) { ForwardInput from_shm; reader_manager.input_read(from_shm, torch::Device(torch::kCPU)); - EXPECT_EQ(from_shm.input_params.embedding.linear_state_ids, - std::vector({4, 6})); + 
EXPECT_EQ(from_shm.embedding.linear_state_ids, std::vector({4, 6})); + + forward_input.embedding.linear_state_ids.clear(); + ASSERT_TRUE(writer_manager.input_write(forward_input)); + + ForwardInput legacy_from_shm; + reader_manager.input_read(legacy_from_shm, torch::Device(torch::kCPU)); + EXPECT_EQ(legacy_from_shm.embedding.linear_state_ids, + std::vector({-1, -1})); } TEST(BatchTest, SampleRequestProcessesAllMatchedRawOutputs) { diff --git a/tests/core/framework/hf_model_loader_test.cpp b/tests/core/framework/hf_model_loader_test.cpp index 67ebc16dba..c03238677e 100644 --- a/tests/core/framework/hf_model_loader_test.cpp +++ b/tests/core/framework/hf_model_loader_test.cpp @@ -41,14 +41,10 @@ class DummyRecCausalLM final : public RecCausalLM { explicit DummyRecCausalLM(const torch::TensorOptions& options) : options_(options) {} - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - UNUSED_PARAMETER(tokens); - UNUSED_PARAMETER(positions); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + UNUSED_PARAMETER(input); UNUSED_PARAMETER(kv_caches); - UNUSED_PARAMETER(parameters); return ModelOutput(); } diff --git a/tests/core/layers/mlu/deepseek_v2_decoder_layer_test.cpp b/tests/core/layers/mlu/deepseek_v2_decoder_layer_test.cpp index 916037650b..3d88d3ffdf 100644 --- a/tests/core/layers/mlu/deepseek_v2_decoder_layer_test.cpp +++ b/tests/core/layers/mlu/deepseek_v2_decoder_layer_test.cpp @@ -39,6 +39,7 @@ limitations under the License. 
#include "layers/common/attention_metadata_builder.h" #include "layers/mlu/tests_utils.h" #include "platform/device.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { @@ -64,7 +65,7 @@ class DeepseekV2DecoderLayerTestPeer { DeepseekV2DecoderLayerImpl& decoder, torch::Tensor x, const torch::Tensor& residual, - const ModelInputParams& input_params, + const ForwardInput& input_params, DeepseekV2AttentionImpl::PostAttnLayout attn_layout, bool need_dp_gather, bool enable_moe_all2all) { @@ -74,11 +75,10 @@ class DeepseekV2DecoderLayerTestPeer { return decoder.build_post_attn_carrier(x, residual, attn_layout); } - static torch::Tensor restore_ffn_output( - DeepseekV2DecoderLayerImpl& decoder, - torch::Tensor x, - const Carrier& carrier, - const ModelInputParams& input_params) { + static torch::Tensor restore_ffn_output(DeepseekV2DecoderLayerImpl& decoder, + torch::Tensor x, + const Carrier& carrier, + const ForwardInput& input_params) { (void)input_params; return decoder.restore_ffn_output(x, carrier); } @@ -107,8 +107,9 @@ class DeepseekV2DecoderLayerTestPeer { static torch::Tensor run_mlp(DeepseekV2DecoderLayerImpl& decoder, torch::Tensor x, - const ModelInputParams& input_params) { - return decoder.run_mlp(std::move(x), input_params); + BatchForwardType batch_forward_type) { + return decoder.run_mlp(std::move(x), + decoder.can_sp_chunk(batch_forward_type)); } static int64_t sp_ffn_chunk(DeepseekV2DecoderLayerImpl& decoder) { @@ -736,8 +737,8 @@ class DeepseekV2DecoderLayerTest : public ::testing::Test { decoder->set_sequence_parallel_context(&sp_ctx_); } - ModelInputParams build_prefill_params(int64_t batch_size, int64_t seq_len) { - ModelInputParams input_params; + ForwardInput build_prefill_params(int64_t batch_size, int64_t seq_len) { + ForwardInput input_params; input_params.meta.batch_forward_type = BatchForwardType::PREFILL; input_params.meta.num_sequences = batch_size; input_params.meta.q_max_seq_len = seq_len; @@ -773,7 +774,8 @@ class 
DeepseekV2DecoderLayerTest : public ::testing::Test { .device(options_.device()))); } - return input_params.to(options_.device()); + return input_params.to(options_.device(), + c10::typeMetaToScalarType(options_.dtype())); } KVCache build_indexed_cache(torch::Tensor key_cache, @@ -1046,7 +1048,7 @@ TEST_P(DeepseekV2DecoderCarrierTest, const auto& tc = GetParam(); auto* tp_pg_raw = init_env(tc); auto decoder = make_decoder(/*layer_id=*/0); - ModelInputParams input_params; + ForwardInput input_params; input_params.parallel.dp_global_token_nums = tc.dp_global_token_nums; set_tp_full_tokens(tc); @@ -1090,7 +1092,7 @@ TEST_F(DeepseekV2DecoderLayerTest, BuildPostAttnCarrierPackedLocal) { torch::full({2, model_args_.hidden_size()}, 5.0f, hidden_opts()); sp_pg_->set_allgather_outputs({expected_local_norm, remote_norm}); - ModelInputParams input_params; + ForwardInput input_params; auto carrier = DeepseekV2DecoderLayerTestPeer::build_post_attn_carrier( *decoder, attn_out, @@ -1110,7 +1112,7 @@ TEST_F(DeepseekV2DecoderLayerTest, BuildPostAttnCarrierPackedLocal) { TEST_F(DeepseekV2DecoderLayerTest, RestoreFfnOutputReplicated) { auto decoder = make_decoder(/*layer_id=*/0); - ModelInputParams input_params; + ForwardInput input_params; auto attn_out = torch::tensor( {{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -1156,7 +1158,7 @@ TEST_F(DeepseekV2DecoderLayerTest, RestoreFfnOutputPackedLocal) { sp_pg_->set_allgather_outputs( std::vector{expected_local_norm, remote_norm}); - ModelInputParams input_params; + ForwardInput input_params; auto carrier = DeepseekV2DecoderLayerTestPeer::build_post_attn_carrier( *decoder, attn_out, @@ -1333,7 +1335,7 @@ TEST_F(DeepseekV2DecoderLayerTest, ForwardMixedDpMoEReturnsLocalSlice) { 2, torch::TensorOptions().dtype(torch::kInt32).device(options_.device())); - ModelInputParams input_params; + ForwardInput input_params; input_params.meta.batch_forward_type = BatchForwardType::PREFILL; input_params.meta.num_sequences = 2; input_params.meta.q_max_seq_len = 1; @@ 
-1355,10 +1357,16 @@ TEST_F(DeepseekV2DecoderLayerTest, ForwardMixedDpMoEReturnsLocalSlice) { torch::TensorOptions().dtype(torch::kInt32).device(options_.device())); input_params.parallel.dp_global_token_nums = {2, 1}; input_params.parallel.dp_is_decode = {0, 0}; - input_params = input_params.to(options_.device()); + input_params = input_params.to(options_.device(), + c10::typeMetaToScalarType(options_.dtype())); auto attn_metadata = - AttentionMetadataBuilder::build(input_params, /*enable_mla=*/true); + AttentionMetadataBuilder::build(input_params.meta, + input_params.attention, + input_params.graph, + input_params.llmrec_params(), + input_params.enable_cuda_graph, + /*enable_mla=*/true); auto k_cache = torch::zeros( {2048, 1, 1, model_args_.qk_rope_head_dim() + model_args_.kv_lora_rank()}, options_); @@ -1433,7 +1441,7 @@ TEST_F(DeepseekV2DecoderLayerTest, DenseMlpChunkMatchesDirectPrefill) { auto expected = DeepseekV2DecoderLayerTestPeer::mlp(*decoder)->forward(hidden_states); auto actual = DeepseekV2DecoderLayerTestPeer::run_mlp( - *decoder, hidden_states, input_params); + *decoder, hidden_states, input_params.meta.batch_forward_type); sync_dev(); test::verify_tensor_close(actual, expected, 1e-3, 1e-4); @@ -1454,7 +1462,7 @@ TEST_F(DeepseekV2DecoderLayerTest, DenseMlpChunkMatchesDirectChunkedPrefill) { auto expected = DeepseekV2DecoderLayerTestPeer::mlp(*decoder)->forward(hidden_states); auto actual = DeepseekV2DecoderLayerTestPeer::run_mlp( - *decoder, hidden_states, input_params); + *decoder, hidden_states, input_params.meta.batch_forward_type); sync_dev(); test::verify_tensor_close(actual, expected, 1e-3, 1e-4); @@ -1478,7 +1486,12 @@ TEST_P(DeepseekV2DecoderLayerParamTest, torch::TensorOptions().dtype(torch::kInt32).device(options_.device())); auto input_params = build_prefill_params(kBatchSize, kSeqLen); auto attn_metadata = - AttentionMetadataBuilder::build(input_params, /*enable_mla=*/true); + AttentionMetadataBuilder::build(input_params.meta, + 
input_params.attention, + input_params.graph, + input_params.llmrec_params(), + input_params.enable_cuda_graph, + /*enable_mla=*/true); auto kv_cache = build_cache(block_num, /*block_size=*/1); std::optional residual = std::nullopt; diff --git a/tests/core/layers/mlu/deepseek_v2_sparse_moe_block_test.cpp b/tests/core/layers/mlu/deepseek_v2_sparse_moe_block_test.cpp index 6cf782f161..2755ff56cc 100644 --- a/tests/core/layers/mlu/deepseek_v2_sparse_moe_block_test.cpp +++ b/tests/core/layers/mlu/deepseek_v2_sparse_moe_block_test.cpp @@ -257,17 +257,17 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PlanExecEnablesAll2AllOnlyForDecode) { auto block = create_block(); DeepseekV2SparseMoEBlockTestPeer::set_enable_deep_ep(*block, true); - ModelInputParams decode_params; - decode_params.parallel.dp_global_token_nums = {1, 1}; - decode_params.parallel.dp_is_decode = {1, 1}; - auto decode_cfg = block->plan_exec(decode_params); + ParallelInput decode_input; + decode_input.dp_global_token_nums = {1, 1}; + decode_input.dp_is_decode = {1, 1}; + auto decode_cfg = block->plan_exec(decode_input); EXPECT_TRUE(decode_cfg.enable_all2all); EXPECT_FALSE(decode_cfg.need_dp_gather); - ModelInputParams mixed_params; - mixed_params.parallel.dp_global_token_nums = {2, 1}; - mixed_params.parallel.dp_is_decode = {0, 1}; - auto mixed_cfg = block->plan_exec(mixed_params); + ParallelInput mixed_input; + mixed_input.dp_global_token_nums = {2, 1}; + mixed_input.dp_is_decode = {0, 1}; + auto mixed_cfg = block->plan_exec(mixed_input); EXPECT_FALSE(mixed_cfg.enable_all2all); EXPECT_FALSE(mixed_cfg.need_dp_gather); } @@ -276,11 +276,11 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PlanExecSetsDpGatherWhenAll2AllOff) { set_tp_dp_ctx(/*world_size=*/4, /*dp_size=*/2, /*tp_size=*/2, /*ep_size=*/4); auto block = create_block(); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {3, 1}; - input_params.parallel.dp_is_decode = {0, 0}; + ParallelInput parallel_input; + 
parallel_input.dp_global_token_nums = {3, 1}; + parallel_input.dp_is_decode = {0, 0}; - auto cfg = block->plan_exec(input_params); + auto cfg = block->plan_exec(parallel_input); EXPECT_FALSE(cfg.enable_all2all); EXPECT_TRUE(cfg.need_dp_gather); } @@ -289,8 +289,8 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PrepInDpGatherBuildsLocalSkip) { set_tp_dp_ctx(/*world_size=*/4, /*dp_size=*/2, /*tp_size=*/2, /*ep_size=*/4); auto block = create_block(); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {3, 1}; + ParallelInput parallel_input; + parallel_input.dp_global_token_nums = {3, 1}; auto attn_out = mat(/*rows=*/4, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}); auto residual = @@ -302,8 +302,8 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PrepInDpGatherBuildsLocalSkip) { auto prep = block->prep_in(attn_out, residual, - input_params, - block->plan_exec(input_params), + parallel_input, + block->plan_exec(parallel_input), DeepseekV2AttentionImpl::PostAttnLayout::kTpShard); EXPECT_TRUE(prep.need_dp_gather); @@ -319,8 +319,8 @@ TEST_F(DeepseekV2SparseMoEBlockTest, GatherInDpGatherRebuildsGlobalTokens) { set_tp_dp_ctx(/*world_size=*/4, /*dp_size=*/2, /*tp_size=*/2, /*ep_size=*/4); auto block = create_block(); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {3, 1}; + ParallelInput parallel_input; + parallel_input.dp_global_token_nums = {3, 1}; DeepseekV2SparseMoEBlockImpl::PrepOut prep; prep.ffn_in = mat(/*rows=*/2, {11.0f, 22.0f, 33.0f, 44.0f}); prep.need_dp_gather = true; @@ -329,7 +329,7 @@ TEST_F(DeepseekV2SparseMoEBlockTest, GatherInDpGatherRebuildsGlobalTokens) { auto dp1_tp1 = torch::zeros_like(dp1_tp0); global_pg_->set_allgather_outputs({prep.ffn_in, dp0_tp1, dp1_tp0, dp1_tp1}); - auto gathered = block->gather_in(prep, input_params); + auto gathered = block->gather_in(prep, parallel_input); test::verify_tensor_close( gathered, @@ -342,16 +342,16 @@ TEST_F(DeepseekV2SparseMoEBlockTest, 
PrepInAll2AllPadsTpShardInput) { auto block = create_block(); DeepseekV2SparseMoEBlockTestPeer::set_enable_deep_ep(*block, true); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {1, 1}; - input_params.parallel.dp_is_decode = {1, 1}; + ParallelInput parallel_input; + parallel_input.dp_global_token_nums = {1, 1}; + parallel_input.dp_is_decode = {1, 1}; auto attn_out = mat(/*rows=*/3, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); auto residual = mat(/*rows=*/3, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}); auto prep = block->prep_in(attn_out, residual, - input_params, - block->plan_exec(input_params), + parallel_input, + block->plan_exec(parallel_input), DeepseekV2AttentionImpl::PostAttnLayout::kTpShard); EXPECT_FALSE(prep.need_dp_gather); @@ -368,10 +368,10 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PrepInUsesProvidedExecCfg) { set_tp_dp_ctx(/*world_size=*/4, /*dp_size=*/2, /*tp_size=*/2, /*ep_size=*/4); auto block = create_block(); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {3, 1}; - input_params.parallel.dp_is_decode = {0, 0}; - auto planned_cfg = block->plan_exec(input_params); + ParallelInput parallel_input; + parallel_input.dp_global_token_nums = {3, 1}; + parallel_input.dp_is_decode = {0, 0}; + auto planned_cfg = block->plan_exec(parallel_input); EXPECT_FALSE(planned_cfg.enable_all2all); EXPECT_TRUE(planned_cfg.need_dp_gather); @@ -383,7 +383,7 @@ TEST_F(DeepseekV2SparseMoEBlockTest, PrepInUsesProvidedExecCfg) { auto prep = block->prep_in(attn_out, residual, - input_params, + parallel_input, forced_cfg, DeepseekV2AttentionImpl::PostAttnLayout::kTpShard); @@ -397,7 +397,7 @@ TEST_F(DeepseekV2SparseMoEBlockTest, MergeOutTpPadGathersAndUnpads) { set_tp_ctx(/*world_size=*/2, /*ep_size=*/2); auto block = create_block(); - ModelInputParams input_params; + ParallelInput parallel_input; DeepseekV2SparseMoEBlockImpl::PrepOut prep; prep.need_tp_pad = true; prep.pad_info = {.original_tokens = 3, .padded_tokens = 4, 
.active = true}; @@ -406,7 +406,7 @@ TEST_F(DeepseekV2SparseMoEBlockTest, MergeOutTpPadGathersAndUnpads) { prep.skip_local = shard0; tp_pg_->set_allgather_outputs({shard0, shard1}); - auto merged = block->merge_out(shard0, prep, input_params); + auto merged = block->merge_out(shard0, prep, parallel_input); test::verify_tensor_close( merged, mat(/*rows=*/3, {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f})); @@ -416,8 +416,8 @@ TEST_F(DeepseekV2SparseMoEBlockTest, MergeOutDpGatherSlicesLocalTokens) { set_tp_dp_ctx(/*world_size=*/4, /*dp_size=*/2, /*tp_size=*/2, /*ep_size=*/4); auto block = create_block(); - ModelInputParams input_params; - input_params.parallel.dp_global_token_nums = {3, 1}; + ParallelInput parallel_input; + parallel_input.dp_global_token_nums = {3, 1}; DeepseekV2SparseMoEBlockImpl::PrepOut prep; prep.skip_local = mat(/*rows=*/3, {11.0f, 22.0f, 33.0f, 44.0f, 55.0f, 66.0f}); prep.need_dp_gather = true; @@ -425,7 +425,7 @@ TEST_F(DeepseekV2SparseMoEBlockTest, MergeOutDpGatherSlicesLocalTokens) { mat(/*rows=*/4, {101.0f, 102.0f, 103.0f, 104.0f, 105.0f, 106.0f, 107.0f, 108.0f}); - auto merged = block->merge_out(ffn_out, prep, input_params); + auto merged = block->merge_out(ffn_out, prep, parallel_input); test::verify_tensor_close( merged, diff --git a/tests/core/layers/mlu/dp_utils_test.cpp b/tests/core/layers/mlu/dp_utils_test.cpp index abe6707df8..aff955b209 100644 --- a/tests/core/layers/mlu/dp_utils_test.cpp +++ b/tests/core/layers/mlu/dp_utils_test.cpp @@ -329,13 +329,13 @@ TEST(DpUtilsTest, UnpadTokensRestoresOriginalLength) { } TEST(DpUtilsTest, AllDpRanksAreDecodeNeedsEveryRankDecode) { - ModelInputParams decode_params; - decode_params.parallel.dp_is_decode = {1, 1, 1}; - EXPECT_TRUE(all_dp_ranks_are_decode(decode_params)); + ParallelInput decode_input; + decode_input.dp_is_decode = {1, 1, 1}; + EXPECT_TRUE(all_dp_ranks_are_decode(decode_input)); - ModelInputParams mixed_params; - mixed_params.parallel.dp_is_decode = {1, 0, 1}; - 
EXPECT_FALSE(all_dp_ranks_are_decode(mixed_params)); + ParallelInput mixed_input; + mixed_input.dp_is_decode = {1, 0, 1}; + EXPECT_FALSE(all_dp_ranks_are_decode(mixed_input)); } } // namespace test diff --git a/tests/core/layers/mlu/qwen2_vision_attention_test.cpp b/tests/core/layers/mlu/qwen2_vision_attention_test.cpp index 95e1ca3882..9b3cbfea4e 100644 --- a/tests/core/layers/mlu/qwen2_vision_attention_test.cpp +++ b/tests/core/layers/mlu/qwen2_vision_attention_test.cpp @@ -114,10 +114,8 @@ TEST_F(Qwen2VisionAttentionTest, ForwardTest) { torch::kBFloat16, options_.device()); - // Create ModelInputParams - ModelInputParams params; auto output = vision_attention->forward( - hidden_states, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, params); + hidden_states, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec); xllm::Device device(options_.device()); device.synchronize_default_stream(); @@ -138,7 +136,7 @@ TEST_F(Qwen2VisionAttentionTest, ForwardTest) { test::verify_precision(test_output.unsqueeze(0), expected_values, 1e-4, 1e-5); auto output2 = vision_attention->forward( - hidden_states, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, params); + hidden_states, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec); device.synchronize_default_stream(); ASSERT_TRUE(torch::allclose(output.flatten().to(torch::kFloat32), output2.flatten().to(torch::kFloat32), diff --git a/tests/core/runtime/acl_graph_executor_test.cpp b/tests/core/runtime/acl_graph_executor_test.cpp index e7efc5aa6e..3e86458ebe 100644 --- a/tests/core/runtime/acl_graph_executor_test.cpp +++ b/tests/core/runtime/acl_graph_executor_test.cpp @@ -145,13 +145,13 @@ class SimpleCausalLM : public CausalLM { torch::Tensor forward_impl(const torch::Tensor& tokens, const torch::Tensor& positions, std::vector& kv_caches, - const ModelInputParams& params) { + const ForwardInput& input) { // Simple computation: token embedding + position embedding + linear layer // This creates temporary tensors that NPUGraph 
mempool will manage LOG(INFO) << "SimpleCausalLM forward_impl, tokens: " << tokens.sizes() << ", positions: " << positions.sizes() << ", kv_caches: " << kv_caches.size() - << ", params: " << params.meta.num_sequences; + << ", params: " << input.meta.num_sequences; const int64_t num_tokens = tokens.size(0); const int64_t hidden_size = args_.hidden_size(); @@ -168,30 +168,29 @@ class SimpleCausalLM : public CausalLM { // Apply linear layer auto output = linear_->forward(combined); - // Add some computation using other params to make it more realistic - // if (params.attention.device.kv_seq_lens.defined()) { + // Add some computation using other input fields to make it more realistic + // if (input.attention.device.kv_seq_lens.defined()) { // // Use kv_seq_lens in computation - // auto kv_lens_sum = torch::sum(params.attention.device.kv_seq_lens); + // auto kv_lens_sum = torch::sum(input.attention.device.kv_seq_lens); // output = output + kv_lens_sum * kv_scale_; // } - // if (params.attention.device.q_seq_lens.defined()) { + // if (input.attention.device.q_seq_lens.defined()) { // // Use q_seq_lens in computation - // auto q_lens_sum = torch::sum(params.attention.device.q_seq_lens); + // auto q_lens_sum = torch::sum(input.attention.device.q_seq_lens); // output = output + q_lens_sum * q_scale_; // } - if (params.attention.device.new_cache_slots.defined()) { + if (input.attention.device.new_cache_slots.defined()) { // Use new_cache_slots in computation - auto cache_slots_sum = - torch::sum(params.attention.device.new_cache_slots); + auto cache_slots_sum = torch::sum(input.attention.device.new_cache_slots); output = output + cache_slots_sum * cache_scale_; } - if (params.attention.device.block_tables.defined() && !kv_caches.empty()) { + if (input.attention.device.block_tables.defined() && !kv_caches.empty()) { // Use block_tables to do embedding lookup from kv_cache - Rec multi-round // computation Calculate max_seq_len from actual seq_len tensor - auto max_seq_len = 
torch::max(params.attention.device.kv_seq_lens); + auto max_seq_len = torch::max(input.attention.device.kv_seq_lens); // Calculate max_block_nums_per_seq auto max_block_nums_per_seq = torch::ceil(max_seq_len / block_size_); @@ -201,14 +200,14 @@ class SimpleCausalLM : public CausalLM { first_full_attention_cache(kv_caches).get_k_cache(); // Create col_indices and mask - int64_t block_table_len = params.attention.device.block_tables.size(1); + int64_t block_table_len = input.attention.device.block_tables.size(1); auto col_indices = torch::arange( block_table_len, torch::dtype(torch::kInt64).device(device_)); auto mask = col_indices < (max_block_nums_per_seq - scalar_one_); // Directly compute embedding auto kv_embeddings = torch::embedding( - kv_cache_tensor, params.attention.device.block_tables); + kv_cache_tensor, input.attention.device.block_tables); // Apply mask and sum auto kv_embeddings_masked = kv_embeddings * mask.view({1, -1, 1}); @@ -220,11 +219,10 @@ class SimpleCausalLM : public CausalLM { } // Adapter method to match CausalLM base class interface - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - auto hidden_states = forward_impl(tokens, positions, kv_caches, parameters); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + auto hidden_states = + forward_impl(input.token_ids, input.positions, kv_caches, input); return ModelOutput(hidden_states); } @@ -431,33 +429,23 @@ TEST_F(AclGraphExecutorTest, GraphExecutorVsEagerExecution) { << std::endl; std::cout << "forward_input.positions: " << forward_input.positions << std::endl; - std::cout << "forward_input.input_params.attention.device.q_seq_lens: " - << forward_input.input_params.attention.device.q_seq_lens - << std::endl; - std::cout << "forward_input.input_params.attention.device.kv_seq_lens: " - << forward_input.input_params.attention.device.kv_seq_lens - << 
std::endl; - std::cout << "forward_input.input_params.attention.device.new_cache_slots: " - << forward_input.input_params.attention.device.new_cache_slots - << std::endl; - std::cout << "forward_input.input_params.attention.device.block_tables: " - << forward_input.input_params.attention.device.block_tables - << std::endl; + std::cout << "forward_input.attention.device.q_seq_lens: " + << forward_input.attention.device.q_seq_lens << std::endl; + std::cout << "forward_input.attention.device.kv_seq_lens: " + << forward_input.attention.device.kv_seq_lens << std::endl; + std::cout << "forward_input.attention.device.new_cache_slots: " + << forward_input.attention.device.new_cache_slots << std::endl; + std::cout << "forward_input.attention.device.block_tables: " + << forward_input.attention.device.block_tables << std::endl; // Test eager execution (direct model forward) - auto eager_model_output = model_->forward({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto eager_model_output = model_->forward(forward_input, kv_caches_); auto eager_output = eager_model_output.hidden_states; // Create ACL graph executor auto graph_executor = std::make_unique<::xllm::npu::AclGraphExecutorImpl>( model_.get(), model_args_, *device_, options_); // Test graph execution with NPUGraph mempool optimization - auto graph_model_output = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto graph_model_output = graph_executor->run(forward_input, kv_caches_); auto graph_output = graph_model_output.hidden_states; // Compare outputs - should be identical EXPECT_TRUE( @@ -483,16 +471,10 @@ TEST_F(AclGraphExecutorTest, GraphReplayConsistency) { model_.get(), model_args_, *device_, options_); // First execution (should create graph with NPUGraph mempool) - auto output1 = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - 
{forward_input.input_params}); + auto output1 = graph_executor->run(forward_input, kv_caches_); // Second execution (should replay graph using mempool-managed tensors) - auto output2 = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto output2 = graph_executor->run(forward_input, kv_caches_); // Compare outputs - should be identical EXPECT_TRUE(torch::allclose(output1.hidden_states, @@ -551,10 +533,7 @@ TEST_F(AclGraphExecutorTest, DifferentBatchSizes) { model_.get(), model_args_, *device_, options_); // Test graph execution - auto output = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto output = graph_executor->run(forward_input, kv_caches_); // Verify output shape EXPECT_EQ(output.hidden_states.size(0), @@ -579,20 +558,14 @@ TEST_F(AclGraphExecutorTest, AclGraphExecutorVsBaseExecutorImpl) { auto npu_executor = std::make_unique( model_.get(), model_args_, *device_, options_); - auto npu_model_output = npu_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto npu_model_output = npu_executor->run(forward_input, kv_caches_); auto npu_output = npu_model_output.hidden_states; // Test ACL Graph Executor with NPUGraph mempool optimization auto graph_executor = std::make_unique<::xllm::npu::AclGraphExecutorImpl>( model_.get(), model_args_, *device_, options_); - auto graph_model_output = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto graph_model_output = graph_executor->run(forward_input, kv_caches_); auto graph_output = graph_model_output.hidden_states; // Compare outputs - should be identical @@ -628,24 +601,15 @@ TEST_F(AclGraphExecutorTest, AclGraphExecutorVsBaseExecutorImplMultipleRuns) { const int num_runs = 3; for (int i = 0; i < num_runs; ++i) { // Direct 
model forward call (baseline) - auto direct_model_output = model_->forward({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto direct_model_output = model_->forward(forward_input, kv_caches_); auto direct_output = direct_model_output.hidden_states; // NPU Executor run - auto npu_model_output = npu_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto npu_model_output = npu_executor->run(forward_input, kv_caches_); auto npu_output = npu_model_output.hidden_states; // ACL Graph Executor run with NPUGraph mempool - auto graph_model_output = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto graph_model_output = graph_executor->run(forward_input, kv_caches_); auto graph_output = graph_model_output.hidden_states; // Compare direct model output with NPU Executor output @@ -688,13 +652,12 @@ TEST_F(AclGraphExecutorTest, BatchInputCarriesLinearStateIds) { auto forward_input = batch->prepare_forward_input( options_.num_decoding_tokens(), 0, model_args_); - ASSERT_EQ(forward_input.input_params.meta.num_sequences, 1); - ASSERT_EQ(forward_input.input_params.embedding.linear_state_ids.size(), 1); - EXPECT_EQ(forward_input.input_params.embedding.linear_state_ids[0], - expected_linear_state_id); - ASSERT_EQ(forward_input.input_params.embedding.embedding_ids.size(), 1); - EXPECT_EQ(forward_input.input_params.embedding.embedding_ids[0], + ASSERT_EQ(forward_input.meta.num_sequences, 1); + ASSERT_EQ(forward_input.embedding.linear_state_ids.size(), 1); + EXPECT_EQ(forward_input.embedding.linear_state_ids[0], expected_linear_state_id); + ASSERT_EQ(forward_input.embedding.embedding_ids.size(), 1); + EXPECT_EQ(forward_input.embedding.embedding_ids[0], expected_linear_state_id); } TEST(AclGraphExecutorHybridTest, KvCacheSupportsLinearOnlyLayers) { @@ -744,16 +707,11 @@ 
TEST_F(AclGraphExecutorTest, GraphExecutorUsesFirstFullAttentionKvCache) { LinearAttentionKVCacheTensors{conv_cache, ssm_cache}); hybrid_kv_caches.emplace_back(KVCacheTensors{full_k, full_v}); - auto eager_model_output = model_->forward({forward_input.token_ids}, - {forward_input.positions}, - hybrid_kv_caches, - {forward_input.input_params}); + auto eager_model_output = model_->forward(forward_input, hybrid_kv_caches); auto graph_executor = std::make_unique<::xllm::npu::AclGraphExecutorImpl>( model_.get(), model_args_, *device_, options_); - auto graph_model_output = graph_executor->run({forward_input.token_ids}, - {forward_input.positions}, - hybrid_kv_caches, - {forward_input.input_params}); + auto graph_model_output = + graph_executor->run(forward_input, hybrid_kv_caches); EXPECT_TRUE(torch::allclose(eager_model_output.hidden_states, graph_model_output.hidden_states, diff --git a/tests/core/runtime/cuda_graph_executor_test.cpp b/tests/core/runtime/cuda_graph_executor_test.cpp index cbf4fbe55f..f5984709ba 100644 --- a/tests/core/runtime/cuda_graph_executor_test.cpp +++ b/tests/core/runtime/cuda_graph_executor_test.cpp @@ -28,7 +28,6 @@ limitations under the License. 
#include "core/framework/batch/batch_forward_type.h" #include "core/framework/model/causal_lm.h" #include "core/framework/model/model_args.h" -#include "core/framework/model/model_input_params.h" #include "core/layers/cuda/attention.h" #include "core/layers/cuda/flashinfer_workspace.h" #include "core/platform/device.h" @@ -62,6 +61,15 @@ std::vector MakeSingleKvCaches(torch::Tensor k_cache, return kv_caches; } +ForwardInput MakeForwardInput(const torch::Tensor& tokens, + const torch::Tensor& positions, + const ForwardInput& params) { + ForwardInput input = params; + input.token_ids = tokens; + input.positions = positions; + return input; +} + class CudaGraphExecutorTestEnvironment : public ::testing::Environment { public: void SetUp() override { @@ -153,20 +161,25 @@ class FakeAttnCausalLM final : public CausalLM { /*sliding_window=*/-1); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; (void)positions; CHECK(!kv_caches.empty()); // Use the executor-provided metadata when available (CUDA graph mode). 
layer::AttentionMetadata attn_meta; - if (params.attn_metadata) { - attn_meta = *params.attn_metadata; + if (input.attn_metadata) { + attn_meta = *input.attn_metadata; } else { attn_meta = - layer::AttentionMetadataBuilder::build(params, /*enable_mla=*/false); + layer::AttentionMetadataBuilder::build(input.meta, + input.attention, + input.graph, + input.llmrec_params(), + input.enable_cuda_graph, + /*enable_mla=*/false); } CHECK(attn_meta.plan_info) << "attn_meta.plan_info must be set"; attn_meta.plan_info->layer_id = 0; @@ -214,8 +227,8 @@ class FakeAttnCausalLM final : public CausalLM { std::unique_ptr attn_; }; -ModelInputParams MakeDecodeParams(const torch::Device& device) { - ModelInputParams p; +ForwardInput MakeDecodeParams(const torch::Device& device) { + ForwardInput p; p.meta.batch_forward_type = BatchForwardType::DECODE; p.meta.num_sequences = 1; p.meta.kv_max_seq_len = 4; @@ -241,11 +254,11 @@ ModelInputParams MakeDecodeParams(const torch::Device& device) { return p; } -ModelInputParams MakePrefillParams(const torch::Device& device, - int32_t num_tokens) { +ForwardInput MakePrefillParams(const torch::Device& device, + int32_t num_tokens) { CHECK_GT(num_tokens, 0); - ModelInputParams p; + ForwardInput p; p.meta.batch_forward_type = BatchForwardType::PREFILL; p.meta.num_sequences = 1; p.meta.kv_max_seq_len = num_tokens; @@ -326,9 +339,8 @@ runtime::Options make_test_runtime_options(int64_t max_seqs_per_batch) { return options; } -ModelInputParams make_multi_sequence_decode_params( - const torch::Device& device) { - ModelInputParams p; +ForwardInput make_multi_sequence_decode_params(const torch::Device& device) { + ForwardInput p; p.meta.batch_forward_type = BatchForwardType::DECODE; p.meta.num_sequences = 2; p.meta.kv_max_seq_len = 9; @@ -366,21 +378,21 @@ TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesPersistentBuffers) { torch::TensorOptions().dtype(torch::kInt32).device(device); torch::Tensor tokens = torch::tensor({10, 11}, iopt); 
torch::Tensor positions = torch::tensor({20, 21}, iopt); - ModelInputParams params = make_multi_sequence_decode_params(device); + ForwardInput input = make_multi_sequence_decode_params(device); + input.token_ids = tokens; + input.positions = positions; std::vector kv = MakeKvCaches(device, /*num_pages=*/16, /*page_size=*/1, /*num_kv_heads=*/1, /*head_dim=*/128); - std::optional updated = - persistent.update(tokens, + std::optional updated = + persistent.update(input, kv[0].get_k_cache(), kv[0].get_v_cache(), - positions, - params, /*padded_num_tokens=*/4, - /*return_capture_params=*/true); + /*return_capture_input=*/true); ASSERT_TRUE(updated.has_value()); ASSERT_TRUE(updated->attn_metadata); @@ -429,7 +441,7 @@ TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesPersistentBuffers) { torch::equal(updated->attn_metadata->qo_indptr.value().cpu(), torch::tensor({0, 1, 2}, torch::dtype(torch::kInt32)))); EXPECT_TRUE(torch::equal(updated->attn_metadata->block_table.cpu(), - params.attention.device.block_tables.cpu())); + input.attention.device.block_tables.cpu())); } TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesLinearStateIndices) { @@ -449,23 +461,23 @@ TEST(CudaGraphExecutorTest, DecodeMetadataFastPathUpdatesLinearStateIndices) { torch::TensorOptions().dtype(torch::kInt32).device(device); torch::Tensor tokens = torch::tensor({10, 11}, iopt); torch::Tensor positions = torch::tensor({20, 21}, iopt); - ModelInputParams params = make_multi_sequence_decode_params(device); - params.embedding.linear_state_ids = {8, 6}; - params.embedding.linear_state_indices = torch::tensor({8, 6}, iopt); + ForwardInput input = make_multi_sequence_decode_params(device); + input.token_ids = tokens; + input.positions = positions; + input.embedding.linear_state_ids = {8, 6}; + input.embedding.linear_state_indices = torch::tensor({8, 6}, iopt); std::vector kv = MakeKvCaches(device, /*num_pages=*/16, /*page_size=*/1, /*num_kv_heads=*/1, /*head_dim=*/128); - std::optional updated 
= - persistent.update(tokens, + std::optional updated = + persistent.update(input, kv[0].get_k_cache(), kv[0].get_v_cache(), - positions, - params, /*padded_num_tokens=*/4, - /*return_capture_params=*/true); + /*return_capture_input=*/true); ASSERT_TRUE(updated.has_value()); EXPECT_TRUE(torch::equal( @@ -495,14 +507,17 @@ TEST(CudaGraphExecutorTest, DecodeMetadataFastPathFallbackMatchesLegacyPath) { torch::TensorOptions().dtype(torch::kInt32).device(device); torch::Tensor tokens = torch::tensor({10, 11}, iopt); torch::Tensor positions = torch::tensor({20, 21}, iopt); - ModelInputParams fast_params = make_multi_sequence_decode_params(device); - ModelInputParams fallback_params = make_multi_sequence_decode_params(device); + ForwardInput fast_input = make_multi_sequence_decode_params(device); + ForwardInput fallback_input = make_multi_sequence_decode_params(device); + fast_input.token_ids = tokens; + fast_input.positions = positions; + fallback_input.token_ids = tokens; + fallback_input.positions = positions; torch::Tensor new_cache_slots_base = torch::tensor({5, 99, 7, 88}, iopt).view({2, 2}); - fallback_params.attention.device.new_cache_slots = + fallback_input.attention.device.new_cache_slots = new_cache_slots_base.select(1, 0); - ASSERT_FALSE( - fallback_params.attention.device.new_cache_slots.is_contiguous()); + ASSERT_FALSE(fallback_input.attention.device.new_cache_slots.is_contiguous()); std::vector kv = MakeKvCaches(device, /*num_pages=*/16, @@ -510,22 +525,18 @@ TEST(CudaGraphExecutorTest, DecodeMetadataFastPathFallbackMatchesLegacyPath) { /*num_kv_heads=*/1, /*head_dim=*/128); - std::optional fast_updated = - fast_path_persistent.update(tokens, + std::optional fast_updated = + fast_path_persistent.update(fast_input, kv[0].get_k_cache(), kv[0].get_v_cache(), - positions, - fast_params, /*padded_num_tokens=*/4, - /*return_capture_params=*/true); - std::optional fallback_updated = - fallback_persistent.update(tokens, + /*return_capture_input=*/true); + 
std::optional fallback_updated = + fallback_persistent.update(fallback_input, kv[0].get_k_cache(), kv[0].get_v_cache(), - positions, - fallback_params, /*padded_num_tokens=*/4, - /*return_capture_params=*/true); + /*return_capture_input=*/true); ASSERT_TRUE(fast_updated.has_value()); ASSERT_TRUE(fallback_updated.has_value()); @@ -609,17 +620,20 @@ TEST(CudaGraphExecutorTest, BatchDecodeCaptureAndReplay) { /*num_kv_heads=*/1, /*head_dim=*/128); auto eager_out = - model->forward(tokens, positions, kv, params).hidden_states.clone(); + model->forward(MakeForwardInput(tokens, positions, params), kv) + .hidden_states.clone(); torch::cuda::synchronize(); // LOG(INFO) << "eager_out: " << eager_out; // Graph capture (first run) + replay (second run). - auto out1 = graph_exec->run(tokens, positions, kv, params).hidden_states; + auto out1 = graph_exec->run(MakeForwardInput(tokens, positions, params), kv) + .hidden_states; torch::cuda::synchronize(); // LOG(INFO) << "out1: " << out1; EXPECT_TRUE(torch::allclose(out1, eager_out, /*rtol=*/1e-3, /*atol=*/1e-3)) << "graph capture output should match eager output"; - auto out2 = graph_exec->run(tokens, positions, kv, params).hidden_states; + auto out2 = graph_exec->run(MakeForwardInput(tokens, positions, params), kv) + .hidden_states; torch::cuda::synchronize(); EXPECT_TRUE(torch::allclose(out2, eager_out, /*rtol=*/1e-3, /*atol=*/1e-3)) << "graph replay output should match eager output"; @@ -672,11 +686,14 @@ TEST(CudaGraphExecutorTest, /*head_dim=*/128); auto eager_out = - model->forward(tokens, positions, kv, params).hidden_states.clone(); + model->forward(MakeForwardInput(tokens, positions, params), kv) + .hidden_states.clone(); torch::cuda::synchronize(); - auto out1 = graph_exec->run(tokens, positions, kv, params).hidden_states; + auto out1 = graph_exec->run(MakeForwardInput(tokens, positions, params), kv) + .hidden_states; torch::cuda::synchronize(); - auto out2 = graph_exec->run(tokens, positions, kv, params).hidden_states; 
+ auto out2 = graph_exec->run(MakeForwardInput(tokens, positions, params), kv) + .hidden_states; torch::cuda::synchronize(); EXPECT_TRUE(torch::allclose(out1, eager_out, /*rtol=*/1e-3, /*atol=*/1e-3)); @@ -746,16 +763,21 @@ TEST(CudaGraphExecutorTest, PrefillPiecewiseCaptureAndReplay) { kv[0].get_v_cache().clone()); auto eager_out = - model->forward(tokens, positions, kv_eager, params).hidden_states.clone(); + model->forward(MakeForwardInput(tokens, positions, params), kv_eager) + .hidden_states.clone(); torch::cuda::synchronize(); auto out1 = - graph_exec->run(tokens, positions, kv_graph_first, params).hidden_states; + graph_exec + ->run(MakeForwardInput(tokens, positions, params), kv_graph_first) + .hidden_states; out1 = out1.clone(); torch::cuda::synchronize(); auto out2 = - graph_exec->run(tokens, positions, kv_graph_second, params).hidden_states; + graph_exec + ->run(MakeForwardInput(tokens, positions, params), kv_graph_second) + .hidden_states; out2 = out2.clone(); torch::cuda::synchronize(); @@ -846,12 +868,14 @@ TEST(CudaGraphExecutorTest, CompareMqa2v1AndMqa8v1) { auto kv_graph = MakeSingleKvCaches(kv[0].get_k_cache().clone(), kv[0].get_v_cache().clone()); - auto eager_out = model->forward(tokens, positions, kv_eager, params) - .hidden_states.clone(); + auto eager_out = + model->forward(MakeForwardInput(tokens, positions, params), kv_eager) + .hidden_states.clone(); torch::cuda::synchronize(); auto graph_out = - graph_exec->run(tokens, positions, kv_graph, params).hidden_states; + graph_exec->run(MakeForwardInput(tokens, positions, params), kv_graph) + .hidden_states; graph_out = graph_out.clone(); torch::cuda::synchronize(); @@ -966,7 +990,8 @@ TEST(CudaGraphExecutorTest, GraphVmmPoolMemoryReuseAcrossMultiShape) { auto positions = all_positions.slice(/*dim=*/0, /*start=*/0, /*end=*/num_tokens); auto params = MakePrefillParams(device, num_tokens); - auto out = exec->run(tokens, positions, kv, params).hidden_states; + auto out = 
exec->run(MakeForwardInput(tokens, positions, params), kv) + .hidden_states; (void)out; torch::cuda::synchronize(); memory_usage_bytes.push_back(exec->get_graph_memory_usage_bytes()); @@ -1071,10 +1096,12 @@ TEST(CudaGraphExecutorTest, GraphVmmPoolEnabledPrefillCorrectness) { auto kv_graph = MakeSingleKvCaches(kv_base[0].get_k_cache().clone(), kv_base[0].get_v_cache().clone()); - auto eager_out = model->forward(tokens, positions, kv_eager, params) - .hidden_states.clone(); + auto eager_out = + model->forward(MakeForwardInput(tokens, positions, params), kv_eager) + .hidden_states.clone(); auto graph_out = - exec->run(tokens, positions, kv_graph, params).hidden_states.clone(); + exec->run(MakeForwardInput(tokens, positions, params), kv_graph) + .hidden_states.clone(); torch::cuda::synchronize(); EXPECT_TRUE(torch::isfinite(graph_out).all().item()) diff --git a/tests/core/runtime/mlu_graph_executor_test.cpp b/tests/core/runtime/mlu_graph_executor_test.cpp index f21cf413df..7856e41d36 100644 --- a/tests/core/runtime/mlu_graph_executor_test.cpp +++ b/tests/core/runtime/mlu_graph_executor_test.cpp @@ -38,17 +38,14 @@ class MockCausalLM : public CausalLM { weight_ = register_parameter("weight", weight, false); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override { - (void)tokens; - (void)positions; + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { (void)kv_caches; + const torch::Tensor& tokens = input.token_ids; ++forward_cnt_; last_tokens_size_ = tokens.size(0); - last_dp_token_nums_ = params.parallel.dp_global_token_nums; - auto hidden_states = params.embedding.input_embedding.matmul(weight_); + last_dp_token_nums_ = input.parallel.dp_global_token_nums; + auto hidden_states = input.embedding.input_embedding.matmul(weight_); if (return_aux_hidden_states_) { auto aux_hidden_states = hidden_states + 1; return ModelOutput(hidden_states, 
torch::Tensor(), aux_hidden_states); @@ -131,7 +128,9 @@ class MluGraphExecutorTest : public ::testing::Test { auto input_embedding = torch::randn({batch_size, model_args_.hidden_size()}, tensor_options_) * 0.1; - ModelInputParams input_params; + ForwardInput input_params; + input_params.token_ids = token_ids; + input_params.positions = positions; input_params.meta.batch_forward_type = BatchForwardType::DECODE; input_params.meta.num_sequences = batch_size; input_params.meta.kv_max_seq_len = 1; @@ -147,11 +146,7 @@ class MluGraphExecutorTest : public ::testing::Test { input_params.embedding.input_embedding = input_embedding; kv_caches_.resize(batch_size); - ForwardInput input; - input.token_ids = token_ids; - input.positions = positions; - input.input_params = input_params; - return input; + return input_params; } void rebuild_impl() { @@ -177,22 +172,13 @@ TEST_F(MluGraphExecutorTest, DifferentBatchSizes) { const std::vector batch_sizes = {1, 3, 13, 21, 65}; for (auto batch_size : batch_sizes) { auto forward_input = prepare_inputs(batch_size, 1); - auto eager_model_output = base_impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto eager_model_output = base_impl_->run(forward_input, kv_caches_); auto eager_output = eager_model_output.hidden_states; - auto graph_model_output = impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto graph_model_output = impl_->run(forward_input, kv_caches_); auto graph_output = graph_model_output.hidden_states; - auto replay_model_output = impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto replay_model_output = impl_->run(forward_input, kv_caches_); auto replay_output = replay_model_output.hidden_states; CHECK_EQ(eager_output.sizes(), graph_output.sizes()); @@ -209,16 +195,10 @@ TEST_F(MluGraphExecutorTest, 
MluGraphExecutorVsBaseExecutorImplMultipleRuns) { int32_t batch_size = 5; int32_t seed = 42; auto forward_input = prepare_inputs(batch_size, seed); - auto eager_model_output = base_impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto eager_model_output = base_impl_->run(forward_input, kv_caches_); auto eager_output = eager_model_output.hidden_states; - auto graph_model_output = impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto graph_model_output = impl_->run(forward_input, kv_caches_); auto graph_output = graph_model_output.hidden_states; CHECK_EQ(eager_output.sizes(), graph_output.sizes()); @@ -230,35 +210,27 @@ TEST_F(MluGraphExecutorTest, MluGraphExecutorVsBaseExecutorImplMultipleRuns) { const int num_runs = 5; auto base_forward_input = prepare_inputs(batch_size + 1, seed); auto replay_forward_input = prepare_inputs(batch_size + 1, seed); - EXPECT_TRUE(torch::allclose( - base_forward_input.input_params.embedding.input_embedding, - replay_forward_input.input_params.embedding.input_embedding, - 1e-5, - 1e-6)); + EXPECT_TRUE(torch::allclose(base_forward_input.embedding.input_embedding, + replay_forward_input.embedding.input_embedding, + 1e-5, + 1e-6)); for (int i = 0; i < num_runs; ++i) { - auto base_model_output = base_impl_->run({base_forward_input.token_ids}, - {base_forward_input.positions}, - kv_caches_, - {base_forward_input.input_params}); + auto base_model_output = base_impl_->run(base_forward_input, kv_caches_); auto base_output = base_model_output.hidden_states; - auto replay_model_output = impl_->run({replay_forward_input.token_ids}, - {replay_forward_input.positions}, - kv_caches_, - {replay_forward_input.input_params}); + auto replay_model_output = impl_->run(replay_forward_input, kv_caches_); auto replay_output = replay_model_output.hidden_states; - base_forward_input.input_params.embedding.input_embedding = 
base_output; - replay_forward_input.input_params.embedding.input_embedding = replay_output; + base_forward_input.embedding.input_embedding = base_output; + replay_forward_input.embedding.input_embedding = replay_output; CHECK_EQ(base_output.sizes(), replay_output.sizes()); } torch_mlu::synchronize(); - EXPECT_TRUE(torch::allclose( - base_forward_input.input_params.embedding.input_embedding, - replay_forward_input.input_params.embedding.input_embedding, - 1e-5, - 1e-6)); + EXPECT_TRUE(torch::allclose(base_forward_input.embedding.input_embedding, + replay_forward_input.embedding.input_embedding, + 1e-5, + 1e-6)); } TEST_F(MluGraphExecutorTest, DraftDecodeFallsBackToEager) { @@ -269,24 +241,11 @@ TEST_F(MluGraphExecutorTest, DraftDecodeFallsBackToEager) { const uint64_t seed = 7; auto forward_input = prepare_inputs(batch_size, seed); - auto eager_model_output = base_impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}); + auto eager_model_output = base_impl_->run(forward_input, kv_caches_); auto eager_output = eager_model_output.hidden_states; - auto first_impl_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_impl_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + auto first_impl_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_impl_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(eager_output, first_impl_output, 1e-5, 1e-6)); @@ -305,10 +264,7 @@ TEST_F(MluGraphExecutorTest, DraftEagerDoesNotExposeAuxWhenDisabled) { const uint64_t seed = 17; auto forward_input = prepare_inputs(batch_size, seed); - ModelOutput output = impl_->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - 
{forward_input.input_params}); + ModelOutput output = impl_->run(forward_input, kv_caches_); EXPECT_FALSE(output.aux_hidden_states.defined()); EXPECT_EQ(model_->forward_cnt(), 1); @@ -322,18 +278,8 @@ TEST_F(MluGraphExecutorTest, TargetDecodeCapturesThenReplays) { const uint64_t seed = 11; auto forward_input = prepare_inputs(batch_size, seed); - auto first_impl_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_impl_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + auto first_impl_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_impl_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE( @@ -348,28 +294,15 @@ TEST_F(MluGraphExecutorTest, PrefillThenDecodeCapturesAndReplays) { const int32_t batch_size = 5; const uint64_t prefill_seed = 23; auto prefill_input = prepare_inputs(batch_size, prefill_seed); - prefill_input.input_params.meta.batch_forward_type = - BatchForwardType::PREFILL; + prefill_input.meta.batch_forward_type = BatchForwardType::PREFILL; - ModelOutput prefill_output = impl_->run({prefill_input.token_ids}, - {prefill_input.positions}, - kv_caches_, - {prefill_input.input_params}); + ModelOutput prefill_output = impl_->run(prefill_input, kv_caches_); const uint64_t decode_seed = 29; auto decode_input = prepare_inputs(batch_size, decode_seed); - auto first_decode_output = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; - auto second_decode_output = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; + auto first_decode_output = impl_->run(decode_input, kv_caches_).hidden_states; + auto second_decode_output = + 
impl_->run(decode_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(prefill_output.hidden_states.defined()); @@ -385,21 +318,11 @@ TEST_F(MluGraphExecutorTest, EqualDpDecodePadsToTpGraphSize) { rebuild_impl(); auto forward_input = prepare_inputs(/*batch_size=*/2, /*seed=*/61); - forward_input.input_params.parallel.dp_global_token_nums = {2, 2}; - forward_input.input_params.parallel.dp_is_decode = {1, 1}; - - auto first_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + forward_input.parallel.dp_global_token_nums = {2, 2}; + forward_input.parallel.dp_is_decode = {1, 1}; + + auto first_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_output, second_output, 1e-5, 1e-6)); @@ -415,21 +338,11 @@ TEST_F(MluGraphExecutorTest, UnevenDpDecodePadsToTpGraphSize) { rebuild_impl(); auto forward_input = prepare_inputs(/*batch_size=*/2, /*seed=*/67); - forward_input.input_params.parallel.dp_global_token_nums = {1, 2}; - forward_input.input_params.parallel.dp_is_decode = {1, 1}; - - auto first_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + forward_input.parallel.dp_global_token_nums = {1, 2}; + forward_input.parallel.dp_is_decode = {1, 1}; + + auto first_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); 
EXPECT_TRUE(torch::allclose(first_output, second_output, 1e-5, 1e-6)); @@ -445,21 +358,11 @@ TEST_F(MluGraphExecutorTest, MtpSeqLensCapacityUsesSpecFactor) { rebuild_impl(); auto forward_input = prepare_inputs(/*batch_size=*/4, /*seed=*/71); - forward_input.input_params.parallel.dp_global_token_nums = {4, 4}; - forward_input.input_params.parallel.dp_is_decode = {1, 1}; - - auto first_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + forward_input.parallel.dp_global_token_nums = {4, 4}; + forward_input.parallel.dp_is_decode = {1, 1}; + + auto first_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_output, second_output, 1e-5, 1e-6)); @@ -472,22 +375,12 @@ TEST_F(MluGraphExecutorTest, DpDummyFallsBackToEager) { const int32_t batch_size = 5; const uint64_t seed = 31; auto forward_input = prepare_inputs(batch_size, seed); - forward_input.input_params.parallel.dp_global_token_nums = {batch_size, 0}; - forward_input.input_params.parallel.dp_is_decode = {1, 0}; + forward_input.parallel.dp_global_token_nums = {batch_size, 0}; + forward_input.parallel.dp_is_decode = {1, 0}; const int32_t start_cnt = model_->forward_cnt(); - auto first_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + auto first_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_output = impl_->run(forward_input, kv_caches_).hidden_states; 
torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_output, second_output, 1e-5, 1e-6)); @@ -500,24 +393,13 @@ TEST_F(MluGraphExecutorTest, DpUnevenDecodeFallsBackToEager) { const int32_t batch_size = 5; auto forward_input = prepare_inputs(batch_size, 43); - forward_input.input_params.parallel.dp_global_token_nums = {batch_size, - batch_size - 1}; - forward_input.input_params.parallel.dp_is_decode = {1, 1}; - forward_input.input_params.meta.q_max_seq_len = 2; + forward_input.parallel.dp_global_token_nums = {batch_size, batch_size - 1}; + forward_input.parallel.dp_is_decode = {1, 1}; + forward_input.meta.q_max_seq_len = 2; const int32_t start_cnt = model_->forward_cnt(); - auto first_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; - auto second_output = impl_ - ->run({forward_input.token_ids}, - {forward_input.positions}, - kv_caches_, - {forward_input.input_params}) - .hidden_states; + auto first_output = impl_->run(forward_input, kv_caches_).hidden_states; + auto second_output = impl_->run(forward_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_output, second_output, 1e-5, 1e-6)); @@ -530,32 +412,18 @@ TEST_F(MluGraphExecutorTest, DpDummyDoesNotPoisonGraphCache) { const int32_t batch_size = 5; auto dummy_input = prepare_inputs(batch_size, 37); - dummy_input.input_params.parallel.dp_global_token_nums = {batch_size, 0}; - dummy_input.input_params.parallel.dp_is_decode = {1, 0}; + dummy_input.parallel.dp_global_token_nums = {batch_size, 0}; + dummy_input.parallel.dp_is_decode = {1, 0}; const int32_t start_cnt = model_->forward_cnt(); - impl_->run({dummy_input.token_ids}, - {dummy_input.positions}, - kv_caches_, - {dummy_input.input_params}); + impl_->run(dummy_input, kv_caches_); auto decode_input = prepare_inputs(batch_size, 41); - decode_input.input_params.parallel.dp_global_token_nums = {batch_size, - batch_size}; - 
decode_input.input_params.parallel.dp_is_decode = {1, 1}; - - auto first_decode = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; - auto second_decode = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; + decode_input.parallel.dp_global_token_nums = {batch_size, batch_size}; + decode_input.parallel.dp_is_decode = {1, 1}; + + auto first_decode = impl_->run(decode_input, kv_caches_).hidden_states; + auto second_decode = impl_->run(decode_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_decode, second_decode, 1e-5, 1e-6)); @@ -568,34 +436,19 @@ TEST_F(MluGraphExecutorTest, DpUnevenDecodeDoesNotPoisonGraphCache) { const int32_t batch_size = 5; auto uneven_input = prepare_inputs(batch_size, 47); - uneven_input.input_params.parallel.dp_global_token_nums = {batch_size, - batch_size - 1}; - uneven_input.input_params.parallel.dp_is_decode = {1, 1}; - uneven_input.input_params.meta.q_max_seq_len = 2; + uneven_input.parallel.dp_global_token_nums = {batch_size, batch_size - 1}; + uneven_input.parallel.dp_is_decode = {1, 1}; + uneven_input.meta.q_max_seq_len = 2; const int32_t start_cnt = model_->forward_cnt(); - impl_->run({uneven_input.token_ids}, - {uneven_input.positions}, - kv_caches_, - {uneven_input.input_params}); + impl_->run(uneven_input, kv_caches_); auto decode_input = prepare_inputs(batch_size, 53); - decode_input.input_params.parallel.dp_global_token_nums = {batch_size, - batch_size}; - decode_input.input_params.parallel.dp_is_decode = {1, 1}; - - auto first_decode = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; - auto second_decode = impl_ - ->run({decode_input.token_ids}, - {decode_input.positions}, - kv_caches_, - {decode_input.input_params}) - .hidden_states; + 
decode_input.parallel.dp_global_token_nums = {batch_size, batch_size}; + decode_input.parallel.dp_is_decode = {1, 1}; + + auto first_decode = impl_->run(decode_input, kv_caches_).hidden_states; + auto second_decode = impl_->run(decode_input, kv_caches_).hidden_states; torch_mlu::synchronize(); EXPECT_TRUE(torch::allclose(first_decode, second_decode, 1e-5, 1e-6)); diff --git a/tests/core/runtime/spec_input_builder_test.cpp b/tests/core/runtime/spec_input_builder_test.cpp index 549df38966..c5a857f697 100644 --- a/tests/core/runtime/spec_input_builder_test.cpp +++ b/tests/core/runtime/spec_input_builder_test.cpp @@ -58,17 +58,16 @@ ForwardInput make_forward_input(const torch::Tensor& token_ids, const torch::Tensor& block_tables, const std::vector& kv_seq_lens) { ForwardInput input; - input.input_params.meta.num_sequences = - static_cast(positions.numel()); + input.meta.num_sequences = static_cast(positions.numel()); input.token_ids_host = token_ids; input.positions_host = positions; - input.input_params.attention.host.block_tables = block_tables; - input.input_params.attention.host.kv_seq_lens = kv_seq_lens; + input.attention.host.block_tables = block_tables; + input.attention.host.kv_seq_lens = kv_seq_lens; return input; } TEST(SpecDecodeInputBuilderTest, DraftInputsSingleRowPerSeq) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 2; std::vector kv_seq_lens = to_layout_seq_lens({5, 9}); @@ -95,7 +94,7 @@ TEST(SpecDecodeInputBuilderTest, DraftInputsSingleRowPerSeq) { } TEST(SpecDecodeInputBuilderTest, ValidateInputsNonAtbExpansion) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 2; const int32_t num_speculative_tokens = 2; const int32_t num_val_tokens = num_speculative_tokens + 1; @@ -186,7 +185,7 @@ TEST(SpecDecodeInputBuilderTest, AppendDecodeRowUsesInputBlockTableLayout) { } TEST(SpecDecodeInputBuilderTest, ValidateRowsStartFromCorrectedCurrentView) { - ModelInputParams params; + ForwardInput params; 
params.meta.num_sequences = 2; std::vector token_ids = {31, 41}; std::vector positions = {6, 9}; @@ -243,7 +242,7 @@ TEST(SpecDecodeInputBuilderTest, ValidateInputsAtbChunkedPrefillShape) { } TEST(SpecDecodeInputBuilderTest, FirstDecodeInputsFixAndNonFixMix) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 2; std::vector kv_seq_lens = to_layout_seq_lens({6, 9}); @@ -286,7 +285,7 @@ TEST(SpecDecodeInputBuilderTest, FirstDecodeInputsFixAndNonFixMix) { } TEST(SpecDecodeInputBuilderTest, AppendDecodeRowWithInputTokenSource) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 2; std::vector kv_seq_lens = to_layout_seq_lens({5, 9}); @@ -356,7 +355,7 @@ TEST(SpecDecodeInputBuilderTest, ResolveTokenWithPositionOffset) { } TEST(SpecDecodeInputBuilderTest, AppendDecodeRowFromLastStep) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 2; std::vector kv_seq_lens = to_layout_seq_lens({6, 9}); @@ -395,12 +394,13 @@ TEST(SpecDecodeInputBuilderTest, AppendDecodeRowFromLastStep) { } TEST(SpecDecodeInputBuilderTest, QCuSeqLensConsistency) { - ModelInputParams params; + ForwardInput params; params.meta.num_sequences = 3; params.attention.host.q_seq_lens = to_layout_seq_lens({1, 2, 3}); params.attention.host.q_cu_seq_lens = {1, 3, 6}; - torch::Tensor q_cu_seq_lens = build_q_cu_seq_lens_tensor(params); + torch::Tensor q_cu_seq_lens = + build_q_cu_seq_lens_tensor(params.attention.host); EXPECT_EQ(tensor_to_vec_int32(q_cu_seq_lens), std::vector({1, 3, 6})); } diff --git a/tests/core/scheduler/continuous_scheduler_test.cpp b/tests/core/scheduler/continuous_scheduler_test.cpp index b0c177d143..c3bebaa5ed 100644 --- a/tests/core/scheduler/continuous_scheduler_test.cpp +++ b/tests/core/scheduler/continuous_scheduler_test.cpp @@ -325,10 +325,9 @@ TEST(ContinuousSchedulerFactoryTest, const auto forward_input = batches[0].prepare_forward_input(1, 0, ModelArgs()); - EXPECT_TRUE( - 
forward_input.input_params.meta.batch_forward_type.is_chunked_prefill()); - EXPECT_FALSE(forward_input.input_params.meta.batch_forward_type.is_mixed()); - EXPECT_EQ(forward_input.input_params.meta.num_sequences, 1); + EXPECT_TRUE(forward_input.meta.batch_forward_type.is_chunked_prefill()); + EXPECT_FALSE(forward_input.meta.batch_forward_type.is_mixed()); + EXPECT_EQ(forward_input.meta.num_sequences, 1); EXPECT_EQ(batches[0].get_allowed_max_tokens()[0], opt.max_tokens_per_chunk_for_prefill()); } diff --git a/xllm/core/distributed_runtime/dit_engine.cpp b/xllm/core/distributed_runtime/dit_engine.cpp index aae14ca033..554496bf74 100644 --- a/xllm/core/distributed_runtime/dit_engine.cpp +++ b/xllm/core/distributed_runtime/dit_engine.cpp @@ -112,7 +112,7 @@ DiTForwardOutput DiTEngine::step(std::vector& batches) { Timer timer; auto dit_forward_input = batches[0].prepare_forward_input(); ForwardInput forward_input; - forward_input.input_params.dit_forward_input = dit_forward_input; + forward_input.dit_forward_input = dit_forward_input; COUNTER_ADD(prepare_input_latency_seconds, timer.elapsed_seconds()); std::vector>> futures; diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index ea83009d9e..346fb7d2b6 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -1001,8 +1001,7 @@ ForwardOutput LLMEngine::step(std::vector& batch) { if (cp_size_ > 1) { for (int32_t dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - if (!forward_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_prefill()) { + if (!forward_inputs[dp_rank].meta.batch_forward_type.is_prefill()) { continue; } auto& inputs_per_cp = cp_partitioned_inputs[dp_rank]; @@ -1019,8 +1018,7 @@ ForwardOutput LLMEngine::step(std::vector& batch) { const int32_t dp_rank = worker_rank / dp_local_size_; const ForwardInput* input_to_send = &forward_inputs[dp_rank]; if (cp_size_ > 1 && - forward_inputs[dp_rank] - 
.input_params.meta.batch_forward_type.is_prefill()) { + forward_inputs[dp_rank].meta.batch_forward_type.is_prefill()) { const int32_t local_rank_in_dp_group = worker_rank % dp_local_size_; const int32_t cp_rank = local_rank_in_dp_group / dp_local_tp_size_; CHECK_GE(cp_rank, 0); @@ -1187,17 +1185,14 @@ std::vector LLMEngine::prepare_inputs(std::vector& batch) { dp_global_token_nums[dp_rank] = static_cast(batched_inputs[dp_rank].host_token_ids().numel()); if (batch_forward_type.is_empty() && - !batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batch_forward_type = - batched_inputs[dp_rank].input_params.meta.batch_forward_type; + !batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batch_forward_type = batched_inputs[dp_rank].meta.batch_forward_type; if (batch_forward_type.is_chunked_prefill()) { batch_forward_type = BatchForwardType::PREFILL; } } - dp_is_decode[dp_rank] = - batch_forward_type.is_decode() && - batched_inputs[dp_rank].input_params.meta.q_max_seq_len == 1; + dp_is_decode[dp_rank] = batch_forward_type.is_decode() && + batched_inputs[dp_rank].meta.q_max_seq_len == 1; } // eplb related @@ -1208,16 +1203,14 @@ std::vector LLMEngine::prepare_inputs(std::vector& batch) { // update dp_global_token_nums and batch_forward_type for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - batched_inputs[dp_rank].input_params.parallel.dp_global_token_nums = + batched_inputs[dp_rank].parallel.dp_global_token_nums = dp_global_token_nums; - batched_inputs[dp_rank].input_params.parallel.dp_is_decode = dp_is_decode; + batched_inputs[dp_rank].parallel.dp_is_decode = dp_is_decode; if (FLAGS_enable_eplb) { - batched_inputs[dp_rank].input_params.expert.eplb_info = eplb_info; + batched_inputs[dp_rank].expert.eplb_info = eplb_info; } - if (batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batched_inputs[dp_rank].input_params.meta.batch_forward_type = - batch_forward_type; + if 
(batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batched_inputs[dp_rank].meta.batch_forward_type = batch_forward_type; } } diff --git a/xllm/core/distributed_runtime/rec_engine.cpp b/xllm/core/distributed_runtime/rec_engine.cpp index 3bed606537..617c308f15 100644 --- a/xllm/core/distributed_runtime/rec_engine.cpp +++ b/xllm/core/distributed_runtime/rec_engine.cpp @@ -371,24 +371,19 @@ std::vector RecEngine::LlmRecEnginePipeline::prepare_inputs( dp_global_token_nums[dp_rank] = static_cast(batched_inputs[dp_rank].host_token_ids().numel()); if (batch_forward_type.is_empty() && - !batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batch_forward_type = - batched_inputs[dp_rank].input_params.meta.batch_forward_type; + !batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batch_forward_type = batched_inputs[dp_rank].meta.batch_forward_type; } - dp_is_decode[dp_rank] = - batch_forward_type.is_decode() && - batched_inputs[dp_rank].input_params.meta.q_max_seq_len == 1; + dp_is_decode[dp_rank] = batch_forward_type.is_decode() && + batched_inputs[dp_rank].meta.q_max_seq_len == 1; } for (int32_t dp_rank = 0; dp_rank < engine_.dp_size_; ++dp_rank) { - batched_inputs[dp_rank].input_params.parallel.dp_global_token_nums = + batched_inputs[dp_rank].parallel.dp_global_token_nums = dp_global_token_nums; - batched_inputs[dp_rank].input_params.parallel.dp_is_decode = dp_is_decode; - if (batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batched_inputs[dp_rank].input_params.meta.batch_forward_type = - batch_forward_type; + batched_inputs[dp_rank].parallel.dp_is_decode = dp_is_decode; + if (batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batched_inputs[dp_rank].meta.batch_forward_type = batch_forward_type; } } diff --git a/xllm/core/distributed_runtime/vlm_engine.cpp b/xllm/core/distributed_runtime/vlm_engine.cpp index b951b2d76c..4df69d7810 100644 --- 
a/xllm/core/distributed_runtime/vlm_engine.cpp +++ b/xllm/core/distributed_runtime/vlm_engine.cpp @@ -442,26 +442,21 @@ std::vector VLMEngine::prepare_inputs(std::vector& batch) { dp_global_token_nums[dp_rank] = static_cast(batched_inputs[dp_rank].host_token_ids().numel()); if (batch_forward_type.is_empty() && - !batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batch_forward_type = - batched_inputs[dp_rank].input_params.meta.batch_forward_type; + !batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batch_forward_type = batched_inputs[dp_rank].meta.batch_forward_type; } dp_is_decode[dp_rank] = - batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_decode() && - batched_inputs[dp_rank].input_params.meta.q_max_seq_len == 1; + batched_inputs[dp_rank].meta.batch_forward_type.is_decode() && + batched_inputs[dp_rank].meta.q_max_seq_len == 1; } // update dp_global_token_nums and batch_forward_type for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - batched_inputs[dp_rank].input_params.parallel.dp_global_token_nums = + batched_inputs[dp_rank].parallel.dp_global_token_nums = dp_global_token_nums; - if (batched_inputs[dp_rank] - .input_params.meta.batch_forward_type.is_empty()) { - batched_inputs[dp_rank].input_params.parallel.dp_is_decode = dp_is_decode; - batched_inputs[dp_rank].input_params.meta.batch_forward_type = - batch_forward_type; + if (batched_inputs[dp_rank].meta.batch_forward_type.is_empty()) { + batched_inputs[dp_rank].parallel.dp_is_decode = dp_is_decode; + batched_inputs[dp_rank].meta.batch_forward_type = batch_forward_type; } } diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 8bc54062a9..e01b05ea84 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -687,7 +687,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { } forward_input.positions_host = 
forward_input.positions; - auto& input_params = forward_input.input_params; + ForwardInput& input_params = forward_input; input_params.meta.batch_forward_type = state_.batch_forward_type; input_params.meta.num_sequences = state_.block_tables_vec.size(); input_params.meta.kv_max_seq_len = state_.max_seq_len; @@ -768,7 +768,7 @@ void BatchInputBuilder::process_swap_block_infos(ForwardInput& forward_input) { return; } - auto& input_params = forward_input.input_params; + auto& input_params = forward_input; auto& swap_blocks = *swap_block_transfer_infos_; if (FLAGS_enable_block_copy_kernel) { std::sort(swap_blocks.begin(), diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_batch_input_builder.cpp index 6cb26916da..6d6d4e4171 100644 --- a/xllm/core/framework/batch/onerec_batch_input_builder.cpp +++ b/xllm/core/framework/batch/onerec_batch_input_builder.cpp @@ -431,7 +431,7 @@ ForwardInput OneRecBatchInputBuilder::build_rec_forward_input( // ========== High-performance ForwardInput construction ========== ForwardInput forward_input; - auto& input_params = forward_input.input_params; + ForwardInput& input_params = forward_input; auto& onerec_params = input_params.mutable_onerec_params(); auto& cache_data = perf_cache.cache_data; diff --git a/xllm/core/framework/batch/onerec_xattention_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_xattention_batch_input_builder.cpp index 1ad9f897c1..0add0e5ce3 100644 --- a/xllm/core/framework/batch/onerec_xattention_batch_input_builder.cpp +++ b/xllm/core/framework/batch/onerec_xattention_batch_input_builder.cpp @@ -39,14 +39,14 @@ ForwardInput OneRecXAttentionBatchInputBuilder::build_rec_forward_input( uint32_t min_decoding_batch_size) { auto input = OneRecBatchInputBuilder::build_rec_forward_input( num_decoding_tokens, min_decoding_batch_size); - if (const auto* onerec = input.input_params.onerec_params()) { + if (const auto* onerec = input.onerec_params()) { 
OneRecModelInputParams legacy_params = *onerec; - auto& xattn_params = input.input_params.mutable_onerec_xattention_params(); + auto& xattn_params = input.mutable_onerec_xattention_params(); static_cast(xattn_params) = std::move(legacy_params); } - input.input_params.meta.batch_forward_type = BatchForwardType::PREFILL; + input.meta.batch_forward_type = BatchForwardType::PREFILL; if (sequence_groups_.empty()) { return input; @@ -149,13 +149,12 @@ ForwardInput OneRecXAttentionBatchInputBuilder::build_rec_forward_input( if (!block_tables_vec.empty()) { util::pad_2d_vector(block_tables_vec, /*pad_value=*/0); - input.input_params.attention.device.block_tables = + input.attention.device.block_tables = create_2d_tensor(block_tables_vec, torch::kInt); - input.input_params.attention.host.block_tables = - input.input_params.attention.device.block_tables; + input.attention.host.block_tables = input.attention.device.block_tables; } if (!new_cache_slots_vec.empty()) { - input.input_params.attention.device.new_cache_slots = + input.attention.device.new_cache_slots = torch::tensor(new_cache_slots_vec, torch::kInt); } diff --git a/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp b/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp index fcaa4e2ca1..d4a087ca60 100644 --- a/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp +++ b/xllm/core/framework/batch/rec_multi_round_batch_input_builder.cpp @@ -356,7 +356,7 @@ ForwardInput RecMultiRoundBatchInputBuilder::state_to_forward_input() { } forward_input.positions_host = forward_input.positions; - auto& input_params = forward_input.input_params; + ForwardInput& input_params = forward_input; input_params.meta.batch_forward_type = state.batch_forward_type; input_params.meta.num_sequences = state.block_tables_vec.size(); input_params.meta.kv_max_seq_len = state.max_seq_len; diff --git a/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp 
b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp index 5a115caf69..56c36f4de2 100644 --- a/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp +++ b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.cpp @@ -23,6 +23,7 @@ limitations under the License. #include #include "framework/kv_cache_transfer/kv_cache_store.h" +#include "runtime/forward_params.h" namespace xllm { constexpr uint64_t MBUF_SIZE = 128 * 1024 * 1024; @@ -122,18 +123,17 @@ uint32_t HierarchyKVCacheTransfer::transfer_kv_blocks( return 0; } -void HierarchyKVCacheTransfer::set_layer_synchronizer( - ModelInputParams& params) { +void HierarchyKVCacheTransfer::set_layer_synchronizer(ForwardInput& input) { #if defined(USE_NPU) { std::lock_guard lock(mutex_); - if (layer_wise_load_synchronizer_.count(params.meta.batch_id) != 0) { - params.parallel.layer_wise_load_synchronizer = - layer_wise_load_synchronizer_[params.meta.batch_id]; - layer_wise_load_synchronizer_.erase(params.meta.batch_id); + if (layer_wise_load_synchronizer_.count(input.meta.batch_id) != 0) { + input.parallel.layer_wise_load_synchronizer = + layer_wise_load_synchronizer_[input.meta.batch_id]; + layer_wise_load_synchronizer_.erase(input.meta.batch_id); uint32_t event_cnt = - params.parallel.layer_wise_load_synchronizer->get_event_size(); - params.parallel.layers_per_bacth_copy = + input.parallel.layer_wise_load_synchronizer->get_event_size(); + input.parallel.layers_per_bacth_copy = (options_.layers() + event_cnt - 1) / event_cnt; } } diff --git a/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h index 3852ba341e..6dba6663aa 100644 --- a/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h +++ b/xllm/core/framework/kv_cache_transfer/hierarchy_kv_cache_transfer.h @@ -34,6 +34,8 @@ limitations under the License. 
#include "util/threadpool.h" namespace xllm { +struct ForwardInput; + class HierarchyKVCacheTransfer { public: struct Options { @@ -60,7 +62,7 @@ class HierarchyKVCacheTransfer { uint32_t transfer_kv_blocks(const uint64_t batch_id, Slice& block_transfer_info); - void set_layer_synchronizer(ModelInputParams& params); + void set_layer_synchronizer(ForwardInput& input); private: void create_page_aligned_host_cache(); diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h index 3aa7a00cae..a4624aa0b3 100644 --- a/xllm/core/framework/model/causal_lm.h +++ b/xllm/core/framework/model/causal_lm.h @@ -45,17 +45,15 @@ limitations under the License. namespace xllm { +struct ForwardInput; + class CausalLM : public torch::nn::Module { public: ~CausalLM() override = default; - // tokens: [num_tokens] - // positions: [num_tokens] // returns: [num_tokens, hidden_size] - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) = 0; + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) = 0; // hidden_states: [num_tokens, hidden_size] // seleted_idxes: [num_tokens] @@ -141,11 +139,9 @@ class CausalLMImpl : public CausalLM { CausalLMImpl(Model model, const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - return model_->forward(tokens, positions, kv_caches, parameters); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + return model_->forward(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/core/framework/model/causal_vlm.h b/xllm/core/framework/model/causal_vlm.h index f418c78466..89bbeb3b4f 100644 --- a/xllm/core/framework/model/causal_vlm.h +++ 
b/xllm/core/framework/model/causal_vlm.h @@ -24,6 +24,7 @@ limitations under the License. #include "core/framework/kv_cache/kv_cache.h" #include "core/framework/quant_args.h" #include "core/framework/state_dict/state_dict.h" +#include "core/runtime/forward_params.h" #include "model_args.h" #include "model_input_params.h" @@ -32,10 +33,10 @@ namespace xllm { class CausalVLM : public CausalLM { public: ~CausalVLM() override = default; - virtual MMDict encode(const ModelInputParams& parameters) = 0; + virtual MMDict encode(const ForwardInput& input) = 0; virtual torch::Tensor get_input_embeddings( const torch::Tensor& input_ids, - const ModelInputParams& input_params) = 0; + const MultiModalInput& multimodal) = 0; }; template @@ -44,21 +45,19 @@ class CausalVLMImpl : public CausalVLM { CausalVLMImpl(Model model, const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - MMDict encode(const ModelInputParams& parameters) override { - return model_->get_multimodal_embeddings(parameters); + MMDict encode(const ForwardInput& input) override { + return model_->get_multimodal_embeddings(input); } torch::Tensor get_input_embeddings( const torch::Tensor& input_ids, - const ModelInputParams& input_params) override { - return model_->get_input_embeddings(input_ids, input_params); + const MultiModalInput& multimodal) override { + return model_->get_input_embeddings(input_ids, multimodal); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - return model_->forward(tokens, positions, kv_caches, parameters); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + return model_->forward(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/core/framework/model/mm_embedding_vlm.h b/xllm/core/framework/model/mm_embedding_vlm.h index b93a272eb0..12f67c6f87 100755 --- 
a/xllm/core/framework/model/mm_embedding_vlm.h +++ b/xllm/core/framework/model/mm_embedding_vlm.h @@ -41,8 +41,8 @@ class MMEmbeddingVLMImpl : public MMEmbeddingVLM { MMEmbeddingVLMImpl(Model model, const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - virtual MMDict encode(const ModelInputParams& input_params) override { - return model_->encode(input_params); + virtual MMDict encode(const ForwardInput& input) override { + return model_->encode(input); }; virtual torch::Tensor logits(const torch::Tensor& hidden_states, @@ -52,14 +52,12 @@ class MMEmbeddingVLMImpl : public MMEmbeddingVLM { virtual torch::Tensor get_input_embeddings( const torch::Tensor& input_ids, - const ModelInputParams& input_params) override { + const MultiModalInput& multimodal) override { return torch::Tensor{}; } - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { return ModelOutput(); } virtual void prepare_expert_weight(int32_t layer_id, diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index d9dc89bf3d..814b34302e 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -42,9 +41,6 @@ limitations under the License. 
#include "util/tensor_helper.h" namespace xllm { -namespace layer { -struct AttentionMetadata; -} // namespace layer struct OneRecModelInputParams { enum class RecStage { @@ -842,185 +838,4 @@ struct GraphInput { } }; -struct ModelInputParams { - ModelInputParams to(const torch::Device& device) const { - ModelInputParams params; - params.meta = meta; - params.attention = attention.to(device); - params.embedding = embedding.to(device); - params.block_copy = block_copy.to(device); - params.multimodal = multimodal.to(device); - params.parallel = parallel.to(device); - params.expert = expert.to(device); - params.graph = graph.to(device); - params.dit_forward_input = dit_forward_input.to(device); - - // rec_params device conversion for both OneRec and LLM-Rec variants - if (const auto* onerec_xattn = onerec_xattention_params()) { - params.rec_params = onerec_xattn->to(device); - } else if (const auto* onerec = onerec_params()) { - params.rec_params = onerec->to(device); - } else if (const auto* llmrec = llmrec_params()) { - params.rec_params = llmrec->to(device); - } - - return params; - } - - void print() const { - LOG(INFO) << "ModelInputParams: batch_forward_type is " - << meta.batch_forward_type.to_string() << " , num_sequences is " - << meta.num_sequences << " , kv_max_seq_len is " - << meta.kv_max_seq_len << " , q_max_seq_len is " - << meta.q_max_seq_len; - LOG(INFO) << "ModelInputParams: attention.host.kv_seq_lens is " - << attention.host.kv_seq_lens; - LOG(INFO) << "ModelInputParams: attention.host.q_seq_lens is " - << attention.host.q_seq_lens; - LOG(INFO) << "ModelInputParams: batch_forward_type is " - << meta.batch_forward_type.to_string(); - print_tensor( - attention.device.kv_seq_lens, "ModelInputParams: kv_seq_lens", 4); - print_tensor( - attention.device.q_seq_lens, "ModelInputParams: q_seq_lens", 4); - print_tensor( - attention.device.q_cu_seq_lens, "ModelInputParams: q_cu_seq_lens", 4); - print_tensor(attention.device.new_cache_slots, - 
"ModelInputParams: new_cache_slots", - 4); - print_tensor( - attention.device.block_tables, "ModelInputParams: block_tables", 4); - LOG(INFO) << "ModelInputParams: dp_global_token_nums is " - << parallel.dp_global_token_nums - << ", dp_is_decode: " << parallel.dp_is_decode; - - if (const auto* onerec_xattn = onerec_xattention_params()) { - LOG(INFO) << "ModelInputParams: has onerec_xattention_params"; - onerec_xattn->print(); - } else if (const auto* onerec = onerec_params()) { - LOG(INFO) << "ModelInputParams: has onerec_params"; - onerec->print(); - } else if (const auto* llmrec = llmrec_params()) { - LOG(INFO) << "ModelInputParams: has llm_rec_multi_round_params" - << ", beam_width=" << llmrec->beam_width - << ", total_round=" << llmrec->total_round; - } - } - - int32_t get_q_seq_len(int32_t seq_idx) const { -#if defined(USE_NPU) - CHECK(seq_idx < attention.host.q_seq_lens.size()) << "seq_idx out of range"; - return attention.host.q_seq_lens[seq_idx]; -#else - CHECK(seq_idx < attention.host.q_seq_lens.size() - 1) - << "seq_idx out of range"; - return attention.host.q_seq_lens[seq_idx + 1] - - attention.host.q_seq_lens[seq_idx]; -#endif - } - - bool synchronize_layer(uint32_t layer_idx) const { -#if defined(USE_NPU) - if (parallel.layer_wise_load_synchronizer != nullptr && - layer_idx % parallel.layers_per_bacth_copy == 0) { - if (!parallel.layer_wise_load_synchronizer->synchronize_layer( - layer_idx / parallel.layers_per_bacth_copy)) { - return false; - } - } -#else - (void)layer_idx; -#endif - return true; - } - - bool record_layer(uint32_t layer_idx, const torch::Device& device) const { -#if defined(USE_MLU) - if (parallel.layer_synchronizer != nullptr) { - return parallel.layer_synchronizer->record_current(layer_idx, - device.index()); - } -#else - (void)layer_idx; - (void)device; -#endif - return true; - } - - BatchInputMeta meta; - AttentionInput attention; - ModelEmbeddingInput embedding; - ParallelInput parallel; - BlockCopyInput block_copy; - 
MultiModalInput multimodal; - ExpertInput expert; - GraphInput graph; - - RecModelInputParams rec_params; - - // dit input data - DiTForwardInput dit_forward_input; - - const OneRecModelInputParams* onerec_params() const { - if (const auto* params = std::get_if(&rec_params)) { - return params; - } - if (const auto* params = std::get_if(&rec_params)) { - return static_cast(params); - } - return nullptr; - } - - bool has_onerec_params() const { return onerec_params() != nullptr; } - - OneRecModelInputParams& mutable_onerec_params() { - if (auto* params = std::get_if(&rec_params)) { - return *params; - } - if (auto* params = std::get_if(&rec_params)) { - return static_cast(*params); - } - if (!has_onerec_params()) { - rec_params.emplace(); - } - return std::get(rec_params); - } - - const OneRecXAttentionParams* onerec_xattention_params() const { - return std::get_if(&rec_params); - } - - bool has_onerec_xattention_params() const { - return onerec_xattention_params() != nullptr; - } - - OneRecXAttentionParams& mutable_onerec_xattention_params() { - if (!has_onerec_xattention_params()) { - rec_params.emplace(); - } - return std::get(rec_params); - } - - // Accessors for LLM Rec multi-round params inside rec_params variant - const LlmRecMultiRoundParams* llmrec_params() const { - return std::get_if(&rec_params); - } - - bool has_llmrec_params() const { return llmrec_params() != nullptr; } - - LlmRecMultiRoundParams& mutable_llmrec_params() { - if (!has_llmrec_params()) { - rec_params.emplace(); - } - return std::get(rec_params); - } - - // Optional attention metadata, built by executor - // Using shared_ptr with forward declaration to avoid circular dependency - std::shared_ptr attn_metadata; - - // Flag for CUDA graph capture mode - bool enable_cuda_graph = false; -}; - } // namespace xllm diff --git a/xllm/core/framework/model/rec_causal_lm.h b/xllm/core/framework/model/rec_causal_lm.h index 0a56677f8d..0df153c5c6 100644 --- a/xllm/core/framework/model/rec_causal_lm.h 
+++ b/xllm/core/framework/model/rec_causal_lm.h @@ -30,11 +30,9 @@ class RecCausalLMImpl : public RecCausalLM { RecCausalLMImpl(Model model, const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& parameters) override { - return model_->forward(tokens, positions, kv_caches, parameters); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + return model_->forward(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/core/layers/common/attention_metadata.h b/xllm/core/layers/common/attention_metadata.h index ecaadcbfdc..857ba2d432 100644 --- a/xllm/core/layers/common/attention_metadata.h +++ b/xllm/core/layers/common/attention_metadata.h @@ -65,7 +65,7 @@ struct XAttentionTwoStageDecodeCache { // reused by all layers. This avoids redundant computation and memory allocation // for metadata that is identical across layers (e.g., sequence lengths, paged // KV cache indices, plan_info). Use -// AttentionMetadataBuilder to build instances from ModelInputParams. +// AttentionMetadataBuilder to build instances from narrow attention inputs.
struct AttentionMetadata { torch::Tensor q_cu_seq_lens; torch::Tensor kv_cu_seq_lens; diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index b78b2b66b5..e8807d53af 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -29,37 +29,40 @@ namespace xllm::layer { namespace { AttentionMetadata build_attention_metadata( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, bool enable_mla, const std::string& compute_dtype, const std::optional& attn_mask) { // MLA mode still affects which shared tensors must be materialized for // attention execution, but the flag itself is no longer carried in metadata. AttentionMetadata attn_metadata; - attn_metadata.q_cu_seq_lens = params.attention.device.q_seq_lens; - attn_metadata.kv_cu_seq_lens = params.attention.device.kv_seq_lens; - attn_metadata.max_query_len = params.meta.q_max_seq_len; - attn_metadata.max_seq_len = params.meta.kv_max_seq_len; - if (!params.attention.host.kv_seq_lens.empty()) { + attn_metadata.q_cu_seq_lens = attention.device.q_seq_lens; + attn_metadata.kv_cu_seq_lens = attention.device.kv_seq_lens; + attn_metadata.max_query_len = meta.q_max_seq_len; + attn_metadata.max_seq_len = meta.kv_max_seq_len; + if (!attention.host.kv_seq_lens.empty()) { const bool is_cu_seq_lens = - params.attention.host.kv_seq_lens.size() == - static_cast(params.meta.num_sequences + 1) && - params.attention.host.kv_seq_lens.front() == 0; + attention.host.kv_seq_lens.size() == + static_cast(meta.num_sequences + 1) && + attention.host.kv_seq_lens.front() == 0; attn_metadata.total_kv_len = - is_cu_seq_lens - ? 
params.attention.host.kv_seq_lens.back() - : std::accumulate(params.attention.host.kv_seq_lens.begin(), - params.attention.host.kv_seq_lens.end(), - int64_t{0}); + is_cu_seq_lens ? attention.host.kv_seq_lens.back() + : std::accumulate(attention.host.kv_seq_lens.begin(), + attention.host.kv_seq_lens.end(), + int64_t{0}); } - attn_metadata.slot_mapping = params.attention.device.new_cache_slots; + attn_metadata.slot_mapping = attention.device.new_cache_slots; attn_metadata.compute_dtype = compute_dtype; // for flashinfer - attn_metadata.paged_kv_indptr = params.attention.device.paged_kv_indptr; - attn_metadata.paged_kv_indices = params.attention.device.paged_kv_indices; + attn_metadata.paged_kv_indptr = attention.device.paged_kv_indptr; + attn_metadata.paged_kv_indices = attention.device.paged_kv_indices; attn_metadata.paged_kv_last_page_len = - params.attention.device.paged_kv_last_page_len; + attention.device.paged_kv_last_page_len; #if defined(USE_CUDA) || defined(USE_MUSA) attn_metadata.plan_info = std::make_shared(); attn_metadata.shared_plan_info = std::make_shared(); @@ -71,8 +74,8 @@ AttentionMetadata build_attention_metadata( // graph_buffer.attn_mask (e.g. 
Qwen2_5_VL sets graph_buffer.attn_mask for // LongCat text encoding) std::optional mask_to_use = attn_mask; - if (!mask_to_use.has_value() && params.graph.attn_mask.defined()) { - mask_to_use = params.graph.attn_mask; + if (!mask_to_use.has_value() && graph.attn_mask.defined()) { + mask_to_use = graph.attn_mask; } if (mask_to_use.has_value()) { attn_metadata.attn_mask = mask_to_use.value(); @@ -84,57 +87,56 @@ AttentionMetadata build_attention_metadata( // - FLAGS_enable_graph must be enabled // - Must be decode phase (not prefill) // - tiling_data must be available - bool is_decode = !params.meta.batch_forward_type.is_prefill() && - !params.meta.batch_forward_type.is_mixed() && - !params.meta.batch_forward_type.is_chunked_prefill(); + bool is_decode = !meta.batch_forward_type.is_prefill() && + !meta.batch_forward_type.is_mixed() && + !meta.batch_forward_type.is_chunked_prefill(); bool use_acl_graph = - FLAGS_enable_graph && is_decode && params.graph.tiling_data.defined(); + FLAGS_enable_graph && is_decode && graph.tiling_data.defined(); if (use_acl_graph) { // ACL graph mode: use CustomPagedAttention with tiling_data on device - attn_metadata.paged_attention_tiling_data = params.graph.tiling_data; + attn_metadata.paged_attention_tiling_data = graph.tiling_data; } // Provide host seq_lens for NPU kernels (required by CustomPagedAttention). 
- if (!params.attention.host.kv_seq_lens.empty()) { + if (!attention.host.kv_seq_lens.empty()) { attn_metadata.kv_seq_lens_host = - torch::tensor(params.attention.host.kv_seq_lens, torch::kInt); + torch::tensor(attention.host.kv_seq_lens, torch::kInt); } #endif attn_metadata.is_chunked_prefill = - params.meta.batch_forward_type.is_mixed() || - params.meta.batch_forward_type.is_chunked_prefill(); - attn_metadata.is_prefill = params.meta.batch_forward_type.is_prefill(); + meta.batch_forward_type.is_mixed() || + meta.batch_forward_type.is_chunked_prefill(); + attn_metadata.is_prefill = meta.batch_forward_type.is_prefill(); // enable_mla is for DeepSeekv32 on mlu device if (!attn_metadata.is_prefill || enable_mla) { - attn_metadata.block_table = params.attention.device.block_tables; + attn_metadata.block_table = attention.device.block_tables; #if !defined(USE_NPU) && !defined(USE_CUDA) attn_metadata.kv_seq_lens = - torch::diff(params.attention.device.kv_seq_lens); // kv seqlens + torch::diff(attention.device.kv_seq_lens); // kv seqlens attn_metadata.q_seq_lens = - torch::diff(params.attention.device.q_seq_lens); // q seqlens + torch::diff(attention.device.q_seq_lens); // q seqlens #endif } #if defined(USE_NPU) // NPU path uses per-sequence lengths (not cumulative), so no diff. // Ensure per-sequence lengths are available for NPU kernels in all phases. 
- if (params.attention.device.kv_seq_lens.defined()) { - attn_metadata.kv_seq_lens = params.attention.device.kv_seq_lens; + if (attention.device.kv_seq_lens.defined()) { + attn_metadata.kv_seq_lens = attention.device.kv_seq_lens; } - if (params.attention.device.q_seq_lens.defined()) { - attn_metadata.q_seq_lens = params.attention.device.q_seq_lens; - CHECK(params.attention.device.q_cu_seq_lens.defined()) + if (attention.device.q_seq_lens.defined()) { + attn_metadata.q_seq_lens = attention.device.q_seq_lens; + CHECK(attention.device.q_cu_seq_lens.defined()) << "q_cu_seq_lens must be provided by upstream"; - auto zero = - torch::zeros({1}, params.attention.device.q_cu_seq_lens.options()); + auto zero = torch::zeros({1}, attention.device.q_cu_seq_lens.options()); attn_metadata.q_cu_seq_lens = - torch::cat({zero, params.attention.device.q_cu_seq_lens}, 0); + torch::cat({zero, attention.device.q_cu_seq_lens}, 0); } #endif - attn_metadata.is_dummy = (params.meta.q_max_seq_len == 0); + attn_metadata.is_dummy = (meta.q_max_seq_len == 0); if (attn_metadata.is_dummy) { attn_metadata.slot_mapping = - torch::tensor({1}, params.attention.device.new_cache_slots.options()); + torch::tensor({1}, attention.device.new_cache_slots.options()); } // Set is_causal: true for prefill (causal attention), false for decode @@ -142,8 +144,7 @@ AttentionMetadata build_attention_metadata( attn_metadata.is_causal = attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; - // Copy enable_cuda_graph flag from params - attn_metadata.enable_cuda_graph = params.enable_cuda_graph; + attn_metadata.enable_cuda_graph = enable_cuda_graph; #if defined(USE_CUDA) || defined(USE_MUSA) if (attn_metadata.is_causal && !attn_metadata.enable_cuda_graph) { @@ -152,16 +153,15 @@ AttentionMetadata build_attention_metadata( #endif #if defined(USE_ILU) - attn_metadata.block_table = params.attention.device.block_tables; + attn_metadata.block_table = attention.device.block_tables; #endif // TODO: set 
use_tensor_core from options. // for xattention - if (params.has_llmrec_params()) { - const auto& llmrec_params = *params.llmrec_params(); - if (llmrec_params.current_round_tensor.defined() && - llmrec_params.current_round_tensor.numel() > 0) { - attn_metadata.step_tensor = llmrec_params.current_round_tensor; + if (llmrec_params != nullptr) { + if (llmrec_params->current_round_tensor.defined() && + llmrec_params->current_round_tensor.numel() > 0) { + attn_metadata.step_tensor = llmrec_params->current_round_tensor; } if (!FLAGS_enable_xattention_one_stage) { @@ -170,35 +170,36 @@ AttentionMetadata build_attention_metadata( XAttentionTwoStageDecodeCache{}); auto& cache = attn_metadata.xattention_two_stage_decode_cache.value(); - cache.shared_lse = llmrec_params.two_stage_shared_lse; - cache.shared_o = llmrec_params.two_stage_shared_o; - cache.unshared_lse = llmrec_params.two_stage_unshared_lse; - cache.unshared_o = llmrec_params.two_stage_unshared_o; - cache.q_cu_seq_lens_shared = llmrec_params.two_stage_q_cu_seq_lens_shared; - cache.qo_indptr_expanded = llmrec_params.two_stage_qo_indptr_expanded; + cache.shared_lse = llmrec_params->two_stage_shared_lse; + cache.shared_o = llmrec_params->two_stage_shared_o; + cache.unshared_lse = llmrec_params->two_stage_unshared_lse; + cache.unshared_o = llmrec_params->two_stage_unshared_o; + cache.q_cu_seq_lens_shared = + llmrec_params->two_stage_q_cu_seq_lens_shared; + cache.qo_indptr_expanded = llmrec_params->two_stage_qo_indptr_expanded; cache.paged_kv_indptr_expanded = - llmrec_params.two_stage_paged_kv_indptr_expanded; + llmrec_params->two_stage_paged_kv_indptr_expanded; cache.paged_kv_indices_expanded = - llmrec_params.two_stage_paged_kv_indices_expanded; + llmrec_params->two_stage_paged_kv_indices_expanded; cache.paged_kv_last_page_len_expanded = - llmrec_params.two_stage_paged_kv_last_page_len_expanded; + llmrec_params->two_stage_paged_kv_last_page_len_expanded; if (cache.q_cu_seq_lens_shared.defined()) { 
cache.cached_batch_size = static_cast(cache.q_cu_seq_lens_shared.numel()) - 1; } - cache.cached_beam_size = llmrec_params.beam_width; - if (!llmrec_params.unshared_k_caches.empty()) { + cache.cached_beam_size = llmrec_params->beam_width; + if (!llmrec_params->unshared_k_caches.empty()) { cache.cached_max_decode_step = - static_cast(llmrec_params.unshared_k_caches[0].size(2)); + static_cast(llmrec_params->unshared_k_caches[0].size(2)); } if (cache.shared_o.defined() && cache.shared_o.dim() == 3) { cache.cached_num_heads = static_cast(cache.shared_o.size(1)); cache.cached_head_size = static_cast(cache.shared_o.size(2)); } - if (llmrec_params.current_round_tensor.defined() && - llmrec_params.current_round_tensor.numel() > 0) { - cache.cached_step = llmrec_params.current_round_tensor.item(); + if (llmrec_params->current_round_tensor.defined() && + llmrec_params->current_round_tensor.numel() > 0) { + cache.cached_step = llmrec_params->current_round_tensor.item(); } #endif } @@ -210,19 +211,40 @@ AttentionMetadata build_attention_metadata( } // namespace AttentionMetadata AttentionMetadataBuilder::build( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, bool enable_mla, const std::optional& attn_mask) { - return AttentionMetadataBuilder::build( - params, enable_mla, "float", attn_mask); + return AttentionMetadataBuilder::build(meta, + attention, + graph, + llmrec_params, + enable_cuda_graph, + enable_mla, + "float", + attn_mask); } AttentionMetadata AttentionMetadataBuilder::build( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, bool enable_mla, const std::string& compute_dtype, const std::optional& attn_mask) { - return build_attention_metadata(params, enable_mla, 
compute_dtype, attn_mask); + return build_attention_metadata(meta, + attention, + graph, + llmrec_params, + enable_cuda_graph, + enable_mla, + compute_dtype, + attn_mask); } } // namespace xllm::layer diff --git a/xllm/core/layers/common/attention_metadata_builder.h b/xllm/core/layers/common/attention_metadata_builder.h index 7f1810fae0..1b0c315ef7 100644 --- a/xllm/core/layers/common/attention_metadata_builder.h +++ b/xllm/core/layers/common/attention_metadata_builder.h @@ -21,28 +21,38 @@ limitations under the License. #include namespace xllm { -struct ModelArgs; -struct ModelInputParams; +struct AttentionInput; +struct BatchInputMeta; +struct GraphInput; +struct LlmRecMultiRoundParams; namespace layer { struct AttentionMetadata; // Builder class for AttentionMetadata to avoid circular dependency. -// This class handles building AttentionMetadata from ModelInputParams, -// allowing attention_metadata.h to not depend on model_input_params.h. +// This class handles building AttentionMetadata without making +// attention_metadata.h depend on runtime input structures. class AttentionMetadataBuilder { public: - // Build AttentionMetadata from ModelInputParams with default compute_dtype - // ("float"). + // Build AttentionMetadata from the narrow inputs consumed by attention. static AttentionMetadata build( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, bool enable_mla, const std::optional& attn_mask = {}); - // Build AttentionMetadata from ModelInputParams with specified compute_dtype. + // Build AttentionMetadata from the narrow inputs consumed by attention with + // specified compute_dtype. 
static AttentionMetadata build( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, bool enable_mla, const std::string& compute_dtype, const std::optional& attn_mask = {}); diff --git a/xllm/core/layers/common/dp_utils.cpp b/xllm/core/layers/common/dp_utils.cpp index 05810f2645..196b329ebe 100644 --- a/xllm/core/layers/common/dp_utils.cpp +++ b/xllm/core/layers/common/dp_utils.cpp @@ -181,25 +181,24 @@ bool need_dp_moe_gather(const ParallelArgs& args, bool enable_moe_all2all) { } torch::Tensor gather_dp_tokens(const torch::Tensor& input, - const ModelInputParams& params, + const ParallelInput& parallel, const ParallelArgs& args) { if (args.dp_size() <= 1) { return input; } - return parallel_state::gather(input, - args.dp_local_process_group_, - params.parallel.dp_global_token_nums); + return parallel_state::gather( + input, args.dp_local_process_group_, parallel.dp_global_token_nums); } torch::Tensor get_dp_local_slice(const torch::Tensor& input, - const ModelInputParams& params, + const ParallelInput& parallel, const ParallelArgs& args) { if (args.dp_size() <= 1) { return input; } - const auto& dp_tokens = params.parallel.dp_global_token_nums; + const auto& dp_tokens = parallel.dp_global_token_nums; const int64_t dp_rank = args.dp_local_process_group_->rank(); int64_t start = 0; @@ -211,13 +210,13 @@ torch::Tensor get_dp_local_slice(const torch::Tensor& input, return input.slice(0, start, end); } -bool all_dp_ranks_are_decode(const ModelInputParams& params) { - if (params.parallel.dp_is_decode.empty()) { - return params.parallel.dp_global_token_nums.size() <= 1; +bool all_dp_ranks_are_decode(const ParallelInput& parallel) { + if (parallel.dp_is_decode.empty()) { + return parallel.dp_global_token_nums.size() <= 1; } - return std::all_of(params.parallel.dp_is_decode.begin(), - params.parallel.dp_is_decode.end(), + return 
std::all_of(parallel.dp_is_decode.begin(), + parallel.dp_is_decode.end(), [](int32_t val) { return val == 1; }); } diff --git a/xllm/core/layers/common/dp_utils.h b/xllm/core/layers/common/dp_utils.h index 2d8a40da4b..91c450f422 100644 --- a/xllm/core/layers/common/dp_utils.h +++ b/xllm/core/layers/common/dp_utils.h @@ -61,17 +61,17 @@ bool need_dp_moe_gather(const ParallelArgs& args, bool enable_moe_all2all); // gather tokens from all dp ranks before moe torch::Tensor gather_dp_tokens(const torch::Tensor& input, - const ModelInputParams& params, + const ParallelInput& parallel, const ParallelArgs& args); // given a tensor containing data from all DP ranks, // returns a slice containing only the tokens for the current DP rank torch::Tensor get_dp_local_slice(const torch::Tensor& input, - const ModelInputParams& params, + const ParallelInput& parallel, const ParallelArgs& args); // check if all dp ranks are decode -bool all_dp_ranks_are_decode(const ModelInputParams& params); +bool all_dp_ranks_are_decode(const ParallelInput& parallel); } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/fused_moe.cpp b/xllm/core/layers/common/fused_moe.cpp index b91dc08ec6..6a54c2c552 100644 --- a/xllm/core/layers/common/fused_moe.cpp +++ b/xllm/core/layers/common/fused_moe.cpp @@ -41,7 +41,7 @@ torch::Tensor FusedMoEImpl::forward_experts( } torch::Tensor FusedMoEImpl::forward(const torch::Tensor& /*hidden_states*/, - const ModelInputParams& /*input_params*/) { + const ParallelInput& /*parallel_input*/) { NOT_IMPLEMENTED_WITH_MSG( "FusedMoE is not supported for this backend. 
Please use CUDA, MLU or " "ILU backend for MoE models."); diff --git a/xllm/core/layers/common/fused_moe.h b/xllm/core/layers/common/fused_moe.h index 6e148c1551..cd040f0aec 100644 --- a/xllm/core/layers/common/fused_moe.h +++ b/xllm/core/layers/common/fused_moe.h @@ -45,7 +45,7 @@ class FusedMoEImpl : public torch::nn::Module { const torch::Tensor& router_logits, bool enable_all2all_communication); torch::Tensor forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params); + const ParallelInput& parallel_input); void load_state_dict(const StateDict& state_dict); }; TORCH_MODULE(FusedMoE); diff --git a/xllm/core/layers/common/oxygen_vision_attention.cpp b/xllm/core/layers/common/oxygen_vision_attention.cpp index f49f3f2e7b..ef1bf8ccfe 100644 --- a/xllm/core/layers/common/oxygen_vision_attention.cpp +++ b/xllm/core/layers/common/oxygen_vision_attention.cpp @@ -31,8 +31,7 @@ torch::Tensor OxygenVisionAttentionImpl::forward( torch::Tensor& m_cos_pos, torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, - std::vector& cu_seq_len_vec, - ModelInputParams& params) { + std::vector& cu_seq_len_vec) { // 1. qkv projection auto qkv = qkv_proj_->forward(hidden_states); // 2. 
split qkv diff --git a/xllm/core/layers/common/oxygen_vision_attention.h b/xllm/core/layers/common/oxygen_vision_attention.h index cde16aaa06..fcd89a4522 100644 --- a/xllm/core/layers/common/oxygen_vision_attention.h +++ b/xllm/core/layers/common/oxygen_vision_attention.h @@ -33,8 +33,7 @@ class OxygenVisionAttentionImpl : public Qwen2VisionAttentionImpl { torch::Tensor& m_cos_pos, torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, - std::vector& cu_seq_len_vec, - ModelInputParams& input_params) override; + std::vector& cu_seq_len_vec) override; }; TORCH_MODULE(OxygenVisionAttention); diff --git a/xllm/core/layers/common/qwen2_vision_attention.cpp b/xllm/core/layers/common/qwen2_vision_attention.cpp index d2474801d5..69f3223649 100644 --- a/xllm/core/layers/common/qwen2_vision_attention.cpp +++ b/xllm/core/layers/common/qwen2_vision_attention.cpp @@ -148,8 +148,7 @@ torch::Tensor Qwen2VisionAttentionImpl::forward( torch::Tensor& m_cos_pos, torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, - std::vector& cu_seq_len_vec, - ModelInputParams& params) { + std::vector& cu_seq_len_vec) { // 1. qkv projection auto qkv = qkv_proj_->forward(hidden_states); // 2. 
split qkv diff --git a/xllm/core/layers/common/qwen2_vision_attention.h b/xllm/core/layers/common/qwen2_vision_attention.h index b59e75b621..0bbff8d84a 100644 --- a/xllm/core/layers/common/qwen2_vision_attention.h +++ b/xllm/core/layers/common/qwen2_vision_attention.h @@ -37,8 +37,7 @@ class Qwen2VisionAttentionImpl : public torch::nn::Module { torch::Tensor& m_cos_pos, torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, - std::vector& cu_seq_len_vec, - ModelInputParams& input_params); + std::vector& cu_seq_len_vec); void load_state_dict(const StateDict& state_dict); diff --git a/xllm/core/layers/cuda/fused_moe.cpp b/xllm/core/layers/cuda/fused_moe.cpp index ba965262f1..e157606910 100644 --- a/xllm/core/layers/cuda/fused_moe.cpp +++ b/xllm/core/layers/cuda/fused_moe.cpp @@ -118,7 +118,7 @@ torch::Tensor FusedMoEImpl::forward_experts(const torch::Tensor& hidden_states, } torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, - const ModelInputParams& /*input_params*/) { + const ParallelInput& /*parallel_input*/) { torch::Tensor router_logits = gate_->forward(hidden_states); return forward_experts(hidden_states, router_logits); } diff --git a/xllm/core/layers/cuda/fused_moe.h b/xllm/core/layers/cuda/fused_moe.h index b5722e351f..fa131a53c1 100644 --- a/xllm/core/layers/cuda/fused_moe.h +++ b/xllm/core/layers/cuda/fused_moe.h @@ -45,7 +45,7 @@ class FusedMoEImpl : public torch::nn::Module { torch::Tensor forward_experts(const torch::Tensor& hidden_states, torch::Tensor router_logits); torch::Tensor forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params); + const ParallelInput& parallel_input); void load_state_dict(const StateDict& state_dict); private: diff --git a/xllm/core/layers/ilu/fused_moe.cpp b/xllm/core/layers/ilu/fused_moe.cpp index a0cda6bad7..00721f3d50 100644 --- a/xllm/core/layers/ilu/fused_moe.cpp +++ b/xllm/core/layers/ilu/fused_moe.cpp @@ -711,11 +711,11 @@ torch::Tensor FusedMoEImpl::forward_experts(const 
torch::Tensor& hidden_states, } torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params) { + const ParallelInput& parallel_input) { // we only support all2all communication for decode stage for now bool enable_all2all_communication = - enable_deep_ep_ && std::all_of(input_params.parallel.dp_is_decode.begin(), - input_params.parallel.dp_is_decode.end(), + enable_deep_ep_ && std::all_of(parallel_input.dp_is_decode.begin(), + parallel_input.dp_is_decode.end(), [](int32_t val) { return val == 1; }); bool is_dp_ep_parallel = @@ -730,7 +730,7 @@ torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, if (need_gather_and_slice) { input = parallel_state::gather(input, parallel_args_.dp_local_process_group_, - input_params.parallel.dp_global_token_nums); + parallel_input.dp_global_token_nums); } // MoE Gate auto router_logits = gate_(input); @@ -740,7 +740,7 @@ torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, forward_experts(input, router_logits, enable_all2all_communication); if (need_gather_and_slice) { - output = get_dp_local_slice(output, input_params, parallel_args_); + output = get_dp_local_slice(output, parallel_input, parallel_args_); } return output; diff --git a/xllm/core/layers/ilu/fused_moe.h b/xllm/core/layers/ilu/fused_moe.h index 3e477064ae..42866abc9e 100644 --- a/xllm/core/layers/ilu/fused_moe.h +++ b/xllm/core/layers/ilu/fused_moe.h @@ -46,7 +46,7 @@ class FusedMoEImpl : public torch::nn::Module { const torch::Tensor& router_logits, bool enable_all2all_communication); torch::Tensor forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params); + const ParallelInput& parallel_input); void load_state_dict(const StateDict& state_dict); private: diff --git a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp index 172955c49f..5fd83a6553 100644 --- 
a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.cpp @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "runtime/forward_params.h" + namespace xllm { namespace layer { @@ -159,7 +161,7 @@ DeepseekV2DecoderLayerImpl::MoeInputPrepResult DeepseekV2DecoderLayerImpl::prepare_moe_inputs( torch::Tensor x, const torch::Tensor& residual, - const ModelInputParams& input_params, + const ParallelInput& parallel_input, DeepseekV2AttentionImpl::PostAttnLayout attn_layout) { MoeInputPrepResult result; if (!sparse_moe_) { @@ -169,7 +171,7 @@ DeepseekV2DecoderLayerImpl::prepare_moe_inputs( return result; } - result.exec_cfg = sparse_moe_->plan_exec(input_params); + result.exec_cfg = sparse_moe_->plan_exec(parallel_input); result.use_sp_moe_overlap = attn_layout == DeepseekV2AttentionImpl::PostAttnLayout::kPackedLocal && !result.exec_cfg->enable_all2all && !result.exec_cfg->need_dp_gather && @@ -194,7 +196,7 @@ DeepseekV2DecoderLayerImpl::prepare_moe_inputs( if (result.exec_cfg->enable_all2all || result.exec_cfg->need_dp_gather) { result.moe_prep = sparse_moe_->prep_in(std::move(x), residual, - input_params, + parallel_input, result.exec_cfg.value(), attn_layout); result.ffn_in = result.moe_prep->ffn_in; @@ -233,9 +235,9 @@ bool DeepseekV2DecoderLayerImpl::can_keep_local_output( } bool DeepseekV2DecoderLayerImpl::can_sp_chunk( - const ModelInputParams& input_params) const { + BatchForwardType batch_forward_type) const { return sequence_parallel_context_ != nullptr && sp_ffn_chunk_size_ > 0 && - input_params.meta.batch_forward_type.no_decode(); + batch_forward_type.no_decode(); } torch::Tensor DeepseekV2DecoderLayerImpl::comm_out( @@ -255,10 +257,9 @@ torch::Tensor DeepseekV2DecoderLayerImpl::comm_out( return v32_sp::slice_local_packed(x, *sequence_parallel_context_); } -torch::Tensor DeepseekV2DecoderLayerImpl::run_mlp( - torch::Tensor x, - const ModelInputParams& input_params) { - if 
(!can_sp_chunk(input_params) || !x.defined() || x.dim() == 0 || +torch::Tensor DeepseekV2DecoderLayerImpl::run_mlp(torch::Tensor x, + bool can_use_sp_chunk) { + if (!can_use_sp_chunk || !x.defined() || x.dim() == 0 || x.size(0) <= sp_ffn_chunk_size_) { return mlp_(x); } @@ -299,7 +300,7 @@ torch::Tensor DeepseekV2DecoderLayerImpl::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input) { // Pre-attention norm residual = x; x = std::get<0>(input_norm_->forward(x)); @@ -311,7 +312,7 @@ torch::Tensor DeepseekV2DecoderLayerImpl::forward( sequence_parallel_context_ != nullptr && attention_->can_use_sp(); const auto attn_layout = attention_->post_attn_layout(use_sp_output); auto prep = prepare_moe_inputs( - std::move(x), residual.value(), input_params, attn_layout); + std::move(x), residual.value(), input.parallel, attn_layout); auto& carrier = prep.carrier; auto& moe_prep = prep.moe_prep; auto& exec_cfg = prep.exec_cfg; @@ -327,13 +328,13 @@ torch::Tensor DeepseekV2DecoderLayerImpl::forward( } } if (moe_prep.has_value()) { - x = sparse_moe_->gather_in(*moe_prep, input_params); + x = sparse_moe_->gather_in(*moe_prep, input.parallel); } // MLP forward bool keep_local_output = false; const int64_t sp_chunk_size = - can_sp_chunk(input_params) ? sp_ffn_chunk_size_ : -1; + can_sp_chunk(input.meta.batch_forward_type) ? sp_ffn_chunk_size_ : -1; if (sparse_moe_) { auto can_keep_local = [&](ProcessGroup* pg) { return carrier.has_value() && can_keep_local_output(*carrier, pg); @@ -369,7 +370,7 @@ torch::Tensor DeepseekV2DecoderLayerImpl::forward( } else { keep_local_output = can_keep_local_output(*carrier, parallel_args_.tp_group_); - x = run_mlp(std::move(x), input_params); + x = run_mlp(std::move(x), sp_chunk_size > 0); x = keep_local_output ? 
comm_out(x, *carrier, parallel_args_.tp_group_) : reduce_out(x, parallel_args_.tp_group_); } @@ -381,7 +382,7 @@ torch::Tensor DeepseekV2DecoderLayerImpl::forward( x = x + skip_src; } else if (moe_prep.has_value() && (moe_prep->need_dp_gather || moe_prep->need_tp_pad)) { - x = sparse_moe_->merge_out(x, *moe_prep, input_params); + x = sparse_moe_->merge_out(x, *moe_prep, input.parallel); } else { x = restore_ffn_output(x, *carrier); } diff --git a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h index 90317e5cf8..b3f8d53d50 100644 --- a/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/mlu/deepseek_v2_decoder_layer_impl.h @@ -36,6 +36,8 @@ limitations under the License. #include "layers/mlu/deepseek_v32_sp_context.h" namespace xllm { +struct ForwardInput; + namespace layer { class DeepseekV2DecoderLayerTestPeer; @@ -64,7 +66,7 @@ class DeepseekV2DecoderLayerImpl : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input); private: enum class PostAttnMode { @@ -96,16 +98,16 @@ class DeepseekV2DecoderLayerImpl : public torch::nn::Module { MoeInputPrepResult prepare_moe_inputs( torch::Tensor x, const torch::Tensor& residual, - const ModelInputParams& input_params, + const ParallelInput& parallel_input, DeepseekV2AttentionImpl::PostAttnLayout attn_layout); bool can_keep_local_output(const PostAttnCarrier& carrier, ProcessGroup* pg) const; - bool can_sp_chunk(const ModelInputParams& input_params) const; + bool can_sp_chunk(BatchForwardType batch_forward_type) const; torch::Tensor comm_out(torch::Tensor x, const PostAttnCarrier& carrier, ProcessGroup* pg) const; - torch::Tensor run_mlp(torch::Tensor x, const ModelInputParams& input_params); + torch::Tensor run_mlp(torch::Tensor x, bool can_use_sp_chunk); torch::Tensor restore_ffn_output(torch::Tensor x, const 
PostAttnCarrier& carrier); torch::Tensor reduce_out(torch::Tensor x, ProcessGroup* pg) const; diff --git a/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.cpp b/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.cpp index 62db991f94..0746e86963 100644 --- a/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.cpp +++ b/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.cpp @@ -58,9 +58,10 @@ void DeepseekV2SparseMoEBlockImpl::verify_loaded_weights() const { } DeepseekV2SparseMoEBlockImpl::ExecCfg DeepseekV2SparseMoEBlockImpl::plan_exec( - const ModelInputParams& input_params) const { + const ParallelInput& parallel_input) const { ExecCfg cfg; - cfg.enable_all2all = enable_deep_ep_ && all_dp_ranks_are_decode(input_params); + cfg.enable_all2all = + enable_deep_ep_ && all_dp_ranks_are_decode(parallel_input); cfg.need_dp_gather = need_dp_moe_gather(parallel_args_, cfg.enable_all2all); return cfg; } @@ -68,7 +69,7 @@ DeepseekV2SparseMoEBlockImpl::ExecCfg DeepseekV2SparseMoEBlockImpl::plan_exec( DeepseekV2SparseMoEBlockImpl::PrepOut DeepseekV2SparseMoEBlockImpl::prep_in( torch::Tensor x, const torch::Tensor& residual, - const ModelInputParams& input_params, + const ParallelInput& parallel_input, const ExecCfg& exec, DeepseekV2AttentionImpl::PostAttnLayout attn_layout) const { PrepOut prep; @@ -90,8 +91,7 @@ DeepseekV2SparseMoEBlockImpl::PrepOut DeepseekV2SparseMoEBlockImpl::prep_in( auto shard = shard_attn_out( x, residual, - get_dp_gather_tokens(input_params.parallel.dp_global_token_nums, - parallel_args_), + get_dp_gather_tokens(parallel_input.dp_global_token_nums, parallel_args_), attn_layout); prep.ffn_in = shard.first; prep.pad_info = shard.second; @@ -107,34 +107,32 @@ DeepseekV2SparseMoEBlockImpl::PrepOut DeepseekV2SparseMoEBlockImpl::prep_in( << "dp gather prep requires dp_local_process_group_"; const int64_t dp_rank = parallel_args_.dp_local_process_group_->rank(); CHECK_GE(dp_rank, 0) << "invalid dp rank " << dp_rank; - CHECK_LT( - dp_rank, - 
static_cast(input_params.parallel.dp_global_token_nums.size())) + CHECK_LT(dp_rank, + static_cast(parallel_input.dp_global_token_nums.size())) << "dp rank " << dp_rank << " exceeds dp_global_token_nums size " - << input_params.parallel.dp_global_token_nums.size(); - const int64_t local_token_num = - input_params.parallel.dp_global_token_nums[dp_rank]; + << parallel_input.dp_global_token_nums.size(); + const int64_t local_token_num = parallel_input.dp_global_token_nums[dp_rank]; prep.skip_local = local_tokens.slice(0, 0, local_token_num); return prep; } torch::Tensor DeepseekV2SparseMoEBlockImpl::gather_in( const PrepOut& prep, - const ModelInputParams& input_params) const { + const ParallelInput& parallel_input) const { if (!prep.need_dp_gather) { return prep.ffn_in; } return gather_global_tokens( - prep.ffn_in, input_params.parallel.dp_global_token_nums, parallel_args_); + prep.ffn_in, parallel_input.dp_global_token_nums, parallel_args_); } torch::Tensor DeepseekV2SparseMoEBlockImpl::merge_out( torch::Tensor x, const PrepOut& prep, - const ModelInputParams& input_params) const { + const ParallelInput& parallel_input) const { if (prep.need_dp_gather) { - x = get_dp_local_slice(x, input_params, parallel_args_); + x = get_dp_local_slice(x, parallel_input, parallel_args_); return x + prep.skip_local; } diff --git a/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.h b/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.h index 99637536dc..e8e581387a 100644 --- a/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.h +++ b/xllm/core/layers/mlu/deepseek_v2_sparse_moe_block.h @@ -73,17 +73,17 @@ class DeepseekV2SparseMoEBlockImpl : public torch::nn::Module { void load_state_dict(const StateDict& state_dict); void verify_loaded_weights() const; - ExecCfg plan_exec(const ModelInputParams& input_params) const; + ExecCfg plan_exec(const ParallelInput& parallel_input) const; PrepOut prep_in(torch::Tensor x, const torch::Tensor& residual, - const ModelInputParams& input_params, + 
const ParallelInput& parallel_input, const ExecCfg& exec, DeepseekV2AttentionImpl::PostAttnLayout attn_layout) const; torch::Tensor gather_in(const PrepOut& prep, - const ModelInputParams& input_params) const; + const ParallelInput& parallel_input) const; torch::Tensor merge_out(torch::Tensor x, const PrepOut& prep, - const ModelInputParams& input_params) const; + const ParallelInput& parallel_input) const; bool has_shared() const; ForwardResult forward(torch::Tensor x, diff --git a/xllm/core/layers/mlu/fused_moe.cpp b/xllm/core/layers/mlu/fused_moe.cpp index d7bd5a8984..c122ff1cc9 100644 --- a/xllm/core/layers/mlu/fused_moe.cpp +++ b/xllm/core/layers/mlu/fused_moe.cpp @@ -513,10 +513,10 @@ torch::Tensor FusedMoEImpl::forward_experts( } torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params) { + const ParallelInput& parallel_input) { // we only support all2all communication for decode stage for now bool enable_all2all_communication = - enable_deep_ep_ && all_dp_ranks_are_decode(input_params); + enable_deep_ep_ && all_dp_ranks_are_decode(parallel_input); return forward_experts( hidden_states, enable_all2all_communication, std::nullopt); } diff --git a/xllm/core/layers/mlu/fused_moe.h b/xllm/core/layers/mlu/fused_moe.h index f0dc5449be..c81d676963 100644 --- a/xllm/core/layers/mlu/fused_moe.h +++ b/xllm/core/layers/mlu/fused_moe.h @@ -57,7 +57,7 @@ class FusedMoEImpl : public torch::nn::Module { const std::optional& route_info = std::nullopt); torch::Tensor forward_shared(const torch::Tensor& hidden_states); torch::Tensor forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params); + const ParallelInput& parallel_input); virtual void load_state_dict(const StateDict& state_dict); void verify_loaded_weights() const; bool has_shared() const { return static_cast(shared_experts_); } diff --git a/xllm/core/layers/mlu/qwen3_5_decoder_layer.cpp b/xllm/core/layers/mlu/qwen3_5_decoder_layer.cpp 
index 1a6021e0d6..eee4076fb4 100644 --- a/xllm/core/layers/mlu/qwen3_5_decoder_layer.cpp +++ b/xllm/core/layers/mlu/qwen3_5_decoder_layer.cpp @@ -19,13 +19,13 @@ limitations under the License. #include "common/global_flags.h" #include "layers/common/dp_utils.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { namespace { -bool use_moe_all2all(bool enable_deep_ep, - const ModelInputParams& input_params) { - return enable_deep_ep && all_dp_ranks_are_decode(input_params); +bool use_moe_all2all(bool enable_deep_ep, const ParallelInput& parallel_input) { + return enable_deep_ep && all_dp_ranks_are_decode(parallel_input); } bool is_moe_layer(const ModelArgs& model_args, int32_t layer_id) { @@ -130,13 +130,13 @@ void Qwen3_5DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { torch::Tensor Qwen3_5DecoderLayerImpl::run_moe( torch::Tensor x, - const ModelInputParams& input_params) { + const ParallelInput& parallel_input) { const bool enable_moe_all2all = - use_moe_all2all(enable_deep_ep_, input_params); + use_moe_all2all(enable_deep_ep_, parallel_input); if (need_dp_moe_gather(parallel_args_, enable_moe_all2all)) { - x = gather_dp_tokens(x, input_params, parallel_args_); + x = gather_dp_tokens(x, parallel_input, parallel_args_); x = moe_mlp_->forward_experts(x, enable_moe_all2all); - return get_dp_local_slice(x, input_params, parallel_args_); + return get_dp_local_slice(x, parallel_input, parallel_args_); } return moe_mlp_->forward_experts(x, enable_moe_all2all); } @@ -164,7 +164,7 @@ torch::Tensor Qwen3_5DecoderLayerImpl::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input_params) { // Pre-attention norm std::tie(x, residual) = apply_norm(input_norm_, x, residual); @@ -181,7 +181,7 @@ torch::Tensor Qwen3_5DecoderLayerImpl::forward( // MLP/MoE if (moe_mlp_) { - x = run_moe(x, input_params); + x = run_moe(x, 
input_params.parallel); } else { x = mlp_->forward(x); } diff --git a/xllm/core/layers/mlu/qwen3_5_decoder_layer.h b/xllm/core/layers/mlu/qwen3_5_decoder_layer.h index efe56d1bb4..5f761369c3 100644 --- a/xllm/core/layers/mlu/qwen3_5_decoder_layer.h +++ b/xllm/core/layers/mlu/qwen3_5_decoder_layer.h @@ -32,6 +32,8 @@ limitations under the License. #include "layers/mlu/qwen3_5_fused_moe.h" namespace xllm { +struct ForwardInput; + namespace layer { class Qwen3_5DecoderLayerImpl final : public torch::nn::Module { @@ -45,7 +47,7 @@ class Qwen3_5DecoderLayerImpl final : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input_params); private: std::tuple> apply_norm( @@ -53,7 +55,7 @@ class Qwen3_5DecoderLayerImpl final : public torch::nn::Module { torch::Tensor& input, std::optional& residual); - torch::Tensor run_moe(torch::Tensor x, const ModelInputParams& input_params); + torch::Tensor run_moe(torch::Tensor x, const ParallelInput& parallel_input); std::string layer_type_; Qwen3_5Attention full_attention_{nullptr}; diff --git a/xllm/core/layers/musa/attention.cpp b/xllm/core/layers/musa/attention.cpp index 31393f3bbf..85f14720c2 100644 --- a/xllm/core/layers/musa/attention.cpp +++ b/xllm/core/layers/musa/attention.cpp @@ -50,21 +50,21 @@ AttentionImpl::AttentionImpl(int64_t num_heads, torch::Tensor AttentionImpl::forward(torch::Tensor& input, ForwardParams& fwd_params) { auto&& cache = fwd_params.kv_cache; - auto& input_params = const_cast(fwd_params.input_params); + const auto& attention = fwd_params.attention; - auto musa_attn_meta = xllm_musa::AttnMetaData::build( - input_params.attention.host.q_seq_lens, - input_params.attention.host.kv_seq_lens, - num_heads_, - num_kv_heads_, - head_dim_, - input_params.attention.device.new_cache_slots, - 64); + auto musa_attn_meta = + xllm_musa::AttnMetaData::build(attention.host.q_seq_lens, + 
attention.host.kv_seq_lens, + num_heads_, + num_kv_heads_, + head_dim_, + attention.device.new_cache_slots, + 64); return xllm_musa::QWen3Attn(input, cache.get_k_cache(), cache.get_v_cache(), - input_params.attention.device.block_tables, + attention.device.block_tables, fwd_params.attn_meta.mrope_cos, fwd_params.positions, weights_, diff --git a/xllm/core/layers/musa/musa_layer_base.h b/xllm/core/layers/musa/musa_layer_base.h index 4f42fdfd05..345b817515 100644 --- a/xllm/core/layers/musa/musa_layer_base.h +++ b/xllm/core/layers/musa/musa_layer_base.h @@ -31,7 +31,7 @@ struct ForwardParams { torch::Tensor& positions; AttentionMetadata const& attn_meta; KVCache& kv_cache; - ModelInputParams const& input_params; + AttentionInput const& attention; }; class MUSALayerBaseImpl : public torch::nn::Module { @@ -66,4 +66,4 @@ class MUSALayerBaseImpl : public torch::nn::Module { }; TORCH_MODULE(MUSALayerBase); -} // namespace xllm::layer \ No newline at end of file +} // namespace xllm::layer diff --git a/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.cpp index f197e4b2ee..5422a2382a 100644 --- a/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.cpp @@ -18,6 +18,7 @@ limitations under the License. 
#include "attention.h" #include "layers/common/rotary_embedding.h" #include "musa_mlp.h" +#include "runtime/forward_params.h" namespace xllm::layer { @@ -53,13 +54,13 @@ torch::Tensor Qwen3DecoderLayerImpl::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input) { // torch::Tensor k_cache = kv_cache.get_k_cache(); // k_cache = k_cache.view({-1, k_cache.size(1) * 8, k_cache.size(2)}); // torch::Tensor v_cache = kv_cache.get_v_cache(); // v_cache = v_cache.view({-1, v_cache.size(1) * 8, v_cache.size(2)}); - ForwardParams f{positions, attn_metadata, kv_cache, input_params}; + ForwardParams f{positions, attn_metadata, kv_cache, input.attention}; for (auto&& mod : layers_) { x = mod->forward(x, f); diff --git a/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.h b/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.h index dbb1eae89b..e1058bb7b3 100644 --- a/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/musa/musa_qwen3_decoder_layer_impl.h @@ -18,7 +18,6 @@ limitations under the License. #include #include "framework/kv_cache/kv_cache.h" -#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" #include "framework/state_dict/utils.h" @@ -26,6 +25,8 @@ limitations under the License. 
#include "musa_layer_base.h" namespace xllm { +struct ForwardInput; + namespace layer { class Qwen3DecoderLayerImpl : public torch::nn::Module { public: @@ -40,7 +41,7 @@ class Qwen3DecoderLayerImpl : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input); private: std::vector> layers_; @@ -49,4 +50,4 @@ class Qwen3DecoderLayerImpl : public torch::nn::Module { TORCH_MODULE(Qwen3DecoderLayer); } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index 6fd08e444f..ef75f2b49c 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "common/global_flags.h" #include "layers/common/rotary_embedding_util.h" #include "loader/deepseek_v2_decoder_loader.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { @@ -758,36 +759,28 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - ModelInputParams& input_params_new = - const_cast(input_params); // all micro batches are in same prefill/decode stage, - if (input_params_new.meta.batch_forward_type.is_chunked_prefill()) { + if (input.meta.batch_forward_type.is_chunked_prefill()) { build_node_variant_pack(prefill_node_prefixcache_, x, cos_pos, sin_pos, attn_mask, kv_cache, - input_params_new, + input, true); st = execute_node(prefill_node_prefixcache_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute is_chunked_prefill layer fail, error 
code: " << st; - } else if (input_params_new.meta.batch_forward_type.is_prefill()) { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params_new, - true); + } else if (input.meta.batch_forward_type.is_prefill()) { + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "excute prefill layer fail, error code: " << st; @@ -804,7 +797,7 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( sin_pos, /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params_new, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) @@ -816,7 +809,7 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( sin_pos, /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params_new, + input, false); st = execute_node(decode_mla_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) @@ -833,12 +826,12 @@ void NpuDeepseekV2DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill) { internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); // final_hidden_states_ = torch::zeros_like(x); int32_t input_idx = 0; - auto& dp_ep_padding = input_params.parallel.dp_ep_padding_data; + DpEpPaddingData dp_ep_padding = input_params.parallel.dp_ep_padding_data; if (dp_size_ <= 1 && ep_size_ <= 1) { dp_ep_padding.set_placeholder(tensor_placeholder_); } diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index 241cc33bfb..025b043c20 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -32,6 +32,8 @@ limitations under the License. 
#include "xllm_atb_layers/models/deepseekv2/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class ExpertBuffer { @@ -124,7 +126,7 @@ class NpuDeepseekV2DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -199,7 +201,7 @@ class NpuDeepseekV2DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill); torch::Tensor block_tables_placeholder_; diff --git a/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.cpp index 51d06b3e50..1cac9ff6fe 100644 --- a/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.cpp @@ -28,6 +28,7 @@ limitations under the License. 
#include "framework/parallel_state/npu_cp_prepare.h" #include "layers/common/rotary_embedding_util.h" #include "loader/deepseek_v32_decoder_loader.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { @@ -775,34 +776,20 @@ torch::Tensor NpuDeepseekV32DecoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - ModelInputParams& input_params_new = - const_cast(input_params); - if (input_params_new.meta.batch_forward_type.is_decode()) { - build_node_variant_pack(decode_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params_new, - false); + if (input.meta.batch_forward_type.is_decode()) { + build_node_variant_pack( + decode_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, false); st = execute_node(decode_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "execute prefill layer fail, error code: " << st; } else { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params_new, - true); + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << "execute prefill layer fail, error code: " << st; @@ -817,14 +804,14 @@ void NpuDeepseekV32DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill) { internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); // final_hidden_states_ = torch::zeros_like(x); int32_t input_idx = 0; - auto& dp_ep_padding = input_params.parallel.dp_ep_padding_data; - auto& cp_ep_padding = input_params.parallel.cp_ep_padding_data; - auto& cp_inputs = 
input_params.parallel.cp_prefill_inputs; + DpEpPaddingData dp_ep_padding = input_params.parallel.dp_ep_padding_data; + CpEpPaddingData cp_ep_padding = input_params.parallel.cp_ep_padding_data; + const auto& cp_inputs = input_params.parallel.cp_prefill_inputs; const bool use_cp_ep_padding = (cp_size_ > 1 && is_prefill); if (dp_size_ <= 1 && ep_size_ <= 1 || cp_size_ > 1) { diff --git a/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.h index 26bb4edae2..604a0ec901 100644 --- a/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.h @@ -33,6 +33,8 @@ limitations under the License. #include "xllm_atb_layers/models/deepseekv2/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuDeepseekV32DecoderLayerImpl : public BaseLayer { @@ -57,7 +59,7 @@ class NpuDeepseekV32DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -125,7 +127,7 @@ class NpuDeepseekV32DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill); torch::Tensor block_tables_placeholder_; diff --git a/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.cpp index 34bf2120fa..700485cfd4 100644 --- a/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.cpp @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "common/global_flags.h" +#include "runtime/forward_params.h" // #include "attn_mask.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" @@ -221,12 +222,12 @@ torch::Tensor NpuEagle3DecoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); build_node_variant_pack(prefill_node_, hidden_states, @@ -235,7 +236,7 @@ torch::Tensor NpuEagle3DecoderLayerImpl::forward( sin_pos, attn_mask, kv_cache, - input_params, + input, true); // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id, event, event_flag); @@ -249,7 +250,7 @@ torch::Tensor NpuEagle3DecoderLayerImpl::forward( sin_pos, decode_attn_mask_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -267,7 +268,7 @@ void NpuEagle3DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(hidden_states); internal_tensors_extra_ = diff --git a/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.h index c9a22a81a3..b4cdd3fd2c 100644 --- a/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_eagle3_decoder_layer_impl.h @@ -42,6 +42,8 @@ limitations under the License. 
#include "xllm_atb_layers/models/eagle3/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuEagle3DecoderLayerImpl : public BaseLayer { @@ -60,7 +62,7 @@ class NpuEagle3DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -79,7 +81,7 @@ class NpuEagle3DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill); void param_from_args(atb_speed::eagle3::DecoderLayerParam& param, diff --git a/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp index 0cae6d9984..14bc41181d 100644 --- a/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_glm4_decoder_layer_impl.cpp @@ -20,6 +20,7 @@ limitations under the License. 
#include "common/global_flags.h" #include "loader/glm4_decoder_loader.h" +#include "runtime/forward_params.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/core/npu/NPUException.h" @@ -159,20 +160,14 @@ torch::Tensor NpuGlm4DecoderLayerImpl::forward(torch::Tensor& x, torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - true); + if (!input.meta.batch_forward_type.is_decode()) { + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -184,7 +179,7 @@ torch::Tensor NpuGlm4DecoderLayerImpl::forward(torch::Tensor& x, sin_pos, decode_attn_mask_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -201,7 +196,7 @@ void NpuGlm4DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); // std::cout<<"node.variantPack.inTensors.size:"<* event_flag = nullptr, int node_id = 0); @@ -74,7 +76,7 @@ class NpuGlm4DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill); void initialize_quantization_parameters( diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp index d9f56984f7..b01fb3e333 
100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "npu_glm4_moe_decoder_layer.h" #include "common/global_flags.h" +#include "runtime/forward_params.h" DECLARE_string(rank_tablefile); DECLARE_string(communication_backend); DECLARE_int32(expert_parallel_degree); @@ -337,26 +338,19 @@ int64_t NpuGlm4MoeDecoderImpl::init_node(atb_speed::Model::Node& node, return atb::NO_ERROR; } -torch::Tensor NpuGlm4MoeDecoderImpl::forward( - torch::Tensor& x, - torch::Tensor& cos_pos, - torch::Tensor& sin_pos, - torch::Tensor& attn_mask, - KVCache& kv_cache, - const ModelInputParams& input_params, - aclrtEvent* event, - std::atomic* event_flag, - int node_id) { +torch::Tensor NpuGlm4MoeDecoderImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + const ForwardInput& input, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - true); + if (!input.meta.batch_forward_type.is_decode()) { + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << " excute prefill layer fail, error code: " << st; @@ -367,7 +361,7 @@ torch::Tensor NpuGlm4MoeDecoderImpl::forward( sin_pos, /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -384,7 +378,7 @@ void NpuGlm4MoeDecoderImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const 
ForwardInput& input_params, bool is_prefill) { internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); auto& dp_ep_padding = input_params.parallel.dp_ep_padding_data; diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h index a40ce4b89e..38c9d52c3d 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h @@ -30,6 +30,8 @@ limitations under the License. #include "xllm_atb_layers/models/glm/layer/moe_decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuGlm4MoeDecoderImpl : public BaseLayer { @@ -48,7 +50,7 @@ class NpuGlm4MoeDecoderImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -98,7 +100,7 @@ class NpuGlm4MoeDecoderImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill); std::string model_name_; @@ -144,4 +146,4 @@ std::vector get_dtp_inputs(torch::Tensor token_size_per_dp_group, int32_t rank, at::Device device); } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp index df7bf8cb31..13e1c75fdc 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.cpp @@ -22,6 +22,7 @@ limitations under the License. 
#include "common/global_flags.h" #include "layers/common/rotary_embedding_util.h" +#include "runtime/forward_params.h" DECLARE_string(rank_tablefile); DECLARE_string(communication_backend); DECLARE_int32(expert_parallel_degree); @@ -402,26 +403,19 @@ int64_t NpuGlm4MoeDecoderLiteImpl::init_node( return atb::NO_ERROR; } -torch::Tensor NpuGlm4MoeDecoderLiteImpl::forward( - torch::Tensor& x, - torch::Tensor& cos_pos, - torch::Tensor& sin_pos, - torch::Tensor& attn_mask, - KVCache& kv_cache, - const ModelInputParams& input_params, - aclrtEvent* event, - std::atomic* event_flag, - int node_id) { +torch::Tensor NpuGlm4MoeDecoderLiteImpl::forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + const ForwardInput& input, + aclrtEvent* event, + std::atomic* event_flag, + int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - true); + if (!input.meta.batch_forward_type.is_decode()) { + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ << " excute prefill layer fail, error code: " << st; @@ -432,7 +426,7 @@ torch::Tensor NpuGlm4MoeDecoderLiteImpl::forward( sin_pos, /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -449,7 +443,7 @@ void NpuGlm4MoeDecoderLiteImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill) { internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); auto& dp_ep_padding = input_params.parallel.dp_ep_padding_data; diff --git 
a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.h b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.h index 5a5aeb1ef2..7bd115dbde 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.h +++ b/xllm/core/layers/npu/npu_glm4_moe_lite_decoder_layer.h @@ -32,6 +32,8 @@ limitations under the License. #include "xllm_atb_layers/models/glm/layer/moe_decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuGlm4MoeDecoderLiteImpl : public BaseLayer { @@ -50,7 +52,7 @@ class NpuGlm4MoeDecoderLiteImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -102,7 +104,7 @@ class NpuGlm4MoeDecoderLiteImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill); std::string model_name_; diff --git a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp index dd83ffbbdb..c50e4af8ef 100644 --- a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.cpp @@ -125,20 +125,13 @@ torch::Tensor NpuGlm4VisionEncoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id, aclrtEvent* event, std::atomic* event_flag) { atb::Status st; - build_node_variant_pack(encode_node_, - x, - cos_pos, - sin_pos, - cu_seqlen, - cu_seqlen_vec, - input_params, - true); + build_node_variant_pack( + encode_node_, x, cos_pos, sin_pos, cu_seqlen, cu_seqlen_vec, true); // mstxRangeEnd(id); st = execute_node(encode_node_, node_id); LOG_IF(FATAL, st != 0) << model_name_ @@ -153,7 +146,6 @@ void 
NpuGlm4VisionEncoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); diff --git a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h index ad2bce5ded..a992c2f3e1 100644 --- a/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_glm4_vision_encoder_layer_impl.h @@ -59,7 +59,6 @@ class NpuGlm4VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id = 0, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr); @@ -69,7 +68,6 @@ class NpuGlm4VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill); void param_from_args(atb_speed::glm::VisionEncoderLayerParam& param, diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp index 2ff2a84b30..cc171da3f3 100644 --- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp @@ -23,6 +23,7 @@ limitations under the License. 
#include "common/global_flags.h" #include "core/layers/common/attention_mask.h" #include "loader/llama_decoder_loader.h" +#include "runtime/forward_params.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/core/npu/NPUException.h" @@ -163,19 +164,13 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x, torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - true); + if (!input.meta.batch_forward_type.is_decode()) { + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id); LOG_IF(FATAL, st != 0) << model_name_ @@ -187,7 +182,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x, sin_pos, decode_attn_mask_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000); LOG_IF(FATAL, st != 0) << model_name_ @@ -204,7 +199,7 @@ void NpuLlamaDecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h index 25de49f776..72e854997a 100644 --- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h @@ -41,6 +41,8 @@ limitations under the License. 
#include "xllm_atb_layers/models/llama/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuLlamaDecoderLayerImpl : public BaseLayer { @@ -58,7 +60,7 @@ class NpuLlamaDecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, int node_id = 0); private: @@ -68,7 +70,7 @@ class NpuLlamaDecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill); int64_t init_node(atb_speed::Model::Node& node, diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp index 030e69e171..d2e625223b 100644 --- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp @@ -24,6 +24,8 @@ limitations under the License. #include "common/global_flags.h" #include "core/util/rec_model_utils.h" +#include "runtime/forward_params.h" + namespace xllm { namespace layer { namespace { @@ -59,7 +61,7 @@ torch::Tensor PrepareOneRecAttentionMask(const at::Tensor& attn_mask, return EnsureNdFormat(result); } -int64_t ResolveOneRecBatchSize(const ModelInputParams& input_params) { +int64_t ResolveOneRecBatchSize(const ForwardInput& input_params) { const auto* onerec_params = input_params.onerec_xattention_params() != nullptr ? 
static_cast( input_params.onerec_xattention_params()) @@ -715,10 +717,7 @@ void NpuOneRecBlockLayerImpl::param_from_args( atb_speed::onerec::BlockLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, - bool is_prefill, - const ModelInputParams* input_params) { - (void)input_params; - + bool is_prefill) { param.isFA = false; param.isPrefill = is_prefill; param.isBF16 = args.dtype() == "bfloat16"; @@ -1301,13 +1300,13 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( torch::Tensor& x, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, torch::Tensor* encoder_output, int32_t node_id, aclrtEvent* event, std::atomic* event_flag, const torch::Tensor& expert_array) { - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); CHECK(onerec_params != nullptr) << "OneRec requires rec_params."; const bool is_prefill = @@ -1354,7 +1353,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, true, is_first_prefill, is_first_prefill ? encoder_output : nullptr, @@ -1366,7 +1365,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, true, is_first_prefill, is_first_prefill ? 
encoder_output : nullptr, @@ -1378,7 +1377,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, true, true, encoder_output, @@ -1390,7 +1389,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, true, true, encoder_output, @@ -1401,7 +1400,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( << model_name_ << " execute prefill layer fail, error code: " << st; } else { build_encoder_node_variant_pack( - prefill_node_, x, attn_mask, input_params, true, node_id); + prefill_node_, x, attn_mask, input, true, node_id); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -1417,7 +1416,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, false, false, encoder_output, @@ -1428,7 +1427,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward( x, attn_mask, kv_cache, - input_params, + input, false, false, encoder_output, @@ -1446,7 +1445,7 @@ void NpuOneRecBlockLayerImpl::build_encoder_node_variant_pack( atb_speed::Model::Node& node, torch::Tensor& x, at::Tensor& attn_mask, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill, int32_t layer_id) { (void)is_prefill; @@ -1510,7 +1509,7 @@ void NpuOneRecBlockLayerImpl::build_decoder_moe_node_variant_pack( torch::Tensor& x, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill, bool is_first_prefill, torch::Tensor* encoder_output, @@ -1571,7 +1570,7 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors( torch::Tensor& x, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, const atb_speed::onerec::BlockLayerParam& param, bool is_first_prefill, torch::Tensor* encoder_output, @@ -1799,7 +1798,7 @@ void NpuOneRecBlockLayerImpl::build_decoder_node_variant_pack( torch::Tensor& x, at::Tensor& 
attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill, bool is_first_prefill, torch::Tensor* encoder_output, diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h index a52bcdec03..a5d930b80a 100644 --- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h +++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h @@ -36,6 +36,8 @@ limitations under the License. #include "xllm_atb_layers/operations/fusion/utils.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuOneRecBlockLayerImpl final : public BaseLayer { @@ -57,7 +59,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer { torch::Tensor forward(torch::Tensor& x, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, torch::Tensor* encoder_output = nullptr, int32_t node_id = 0, aclrtEvent* event = nullptr, @@ -68,13 +70,12 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer { void param_from_args(atb_speed::onerec::BlockLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, - bool is_prefill, - const ModelInputParams* input_params = nullptr); + bool is_prefill); void build_encoder_node_variant_pack(atb_speed::Model::Node& node, torch::Tensor& x, at::Tensor& attn_mask, - ModelInputParams& input_params, + ForwardInput& input, bool is_prefill, int32_t layer_id = 0); @@ -82,7 +83,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer { torch::Tensor& x, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, bool is_prefill, bool is_first_prefill, torch::Tensor* encoder_output = nullptr, @@ -93,7 +94,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer { torch::Tensor& x, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, bool is_prefill, bool is_first_prefill, torch::Tensor* encoder_output = nullptr, 
@@ -110,7 +111,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer { torch::Tensor& x, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, const atb_speed::onerec::BlockLayerParam& param, bool is_first_prefill, torch::Tensor* encoder_output = nullptr, diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index 35dccedaae..16f52a9815 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -22,6 +22,7 @@ limitations under the License. #include #include "common/global_flags.h" +#include "runtime/forward_params.h" // #include "attn_mask.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" @@ -252,21 +253,15 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x, torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr); - build_node_variant_pack(prefill_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - true); + build_node_variant_pack( + prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, input, true); // mstxRangeEnd(id); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -278,7 +273,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x, sin_pos, decode_attn_mask_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -295,7 +290,7 @@ void NpuQwen2DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, at::Tensor& attn_mask, 
KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h index 4c240d83a1..97ebf71f96 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h @@ -42,6 +42,8 @@ limitations under the License. #include "xllm_atb_layers/models/qwen2/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuQwen2DecoderLayerImpl : public BaseLayer { @@ -59,7 +61,7 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -73,7 +75,7 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill); void param_from_args(atb_speed::qwen::DecoderLayerParam& param, diff --git a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp index f8f0c533f0..4551af511a 100644 --- a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.cpp @@ -125,20 +125,13 @@ torch::Tensor NpuQwen2VisionEncoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id, aclrtEvent* event, std::atomic* event_flag) { atb::Status st; - build_node_variant_pack(encode_node_, - x, - cos_pos, - sin_pos, - cu_seqlen, - cu_seqlen_vec, - input_params, - true); + 
build_node_variant_pack( + encode_node_, x, cos_pos, sin_pos, cu_seqlen, cu_seqlen_vec, true); // mstxRangeEnd(id); st = execute_node(encode_node_, node_id); LOG_IF(FATAL, st != 0) << model_name_ @@ -153,7 +146,6 @@ void NpuQwen2VisionEncoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); diff --git a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h index 9e3bbcf659..d4981d6547 100644 --- a/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h @@ -62,7 +62,6 @@ class NpuQwen2VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id = 0, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr); @@ -74,7 +73,6 @@ class NpuQwen2VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill); void get_weights_col_packed_qkv(); @@ -123,4 +121,4 @@ class NpuQwen2VisionEncoderLayerImpl : public BaseLayer { TORCH_MODULE(NpuQwen2VisionEncoderLayer); } // namespace layer -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp index d8a48e9a7d..a0f0b26e39 100644 --- a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp @@ -123,20 +123,13 @@ torch::Tensor NpuQwen2dot5VisionEncoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& 
input_params, int node_id, aclrtEvent* event, std::atomic* event_flag) { atb::Status st; - build_node_variant_pack(encode_node_, - x, - cos_pos, - sin_pos, - cu_seqlen, - cu_seqlen_vec, - input_params, - true); + build_node_variant_pack( + encode_node_, x, cos_pos, sin_pos, cu_seqlen, cu_seqlen_vec, true); // mstxRangeEnd(id); st = execute_node(encode_node_, node_id); LOG_IF(FATAL, st != 0) << model_name_ @@ -151,7 +144,6 @@ void NpuQwen2dot5VisionEncoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); diff --git a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h index 4e9fdc47b6..8546efe7b0 100644 --- a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h @@ -58,7 +58,6 @@ class NpuQwen2dot5VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id = 0, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr); @@ -70,7 +69,6 @@ class NpuQwen2dot5VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill); void param_from_args(atb_speed::qwen::VisionEncoderLayerParam& param, diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index 5e29e3cf62..0400cd6fc2 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "common/global_flags.h" +#include "runtime/forward_params.h" #include "util/rec_model_utils.h" // #include "attn_mask.h" @@ -225,19 +226,19 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x, torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, cos_pos, sin_pos, attn_mask, kv_cache, - input_params, + input, /*is_prefill=*/true, node_id); // mstxRangeEnd(id); @@ -251,7 +252,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x, sin_pos, decode_attn_mask_, kv_cache, - input_params, + input, /*is_prefill=*/false, node_id); st = execute_node(decode_node_, node_id + 1000, event, event_flag); @@ -269,7 +270,7 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, at::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill, int node_id) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h index 82a0686dbc..b274f97321 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h @@ -41,6 +41,8 @@ limitations under the License. 
#include "xllm_atb_layers/core/include/atb_speed/utils/model_factory.h" #include "xllm_atb_layers/models/qwen3/layer/decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuQwen3DecoderLayerImpl : public BaseLayer { @@ -56,7 +58,7 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -73,7 +75,7 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input_params, bool is_prefill, int node_id); diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index ca7f17ab99..16479306d2 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -20,6 +20,7 @@ limitations under the License. 
#include #include "common/global_flags.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { @@ -312,12 +313,12 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag, int node_id) { atb::Status st; - if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, residual, @@ -325,7 +326,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( sin_pos, attn_mask, kv_cache, - input_params, + input, true); st = execute_node(prefill_node_, node_id, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -338,7 +339,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( sin_pos, /*attn_mask*/ tensor_placeholder_, kv_cache, - input_params, + input, false); st = execute_node(decode_node_, node_id + 1000, event, event_flag); LOG_IF(FATAL, st != 0) << model_name_ @@ -356,7 +357,7 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill) { internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); int32_t input_idx = 0; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h index 0f2638004a..6e052c1894 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h @@ -35,6 +35,8 @@ limitations under the License. 
#include "xllm_atb_layers/models/qwen3/layer/moe_decoder_layer.h" namespace xllm { +struct ForwardInput; + namespace layer { class NpuQwen3MoeDecoderLayerImpl : public BaseLayer { @@ -54,7 +56,7 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr, int node_id = 0); @@ -104,7 +106,7 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input_params, bool is_prefill); torch::Tensor block_tables_placeholder_; diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp index c9169f14e6..31f4777830 100644 --- a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp @@ -125,20 +125,13 @@ torch::Tensor NpuQwen3VisionEncoderLayerImpl::forward( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id, aclrtEvent* event, std::atomic* event_flag) { atb::Status st; - build_node_variant_pack(encode_node_, - x, - cos_pos, - sin_pos, - cu_seqlen, - cu_seqlen_vec, - input_params, - true); + build_node_variant_pack( + encode_node_, x, cos_pos, sin_pos, cu_seqlen, cu_seqlen_vec, true); // mstxRangeEnd(id); st = execute_node(encode_node_, node_id); LOG_IF(FATAL, st != 0) << model_name_ @@ -153,7 +146,6 @@ void NpuQwen3VisionEncoderLayerImpl::build_node_variant_pack( torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill) { internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h 
b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h index aad778cf08..879fd92897 100755 --- a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h @@ -63,7 +63,6 @@ class NpuQwen3VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, int node_id = 0, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr); @@ -75,7 +74,6 @@ class NpuQwen3VisionEncoderLayerImpl : public BaseLayer { torch::Tensor& sin_pos, torch::Tensor& cu_seqlen, std::vector& cu_seqlen_vec, - ModelInputParams& input_params, bool is_prefill); void get_weights_col_packed_qkv(); diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 18a250914b..c8456aec80 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -917,13 +917,13 @@ torch::Tensor FusedMoEImpl::forward_expert( } torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params) { + const ParallelInput& parallel_input) { auto input = hidden_states; bool need_slice = false; if (parallel_args_.dp_size() > 1 && parallel_args_.ep_size() > 1) { input = parallel_state::gather(input, parallel_args_.dp_local_process_group_, - input_params.parallel.dp_global_token_nums); + parallel_input.dp_global_token_nums); need_slice = true; } @@ -942,7 +942,7 @@ torch::Tensor FusedMoEImpl::forward(const torch::Tensor& hidden_states, auto output = forward_expert(input, router_logits, shared_output); if (need_slice) { - const auto& dp_tokens = input_params.parallel.dp_global_token_nums; + const auto& dp_tokens = parallel_input.dp_global_token_nums; const int64_t dp_rank = parallel_args_.dp_local_process_group_->rank(); auto start = std::accumulate(dp_tokens.begin(), dp_tokens.begin() + dp_rank, 0); diff --git 
a/xllm/core/layers/npu_torch/fused_moe.h b/xllm/core/layers/npu_torch/fused_moe.h index 4194791b2b..6c82a8d19e 100644 --- a/xllm/core/layers/npu_torch/fused_moe.h +++ b/xllm/core/layers/npu_torch/fused_moe.h @@ -31,6 +31,8 @@ limitations under the License. #include "layers/common/linear.h" namespace xllm { +struct ForwardInput; + namespace layer { class FusedMoEImpl : public torch::nn::Module { @@ -50,9 +52,9 @@ class FusedMoEImpl : public torch::nn::Module { const torch::Tensor& hidden_states, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, - const ModelInputParams& input_params); + const ForwardInput& input_params); torch::Tensor forward(const torch::Tensor& hidden_states, - const ModelInputParams& input_params); + const ParallelInput& parallel_input); void load_state_dict(const StateDict& state_dict); private: diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index a9c278fd31..af6b712ef1 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "runtime/forward_params.h" #include "xllm/core/kernels/ops_api.h" namespace xllm { @@ -326,7 +327,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( const torch::Tensor& hidden_states, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input) { auto [qkvz_padded, ba_padded] = project_padded_inputs(hidden_states, attn_metadata); int64_t batch_size = qkvz_padded.size(0); @@ -358,21 +359,21 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( torch::Tensor g, beta, core_attn_out, last_recurrent_state; auto device = mixed_qkv.device(); auto conv_weight = conv1d_->weight(); - auto linear_state_indices = get_linear_state_indices(input_params, device); + auto linear_state_indices = get_linear_state_indices(input.embedding, device); if (attn_metadata.is_prefill) { torch::IntArrayRef num_accepted_tokens_opt; torch::Tensor conv_weight_T = conv_weight.transpose(0, 1).contiguous(); std::vector linear_state_indices_vec( - input_params.embedding.linear_state_ids.begin(), - input_params.embedding.linear_state_ids.end()); + input.embedding.linear_state_ids.begin(), + input.embedding.linear_state_ids.end()); mixed_qkv = xllm::kernel::causal_conv1d( mixed_qkv, conv_weight_T, conv_cache, std::optional(), // bias (no bias for qwen3) - torch::IntArrayRef(input_params.parallel.query_start_loc), + torch::IntArrayRef(input.parallel.query_start_loc), torch::IntArrayRef(linear_state_indices_vec), - torch::IntArrayRef(input_params.parallel.has_initial_state), + torch::IntArrayRef(input.parallel.has_initial_state), num_accepted_tokens_opt, 1, // activation_mode -1, // pad_slot_id @@ -504,15 +505,15 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::reshape_qkvz_unpad( } torch::Tensor Qwen3GatedDeltaNetBaseImpl::get_linear_state_indices( - const ModelInputParams& input_params, + const ModelEmbeddingInput& embedding_input, const torch::Device& device) const { - 
CHECK(!input_params.embedding.linear_state_ids.empty()) + CHECK(!embedding_input.linear_state_ids.empty()) << "linear_state_ids must be populated for gated delta net"; - if (input_params.embedding.linear_state_indices.defined()) { - return input_params.embedding.linear_state_indices; + if (embedding_input.linear_state_indices.defined()) { + return embedding_input.linear_state_indices; } return torch::tensor( - input_params.embedding.linear_state_ids, + embedding_input.linear_state_ids, torch::TensorOptions().dtype(torch::kInt).device(device)); } diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h index 2994f329a3..705aadc9c0 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.h @@ -32,6 +32,8 @@ limitations under the License. #include "layers/common/rms_norm_gated.h" namespace xllm { +struct ForwardInput; + namespace layer { class Qwen3GatedDeltaNetBaseImpl : public torch::nn::Module { @@ -48,7 +50,7 @@ class Qwen3GatedDeltaNetBaseImpl : public torch::nn::Module { torch::Tensor forward(const torch::Tensor& hidden_states, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input); protected: virtual std::pair project_padded_inputs( @@ -62,8 +64,9 @@ class Qwen3GatedDeltaNetBaseImpl : public torch::nn::Module { const torch::Tensor& qkvz) const; torch::Tensor reshape_qkvz_unpad(const AttentionMetadata& attn_metadata, const torch::Tensor& padded_qkvz) const; - torch::Tensor get_linear_state_indices(const ModelInputParams& input_params, - const torch::Device& device) const; + torch::Tensor get_linear_state_indices( + const ModelEmbeddingInput& embedding_input, + const torch::Device& device) const; std::tuple process_mixed_qkv( torch::Tensor& mixed_qkv) const; diff --git a/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.cpp 
b/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.cpp index 904561d626..878ea6c4de 100644 --- a/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.cpp @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "runtime/forward_params.h" + namespace xllm { namespace layer { @@ -112,7 +114,7 @@ torch::Tensor Qwen3HybridDecoderLayerImplBase::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, const torch::Tensor& mrope_cos_sin) { // Pre-attention norm if (!residual.has_value()) { @@ -127,7 +129,7 @@ torch::Tensor Qwen3HybridDecoderLayerImplBase::forward( x = attention_->forward( positions, x, attn_metadata, kv_cache, mrope_cos_sin); } else { - x = linear_attention_->forward(x, attn_metadata, kv_cache, input_params); + x = linear_attention_->forward(x, attn_metadata, kv_cache, input); } // Post-attention norm @@ -135,7 +137,7 @@ torch::Tensor Qwen3HybridDecoderLayerImplBase::forward( // MLP forward if (moe_mlp_) { - x = moe_mlp_(x, input_params); + x = moe_mlp_(x, input.parallel); } else { x = mlp_(x); } diff --git a/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h b/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h index 098f8f2ec5..5078644853 100644 --- a/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h +++ b/xllm/core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h @@ -21,7 +21,6 @@ limitations under the License. #include #include "framework/kv_cache/kv_cache.h" -#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" #include "layers/common/dense_mlp.h" @@ -31,6 +30,8 @@ limitations under the License. 
#include "layers/npu_torch/qwen3_next_attention.h" namespace xllm { +struct ForwardInput; + namespace layer { class Qwen3HybridDecoderLayerModule : public torch::nn::Module { @@ -42,7 +43,7 @@ class Qwen3HybridDecoderLayerModule : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, const torch::Tensor& mrope_cos_sin = {}) = 0; virtual torch::Tensor build_mrope_cos_sin( const torch::Tensor& positions) const { @@ -69,7 +70,7 @@ class Qwen3HybridDecoderLayerImplBase : public Qwen3HybridDecoderLayerModule { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, const torch::Tensor& mrope_cos_sin = {}) override; torch::Tensor build_mrope_cos_sin( diff --git a/xllm/core/layers/onerec_block_layer.h b/xllm/core/layers/onerec_block_layer.h index e0c836fb5f..cb3351cfad 100644 --- a/xllm/core/layers/onerec_block_layer.h +++ b/xllm/core/layers/onerec_block_layer.h @@ -21,7 +21,6 @@ limitations under the License. 
#include #include "framework/kv_cache/kv_cache.h" -#include "framework/model/model_input_params.h" #if defined(USE_NPU) #include @@ -30,6 +29,8 @@ using aclrtEvent = void*; #endif namespace xllm { +struct ForwardInput; + namespace layer { class OneRecBlockLayer : public torch::nn::Module { @@ -39,7 +40,7 @@ class OneRecBlockLayer : public torch::nn::Module { virtual torch::Tensor forward(torch::Tensor& hidden_states, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, torch::Tensor* encoder_output = nullptr, int32_t node_id = 0, aclrtEvent* event = nullptr, diff --git a/xllm/core/layers/oxygen_vision_layer.cpp b/xllm/core/layers/oxygen_vision_layer.cpp index 09d62ba4a0..978cd04c72 100644 --- a/xllm/core/layers/oxygen_vision_layer.cpp +++ b/xllm/core/layers/oxygen_vision_layer.cpp @@ -54,15 +54,12 @@ torch::Tensor OxygenVisionLayerImpl::forward( torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { auto norm_output1 = std::get<0>(norm1_(hidden_states)); - auto output = hidden_states + attention_(norm_output1, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params); + auto output = + hidden_states + + attention_( + norm_output1, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec); auto norm_output2 = std::get<0>(norm2_(output)); output = output + mlp_(norm_output2); return output; diff --git a/xllm/core/layers/oxygen_vision_layer.h b/xllm/core/layers/oxygen_vision_layer.h index 4b81a27189..a1392696e5 100644 --- a/xllm/core/layers/oxygen_vision_layer.h +++ b/xllm/core/layers/oxygen_vision_layer.h @@ -41,7 +41,6 @@ class OxygenVisionLayerImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id); private: diff --git a/xllm/core/layers/qwen2_5_vision_layer.cpp b/xllm/core/layers/qwen2_5_vision_layer.cpp index 
6895b9078b..93146acea9 100644 --- a/xllm/core/layers/qwen2_5_vision_layer.cpp +++ b/xllm/core/layers/qwen2_5_vision_layer.cpp @@ -62,15 +62,12 @@ torch::Tensor Qwen2_5_VisionLayerImpl::forward( torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { auto norm_output1 = std::get<0>(norm1_(hidden_states)); - auto output = hidden_states + attention_(norm_output1, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params); + auto output = + hidden_states + + attention_( + norm_output1, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec); auto norm_output2 = std::get<0>(norm2_(output)); output = output + mlp_(norm_output2); return output; diff --git a/xllm/core/layers/qwen2_5_vision_layer.h b/xllm/core/layers/qwen2_5_vision_layer.h index 39bae59eb1..887a3d6c98 100644 --- a/xllm/core/layers/qwen2_5_vision_layer.h +++ b/xllm/core/layers/qwen2_5_vision_layer.h @@ -44,7 +44,6 @@ class Qwen2_5_VisionLayerImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id); protected: diff --git a/xllm/core/layers/qwen2_decoder_layer.cpp b/xllm/core/layers/qwen2_decoder_layer.cpp index db5f8395a5..babf81f598 100644 --- a/xllm/core/layers/qwen2_decoder_layer.cpp +++ b/xllm/core/layers/qwen2_decoder_layer.cpp @@ -90,7 +90,7 @@ torch::Tensor Qwen2DecoderLayerImpl::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput&) { auto pre_fp8_scale = attention_->get_fp8_input_scale(); auto post_fp8_scale = mlp_->get_fp8_input_scale(); diff --git a/xllm/core/layers/qwen2_decoder_layer.h b/xllm/core/layers/qwen2_decoder_layer.h index 19ed4b6015..59745028df 100644 --- a/xllm/core/layers/qwen2_decoder_layer.h +++ b/xllm/core/layers/qwen2_decoder_layer.h @@ -25,12 +25,13 @@ limitations under the License. 
#include "common/rms_norm.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_args.h" -#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/parallel_state/parallel_args.h" #include "framework/state_dict/state_dict.h" namespace xllm { +struct ForwardInput; + namespace layer { class Qwen2DecoderLayerImpl : public torch::nn::Module { @@ -45,7 +46,7 @@ class Qwen2DecoderLayerImpl : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input); private: Qwen2Attention attention_{nullptr}; diff --git a/xllm/core/layers/qwen3_moe_decoder_layer.cpp b/xllm/core/layers/qwen3_moe_decoder_layer.cpp index 6fa3f1aef3..541b7bc9f0 100644 --- a/xllm/core/layers/qwen3_moe_decoder_layer.cpp +++ b/xllm/core/layers/qwen3_moe_decoder_layer.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "common/global_flags.h" #include "layers/common/dp_utils.h" +#include "runtime/forward_params.h" namespace xllm { namespace layer { @@ -26,9 +27,8 @@ namespace layer { namespace { #if defined(USE_MLU) -bool use_moe_all2all(bool enable_deep_ep, - const ModelInputParams& input_params) { - return enable_deep_ep && all_dp_ranks_are_decode(input_params); +bool use_moe_all2all(bool enable_deep_ep, const ParallelInput& parallel_input) { + return enable_deep_ep && all_dp_ranks_are_decode(parallel_input); } #endif bool is_moe_layer(const ModelArgs& model_args, int32_t layer_id) { @@ -98,18 +98,18 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, torch::Tensor Qwen3MoeDecoderLayerImpl::run_moe( torch::Tensor x, - const ModelInputParams& input_params) { + const ParallelInput& parallel_input) { #if defined(USE_MLU) const bool enable_moe_all2all = - use_moe_all2all(enable_deep_ep_, input_params); + use_moe_all2all(enable_deep_ep_, parallel_input); if 
(need_dp_moe_gather(parallel_args_, enable_moe_all2all)) { - x = gather_dp_tokens(x, input_params, parallel_args_); + x = gather_dp_tokens(x, parallel_input, parallel_args_); x = moe_mlp_->forward_experts(x, enable_moe_all2all); - return get_dp_local_slice(x, input_params, parallel_args_); + return get_dp_local_slice(x, parallel_input, parallel_args_); } return moe_mlp_->forward_experts(x, enable_moe_all2all); #else - return moe_mlp_(x, input_params); + return moe_mlp_(x, parallel_input); #endif } @@ -132,7 +132,7 @@ torch::Tensor Qwen3MoeDecoderLayerImpl::forward( torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input) { // Pre-attention norm if (!residual.has_value()) { residual = x; @@ -149,7 +149,7 @@ torch::Tensor Qwen3MoeDecoderLayerImpl::forward( // MLP forward if (moe_mlp_) { - x = run_moe(x, input_params); + x = run_moe(x, input.parallel); } else { x = mlp_(x); } diff --git a/xllm/core/layers/qwen3_moe_decoder_layer.h b/xllm/core/layers/qwen3_moe_decoder_layer.h index 549939124b..6bcaab2a77 100644 --- a/xllm/core/layers/qwen3_moe_decoder_layer.h +++ b/xllm/core/layers/qwen3_moe_decoder_layer.h @@ -33,12 +33,14 @@ limitations under the License. 
#include "common/rms_norm.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_args.h" -#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/parallel_state/parallel_args.h" #include "framework/state_dict/state_dict.h" namespace xllm { +struct ForwardInput; +struct ParallelInput; + namespace layer { class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { @@ -53,10 +55,10 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { torch::Tensor& positions, const AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params); + const ForwardInput& input); private: - torch::Tensor run_moe(torch::Tensor x, const ModelInputParams& input_params); + torch::Tensor run_moe(torch::Tensor x, const ParallelInput& parallel_input); Qwen2Attention attention_{nullptr}; DenseMLP mlp_{nullptr}; diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index dd0428bc62..d32ae64aeb 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -197,18 +197,17 @@ void GraphPersistentParam::set_aux_hidden_states(const torch::Tensor& value) { } } -std::optional GraphPersistentParam::update( - const torch::Tensor& tokens, +std::optional GraphPersistentParam::update( + const ForwardInput& input, const torch::Tensor& k_cache, const torch::Tensor& v_cache, - const torch::Tensor& positions, - const ModelInputParams& params, uint32_t padded_num_tokens, - bool return_capture_params) { - CHECK_GT(padded_num_tokens, 0) - << "padded_num_tokens must be > 0 when return_capture_params is true"; + bool return_capture_input) { + CHECK_GT(padded_num_tokens, 0) << "padded_num_tokens must be > 0"; + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; const uint32_t actual_num_tokens = tokens.size(0); - const int64_t actual_batch_size = 
params.meta.num_sequences; + const int64_t actual_batch_size = input.meta.num_sequences; // Copy data from input parameters to persistent graph tensors persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) @@ -224,22 +223,22 @@ std::optional GraphPersistentParam::update( .copy_(positions, /*non_blocking=*/true); } q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.attention.device.q_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.q_seq_lens, /*non_blocking=*/true); kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.attention.device.kv_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.kv_seq_lens, /*non_blocking=*/true); persistent_new_cache_slots_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) - .copy_(params.attention.device.new_cache_slots, /*non_blocking=*/true); - if (!params.embedding.linear_state_ids.empty()) { - if (params.embedding.linear_state_indices.defined()) { + .copy_(input.attention.device.new_cache_slots, /*non_blocking=*/true); + if (!input.embedding.linear_state_ids.empty()) { + if (input.embedding.linear_state_indices.defined()) { persistent_linear_state_indices_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.embedding.linear_state_indices, /*non_blocking=*/true); + .copy_(input.embedding.linear_state_indices, /*non_blocking=*/true); } else { persistent_linear_state_indices_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(torch::tensor(params.embedding.linear_state_ids, torch::kInt) + .copy_(torch::tensor(input.embedding.linear_state_ids, torch::kInt) .to(device_), /*non_blocking=*/true); } @@ -247,16 +246,16 @@ std::optional GraphPersistentParam::update( // Copy block table data const int64_t actual_block_table_len = - params.attention.device.block_tables.size(1); + input.attention.device.block_tables.size(1); auto slice_persistent_block_tables = 
persistent_block_tables_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) .slice(/*dim=*/1, /*start=*/0, /*end=*/actual_block_table_len); - slice_persistent_block_tables.copy_(params.attention.device.block_tables, + slice_persistent_block_tables.copy_(input.attention.device.block_tables, /*non_blocking=*/true); // Update persistent embedding from input_embedding if available - const auto& embedding = params.embedding.input_embedding; + const auto& embedding = input.embedding.input_embedding; if (embedding.defined()) { const int64_t embedding_tokens = embedding.size(0); @@ -275,9 +274,9 @@ std::optional GraphPersistentParam::update( .slice(/*dim=*/0, /*start=*/0, /*end=*/embedding_tokens) .copy_(embedding, /*non_blocking=*/true); } - // Update q_cu_seq_lens only if params.attention.device.q_cu_seq_lens is + // Update q_cu_seq_lens only if input.attention.device.q_cu_seq_lens is // defined - if (params.attention.device.q_cu_seq_lens.defined()) { + if (input.attention.device.q_cu_seq_lens.defined()) { // Lazy initialization: if q_cu_seq_lens_ is not initialized, initialize it if (q_cu_seq_lens_.numel() == 0) { const int64_t max_seqs_per_batch = get_decode_graph_capacity(options_); @@ -286,12 +285,12 @@ std::optional GraphPersistentParam::update( } // Copy data q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.attention.device.q_cu_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.q_cu_seq_lens, /*non_blocking=*/true); } // Update attention mask only if needed if (need_update_attn_mask_) { - update_attention_mask(params); + update_attention_mask(input); } if (tiling_data_.numel() > 0) { @@ -300,68 +299,71 @@ std::optional GraphPersistentParam::update( if (need_update_attention_plan_ && k_cache.defined() && v_cache.defined() && k_cache.numel() > 0 && v_cache.numel() > 0) { plan_paged_attention_tiling( - tokens, k_cache, v_cache, persistent_block_tables_, params, stream); + tokens, k_cache, v_cache, 
persistent_block_tables_, input, stream); } } - // Return ModelInputParams with persistent buffer references if requested - if (return_capture_params) { - std::optional params_for_capture = - std::make_optional(params); - // Set persistent buffers in params_for_capture - params_for_capture->attention.device.kv_seq_lens = + // Return ForwardInput with persistent buffer references if requested + if (return_capture_input) { + std::optional input_for_capture = input; + input_for_capture->token_ids = persistent_tokens(padded_num_tokens); + input_for_capture->positions = persistent_positions(padded_num_tokens); + // Set persistent buffers in input_for_capture + input_for_capture->attention.device.kv_seq_lens = kv_seq_lens(padded_num_tokens); - params_for_capture->attention.device.q_seq_lens = + input_for_capture->attention.device.q_seq_lens = q_seq_lens(padded_num_tokens); - params_for_capture->attention.host.kv_seq_lens.resize(padded_num_tokens); - params_for_capture->attention.host.q_seq_lens.resize(padded_num_tokens); - // Copy actual values from original params - for (int i = 0; i < actual_batch_size; i++) { - params_for_capture->attention.host.kv_seq_lens[i] = - params.attention.host.kv_seq_lens[i]; - params_for_capture->attention.host.q_seq_lens[i] = - params.attention.host.q_seq_lens[i]; + input_for_capture->attention.host.kv_seq_lens.resize(padded_num_tokens); + input_for_capture->attention.host.q_seq_lens.resize(padded_num_tokens); + // Copy actual values from original input + for (int64_t i = 0; i < actual_batch_size; ++i) { + input_for_capture->attention.host.kv_seq_lens[i] = + input.attention.host.kv_seq_lens[i]; + input_for_capture->attention.host.q_seq_lens[i] = + input.attention.host.q_seq_lens[i]; } // Fill padded positions with default values - for (int i = actual_batch_size; i < padded_num_tokens; i++) { - params_for_capture->attention.host.kv_seq_lens[i] = 1; - params_for_capture->attention.host.q_seq_lens[i] = 1; + for (uint32_t i = 
static_cast(actual_batch_size); + i < padded_num_tokens; + ++i) { + input_for_capture->attention.host.kv_seq_lens[i] = 1; + input_for_capture->attention.host.q_seq_lens[i] = 1; } - params_for_capture->meta.num_sequences = padded_num_tokens; - params_for_capture->meta.batch_forward_type = BatchForwardType::DECODE; - params_for_capture->attention.device.new_cache_slots = + input_for_capture->meta.num_sequences = padded_num_tokens; + input_for_capture->meta.batch_forward_type = BatchForwardType::DECODE; + input_for_capture->attention.device.new_cache_slots = persistent_new_cache_slots(padded_num_tokens); - params_for_capture->attention.device.block_tables = + input_for_capture->attention.device.block_tables = persistent_block_tables(padded_num_tokens); - if (!params.embedding.linear_state_ids.empty()) { - params_for_capture->embedding.linear_state_ids = - params.embedding.linear_state_ids; + if (!input.embedding.linear_state_ids.empty()) { + input_for_capture->embedding.linear_state_ids = + input.embedding.linear_state_ids; const int32_t padding_linear_state_id = 0; - params_for_capture->embedding.linear_state_ids.resize( + input_for_capture->embedding.linear_state_ids.resize( padded_num_tokens, padding_linear_state_id); - params_for_capture->embedding.linear_state_indices = + input_for_capture->embedding.linear_state_indices = persistent_linear_state_indices(padded_num_tokens); } // Only set attn_mask if need_update_attn_mask_ is true if (need_update_attn_mask_) { - params_for_capture->graph.attn_mask = persistent_mask(padded_num_tokens); + input_for_capture->graph.attn_mask = persistent_mask(padded_num_tokens); } - params_for_capture->graph.tiling_data = tiling_data(); + input_for_capture->graph.tiling_data = tiling_data(); // Set persistent embedding if available - if (params.embedding.input_embedding.defined()) { - params_for_capture->embedding.input_embedding = + if (input.embedding.input_embedding.defined()) { + input_for_capture->embedding.input_embedding = 
persistent_embedding(padded_num_tokens); } // Set q_cu_seq_lens if available - if (params.attention.device.q_cu_seq_lens.defined()) { - params_for_capture->attention.device.q_cu_seq_lens = + if (input.attention.device.q_cu_seq_lens.defined()) { + input_for_capture->attention.device.q_cu_seq_lens = q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size); } - return params_for_capture; + return input_for_capture; } return std::nullopt; } @@ -612,7 +614,7 @@ void GraphPersistentParam::plan_paged_attention_tiling( const torch::Tensor& k_cache, const torch::Tensor& v_cache, const torch::Tensor& block_tables, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtStream stream) { // Convert torch tensors to atb tensors atb::Tensor atb_k_cache = atb_speed::Utils::AtTensor2Tensor(k_cache); @@ -620,10 +622,10 @@ void GraphPersistentParam::plan_paged_attention_tiling( atb::Tensor atb_block_tables = atb_speed::Utils::AtTensor2Tensor(block_tables); // Get context_lens from input_params.attention.device.kv_seq_lens - atb::Tensor atb_context_lens = atb_speed::Utils::AtTensor2Tensor( - input_params.attention.device.kv_seq_lens); + atb::Tensor atb_context_lens = + atb_speed::Utils::AtTensor2Tensor(input.attention.device.kv_seq_lens); atb_context_lens.hostData = - const_cast(input_params.attention.host.kv_seq_lens.data()); + const_cast(input.attention.host.kv_seq_lens.data()); atb::Tensor atb_tiling_data = atb_speed::Utils::AtTensor2Tensor(tiling_data_); atb_tiling_data.desc.dtype = ACL_UINT32; @@ -704,14 +706,13 @@ void GraphPersistentParam::plan_paged_attention_tiling( CHECK_EQ(acl_status, ACL_SUCCESS) << "Failed to copy tiling buffer to device"; } -void GraphPersistentParam::update_attention_mask( - const ModelInputParams& input_params) { +void GraphPersistentParam::update_attention_mask(const ForwardInput& input) { torch::Dtype dtype = util::parse_dtype(args_.dtype(), device_); // update persistent_mask_ in-place - const int64_t batch_size = 
input_params.attention.device.kv_seq_lens.size(0); - const int64_t max_seq_len = input_params.meta.kv_max_seq_len > 0 - ? input_params.meta.kv_max_seq_len + const int64_t batch_size = input.attention.device.kv_seq_lens.size(0); + const int64_t max_seq_len = input.meta.kv_max_seq_len > 0 + ? input.meta.kv_max_seq_len : args_.max_position_embeddings(); // persistent_mask_ is already initialized in constructor @@ -721,20 +722,19 @@ void GraphPersistentParam::update_attention_mask( << persistent_mask_.size(1) << ")"; // Check if q_max_seq_len > 1 (prefill mode, not decode mode) - bool chunked_prefill = input_params.meta.q_max_seq_len > 1; + bool chunked_prefill = input.meta.q_max_seq_len > 1; // Calculate num_tokens: in chunked mode, sum of all q_len; in decode mode, // batch_size int64_t num_tokens = batch_size; // Default for decode mode if (chunked_prefill) { - CHECK_EQ(input_params.attention.host.q_seq_lens.size(), batch_size) - << "q_seq_lens_vec size (" - << input_params.attention.host.q_seq_lens.size() << ") != batch_size (" - << batch_size << ")"; - num_tokens = std::accumulate( - input_params.attention.host.q_seq_lens.begin(), - input_params.attention.host.q_seq_lens.begin() + batch_size, - int64_t(0)); + CHECK_EQ(input.attention.host.q_seq_lens.size(), batch_size) + << "q_seq_lens_vec size (" << input.attention.host.q_seq_lens.size() + << ") != batch_size (" << batch_size << ")"; + num_tokens = + std::accumulate(input.attention.host.q_seq_lens.begin(), + input.attention.host.q_seq_lens.begin() + batch_size, + int64_t(0)); } // Check if num_tokens is within bounds @@ -762,15 +762,14 @@ void GraphPersistentParam::update_attention_mask( // q_len // Check if kv_seq_lens_vec is available - CHECK_EQ(input_params.attention.host.kv_seq_lens.size(), batch_size) - << "kv_seq_lens_vec size (" - << input_params.attention.host.kv_seq_lens.size() << ") != batch_size (" - << batch_size << ")"; + CHECK_EQ(input.attention.host.kv_seq_lens.size(), batch_size) + << 
"kv_seq_lens_vec size (" << input.attention.host.kv_seq_lens.size() + << ") != batch_size (" << batch_size << ")"; int64_t offset = 0; for (int64_t i = 0; i < batch_size; i++) { - const int32_t q_len = input_params.attention.host.q_seq_lens[i]; - const int32_t kv_len = input_params.attention.host.kv_seq_lens[i]; + const int32_t q_len = input.attention.host.q_seq_lens[i]; + const int32_t kv_len = input.attention.host.kv_seq_lens[i]; // For chunked mode, slice out q_len rows for this sequence // mask_slice is [num_tokens, max_seq_len] @@ -805,7 +804,7 @@ void GraphPersistentParam::update_attention_mask( .expand({batch_size, max_seq_len}); auto context_lens_expanded = - input_params.attention.device.kv_seq_lens.to(torch::kInt32) + input.attention.device.kv_seq_lens.to(torch::kInt32) .unsqueeze(1) .expand({batch_size, max_seq_len}); @@ -817,13 +816,12 @@ void GraphPersistentParam::update_attention_mask( bool AclGraph::capture(CausalLM* model, const ModelArgs& args, const runtime::Options& options, - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, std::vector& kv_cache, uint32_t bucket_num_tokens) { // Save bucket num_tokens for this graph instance num_tokens_ = bucket_num_tokens; + const torch::Tensor& tokens = input.token_ids; // Get actual num_tokens from tokens tensor // const uint32_t actual_num_tokens = tokens.size(0); @@ -846,18 +844,15 @@ bool AclGraph::capture(CausalLM* model, const uint32_t actual_num_tokens = tokens.size(0); CHECK_GE(num_tokens_, actual_num_tokens) << "num_tokens_ >= actual_num_tokens"; - auto graph_params = persistent_param_.update(tokens, - k_cache, - v_cache, - positions, - params, - num_tokens_, - /*return_capture_params=*/true); - - // Use the returned ModelInputParams for graph capture - CHECK(graph_params.has_value()) - << "update() should return ModelInputParams when " - "return_capture_params=true"; + auto graph_input = persistent_param_.update(input, + 
k_cache, + v_cache, + num_tokens_, + /*return_capture_input=*/true); + + // Use the returned ForwardInput for graph capture + CHECK(graph_input.has_value()) + << "update() should return ForwardInput when return_capture_input=true"; // Synchronize stream to ensure all data is copied to graph persistent buffers aclrtSynchronizeStream(stream); @@ -891,11 +886,7 @@ bool AclGraph::capture(CausalLM* model, graph_.capture_begin( {0, 0}, aclmdlRICaptureMode::ACL_MODEL_RI_CAPTURE_MODE_THREAD_LOCAL); // Execute forward pass - NPUGraph mempool manages temporary tensors - auto forward_result = - model->forward({persistent_param_.persistent_tokens(num_tokens_)}, - {persistent_param_.persistent_positions(num_tokens_)}, - kv_cache, - {graph_params.value()}); + auto forward_result = model->forward(graph_input.value(), kv_cache); // Store result in persistent buffer owned by NPUGraph mempool persistent_param_.set_hidden_states(forward_result.hidden_states); @@ -931,10 +922,9 @@ void AclGraph::initialize_capture_stream(c10::DeviceIndex device_index) { << ", device_index: " << device_index; } -ModelOutput AclGraph::replay(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_cache, - const ModelInputParams& params) { +ModelOutput AclGraph::replay(const ForwardInput& input, + std::vector& kv_cache) { + const torch::Tensor& tokens = input.token_ids; const uint32_t actual_num_tokens = tokens.size(0); CHECK_LE(actual_num_tokens, num_tokens_) << "num_tokens mismatch: expected <= " << num_tokens_ << ", got " @@ -946,13 +936,11 @@ ModelOutput AclGraph::replay(const torch::Tensor& tokens, // be updated when Full Attention layers are involved, which is determined // by k_cache being valid and non-empty auto [k_cache, v_cache] = find_attention_plan_kv_cache(kv_cache); - persistent_param_.update(tokens, + persistent_param_.update(input, k_cache, v_cache, - positions, - params, num_tokens_, - /*return_capture_params=*/false); + /*return_capture_input=*/false); // 
Replay captured graph - NPUGraph mempool reuses temporary tensors // Get current NPU stream from libtorch NPU API @@ -991,18 +979,13 @@ ForwardInput AclGraphExecutorImpl::prepare_inputs(Batch& batch) { // tokens: [num_decode_tokens] // positions: [num_decode_tokens] token pos in the sequence // returns: [num_decode_tokens, hidden_size] -ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { +ModelOutput AclGraphExecutorImpl::run(const ForwardInput& input, + std::vector& kv_caches) { // no mirco batch in decode phase - const torch::Tensor& tokens_tensor = tokens; - const torch::Tensor& positions_tensor = positions; - const ModelInputParams& params_single = params; - const bool in_decoding_phase = - params_single.meta.batch_forward_type.is_decode(); + const torch::Tensor& tokens_tensor = input.token_ids; + const bool in_decoding_phase = input.meta.batch_forward_type.is_decode(); VLOG(50) << "in_decoding_phase: " << in_decoding_phase - << " q_max_seq_len: " << params_single.meta.q_max_seq_len + << " q_max_seq_len: " << input.meta.q_max_seq_len << " n_layers: " << args_.n_layers(); // If not in decode phase, use eager mode directly without acl graph // TODO: fix mtp model support. 
@@ -1010,19 +993,17 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, VLOG(kGraphExecutorLogVerboseLevel) << "AclGraphExecutorImpl::run() in eager mode"; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } // Only use acl graph in decode phase for performance optimization // Get actual num_tokens from tokens shape const uint32_t n_tokens = tokens_tensor.size(/*dim=*/0); - const uint32_t actual_batch_size = n_tokens / options_.num_decoding_tokens(); const uint32_t bucket_num_tokens = get_bucket_num_tokens(n_tokens); // Check if conditions are suitable for graph execution (replay or capture) const auto max_seq_len = args_.max_position_embeddings(); - const bool seq_len_supported = - params_single.meta.kv_max_seq_len <= max_seq_len; + const bool seq_len_supported = input.meta.kv_max_seq_len <= max_seq_len; // Combined condition for graph capture support // ACL graph executor only supports single tensor inputs (no micro-batching) @@ -1032,11 +1013,11 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, if (!capture_supported) { LOG_FIRST_N(WARNING, 1) << "Falling back to eager mode because kv_max_seq_len (" - << params_single.meta.kv_max_seq_len << ") > max_seq_len (" - << max_seq_len << "). This message is logged only once. " + << input.meta.kv_max_seq_len << ") > max_seq_len (" << max_seq_len + << "). This message is logged only once. 
" << "Monitor counter 'num_model_execution_total_eager' for frequency."; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } // Check if captured graph exists for this bucket num_tokens @@ -1045,8 +1026,7 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, // Replay the existing graph VLOG(kGraphExecutorLogVerboseLevel) << "AclGraphExecutorImpl::run() in replay mode"; - auto result = it->second->replay( - tokens_tensor, positions_tensor, kv_caches, params_single); + auto result = it->second->replay(input, kv_caches); // Handle aux_hidden_states based on options if (options_.enable_graph_aux_hidden_states()) { auto aux_hidden_states = persistent_param_->aux_hidden_states(n_tokens); @@ -1062,14 +1042,8 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, auto graph = std::make_unique(*persistent_param_, device_.index()); VLOG(kGraphExecutorLogVerboseLevel) << "AclGraphExecutorImpl::run() in capture mode"; - bool capture_success = graph->capture(model_, - args_, - options_, - tokens_tensor, - positions_tensor, - params_single, - kv_caches, - bucket_num_tokens); + bool capture_success = graph->capture( + model_, args_, options_, input, kv_caches, bucket_num_tokens); if (capture_success) { LOG(INFO) << "Lazy capturing ACL graph for bucket num_tokens: " @@ -1096,7 +1070,7 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens, LOG(ERROR) << "Failed to capture ACL graph for bucket num_tokens: " << bucket_num_tokens; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } void AclGraph::print_graph_tensors() const { diff --git a/xllm/core/runtime/acl_graph_executor_impl.h b/xllm/core/runtime/acl_graph_executor_impl.h index 0e8c1f84f1..3d20970aff 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.h +++ 
b/xllm/core/runtime/acl_graph_executor_impl.h @@ -26,7 +26,7 @@ limitations under the License. #include "core/common/macros.h" #include "core/framework/kv_cache/kv_cache.h" #include "core/framework/model/causal_lm.h" -#include "core/framework/model/model_input_params.h" +#include "core/runtime/forward_params.h" #include "executor_impl.h" #include "executor_impl_factory.h" #include "options.h" @@ -65,18 +65,16 @@ class GraphPersistentParam { ~GraphPersistentParam(); // Update persistent tensors with new input data - // If return_capture_params is true, returns a ModelInputParams with + // If return_capture_input is true, returns a ForwardInput with // persistent buffer references. padded_num_tokens must be > 0 when - // return_capture_params is true, used for build new ModelInputParams for - // capture. If return_capture_params is false, only updates persistent buffers + // return_capture_input is true, used for building new ForwardInput for + // capture. If return_capture_input is false, only updates persistent buffers // and returns std::nullopt. 
- std::optional update(const torch::Tensor& tokens, - const torch::Tensor& k_cache, - const torch::Tensor& v_cache, - const torch::Tensor& positions, - const ModelInputParams& params, - uint32_t padded_num_token, - bool return_capture_params = false); + std::optional update(const ForwardInput& input, + const torch::Tensor& k_cache, + const torch::Tensor& v_cache, + uint32_t padded_num_token, + bool return_capture_input = false); // Getter methods for persistent tensors torch::Tensor persistent_tokens(uint32_t actual_tokens = 0) const { @@ -183,14 +181,14 @@ class GraphPersistentParam { void initialize_paged_attention_plan_context(const torch::Device& device); // Update attention mask efficiently from input parameters - void update_attention_mask(const ModelInputParams& input_params); + void update_attention_mask(const ForwardInput& input); // Update paged attention tiling based on input parameters void plan_paged_attention_tiling(const torch::Tensor& tokens, const torch::Tensor& k_cache, const torch::Tensor& v_cache, const torch::Tensor& block_tables, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtStream stream); const ModelArgs& args_; @@ -258,17 +256,12 @@ class AclGraph { bool capture(CausalLM* model, const ModelArgs& args, const runtime::Options& options, - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, std::vector& kv_cache, uint32_t bucket_num_tokens); // Replay captured graph with new input data - ModelOutput replay(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_cache, - const ModelInputParams& params); + ModelOutput replay(const ForwardInput& input, std::vector& kv_cache); // Get the hidden states from the last capture torch::Tensor get_hidden_states(uint32_t actual_num_tokens = 0) const { @@ -309,10 +302,8 @@ class AclGraphExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; // 
Execute model with graph optimization for decode phase - ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override; + ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) override; static std::optional> find_first_full_attention_cache(const std::vector& kv_caches); diff --git a/xllm/core/runtime/base_executor_impl.cpp b/xllm/core/runtime/base_executor_impl.cpp index 37d3e78245..e32aa952b7 100644 --- a/xllm/core/runtime/base_executor_impl.cpp +++ b/xllm/core/runtime/base_executor_impl.cpp @@ -32,12 +32,10 @@ ForwardInput BaseExecutorImpl::prepare_inputs(Batch& batch) { options_.num_decoding_tokens(), 0, args_, options_.cp_size()); } -ModelOutput BaseExecutorImpl::run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { +ModelOutput BaseExecutorImpl::run(const ForwardInput& input, + std::vector& kv_caches) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } } // namespace xllm diff --git a/xllm/core/runtime/base_executor_impl.h b/xllm/core/runtime/base_executor_impl.h index 980121cff1..3420308df2 100644 --- a/xllm/core/runtime/base_executor_impl.h +++ b/xllm/core/runtime/base_executor_impl.h @@ -41,10 +41,8 @@ class BaseExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; - ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override; + ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) override; private: // not own diff --git a/xllm/core/runtime/cuda_graph_executor_impl.cpp b/xllm/core/runtime/cuda_graph_executor_impl.cpp index 30f1298090..e57ecb1e25 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.cpp +++ 
b/xllm/core/runtime/cuda_graph_executor_impl.cpp @@ -192,51 +192,45 @@ CudaGraphPersistentParam::CudaGraphPersistentParam( } bool CudaGraphPersistentParam::can_use_llm_decode_fast_path( - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params) const { - if (!params.meta.batch_forward_type.is_decode() || - is_rec_multi_round_mode() || params.has_llmrec_params() || - params.embedding.input_embedding.defined()) { + const ForwardInput& input) const { + if (!input.meta.batch_forward_type.is_decode() || is_rec_multi_round_mode() || + input.has_llmrec_params() || input.embedding.input_embedding.defined()) { return false; } - return is_cuda_contiguous_int32_tensor(tokens) && - is_cuda_contiguous_int32_tensor(positions) && + return is_cuda_contiguous_int32_tensor(input.token_ids) && + is_cuda_contiguous_int32_tensor(input.positions) && is_cuda_contiguous_int32_tensor( - params.attention.device.new_cache_slots) && - is_cuda_contiguous_int32_tensor(params.attention.device.kv_seq_lens) && + input.attention.device.new_cache_slots) && + is_cuda_contiguous_int32_tensor(input.attention.device.kv_seq_lens) && is_cuda_contiguous_int32_tensor( - params.attention.device.paged_kv_indptr) && + input.attention.device.paged_kv_indptr) && is_cuda_contiguous_int32_tensor( - params.attention.device.paged_kv_indices) && + input.attention.device.paged_kv_indices) && is_cuda_contiguous_int32_tensor( - params.attention.device.paged_kv_last_page_len); + input.attention.device.paged_kv_last_page_len); } void CudaGraphPersistentParam::update_llm_decode_metadata_fast_path( - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, uint32_t padded_num_tokens, int64_t actual_batch_size, int64_t actual_num_tokens) { CHECK_GE(actual_batch_size, 0) << "actual_batch_size must be >= 0"; CHECK_GE(actual_num_tokens, 0) << "actual_num_tokens must be >= 0"; const int64_t actual_indices_size = - 
params.attention.device.paged_kv_indices.size(0); + input.attention.device.paged_kv_indices.size(0); xllm::kernel::cuda::LlmDecodeMetadataUpdateParams update_params{ - .src_tokens = tokens.data_ptr(), - .src_positions = positions.data_ptr(), + .src_tokens = input.token_ids.data_ptr(), + .src_positions = input.positions.data_ptr(), .src_new_cache_slots = - params.attention.device.new_cache_slots.data_ptr(), - .src_kv_seq_lens = - params.attention.device.kv_seq_lens.data_ptr(), + input.attention.device.new_cache_slots.data_ptr(), + .src_kv_seq_lens = input.attention.device.kv_seq_lens.data_ptr(), .src_paged_kv_indptr = - params.attention.device.paged_kv_indptr.data_ptr(), + input.attention.device.paged_kv_indptr.data_ptr(), .src_paged_kv_indices = - params.attention.device.paged_kv_indices.data_ptr(), + input.attention.device.paged_kv_indices.data_ptr(), .src_paged_kv_last_page_len = - params.attention.device.paged_kv_last_page_len.data_ptr(), + input.attention.device.paged_kv_last_page_len.data_ptr(), .dst_tokens = persistent_tokens_.data_ptr(), .dst_positions = persistent_positions_.data_ptr(), .dst_new_cache_slots = persistent_new_cache_slots_.data_ptr(), @@ -308,65 +302,65 @@ size_t CudaGraphPersistentParam::get_persistent_tensor_bytes() const { return total; } -std::optional CudaGraphPersistentParam::update( - const torch::Tensor& tokens, +std::optional CudaGraphPersistentParam::update( + const ForwardInput& input, const torch::Tensor& k_cache, const torch::Tensor& v_cache, - const torch::Tensor& positions, - const ModelInputParams& params, uint32_t padded_num_tokens, - bool return_capture_params) { - std::optional params_for_capture; - if (return_capture_params) { + bool return_capture_input) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; + std::optional input_for_capture; + if (return_capture_input) { CHECK_GT(padded_num_tokens, 0) - << "padded_num_tokens must be > 0 when return_capture_params is true"; - 
params_for_capture = std::make_optional(params); + << "padded_num_tokens must be > 0 when return_capture_input is true"; + input_for_capture = input; + input_for_capture->token_ids = persistent_tokens(padded_num_tokens); + input_for_capture->positions = persistent_positions(padded_num_tokens); } // Build attn_metadata with original model_input_params. So we can set actual // batch size in plan_info. std::shared_ptr attn_metadata = std::make_shared( - layer::AttentionMetadataBuilder::build(params, args_.enable_mla())); + layer::AttentionMetadataBuilder::build(input.meta, + input.attention, + input.graph, + input.llmrec_params(), + input.enable_cuda_graph, + args_.enable_mla())); CHECK(attn_metadata) << "attn_metadata should not be null"; attn_metadata->enable_cuda_graph = true; - auto build_capture_params_if_needed = - [&]() -> std::optional { - if (!return_capture_params) { + auto build_capture_input_if_needed = [&]() -> std::optional { + if (!return_capture_input) { return std::nullopt; } - CHECK(params_for_capture.has_value()) - << "params_for_capture should be initialized when " - "return_capture_params " - "is true"; - if (params.embedding.input_embedding.defined()) { - params_for_capture->embedding.input_embedding = + CHECK(input_for_capture.has_value()) + << "input_for_capture should be initialized when " + "return_capture_input is true"; + if (input.embedding.input_embedding.defined()) { + input_for_capture->embedding.input_embedding = persistent_embedding(padded_num_tokens); } - if (!params.embedding.linear_state_ids.empty()) { - params_for_capture->embedding.linear_state_ids = - params.embedding.linear_state_ids; - params_for_capture->embedding.linear_state_indices = - persistent_linear_state_indices(params.meta.num_sequences); + if (!input.embedding.linear_state_ids.empty()) { + input_for_capture->embedding.linear_state_ids = + input.embedding.linear_state_ids; + input_for_capture->embedding.linear_state_indices = + 
persistent_linear_state_indices(input.meta.num_sequences); } - params_for_capture->attn_metadata = attn_metadata; - return params_for_capture; + input_for_capture->attn_metadata = attn_metadata; + return input_for_capture; }; const uint32_t actual_num_tokens = tokens.size(0); - const int64_t actual_batch_size = params.meta.num_sequences; - const bool use_llm_decode_fast_path = - can_use_llm_decode_fast_path(tokens, positions, params); + const int64_t actual_batch_size = input.meta.num_sequences; + const bool use_llm_decode_fast_path = can_use_llm_decode_fast_path(input); // Copy data from input parameters to persistent graph tensors if (use_llm_decode_fast_path) { VLOG(kGraphExecutorLogVerboseLevel) << "use fast path for LLM decode metadata update"; - update_llm_decode_metadata_fast_path(tokens, - positions, - params, - padded_num_tokens, - actual_batch_size, - actual_num_tokens); + update_llm_decode_metadata_fast_path( + input, padded_num_tokens, actual_batch_size, actual_num_tokens); } else { VLOG(kGraphExecutorLogVerboseLevel) << "copy_ tokens: src shape=" << tokens.sizes() << ", dst slice shape=[" @@ -398,25 +392,26 @@ std::optional CudaGraphPersistentParam::update( // kv_seq_lens is kv_cu_seq_lens in GPU Model. 
VLOG(kGraphExecutorLogVerboseLevel) << "copy_ q_seq_lens: src shape=" - << params.attention.device.q_seq_lens.sizes() << ", dst slice shape=[" + << input.attention.device.q_seq_lens.sizes() << ", dst slice shape=[" << actual_batch_size + 1 << "]"; q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size + 1) - .copy_(params.attention.device.q_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.q_seq_lens, /*non_blocking=*/true); VLOG(kGraphExecutorLogVerboseLevel) << "copy_ kv_seq_lens: src shape=" - << params.attention.device.kv_seq_lens.sizes() << ", dst slice shape=[" + << input.attention.device.kv_seq_lens.sizes() << ", dst slice shape=[" << actual_batch_size + 1 << "]"; kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size + 1) - .copy_(params.attention.device.kv_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.kv_seq_lens, /*non_blocking=*/true); VLOG(kGraphExecutorLogVerboseLevel) << "copy_ new_cache_slots: src shape=" - << params.attention.device.new_cache_slots.sizes() + << input.attention.device.new_cache_slots.sizes() << ", dst slice shape=[" << actual_num_tokens << "]"; persistent_new_cache_slots_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) - .copy_(params.attention.device.new_cache_slots, /*non_blocking=*/true); + .copy_(input.attention.device.new_cache_slots, + /*non_blocking=*/true); if (padded_num_tokens > actual_num_tokens) { persistent_new_cache_slots_ .slice(/*dim=*/0, @@ -438,16 +433,15 @@ std::optional CudaGraphPersistentParam::update( persistent_new_cache_slots(slot_mapping_tokens); } - if (!is_rec_multi_round_mode() && - !params.embedding.linear_state_ids.empty()) { - if (params.embedding.linear_state_indices.defined()) { + if (!is_rec_multi_round_mode() && !input.embedding.linear_state_ids.empty()) { + if (input.embedding.linear_state_indices.defined()) { persistent_linear_state_indices_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - 
.copy_(params.embedding.linear_state_indices, /*non_blocking=*/true); + .copy_(input.embedding.linear_state_indices, /*non_blocking=*/true); } else { persistent_linear_state_indices_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(torch::tensor(params.embedding.linear_state_ids, torch::kInt) + .copy_(torch::tensor(input.embedding.linear_state_ids, torch::kInt) .to(device_), /*non_blocking=*/true); } @@ -457,10 +451,10 @@ std::optional CudaGraphPersistentParam::update( // expanded to batch_size * beam_width rows while num_sequences still tracks // the logical request count. Use the tensor's real row count here. const int64_t actual_block_table_batch = - is_rec_multi_round_mode() ? params.attention.device.block_tables.size(0) + is_rec_multi_round_mode() ? input.attention.device.block_tables.size(0) : actual_batch_size; const int64_t actual_block_table_len = - params.attention.device.block_tables.size(1); + input.attention.device.block_tables.size(1); torch::Tensor slice_persistent_block_tables = persistent_block_tables_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_block_table_batch) @@ -468,16 +462,16 @@ std::optional CudaGraphPersistentParam::update( VLOG(kGraphExecutorLogVerboseLevel) << "copy_ block_tables: src shape=" - << params.attention.device.block_tables.sizes() + << input.attention.device.block_tables.sizes() << ", dst slice shape=" << slice_persistent_block_tables.sizes(); - slice_persistent_block_tables.copy_(params.attention.device.block_tables, + slice_persistent_block_tables.copy_(input.attention.device.block_tables, /*non_blocking=*/true); if (!attn_metadata->is_prefill || args_.enable_mla()) { attn_metadata->block_table = slice_persistent_block_tables; } // Update persistent embedding from input_embedding if available - const auto& embedding = params.embedding.input_embedding; + const auto& embedding = input.embedding.input_embedding; if (embedding.defined()) { const int64_t embedding_tokens = embedding.size(0); @@ -502,7 +496,7 
@@ std::optional CudaGraphPersistentParam::update( } const bool is_decode_with_llmrec = - params.meta.batch_forward_type.is_decode() && params.has_llmrec_params(); + input.meta.batch_forward_type.is_decode() && input.has_llmrec_params(); const bool use_two_stage_decode = !FLAGS_enable_xattention_one_stage && is_decode_with_llmrec; const int32_t head_dim = args_.head_dim(); @@ -536,19 +530,19 @@ std::optional CudaGraphPersistentParam::update( bool use_tensor_core = xllm::kernel::cuda::should_use_tensor_core(dtype, n_heads, n_kv_heads); if (use_two_stage_decode) { - if (params.attention.device.q_seq_lens.defined() && - params.attention.device.q_seq_lens.numel() > 0) { - const int64_t q_numel = params.attention.device.q_seq_lens.numel(); + if (input.attention.device.q_seq_lens.defined() && + input.attention.device.q_seq_lens.numel() > 0) { + const int64_t q_numel = input.attention.device.q_seq_lens.numel(); q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_numel) - .copy_(params.attention.device.q_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.q_seq_lens, /*non_blocking=*/true); attn_metadata->q_cu_seq_lens = q_seq_lens(/*actual_batch_size=*/q_numel); } - if (params.attention.device.kv_seq_lens.defined() && - params.attention.device.kv_seq_lens.numel() > 0) { - const int64_t kv_numel = params.attention.device.kv_seq_lens.numel(); + if (input.attention.device.kv_seq_lens.defined() && + input.attention.device.kv_seq_lens.numel() > 0) { + const int64_t kv_numel = input.attention.device.kv_seq_lens.numel(); kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_numel) - .copy_(params.attention.device.kv_seq_lens, /*non_blocking=*/true); + .copy_(input.attention.device.kv_seq_lens, /*non_blocking=*/true); attn_metadata->kv_cu_seq_lens = kv_seq_lens(/*actual_batch_size=*/kv_numel); if (kv_numel > 1) { @@ -562,7 +556,7 @@ std::optional CudaGraphPersistentParam::update( attn_metadata->qo_indptr = torch::Tensor(); // Update plan_info if attn_metadata exists 
and enable_cuda_graph is true. - const auto& llmrec_params = *params.llmrec_params(); + const auto& llmrec_params = *input.llmrec_params(); auto cache = attn_metadata->xattention_two_stage_decode_cache.value(); CHECK(cache.q_cu_seq_lens_shared.defined()) << "q_cu_seq_lens_shared must be initialized in rec worker"; @@ -672,7 +666,7 @@ std::optional CudaGraphPersistentParam::update( /*causal=*/false, use_tensor_core, /*is_shared_stage_plan*/ false); - return build_capture_params_if_needed(); + return build_capture_input_if_needed(); } if (use_llm_decode_fast_path) { const uint32_t slot_mapping_tokens = @@ -693,15 +687,15 @@ std::optional CudaGraphPersistentParam::update( attn_metadata->qo_indptr = persistent_decode_qo_indptr(static_cast(actual_batch_size)); } else { - CHECK(params.attention.device.paged_kv_indptr.defined()) + CHECK(input.attention.device.paged_kv_indptr.defined()) << "paged_kv_indptr should not be null"; VLOG(kGraphExecutorLogVerboseLevel) << "copy_ paged_kv_indptr: src shape=" - << params.attention.device.paged_kv_indptr.sizes() + << input.attention.device.paged_kv_indptr.sizes() << ", dst slice shape=[" << (actual_batch_size + 1) << "]"; if (VLOG_IS_ON(kGraphExecutorLogVerboseLevel)) { torch::Tensor paged_kv_indptr_cpu = - params.attention.device.paged_kv_indptr.to(torch::kCPU); + input.attention.device.paged_kv_indptr.to(torch::kCPU); VLOG(kGraphExecutorLogVerboseLevel) << "copy_ paged_kv_indptr: src values=" << paged_kv_indptr_cpu; } @@ -709,31 +703,32 @@ std::optional CudaGraphPersistentParam::update( .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size + 1) - .copy_(params.attention.device.paged_kv_indptr, /*non_blocking=*/true); - CHECK(params.attention.device.paged_kv_indices.defined()) + .copy_(input.attention.device.paged_kv_indptr, /*non_blocking=*/true); + CHECK(input.attention.device.paged_kv_indices.defined()) << "paged_kv_indices should not be null"; const int64_t actual_indices_size = - 
params.attention.device.paged_kv_indices.size(0); + input.attention.device.paged_kv_indices.size(0); VLOG(kGraphExecutorLogVerboseLevel) << "copy_ paged_kv_indices: src shape=" - << params.attention.device.paged_kv_indices.sizes() + << input.attention.device.paged_kv_indices.sizes() << ", dst slice shape=[" << actual_indices_size << "]"; persistent_paged_kv_indices_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_indices_size) - .copy_(params.attention.device.paged_kv_indices, /*non_blocking=*/true); - CHECK(params.attention.device.paged_kv_last_page_len.defined()) + .copy_(input.attention.device.paged_kv_indices, + /*non_blocking=*/true); + CHECK(input.attention.device.paged_kv_last_page_len.defined()) << "paged_kv_last_page_len should not be null"; VLOG(kGraphExecutorLogVerboseLevel) << "copy_ paged_kv_last_page_len: src shape=" - << params.attention.device.paged_kv_last_page_len.sizes() + << input.attention.device.paged_kv_last_page_len.sizes() << ", dst slice shape=[" << actual_batch_size << "]"; persistent_paged_kv_last_page_len_ .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size) - .copy_(params.attention.device.paged_kv_last_page_len, + .copy_(input.attention.device.paged_kv_last_page_len, /*non_blocking=*/true); attn_metadata->kv_seq_lens = torch::diff(kv_seq_lens(/*actual_batch_size=*/actual_batch_size + 1)); @@ -819,22 +814,21 @@ std::optional CudaGraphPersistentParam::update( : 0); } - // Return ModelInputParams with persistent buffer references if requested - return build_capture_params_if_needed(); + // Return ForwardInput with persistent buffer references if requested + return build_capture_input_if_needed(); } // CudaGraph implementation bool CudaGraph::capture(CausalLM* model, const ModelArgs& args, const runtime::Options& options, - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, std::vector& kv_cache, uint32_t bucket_num_tokens, const at::cuda::MempoolId_t& pool, 
TorchMemPool* pool_ptr) { padded_num_tokens_ = bucket_num_tokens; + const torch::Tensor& tokens = input.token_ids; const uint32_t actual_num_tokens = tokens.size(0); CHECK_GE(padded_num_tokens_, actual_num_tokens) << "bucket_num_tokens >= actual_num_tokens"; @@ -850,7 +844,7 @@ bool CudaGraph::capture(CausalLM* model, device_index_); capture_lock_guard.emplace(capture_lock); } - // Use the returned ModelInputParams for graph capture + // Use the returned ForwardInput for graph capture // Always use capture stream for plan/update + capture + forward. at::cuda::CUDAStream original_stream = at::cuda::getCurrentCUDAStream(device_index_); @@ -872,19 +866,16 @@ bool CudaGraph::capture(CausalLM* model, << "CUDA graph capture requires at least one full-attention KV cache"; const torch::Tensor& k_cache = full_attention_cache->first; const torch::Tensor& v_cache = full_attention_cache->second; - auto graph_params_opt = - persistent_param_.update(tokens, + auto graph_input_opt = + persistent_param_.update(input, k_cache, v_cache, - positions, - params, padded_num_tokens_, - /*return_capture_params=*/true); + /*return_capture_input=*/true); - // Use the returned ModelInputParams for graph capture - CHECK(graph_params_opt.has_value()) - << "update() should return ModelInputParams when " - "return_capture_params=true"; + // Use the returned ForwardInput for graph capture + CHECK(graph_input_opt.has_value()) + << "update() should return ForwardInput when return_capture_input=true"; LOG(INFO) << "CUDA graph capture begin, bucket_num_tokens: " << bucket_num_tokens << ", actual_num_tokens: " << actual_num_tokens @@ -895,10 +886,8 @@ bool CudaGraph::capture(CausalLM* model, // Warmup: execute forward once without capture to initialize cuBLAS handles // and other CUDA resources. This is necessary because these resources // cannot be created during CUDA graph capture mode. 
- model->forward(persistent_param_.persistent_tokens(padded_num_tokens_), - persistent_param_.persistent_positions(padded_num_tokens_), - kv_cache, - graph_params_opt.value()); + ForwardInput warmup_input = graph_input_opt.value(); + model->forward(warmup_input, kv_cache); // MemPoolContext has been deprecated in torch >= 2.8 #if TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <= 7 @@ -914,11 +903,8 @@ bool CudaGraph::capture(CausalLM* model, GlobalCaptureInstance::get_instance().begin_capture(pool); // Execute forward pass - attention operations will be captured separately - auto forward_result = model->forward( - persistent_param_.persistent_tokens(padded_num_tokens_), - persistent_param_.persistent_positions(padded_num_tokens_), - kv_cache, - graph_params_opt.value()); + ForwardInput graph_input = graph_input_opt.value(); + auto forward_result = model->forward(graph_input, kv_cache); // Store result in persistent buffer persistent_param_.set_hidden_states(forward_result.hidden_states); @@ -966,11 +952,8 @@ bool CudaGraph::capture(CausalLM* model, graph_.capture_begin(pool, cudaStreamCaptureModeThreadLocal); // Execute forward pass - CUDA graph will capture this - auto forward_result = model->forward( - persistent_param_.persistent_tokens(padded_num_tokens_), - persistent_param_.persistent_positions(padded_num_tokens_), - kv_cache, - graph_params_opt.value()); + ForwardInput graph_input = graph_input_opt.value(); + auto forward_result = model->forward(graph_input, kv_cache); // Store result in persistent buffer persistent_param_.set_hidden_states(forward_result.hidden_states); @@ -997,10 +980,9 @@ bool CudaGraph::capture(CausalLM* model, return true; } -ModelOutput CudaGraph::replay(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_cache, - const ModelInputParams& params) { +ModelOutput CudaGraph::replay(const ForwardInput& input, + std::vector& kv_cache) { + const torch::Tensor& tokens = input.token_ids; const uint32_t 
actual_num_tokens = tokens.size(0); CHECK_LE(actual_num_tokens, padded_num_tokens_) << "num_tokens mismatch: expected <= " << padded_num_tokens_ << ", got " @@ -1029,50 +1011,44 @@ ModelOutput CudaGraph::replay(const torch::Tensor& tokens, if (is_piecewise_) { // Piecewise replay mode (for prefill) // Need to get updated params with attn_metadata for attention replay - auto updated_params_opt = - persistent_param_.update(tokens, + auto updated_input_opt = + persistent_param_.update(input, k_cache, v_cache, - positions, - params, padded_num_tokens_, - /*return_capture_params=*/true); - CHECK(updated_params_opt.has_value()) - << "update() should return ModelInputParams for piecewise replay"; + /*return_capture_input=*/true); + CHECK(updated_input_opt.has_value()) + << "update() should return ForwardInput for piecewise replay"; - const auto& updated_params = updated_params_opt.value(); + const auto& updated_input = updated_input_opt.value(); CHECK(piecewise_graph_.num_runners() > 0) << "Piecewise graph must have attention runners"; - CHECK(updated_params.attn_metadata) + CHECK(updated_input.attn_metadata) << "attn_metadata is required for piecewise replay"; - CHECK(updated_params.attn_metadata->plan_info) + CHECK(updated_input.attn_metadata->plan_info) << "plan_info is required for piecewise replay"; VLOG(kGraphExecutorLogVerboseLevel) << "CudaGraph::replay() piecewise replay with uri=" - << updated_params.attn_metadata->plan_info->uri - << ", plan_info.defined=" - << updated_params.attn_metadata->plan_info->plan_info.defined(); + << updated_input.attn_metadata->plan_info->uri << ", plan_info.defined=" + << updated_input.attn_metadata->plan_info->plan_info.defined(); // Build AttentionReplayParams from updated attn_metadata ::xllm::kernel::cuda::AttentionReplayParams replay_params; replay_params.actual_num_tokens = actual_num_tokens; - replay_params.plan_info = - updated_params.attn_metadata->plan_info->plan_info; - replay_params.q_cu_seq_lens = 
updated_params.attn_metadata->q_cu_seq_lens; - replay_params.kv_cu_seq_lens = updated_params.attn_metadata->kv_cu_seq_lens; + replay_params.plan_info = updated_input.attn_metadata->plan_info->plan_info; + replay_params.q_cu_seq_lens = updated_input.attn_metadata->q_cu_seq_lens; + replay_params.kv_cu_seq_lens = updated_input.attn_metadata->kv_cu_seq_lens; // Replay piecewise graphs and attention runners piecewise_graph_.replay(replay_params); } else { // Normal replay mode (for decode) - persistent_param_.update(tokens, + persistent_param_.update(input, k_cache, v_cache, - positions, - params, padded_num_tokens_, - /*return_capture_params=*/false); + /*return_capture_input=*/false); graph_.replay(); } @@ -1375,12 +1351,11 @@ ModelOutput CudaGraphExecutorImpl::attach_aux_hidden_states_if_needed( return ModelOutput(hidden_states); } -ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { - const bool is_prefill = params.meta.batch_forward_type.is_prefill(); - const bool is_decode = params.meta.batch_forward_type.is_decode(); +ModelOutput CudaGraphExecutorImpl::run(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const bool is_prefill = input.meta.batch_forward_type.is_prefill(); + const bool is_decode = input.meta.batch_forward_type.is_decode(); // Get actual num_tokens from tokens shape const uint32_t n_tokens = tokens.size(/*dim=*/0); @@ -1398,7 +1373,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, << " exceeds max_tokens_for_graph_mode (" << max_tokens_for_graph_mode_ << "), falling back to eager mode"; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } // Check if piecewise graph exists for this bucket @@ -1407,7 +1382,7 @@ ModelOutput CudaGraphExecutorImpl::run(const 
torch::Tensor& tokens, // Replay existing piecewise graph VLOG(kGraphExecutorLogVerboseLevel) << "CudaGraphExecutorImpl::run() in prefill piecewise replay mode"; - auto result = it->second->replay(tokens, positions, kv_caches, params); + auto result = it->second->replay(input, kv_caches); return attach_aux_hidden_states_if_needed(result.hidden_states, n_tokens); } @@ -1432,9 +1407,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, bool capture_success = graph->capture(model_, args_, options_, - tokens, - positions, - params, + input, kv_caches, bucket_num_tokens, mem_pool, @@ -1452,8 +1425,8 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, // Run replay after capture so first request uses same execution path as // subsequent requests. - auto result = prefill_graphs_[bucket_num_tokens]->replay( - tokens, positions, kv_caches, params); + auto result = + prefill_graphs_[bucket_num_tokens]->replay(input, kv_caches); return attach_aux_hidden_states_if_needed(result.hidden_states, n_tokens); } @@ -1470,21 +1443,21 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, // Prefill without piecewise graph: use eager mode if (is_prefill) { COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } // Decode phase with full graph if (is_decode) { // Check if conditions are suitable for graph execution (replay or capture) const auto max_seq_len = args_.max_position_embeddings(); - const bool seq_len_supported = params.meta.kv_max_seq_len <= max_seq_len; + const bool seq_len_supported = input.meta.kv_max_seq_len <= max_seq_len; // Early return if conditions are not suitable for graph operations if (!seq_len_supported) { LOG(WARNING) << "Not suitable for CUDA graph operations, falling back to " "eager mode."; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + 
return model_->forward(input, kv_caches); } // Check if captured graph exists for this bucket num_tokens @@ -1493,7 +1466,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, // Replay the existing graph VLOG(kGraphExecutorLogVerboseLevel) << "CudaGraphExecutorImpl::run() in decode replay mode"; - auto result = it->second->replay(tokens, positions, kv_caches, params); + auto result = it->second->replay(input, kv_caches); return attach_aux_hidden_states_if_needed(result.hidden_states, n_tokens); } @@ -1517,9 +1490,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, bool capture_success = graph->capture(model_, args_, options_, - tokens, - positions, - params, + input, kv_caches, bucket_num_tokens, mem_pool, @@ -1537,8 +1508,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, // Run replay after capture so first request uses same execution path as // subsequent requests. - auto result = graphs_[bucket_num_tokens]->replay( - tokens, positions, kv_caches, params); + auto result = graphs_[bucket_num_tokens]->replay(input, kv_caches); return attach_aux_hidden_states_if_needed(result.hidden_states, n_tokens); } @@ -1555,7 +1525,7 @@ ModelOutput CudaGraphExecutorImpl::run(const torch::Tensor& tokens, LOG(ERROR) << "Failed to capture CUDA graph for bucket num_tokens: " << bucket_num_tokens; COUNTER_INC(num_model_execution_total_eager); - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(input, kv_caches); } // bucket will be [1, 2, 4, 8, 16, 32, 48, 64, ..., max_seqs_per_batch] diff --git a/xllm/core/runtime/cuda_graph_executor_impl.h b/xllm/core/runtime/cuda_graph_executor_impl.h index b4eb5e7089..03cacfdfbd 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.h +++ b/xllm/core/runtime/cuda_graph_executor_impl.h @@ -35,9 +35,9 @@ limitations under the License. 
#include "core/common/macros.h" #include "core/framework/kv_cache/kv_cache.h" #include "core/framework/model/causal_lm.h" -#include "core/framework/model/model_input_params.h" #include "core/kernels/cuda/llm_decode_metadata_update.h" #include "core/kernels/cuda/piecewise_graphs.h" +#include "core/runtime/forward_params.h" #include "executor_impl.h" #include "executor_impl_factory.h" #include "options.h" @@ -62,18 +62,16 @@ class CudaGraphPersistentParam { ~CudaGraphPersistentParam() = default; // Update persistent tensors with new input data - // If return_capture_params is true, returns a ModelInputParams with + // If return_capture_input is true, returns a ForwardInput with // persistent buffer references. padded_num_tokens must be > 0 when - // return_capture_params is true, used for build new ModelInputParams for - // capture. If return_capture_params is false, only updates persistent buffers + // return_capture_input is true, used for building new ForwardInput for + // capture. If return_capture_input is false, only updates persistent buffers // and returns std::nullopt. 
- std::optional update(const torch::Tensor& tokens, - const torch::Tensor& k_cache, - const torch::Tensor& v_cache, - const torch::Tensor& positions, - const ModelInputParams& params, - uint32_t padded_num_tokens = 0, - bool return_capture_params = false); + std::optional update(const ForwardInput& input, + const torch::Tensor& k_cache, + const torch::Tensor& v_cache, + uint32_t padded_num_tokens = 0, + bool return_capture_input = false); // Getter methods for persistent tensors torch::Tensor persistent_tokens(uint32_t actual_tokens) const { @@ -198,12 +196,8 @@ class CudaGraphPersistentParam { } private: - bool can_use_llm_decode_fast_path(const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params) const; - void update_llm_decode_metadata_fast_path(const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + bool can_use_llm_decode_fast_path(const ForwardInput& input) const; + void update_llm_decode_metadata_fast_path(const ForwardInput& input, uint32_t padded_num_tokens, int64_t actual_batch_size, int64_t actual_num_tokens); @@ -253,19 +247,14 @@ class CudaGraph { bool capture(CausalLM* model, const ModelArgs& args, const runtime::Options& options, - const torch::Tensor& tokens, - const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, std::vector& kv_cache, uint32_t bucket_num_tokens, const at::cuda::MempoolId_t& pool, TorchMemPool* pool_ptr = nullptr); // Replay captured graph with new input data - ModelOutput replay(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_cache, - const ModelInputParams& params); + ModelOutput replay(const ForwardInput& input, std::vector& kv_cache); // Get the hidden states from the last capture torch::Tensor get_hidden_states(uint32_t actual_num_tokens) const { @@ -307,10 +296,8 @@ class CudaGraphExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; // 
Execute model with graph optimization for decode phase - ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override; + ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) override; // Return current graph executor memory usage in bytes (including persistent // parameters). Exposed for tests and diagnostics. diff --git a/xllm/core/runtime/dit_worker_impl.cpp b/xllm/core/runtime/dit_worker_impl.cpp index f38bde7688..082f352eb9 100644 --- a/xllm/core/runtime/dit_worker_impl.cpp +++ b/xllm/core/runtime/dit_worker_impl.cpp @@ -155,7 +155,7 @@ std::optional DiTWorkerImpl::step(const ForwardInput& inputs) { torch::DeviceGuard device_guard(device_); Timer timer; auto output = dit_model_executor_->forward( - inputs.input_params.dit_forward_input.to(device_, dtype_)); + inputs.dit_forward_input.to(device_, dtype_)); auto ret = device_.synchronize_default_stream(); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); diff --git a/xllm/core/runtime/embed_vlm_worker_impl.cpp b/xllm/core/runtime/embed_vlm_worker_impl.cpp index 21a8db94ae..f7666becb2 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/embed_vlm_worker_impl.cpp @@ -61,14 +61,11 @@ std::optional EmbedVLMWorkerImpl::step( // TODO to adapt multi stream parallel later, just use [0] temporarily // all tensors should be on the same device as model - auto flatten_tokens = input.token_ids.to(device_); - auto flatten_positions = input.positions.to(device_); - auto params = input.input_params.to(device_); - auto sampling_params = input.sampling_params.to(device_, dtype_); + ForwardInput model_input = input.to(device_, dtype_); + auto sampling_params = model_input.sampling_params; // call model executor forward to get hidden states - auto model_output = model_executor_->forward( - flatten_tokens, flatten_positions, kv_caches_, params); + auto model_output = 
model_executor_->forward(model_input, kv_caches_); auto hidden_states = model_output.hidden_states; ret = device_.synchronize_default_stream(); COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); @@ -88,7 +85,7 @@ std::optional EmbedVLMWorkerImpl::step( // split full embeddings and add them to mm_embeddings // so that the user could receive embeddings of images and texts if (FLAGS_enable_return_mm_full_embeddings) { - auto q_seq_len_vec = input.input_params.attention.host.q_seq_lens; + auto q_seq_len_vec = input.attention.host.q_seq_lens; sample_output.mm_embeddings.reserve(q_seq_len_vec.size()); int32_t token_start_idx = 0; for (auto seq_len : q_seq_len_vec) { diff --git a/xllm/core/runtime/embed_worker_impl.cpp b/xllm/core/runtime/embed_worker_impl.cpp index bc137ba842..584db2d039 100644 --- a/xllm/core/runtime/embed_worker_impl.cpp +++ b/xllm/core/runtime/embed_worker_impl.cpp @@ -60,14 +60,11 @@ std::optional EmbedWorkerImpl::step(const ForwardInput& input) { // TODO to adapt multi stream parallel later, just use [0] temporarily // all tensors should be on the same device as model - auto flatten_tokens = input.token_ids.to(device_); - auto flatten_positions = input.positions.to(device_); - auto params = input.input_params.to(device_); - auto sampling_params = input.sampling_params.to(device_, dtype_); + ForwardInput model_input = input.to(device_, dtype_); + auto sampling_params = model_input.sampling_params; // call model executor forward to get hidden states - auto model_output = model_executor_->forward( - flatten_tokens, flatten_positions, kv_caches_, params); + auto model_output = model_executor_->forward(model_input, kv_caches_); auto hidden_states = model_output.hidden_states; COUNTER_ADD(execution_latency_seconds_model, timer.elapsed_seconds()); diff --git a/xllm/core/runtime/executor.cpp b/xllm/core/runtime/executor.cpp index fe0831bb1c..5f20a10c2a 100644 --- a/xllm/core/runtime/executor.cpp +++ b/xllm/core/runtime/executor.cpp @@ 
-36,11 +36,9 @@ ForwardInput Executor::prepare_inputs(Batch& batch) { return impl_->prepare_inputs(batch); } -ModelOutput Executor::forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { - return impl_->run(tokens, positions, kv_caches, params); +ModelOutput Executor::forward(const ForwardInput& input, + std::vector& kv_caches) { + return impl_->run(input, kv_caches); } } // namespace xllm diff --git a/xllm/core/runtime/executor.h b/xllm/core/runtime/executor.h index cdd16a7649..2f3d73aaaa 100644 --- a/xllm/core/runtime/executor.h +++ b/xllm/core/runtime/executor.h @@ -38,13 +38,8 @@ class Executor final { ForwardInput prepare_inputs(Batch& batch); - // tokens: vector size is dp_size, each element is [num_tokens/dp_size] - // positions: vector size is dp_size, each element is [num_tokens/dp_size] - // token pos in the sequence returns: ModelOutput - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches); private: std::unique_ptr impl_; diff --git a/xllm/core/runtime/executor_impl.h b/xllm/core/runtime/executor_impl.h index ba0b1972e8..fa3058227e 100644 --- a/xllm/core/runtime/executor_impl.h +++ b/xllm/core/runtime/executor_impl.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "framework/model/model_input_params.h" #include "framework/model/model_output.h" #include "options.h" +#include "runtime/forward_params.h" namespace xllm { @@ -36,13 +37,8 @@ class ExecutorImpl { virtual ForwardInput prepare_inputs(Batch& batch) = 0; - // tokens: vector size is dp_size, each element is [num_tokens/dp_size] - // positions: vector size is dp_size, each element is [num_tokens/dp_size] - // token pos in the sequence returns: ModelOutput - virtual ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) = 0; + virtual ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) = 0; }; } // namespace xllm diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 8bf822d876..a151520793 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -39,6 +40,10 @@ limitations under the License. 
namespace xllm { +namespace layer { +struct AttentionMetadata; +} // namespace layer + struct ForwardInput; namespace detail { @@ -201,16 +206,16 @@ inline torch::Tensor normalize_positions_for_device( return positions; } -inline bool has_contiguous_input_buffer_exclusions( - const ModelInputParams& params) { +template +inline bool has_contiguous_input_buffer_exclusions(const Input& params) { return params.multimodal.mm_data.valid() || params.has_onerec_params() || params.has_llmrec_params() || params.dit_forward_input.valid() || params.multimodal.visual_pos_masks.defined() || !params.multimodal.deep_stacks.empty(); } -inline void clear_contiguous_input_buffer_tensor_targets( - ModelInputParams& params) { +template +inline void clear_contiguous_input_buffer_tensor_targets(Input& params) { params.embedding.input_embedding = torch::Tensor(); params.embedding.linear_state_indices = torch::Tensor(); params.block_copy.src_block_indices = torch::Tensor(); @@ -251,8 +256,9 @@ inline bool add_attention_to_plan(const AttentionInput& source, &target.device.ring_cache_seqlen); } -inline bool add_model_tensors_to_plan(const ModelInputParams& source, - ModelInputParams& target, +template +inline bool add_model_tensors_to_plan(const Input& source, + Input& target, ForwardInputBufferPlan& plan) { return plan.add(source.embedding.input_embedding, &target.embedding.input_embedding) && @@ -447,7 +453,7 @@ struct ForwardInput { inputs.token_ids = safe_to(source_token_ids, device, true); inputs.positions = detail::normalize_positions_for_device( safe_to(source_positions, device, true)); - inputs.input_params = input_params.to(device); + copy_model_inputs_to_device(device, inputs); inputs.sampling_params = sampling_params.to(device, dtype); inputs.decoder_sampling_params = decoder_sampling_params.to(device, dtype); copy_metadata_to(inputs); @@ -463,14 +469,13 @@ struct ForwardInput { copy_metadata_to(inputs); set_host_views(inputs); - const ModelInputParams& source_params = 
input_params; if (missing_required_host_views(inputs) || - detail::has_contiguous_input_buffer_exclusions(source_params)) { + detail::has_contiguous_input_buffer_exclusions(*this)) { return false; } - inputs.input_params = source_params; - detail::clear_contiguous_input_buffer_tensor_targets(inputs.input_params); + copy_model_inputs_to(inputs); + detail::clear_contiguous_input_buffer_tensor_targets(inputs); inputs.sampling_params = sampling_params; inputs.decoder_sampling_params = decoder_sampling_params; @@ -484,10 +489,8 @@ struct ForwardInput { return false; } - if (!detail::add_attention_to_plan( - source_params.attention, inputs.input_params.attention, plan) || - !detail::add_model_tensors_to_plan( - source_params, inputs.input_params, plan)) { + if (!detail::add_attention_to_plan(attention, inputs.attention, plan) || + !detail::add_model_tensors_to_plan(*this, inputs, plan)) { return false; } @@ -519,6 +522,43 @@ struct ForwardInput { inputs.skip_sampling_for_logits_only = skip_sampling_for_logits_only; } + void copy_model_inputs_to(ForwardInput& inputs) const { + inputs.meta = meta; + inputs.attention = attention; + inputs.embedding = embedding; + inputs.parallel = parallel; + inputs.block_copy = block_copy; + inputs.multimodal = multimodal; + inputs.expert = expert; + inputs.graph = graph; + inputs.rec_params = rec_params; + inputs.dit_forward_input = dit_forward_input; + inputs.attn_metadata = attn_metadata; + inputs.enable_cuda_graph = enable_cuda_graph; + } + + void copy_model_inputs_to_device(const torch::Device& device, + ForwardInput& inputs) const { + inputs.meta = meta; + inputs.attention = attention.to(device); + inputs.embedding = embedding.to(device); + inputs.block_copy = block_copy.to(device); + inputs.multimodal = multimodal.to(device); + inputs.parallel = parallel.to(device); + inputs.expert = expert.to(device); + inputs.graph = graph.to(device); + inputs.dit_forward_input = dit_forward_input.to(device); + + // rec_params device conversion 
for both OneRec and LLM-Rec variants + if (const auto* onerec_xattn = onerec_xattention_params()) { + inputs.rec_params = onerec_xattn->to(device); + } else if (const auto* onerec = onerec_params()) { + inputs.rec_params = onerec->to(device); + } else if (const auto* llmrec = llmrec_params()) { + inputs.rec_params = llmrec->to(device); + } + } + void set_host_views(ForwardInput& inputs) const { inputs.token_ids_host = token_ids_host.defined() ? token_ids_host : cpu_view(token_ids); @@ -549,7 +589,37 @@ struct ForwardInput { void print() const { LOG(INFO) << " token_ids: " << token_ids << std::endl; LOG(INFO) << " positions: " << positions << std::endl; - input_params.print(); + LOG(INFO) << "ForwardInput: batch_forward_type is " + << meta.batch_forward_type.to_string() << " , num_sequences is " + << meta.num_sequences << " , kv_max_seq_len is " + << meta.kv_max_seq_len << " , q_max_seq_len is " + << meta.q_max_seq_len; + LOG(INFO) << "ForwardInput: attention.host.kv_seq_lens is " + << attention.host.kv_seq_lens; + LOG(INFO) << "ForwardInput: attention.host.q_seq_lens is " + << attention.host.q_seq_lens; + LOG(INFO) << "ForwardInput: batch_forward_type is " + << meta.batch_forward_type.to_string(); + print_tensor(attention.device.kv_seq_lens, "ForwardInput: kv_seq_lens", 4); + print_tensor(attention.device.q_seq_lens, "ForwardInput: q_seq_lens", 4); + print_tensor( + attention.device.q_cu_seq_lens, "ForwardInput: q_cu_seq_lens", 4); + print_tensor( + attention.device.new_cache_slots, "ForwardInput: new_cache_slots", 4); + print_tensor( + attention.device.block_tables, "ForwardInput: block_tables", 4); + LOG(INFO) << "ForwardInput: dp_global_token_nums is " + << parallel.dp_global_token_nums + << ", dp_is_decode: " << parallel.dp_is_decode; + + if (const auto* onerec = onerec_params()) { + LOG(INFO) << "ForwardInput: has onerec_params"; + onerec->print(); + } else if (const auto* llmrec = llmrec_params()) { + LOG(INFO) << "ForwardInput: has 
llm_rec_multi_round_params" + << ", beam_width=" << llmrec->beam_width + << ", total_round=" << llmrec->total_round; + } LOG(INFO) << " params.selected_token_idxes " << sampling_params.selected_token_idxes; LOG(INFO) << " params.sample_idxes " << sampling_params.sample_idxes; @@ -562,13 +632,124 @@ struct ForwardInput { bool has_step_meta() const { return step_decode.has_value(); } + int32_t get_q_seq_len(int32_t seq_idx) const { +#if defined(USE_NPU) + CHECK(seq_idx < attention.host.q_seq_lens.size()) << "seq_idx out of range"; + return attention.host.q_seq_lens[seq_idx]; +#else + CHECK(seq_idx < attention.host.q_seq_lens.size() - 1) + << "seq_idx out of range"; + return attention.host.q_seq_lens[seq_idx + 1] - + attention.host.q_seq_lens[seq_idx]; +#endif + } + + bool synchronize_layer(uint32_t layer_idx) const { +#if defined(USE_NPU) + if (parallel.layer_wise_load_synchronizer != nullptr && + layer_idx % parallel.layers_per_bacth_copy == 0) { + if (!parallel.layer_wise_load_synchronizer->synchronize_layer( + layer_idx / parallel.layers_per_bacth_copy)) { + return false; + } + } +#else + (void)layer_idx; +#endif + return true; + } + + bool record_layer(uint32_t layer_idx, const torch::Device& device) const { +#if defined(USE_MLU) + if (parallel.layer_synchronizer != nullptr) { + return parallel.layer_synchronizer->record_current(layer_idx, + device.index()); + } +#else + (void)layer_idx; + (void)device; +#endif + return true; + } + + const OneRecModelInputParams* onerec_params() const { + if (const auto* params = std::get_if(&rec_params)) { + return params; + } + if (const auto* params = std::get_if(&rec_params)) { + return static_cast(params); + } + return nullptr; + } + + bool has_onerec_params() const { return onerec_params() != nullptr; } + + OneRecModelInputParams& mutable_onerec_params() { + if (auto* params = std::get_if(&rec_params)) { + return *params; + } + if (auto* params = std::get_if(&rec_params)) { + return static_cast(*params); + } + 
rec_params.emplace(); + return std::get(rec_params); + } + + const OneRecXAttentionParams* onerec_xattention_params() const { + return std::get_if(&rec_params); + } + + bool has_onerec_xattention_params() const { + return onerec_xattention_params() != nullptr; + } + + OneRecXAttentionParams& mutable_onerec_xattention_params() { + if (!has_onerec_xattention_params()) { + rec_params.emplace(); + } + return std::get(rec_params); + } + + const LlmRecMultiRoundParams* llmrec_params() const { + return std::get_if(&rec_params); + } + + bool has_llmrec_params() const { return llmrec_params() != nullptr; } + + LlmRecMultiRoundParams& mutable_llmrec_params() { + if (!has_llmrec_params()) { + rec_params.emplace(); + } + return std::get(rec_params); + } + + BatchInputMeta meta; + AttentionInput attention; + ModelEmbeddingInput embedding; + ParallelInput parallel; + BlockCopyInput block_copy; + MultiModalInput multimodal; + ExpertInput expert; + GraphInput graph; + + RecModelInputParams rec_params; + + // dit input data + DiTForwardInput dit_forward_input; + + // Optional attention metadata, built by executor + // Using shared_ptr with forward declaration to avoid circular dependency + std::shared_ptr attn_metadata; + + // Flag for CUDA graph capture mode + bool enable_cuda_graph = false; + // flatten token ids torch::Tensor token_ids; // flatten positions torch::Tensor positions; torch::Tensor token_ids_host; torch::Tensor positions_host; - ModelInputParams input_params; SamplingParameters sampling_params; SamplingParameters decoder_sampling_params; @@ -596,14 +777,14 @@ inline ForwardInput cp_partition_forward_input(const ForwardInput& input, int32_t cp_rank, int32_t cp_size) { if (cp_size <= 1 || !input.host_token_ids().defined() || - !input.input_params.meta.batch_forward_type.is_prefill()) { + !input.meta.batch_forward_type.is_prefill()) { return input; } CHECK_GT(cp_size, 0); CHECK_GE(cp_rank, 0); CHECK_LT(cp_rank, cp_size); - 
CHECK_GT(input.input_params.meta.num_sequences, 0); + CHECK_GT(input.meta.num_sequences, 0); ForwardInput output = input; output.set_host_views(output); @@ -616,13 +797,11 @@ inline ForwardInput cp_partition_forward_input(const ForwardInput& input, return detail::tensor_to_vector(tensor); }; - const std::vector seq_lens = - host_vector_or_tensor(input.input_params.attention.host.kv_seq_lens, - input.input_params.attention.device.kv_seq_lens); - const std::vector q_seq_lens = - host_vector_or_tensor(input.input_params.attention.host.q_seq_lens, - input.input_params.attention.device.q_seq_lens); - const int32_t num_sequences = input.input_params.meta.num_sequences; + const std::vector seq_lens = host_vector_or_tensor( + input.attention.host.kv_seq_lens, input.attention.device.kv_seq_lens); + const std::vector q_seq_lens = host_vector_or_tensor( + input.attention.host.q_seq_lens, input.attention.device.q_seq_lens); + const int32_t num_sequences = input.meta.num_sequences; auto to_seq_lens = [&](const std::vector& lens) -> std::vector { @@ -716,39 +895,34 @@ inline ForwardInput cp_partition_forward_input(const ForwardInput& input, output.positions_host = detail::gather_tensor_by_indices(input.host_positions(), gather_indices); output.positions = output.positions_host; - output.input_params.attention.host.new_cache_slots = host_vector_or_tensor( - input.input_params.attention.host.new_cache_slots, - input.input_params.attention.device.new_cache_slots); - if (!output.input_params.attention.host.new_cache_slots.empty() && - output.input_params.attention.host.new_cache_slots.size() == + output.attention.host.new_cache_slots = + host_vector_or_tensor(input.attention.host.new_cache_slots, + input.attention.device.new_cache_slots); + if (!output.attention.host.new_cache_slots.empty() && + output.attention.host.new_cache_slots.size() == static_cast(token_num)) { std::vector cp_new_cache_slots; cp_new_cache_slots.reserve(gather_indices.size()); for (int64_t idx : 
gather_indices) { cp_new_cache_slots.push_back( - output.input_params.attention.host - .new_cache_slots[static_cast(idx)]); - } - output.input_params.attention.host.new_cache_slots = - std::move(cp_new_cache_slots); - output.input_params.attention.device.new_cache_slots = - detail::int_vector_to_cpu_tensor( - output.input_params.attention.host.new_cache_slots); - } - if (input.input_params.embedding.input_embedding.defined() && - input.input_params.embedding.input_embedding.size(0) == token_num) { - output.input_params.embedding.input_embedding = - detail::gather_tensor_by_indices_on_dim( - input.input_params.embedding.input_embedding, - gather_indices, - /*dim=*/0); + output.attention.host.new_cache_slots[static_cast(idx)]); + } + output.attention.host.new_cache_slots = std::move(cp_new_cache_slots); + output.attention.device.new_cache_slots = + detail::int_vector_to_cpu_tensor(output.attention.host.new_cache_slots); + } + if (input.embedding.input_embedding.defined() && + input.embedding.input_embedding.size(0) == token_num) { + output.embedding.input_embedding = + detail::gather_tensor_by_indices_on_dim(input.embedding.input_embedding, + gather_indices, + /*dim=*/0); } - if (input.input_params.attention.device.new_cache_slot_offsets.defined() && - input.input_params.attention.device.new_cache_slot_offsets.size(0) == - token_num) { - output.input_params.attention.device.new_cache_slot_offsets = + if (input.attention.device.new_cache_slot_offsets.defined() && + input.attention.device.new_cache_slot_offsets.size(0) == token_num) { + output.attention.device.new_cache_slot_offsets = detail::gather_tensor_by_indices_on_dim( - input.input_params.attention.device.new_cache_slot_offsets, + input.attention.device.new_cache_slot_offsets, gather_indices, /*dim=*/0); } @@ -777,17 +951,17 @@ inline ForwardInput cp_partition_forward_input(const ForwardInput& input, std::partial_sum( cp_q_lens.begin(), cp_q_lens.end(), cp_q_cu_seq_lens.begin()); - 
output.input_params.attention.host.q_seq_lens = cp_q_seq_lens; - output.input_params.attention.host.kv_seq_lens = cp_seq_lens; - output.input_params.attention.host.q_cu_seq_lens = cp_q_cu_seq_lens; - output.input_params.attention.device.q_seq_lens = + output.attention.host.q_seq_lens = cp_q_seq_lens; + output.attention.host.kv_seq_lens = cp_seq_lens; + output.attention.host.q_cu_seq_lens = cp_q_cu_seq_lens; + output.attention.device.q_seq_lens = detail::int_vector_to_cpu_tensor(cp_q_seq_lens); - output.input_params.attention.device.kv_seq_lens = + output.attention.device.kv_seq_lens = detail::int_vector_to_cpu_tensor(cp_seq_lens); - output.input_params.attention.device.q_cu_seq_lens = + output.attention.device.q_cu_seq_lens = detail::int_vector_to_cpu_tensor(cp_q_cu_seq_lens); - output.input_params.meta.q_max_seq_len = cp_global_max_seq_len; - output.input_params.meta.kv_max_seq_len = cp_global_max_seq_len; + output.meta.q_max_seq_len = cp_global_max_seq_len; + output.meta.kv_max_seq_len = cp_global_max_seq_len; if (input.sampling_params.selected_token_idxes.defined() && input.sampling_params.selected_token_idxes.numel() > 0) { diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 18fb256738..f14dc0c136 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -2056,7 +2056,7 @@ inline void deserialize_forward_input_payload( context, forward_input.positions, forward_input.positions_host, stream); // input_params - auto& input_params = forward_input.input_params; + ForwardInput& input_params = forward_input; int32_t batch_forward_type; read_data(context, batch_forward_type); input_params.meta.batch_forward_type = BatchForwardType(batch_forward_type); @@ -2142,7 +2142,7 @@ inline void deserialize_forward_input_payload( for (auto& transfer : forward_input.transfer_kv_infos) { read_transfer_kv_info(context, transfer); } - 
read_eplb_info(context, forward_input.input_params.expert.eplb_info); + read_eplb_info(context, forward_input.expert.eplb_info); read_tensor_and_vector(context, input_params.attention.device.new_cache_slots, @@ -2374,7 +2374,7 @@ inline void serialize_forward_input_sections( write_tensor(context, host_token_ids); write_tensor(context, host_positions); - const auto& input_params = input.input_params; + const auto& input_params = input; write_data(context.descriptor, input_params.meta.batch_forward_type.value()); write_data(context.descriptor, input_params.meta.num_sequences); write_data(context.descriptor, input_params.meta.kv_max_seq_len); @@ -2716,7 +2716,6 @@ void packed_proto_to_forward_input( packed_proto_to_forward_input_impl( packed_forward_input, forward_input, device, stream); } - ForwardSharedMemoryManager::ForwardSharedMemoryManager(const std::string& name, size_t size, bool& is_creator, diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 54b650887d..c7219d420b 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -115,8 +115,8 @@ std::optional LLMWorkerImpl::step_internal( context_.get_model_args().n_layers()); #endif #if defined(USE_NPU) || defined(USE_MLU) - const_cast(&(input.input_params)) - ->parallel.layer_synchronizer = layer_synchronizer; + const_cast(input).parallel.layer_synchronizer = + layer_synchronizer; futures.emplace_back( kv_cache_transfer_->push_kv_blocks_async(input.transfer_kv_infos, @@ -127,12 +127,11 @@ std::optional LLMWorkerImpl::step_internal( } if (FLAGS_enable_eplb) { - eplb_executor_->eplb_execute(input.input_params.expert.eplb_info); + eplb_executor_->eplb_execute(input.expert.eplb_info); } // call model executor forward to get hidden states - auto model_output = model_executor_->forward( - input.token_ids, input.positions, kv_caches_, input.input_params); + auto model_output = model_executor_->forward(input, kv_caches_); if 
(!model_output.hidden_states.defined()) { return std::nullopt; } @@ -210,8 +209,7 @@ std::optional LLMWorkerImpl::step_internal( } else { embeddings = model_output.hidden_states; } - if (!input.input_params.meta.batch_forward_type.is_decode() && - !is_spec_draft_) { + if (!input.meta.batch_forward_type.is_decode() && !is_spec_draft_) { output.sample_output.embeddings = embeddings; } else if (sampling_params.selected_token_idxes.defined()) { output.sample_output.embeddings = embeddings.index_select( diff --git a/xllm/core/runtime/mlu_graph_executor_impl.cpp b/xllm/core/runtime/mlu_graph_executor_impl.cpp index 14605a71a4..d5f81d0d8f 100644 --- a/xllm/core/runtime/mlu_graph_executor_impl.cpp +++ b/xllm/core/runtime/mlu_graph_executor_impl.cpp @@ -114,16 +114,16 @@ uint32_t get_tp_size(const xllm::runtime::Options& options) { } uint32_t get_graph_dp_tokens(uint32_t actual_tokens, - const xllm::ModelInputParams& params, + const xllm::ParallelInput& parallel, const xllm::runtime::Options& options) { - if (params.parallel.dp_global_token_nums.size() <= 1) { + if (parallel.dp_global_token_nums.size() <= 1) { return get_bucket_num_tokens(actual_tokens); } const auto max_token_num = - std::max_element(params.parallel.dp_global_token_nums.begin(), - params.parallel.dp_global_token_nums.end()); - CHECK(max_token_num != params.parallel.dp_global_token_nums.end()) + std::max_element(parallel.dp_global_token_nums.begin(), + parallel.dp_global_token_nums.end()); + CHECK(max_token_num != parallel.dp_global_token_nums.end()) << "dp_global_token_nums is empty"; uint32_t bucket_tokens = get_bucket_num_tokens(static_cast(*max_token_num)); @@ -138,52 +138,52 @@ int64_t get_seq_lens_capacity(const xllm::runtime::Options& options) { return max_seqs * seq_expand + 1; } -xllm::ModelInputParams make_graph_params(const xllm::ModelInputParams& params, - uint32_t padding_num_tokens) { - xllm::ModelInputParams graph_params = params; - if (params.parallel.dp_global_token_nums.size() > 1) { - 
graph_params.parallel.dp_global_token_nums = - std::vector(params.parallel.dp_global_token_nums.size(), +xllm::ForwardInput make_graph_input(const xllm::ForwardInput& input, + uint32_t padding_num_tokens) { + xllm::ForwardInput graph_input = input; + if (input.parallel.dp_global_token_nums.size() > 1) { + graph_input.parallel.dp_global_token_nums = + std::vector(input.parallel.dp_global_token_nums.size(), static_cast(padding_num_tokens)); } - return graph_params; + return graph_input; } RunMode get_run_mode(const xllm::runtime::Options& options, - const xllm::ModelInputParams& params) { + const xllm::BatchInputMeta& meta, + const xllm::ParallelInput& parallel) { if (options.is_draft_engine()) { return RunMode::kDraft; } - if (!params.meta.batch_forward_type.is_decode()) { + if (!meta.batch_forward_type.is_decode()) { return RunMode::kNonDecode; } - if (params.meta.q_max_seq_len == 0) { + if (meta.q_max_seq_len == 0) { return RunMode::kDummy; } - if (params.parallel.dp_global_token_nums.size() <= 1) { + if (parallel.dp_global_token_nums.size() <= 1) { return RunMode::kGraph; } - if (has_zero_tokens(params.parallel.dp_global_token_nums)) { + if (has_zero_tokens(parallel.dp_global_token_nums)) { return RunMode::kDummy; } - if (params.parallel.dp_is_decode.size() != - params.parallel.dp_global_token_nums.size()) { + if (parallel.dp_is_decode.size() != parallel.dp_global_token_nums.size()) { return RunMode::kBadDpMeta; } - if (std::find(params.parallel.dp_is_decode.begin(), - params.parallel.dp_is_decode.end(), - 0) != params.parallel.dp_is_decode.end()) { + if (std::find(parallel.dp_is_decode.begin(), + parallel.dp_is_decode.end(), + 0) != parallel.dp_is_decode.end()) { return RunMode::kMixedDp; } - if (!dp_tokens_equal(params.parallel.dp_global_token_nums)) { - if (params.meta.q_max_seq_len == 1) { + if (!dp_tokens_equal(parallel.dp_global_token_nums)) { + if (meta.q_max_seq_len == 1) { return RunMode::kPaddedDpGraph; } return RunMode::kUnevenDp; @@ -232,31 +232,31 @@ 
GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, kv_seq_lens_ = torch::zeros({max_seq_lens}, int_tensor_options); } -void GraphPersistentParam::init_params(const ModelInputParams& params, +void GraphPersistentParam::init_params(const ForwardInput& input, uint32_t padding_num_tokens, uint32_t padding_needed) { - params_ = params.to(tokens_.device()); - params_.attention.device.q_seq_lens = q_seq_lens_.slice( - 0, 0, params.attention.device.q_seq_lens.size(0) + padding_needed); - params_.attention.device.kv_seq_lens = kv_seq_lens_.slice( - 0, 0, params.attention.device.kv_seq_lens.size(0) + padding_needed); - params_.attention.device.new_cache_slots = + input_ = input.to(tokens_.device(), output_.scalar_type()); + input_.attention.device.q_seq_lens = q_seq_lens_.slice( + 0, 0, input.attention.device.q_seq_lens.size(0) + padding_needed); + input_.attention.device.kv_seq_lens = kv_seq_lens_.slice( + 0, 0, input.attention.device.kv_seq_lens.size(0) + padding_needed); + input_.attention.device.new_cache_slots = new_cache_slots_.slice(0, 0, padding_num_tokens); - params_.attention.device.block_tables = + input_.attention.device.block_tables = block_table_.slice(0, 0, padding_num_tokens); - if (params.embedding.input_embedding.defined()) { + if (input.embedding.input_embedding.defined()) { if (!input_embeds_.defined()) { input_embeds_ = torch::zeros_like(output_); } - params_.embedding.input_embedding = + input_.embedding.input_embedding = input_embeds_.slice(0, 0, padding_num_tokens); } } void GraphPersistentParam::update_input_buffer(const torch::Tensor& tokens, const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& params, uint32_t padding_needed) { // Copy data from input parameters to persistent graph tensors int32_t slice_dim = use_mrope_ ? 
1 : 0; @@ -340,11 +340,12 @@ void MluGraph::capture(CausalLM* model, torch_mlu::mlu::MLUStreamGuard guard(torch_mlu::getStreamFromPool()); graph_ = torch_mlu::MLUGraph(); graph_.capture_begin(pool, cnrtQueueCaptureModeRelaxed); - auto forward_result = model->forward( - persistent_param_->tokens_.slice(0, 0, padding_num_tokens_), - persistent_param_->positions_.slice(slice_dim, 0, padding_num_tokens_), - kv_cache, - persistent_param_->params_); + ForwardInput graph_input = persistent_param_->input_; + graph_input.token_ids = + persistent_param_->tokens_.slice(0, 0, padding_num_tokens_); + graph_input.positions = + persistent_param_->positions_.slice(slice_dim, 0, padding_num_tokens_); + auto forward_result = model->forward(graph_input, kv_cache); persistent_param_->output_.slice(0, 0, forward_result.hidden_states.size(0)) .copy_(forward_result.hidden_states, true); // Only capture aux_hidden_states when enable_graph_aux_hidden_states is on @@ -381,7 +382,7 @@ ModelOutput MluGraph::replay() { void MluGraph::update_input_buffer(const torch::Tensor& tokens, const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& params, bool is_init) { uint32_t padding_needed = padding_num_tokens_ - tokens.size(0); if (is_init) { @@ -406,11 +407,9 @@ ForwardInput MluGraphExecutorImpl::prepare_inputs(Batch& batch) { options_.num_decoding_tokens(), 0, args_, options_.cp_size()); } -ModelOutput MluGraphExecutorImpl::run_eager(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { - RunMode run_mode = get_run_mode(options_, params); +ModelOutput MluGraphExecutorImpl::run_eager(const ForwardInput& input, + std::vector& kv_caches) { + RunMode run_mode = get_run_mode(options_, input.meta, input.parallel); if (run_mode == RunMode::kDraft) { LOG_FIRST_N(INFO, 1) << "MLU graph fallback to eager for draft worker"; } else if (run_mode == RunMode::kDummy) { @@ -427,7 +426,7 @@ ModelOutput 
MluGraphExecutorImpl::run_eager(const torch::Tensor& tokens, << "MLU graph fallback to eager because dp_is_decode is invalid"; } COUNTER_INC(num_model_execution_total_eager); - ModelOutput result = model_->forward(tokens, positions, kv_caches, params); + ModelOutput result = model_->forward(input, kv_caches); return make_graph_output(result.hidden_states, result.aux_hidden_states, options_.enable_graph_aux_hidden_states()); @@ -444,35 +443,35 @@ void MluGraphExecutorImpl::init_param_once() { // tokens: [num_decode_tokens] // positions: [num_decode_tokens] token pos in the sequence // returns: ModelOutput -ModelOutput MluGraphExecutorImpl::run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { - const RunMode run_mode = get_run_mode(options_, params); +ModelOutput MluGraphExecutorImpl::run(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; + const RunMode run_mode = get_run_mode(options_, input.meta, input.parallel); if (!allow_graph(run_mode)) { - return run_eager(tokens, positions, kv_caches, params); + return run_eager(input, kv_caches); } init_param_once(); const uint32_t actual_tokens = static_cast(tokens.size(0)); const uint32_t graph_tokens = - get_graph_dp_tokens(actual_tokens, params, options_); + get_graph_dp_tokens(actual_tokens, input.parallel, options_); if (graph_tokens > kMaxGraphTokens) { LOG_FIRST_N(INFO, 1) << "MLU graph fallback to eager because graph bucket num_tokens " << graph_tokens << " exceeds limit " << kMaxGraphTokens; - return run_eager(tokens, positions, kv_caches, params); + return run_eager(input, kv_caches); } - const ModelInputParams graph_params = make_graph_params(params, graph_tokens); + const ForwardInput graph_input = make_graph_input(input, graph_tokens); - if (graph_params.parallel.dp_global_token_nums != - params.parallel.dp_global_token_nums) { + if 
(graph_input.parallel.dp_global_token_nums != + input.parallel.dp_global_token_nums) { LOG_FIRST_N(INFO, 4) << "MLU graph padded dp decode path: raw " << "dp_global_token_nums=" - << params.parallel.dp_global_token_nums + << input.parallel.dp_global_token_nums << ", graph dp_global_token_nums=" - << graph_params.parallel.dp_global_token_nums + << graph_input.parallel.dp_global_token_nums << ", tp_size=" << get_tp_size(options_) << ", graph_tokens=" << graph_tokens; } @@ -480,7 +479,7 @@ ModelOutput MluGraphExecutorImpl::run(const torch::Tensor& tokens, auto it = graphs_.find(graph_tokens); if (it != graphs_.end()) { MluGraph* cur_graph = it->second.get(); - cur_graph->update_input_buffer(tokens, positions, graph_params); + cur_graph->update_input_buffer(tokens, positions, graph_input); ModelOutput result = cur_graph->replay(); // Return only the actual num_tokens portion auto hidden_states = result.hidden_states.slice(0, 0, actual_tokens); @@ -497,7 +496,7 @@ ModelOutput MluGraphExecutorImpl::run(const torch::Tensor& tokens, std::unique_ptr graph = std::make_unique(persistent_param_.get(), graph_tokens); - graph->update_input_buffer(tokens, positions, graph_params, true); + graph->update_input_buffer(tokens, positions, graph_input, true); graph->capture(model_, kv_caches, pool_, options_); graphs_[graph_tokens] = std::move(graph); // Return the output from capture diff --git a/xllm/core/runtime/mlu_graph_executor_impl.h b/xllm/core/runtime/mlu_graph_executor_impl.h index 255cbf056a..3b21601d05 100644 --- a/xllm/core/runtime/mlu_graph_executor_impl.h +++ b/xllm/core/runtime/mlu_graph_executor_impl.h @@ -22,8 +22,8 @@ limitations under the License. 
#include "executor_impl_factory.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/causal_lm.h" -#include "framework/model/model_input_params.h" #include "options.h" +#include "runtime/forward_params.h" namespace xllm::mlu { // Helper class to hold persistent parameters for graph execution @@ -36,20 +36,20 @@ class GraphPersistentParam { ~GraphPersistentParam() = default; - void init_params(const ModelInputParams& params, + void init_params(const ForwardInput& input, uint32_t padding_num_tokens, uint32_t padding_needed); // Update persistent tensors with new input data void update_input_buffer(const torch::Tensor& tokens, const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, uint32_t padding_needed); // input tensors torch::Tensor tokens_; torch::Tensor positions_; - ModelInputParams params_; + ForwardInput input_; // mrope bool use_mrope_ = false; // output @@ -87,7 +87,7 @@ class MluGraph { ModelOutput replay(); void update_input_buffer(const torch::Tensor& tokens, const torch::Tensor& positions, - const ModelInputParams& params, + const ForwardInput& input, bool is_init = false); private: @@ -114,16 +114,12 @@ class MluGraphExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; // Execute model with graph optimization for decode phase - ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override; + ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) override; private: - ModelOutput run_eager(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params); + ModelOutput run_eager(const ForwardInput& input, + std::vector& kv_caches); void init_param_once(); CausalLM* model_; // not owned diff --git a/xllm/core/runtime/mm_embed_vlm_worker_impl.cpp b/xllm/core/runtime/mm_embed_vlm_worker_impl.cpp index 
bd6cb18b2a..c4a572baa6 100644 --- a/xllm/core/runtime/mm_embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/mm_embed_vlm_worker_impl.cpp @@ -58,16 +58,14 @@ std::optional MMEmbedVLMWorkerImpl::step( // TODO remove language params in only vision model forward. // TODO to adapt multi stream parallel later, just use [0] temporarily // all tensors should be on the same device as model - auto flatten_tokens = input.token_ids.to(device_); - auto flatten_positions = input.positions.to(device_); - auto params = input.input_params.to(device_); - auto sampling_params = input.sampling_params.to(device_, dtype_); + ForwardInput model_input = input.to(device_, dtype_); + auto sampling_params = model_input.sampling_params; CHECK(input.sampling_params.is_embeddings) << "Only mm embedding is supported."; // call model executor forward to get hidden states MMEmbeddingVLM* em_model = dynamic_cast(model_.get()); - auto encode_output = em_model->encode(params); + auto encode_output = em_model->encode(model_input); const auto it = encode_output.find("image|embedding"); if (it == encode_output.end() || !std::holds_alternative>(it->second)) { @@ -93,4 +91,4 @@ std::optional MMEmbedVLMWorkerImpl::step( return output; } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 76bc01b7ff..8a835e7a31 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -45,7 +45,7 @@ void set_token_ids_device_tensor(ForwardInput& input, const torch::TensorOptions& token_options) { CHECK(token_ids.defined()) << "draft token_ids must be defined"; torch::Tensor flat_token_ids = token_ids.flatten(); - CHECK_EQ(flat_token_ids.numel(), input.input_params.meta.num_sequences) + CHECK_EQ(flat_token_ids.numel(), input.meta.num_sequences) << "draft token count must match num_sequences"; input.device_tensors_ready = false; @@ -311,7 +311,7 @@ void 
MTPWorkerImpl::prepare_work_before_execute(const ForwardInput& input, std::optional MTPWorkerImpl::step_empty( const ForwardInput& input) { - if (!input.input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { auto output = impl_->step(input); auto draft_output = draft_impl_->step(input); (void)draft_output; @@ -319,8 +319,7 @@ std::optional MTPWorkerImpl::step_empty( return output; } else { ForwardInput new_input = input; - for (int32_t& token_num : - new_input.input_params.parallel.dp_global_token_nums) { + for (int32_t& token_num : new_input.parallel.dp_global_token_nums) { token_num *= 2; } auto draft_extend_future = draft_impl_->step_async(new_input); @@ -335,8 +334,7 @@ std::optional MTPWorkerImpl::step_empty( } new_input = input; - for (int32_t& token_num : - new_input.input_params.parallel.dp_global_token_nums) { + for (int32_t& token_num : new_input.parallel.dp_global_token_nums) { token_num *= options_.num_speculative_tokens() + 1; } auto future = impl_->step_async(new_input); @@ -363,7 +361,7 @@ std::optional MTPWorkerImpl::step_prefill( auto& embeddings = output.sample_output.embeddings; if (embeddings.defined()) { - prefill_input.input_params.embedding.input_embedding = embeddings.clone(); + prefill_input.embedding.input_embedding = embeddings.clone(); } if (output.sample_output.next_tokens.defined()) { replace_host_token_placeholders(prefill_input, @@ -382,8 +380,8 @@ std::optional MTPWorkerImpl::step_prefill( if (input.sampling_params.selected_token_idxes.defined()) { embedding_cache_->write_prefill_target_context( - input.input_params.embedding.embedding_ids, - input.input_params.embedding.request_ids, + input.embedding.embedding_ids, + input.embedding.request_ids, output.sample_output.next_tokens, embeddings, input.sampling_params.selected_token_idxes); @@ -400,7 +398,7 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, ForwardInput& prefill_input) { c10::StreamGuard stream_guard 
= prepare_stream_->set_stream_guard(); prefill_input = input.to(device_, dtype_); - auto& input_params = prefill_input.input_params; + ForwardInput& input_params = prefill_input; auto& extra_token_ids = input_params.embedding.extra_token_ids; const torch::Tensor& token_ids = input.token_ids_host; @@ -439,11 +437,9 @@ std::optional MTPWorkerImpl::step_decode( // Get decode state of last step std::vector last_states = - embedding_cache_->read_decode_states( - input.input_params.embedding.embedding_ids, - input.input_params.embedding.request_ids); - CHECK_EQ(last_states.size(), - input.input_params.embedding.embedding_ids.size()) + embedding_cache_->read_decode_states(input.embedding.embedding_ids, + input.embedding.request_ids); + CHECK_EQ(last_states.size(), input.embedding.embedding_ids.size()) << "decode target state count mismatch"; update_decode_step_input(input, last_states); prepare_draft_extend_inputs(input, last_states, current_draft_input); @@ -473,8 +469,7 @@ std::optional MTPWorkerImpl::step_decode( set_token_ids_device_tensor(current_draft_input, last_output.next_tokens, current_draft_input.token_ids.options()); - current_draft_input.input_params.embedding.input_embedding = - last_output.embeddings; + current_draft_input.embedding.input_embedding = last_output.embeddings; } COUNTER_ADD(speculative_execution_latency_seconds_draft, timer.elapsed_seconds()); @@ -553,14 +548,13 @@ void MTPWorkerImpl::write_target_context_to_cache( const SampleOutput& validate_output) { CHECK(embedding_cache_ != nullptr) << "embedding_cache_ must be initialized before target cache write"; - CHECK(!input.input_params.embedding.embedding_ids.empty()) + CHECK(!input.embedding.embedding_ids.empty()) << "target context cache write requires embedding ids"; - embedding_cache_->write_target_context( - input.input_params.embedding.embedding_ids, - input.input_params.embedding.request_ids, - validate_output.next_tokens, - validate_output.embeddings, - 
options_.num_speculative_tokens()); + embedding_cache_->write_target_context(input.embedding.embedding_ids, + input.embedding.request_ids, + validate_output.next_tokens, + validate_output.embeddings, + options_.num_speculative_tokens()); } void MTPWorkerImpl::record_validate_metrics( @@ -599,7 +593,7 @@ void MTPWorkerImpl::process_draft_sample_output(SampleOutput& sample_output) { void MTPWorkerImpl::update_decode_step_input( ForwardInput& input, const std::vector& last_states) const { - const int32_t num_sequences = input.input_params.meta.num_sequences; + const int32_t num_sequences = input.meta.num_sequences; CHECK_EQ(last_states.size(), static_cast(num_sequences)) << "decode context state count mismatch"; const bool enable_cache_correction = enable_schedule_overlap(); @@ -638,7 +632,7 @@ void MTPWorkerImpl::update_decode_step_input( use_cache_correction ? state.position_offset : 0; const int32_t current_position = input_positions[seq_id] + position_offset; const int32_t current_kv_len = specBuilder::calc_kv_len( - input.input_params.attention.host.kv_seq_lens, seq_id, position_offset); + input.attention.host.kv_seq_lens, seq_id, position_offset); CHECK_EQ(current_position + 1, current_kv_len) << "decode context position/kv_len mismatch, seq_id=" << seq_id @@ -654,7 +648,7 @@ void MTPWorkerImpl::update_decode_step_input( input.token_ids_host = make_cpu_int_tensor(token_ids_vec); input.positions_host = make_cpu_int_tensor(positions_vec); - input.input_params.attention.host.kv_seq_lens = std::move(kv_seq_lens_vec); + input.attention.host.kv_seq_lens = std::move(kv_seq_lens_vec); input.device_tensors_ready = false; } @@ -663,7 +657,7 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input, c10::StreamGuard stream_guard = prepare_stream_->set_stream_guard(); validate_input = input; validate_input.device_tensors_ready = false; - auto& input_params = validate_input.input_params; + ForwardInput& input_params = validate_input; torch::TensorOptions 
token_options = validate_input.token_ids.options(); torch::TensorOptions position_options = validate_input.positions.options(); @@ -680,7 +674,7 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input, Slice positions = { input.positions_host.data_ptr(), static_cast(input.positions_host.numel())}; - Slice kv_seq_lens = input.input_params.attention.host.kv_seq_lens; + Slice kv_seq_lens = input.attention.host.kv_seq_lens; specBuilder::DecodeBuildBuffers buf; buf.out_token_ids.reserve(total_num_val_tokens); buf.out_positions.reserve(total_num_val_tokens); @@ -742,7 +736,8 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input, input_params.meta.batch_forward_type = BatchForwardType::CHUNKED_PREFILL; } if (FLAGS_enable_atb_spec_kernel) { - specBuilder::update_input_params(input_params, + specBuilder::update_input_params(input_params.meta, + input_params.attention, buf, num_val_tokens, std::move(atb_q_seq_lens_vec), @@ -750,7 +745,8 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input, atb_kv_max_seq_len, std::move(atb_kv_seq_lens_vec)); } else { - specBuilder::update_input_params(input_params, + specBuilder::update_input_params(input_params.meta, + input_params.attention, buf, 1, std::move(buf.out_q_seq_lens), @@ -778,7 +774,7 @@ void MTPWorkerImpl::prepare_draft_extend_inputs( ForwardInput& extend_input) { extend_input = base_input; extend_input.device_tensors_ready = false; - auto& input_params = extend_input.input_params; + ForwardInput& input_params = extend_input; const int32_t num_sequences = input_params.meta.num_sequences; const bool dp_enabled = parallel_args_.dp_size() > 1; @@ -874,7 +870,8 @@ void MTPWorkerImpl::prepare_draft_extend_inputs( input_params.meta.num_sequences = static_cast(buf.out_positions.size()); input_params.meta.batch_forward_type = BatchForwardType::DECODE; - specBuilder::update_input_params(input_params, + specBuilder::update_input_params(input_params.meta, + input_params.attention, 
buf, 1, std::move(buf.out_q_seq_lens), @@ -916,7 +913,7 @@ void MTPWorkerImpl::prepare_draft_inputs(const ForwardInput& input, draft_input = input; draft_input.device_tensors_ready = false; - auto& input_params = draft_input.input_params; + ForwardInput& input_params = draft_input; const int32_t num_sequences = input_params.meta.num_sequences; const int32_t block_size = options_.block_size(); specBuilder::DecodeRowContext row_ctx = @@ -940,7 +937,8 @@ void MTPWorkerImpl::prepare_draft_inputs(const ForwardInput& input, torch::TensorOptions position_options = input.positions.options(); set_positions_tensor(draft_input, buf.out_positions, position_options); specBuilder::update_input_params( - input_params, + input_params.meta, + input_params.attention, buf, input_params.meta.q_max_seq_len, std::move(input_params.attention.host.q_seq_lens), diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp index ed3a57eaf6..ec181e61dd 100644 --- a/xllm/core/runtime/rec_worker_impl.cpp +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -344,26 +344,25 @@ void RecWorkerImpl::RecWorkPipeline::prepare_work_before_execute( #endif processed_inputs = inputs.to(runtime_.worker.device(), runtime_.worker.dtype()); - auto& input_params = processed_inputs.input_params; - runtime_.worker.apply_kv_block_swaps(input_params); + ForwardInput& input_params = processed_inputs; + runtime_.worker.apply_kv_block_swaps(input_params.block_copy); #if defined(USE_NPU) if (runtime_.context->get_model_args().enable_mla() && input_params.meta.batch_forward_type.is_chunked_prefill()) { - runtime_.worker.prepare_mla_prefixcache_inputs(input_params); + runtime_.worker.prepare_mla_prefixcache_inputs(input_params.attention); } if (!runtime_.context->get_parallel_args().mapping_data().empty() && (runtime_.context->get_parallel_args().dp_size() > 1 || runtime_.context->get_parallel_args().ep_size() > 1)) { - torch::Tensor token_size_per_dp_group = torch::tensor( - 
processed_inputs.input_params.parallel.dp_global_token_nums, - torch::TensorOptions() - .device(torch::kCPU) - .dtype(torch::kInt32) - .pinned_memory(true)); - bool is_prefill = - processed_inputs.input_params.meta.batch_forward_type.is_prefill(); + torch::Tensor token_size_per_dp_group = + torch::tensor(processed_inputs.parallel.dp_global_token_nums, + torch::TensorOptions() + .device(torch::kCPU) + .dtype(torch::kInt32) + .pinned_memory(true)); + bool is_prefill = processed_inputs.meta.batch_forward_type.is_prefill(); DpEpPadding dp_ep_padding( token_size_per_dp_group, runtime_.context->get_model_args().num_experts_per_tok(), @@ -371,8 +370,7 @@ void RecWorkerImpl::RecWorkPipeline::prepare_work_before_execute( runtime_.worker.device(), runtime_.worker.dtype(), is_prefill); - processed_inputs.input_params.parallel.dp_ep_padding_data = - dp_ep_padding.build(); + processed_inputs.parallel.dp_ep_padding_data = dp_ep_padding.build(); } #endif } @@ -400,8 +398,8 @@ std::optional RecWorkerImpl::RecWorkPipeline::step( runtime_.context->get_model_args().n_layers()); #endif #if defined(USE_NPU) || defined(USE_MLU) - const_cast(&(input.input_params)) - ->parallel.layer_synchronizer = layer_synchronizer; + const_cast(input).parallel.layer_synchronizer = + layer_synchronizer; futures.emplace_back( runtime_.worker.kv_cache_transfer_->push_kv_blocks_async( @@ -413,15 +411,13 @@ std::optional RecWorkerImpl::RecWorkPipeline::step( } if (FLAGS_enable_eplb) { - runtime_.eplb_executor->eplb_execute(input.input_params.expert.eplb_info); + runtime_.eplb_executor->eplb_execute(input.expert.eplb_info); } // temporarily use [0], will be adapted in next pr // call model executor forward to get hidden states - auto model_output = runtime_.executor->forward(input.token_ids, - input.positions, - runtime_.worker.kv_caches_, - input.input_params); + auto model_output = + runtime_.executor->forward(input, runtime_.worker.kv_caches_); if (!model_output.hidden_states.defined()) { return 
std::nullopt; } @@ -492,7 +488,7 @@ std::optional RecWorkerImpl::RecWorkPipeline::step( } if (runtime_.worker.options_.enable_speculative_decode()) { - if (!input.input_params.meta.batch_forward_type.is_decode() && + if (!input.meta.batch_forward_type.is_decode() && !runtime_.worker.is_spec_draft_) { output.sample_output.embeddings = model_output.hidden_states; } else if (sampling_params.selected_token_idxes.defined()) { @@ -578,7 +574,7 @@ void RecWorkerImpl::OneRecWorkPipeline::prepare_work_before_execute( ForwardInput& processed_inputs) { RecWorkPipeline::prepare_work_before_execute(inputs, processed_inputs); - auto& onerec_params = processed_inputs.input_params.mutable_onerec_params(); + auto& onerec_params = processed_inputs.mutable_onerec_params(); if (!onerec_params.decoder_context_embedding.defined()) { return; } @@ -636,9 +632,8 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( runtime_.worker.device_.set_device(); const auto& sampling_params = input.sampling_params; - const auto& input_params = input.input_params; - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); CHECK(onerec_params != nullptr) << "OneRec requires rec_params."; const OneRecModelInputParams& rec_params = *onerec_params; @@ -660,14 +655,12 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( LOG(ERROR) << "OneRec prefill requires encoder context."; return std::nullopt; } - ModelInputParams decoder_params = input_params; - decoder_params.mutable_onerec_params().is_encoder_forward = false; - decoder_params.mutable_onerec_params().has_encoder_output = - rec_params.has_encoder_output; - auto model_output = runtime_.executor->forward(input.token_ids, - input.positions, - runtime_.worker.kv_caches_, - decoder_params); + ForwardInput decoder_input = input; + auto& decoder_onerec_params = decoder_input.mutable_onerec_params(); + decoder_onerec_params.is_encoder_forward = false; + decoder_onerec_params.has_encoder_output = 
rec_params.has_encoder_output; + auto model_output = + runtime_.executor->forward(decoder_input, runtime_.worker.kv_caches_); hidden_states = model_output.hidden_states; } else { const bool has_sparse_embedding = @@ -680,34 +673,31 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( return std::nullopt; } - ModelInputParams encoder_params = input_params; - auto& mutable_onerec_params = encoder_params.mutable_onerec_params(); - mutable_onerec_params.is_encoder_forward = true; - mutable_onerec_params.is_hybrid_mode = has_sparse_embedding; + ForwardInput encoder_input = input; + auto& encoder_onerec_params = encoder_input.mutable_onerec_params(); + encoder_onerec_params.is_encoder_forward = true; + encoder_onerec_params.is_hybrid_mode = has_sparse_embedding; torch::Tensor encoder_tokens; if (has_sparse_embedding) { encoder_tokens = rec_params.encoder_sparse_embedding; } else { - mutable_onerec_params.is_hybrid_mode = false; + encoder_onerec_params.is_hybrid_mode = false; encoder_tokens = rec_params.encoder_token_ids; } + encoder_input.token_ids = encoder_tokens; + encoder_input.positions = rec_params.encoder_positions; auto encoder_output = - runtime_.executor->forward(encoder_tokens, - rec_params.encoder_positions, - runtime_.worker.kv_caches_, - encoder_params); + runtime_.executor->forward(encoder_input, runtime_.worker.kv_caches_); - ModelInputParams decoder_params = input_params; - auto& decoder_onerec_params = decoder_params.mutable_onerec_params(); + ForwardInput decoder_input = input; + auto& decoder_onerec_params = decoder_input.mutable_onerec_params(); decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = encoder_output.hidden_states.defined(); - auto model_output = runtime_.executor->forward(input.token_ids, - input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = + runtime_.executor->forward(decoder_input, runtime_.worker.kv_caches_); hidden_states = model_output.hidden_states; } 
} else { @@ -715,14 +705,12 @@ std::optional RecWorkerImpl::OneRecWorkPipeline::step( LOG(ERROR) << "OneRec decode requires encoder context."; return std::nullopt; } - ModelInputParams decoder_params = input_params; - decoder_params.mutable_onerec_params().is_encoder_forward = false; - decoder_params.mutable_onerec_params().has_encoder_output = - rec_params.has_encoder_output; - auto model_output = runtime_.executor->forward(input.token_ids, - input.positions, - runtime_.worker.kv_caches_, - decoder_params); + ForwardInput decoder_input = input; + auto& decoder_onerec_params = decoder_input.mutable_onerec_params(); + decoder_onerec_params.is_encoder_forward = false; + decoder_onerec_params.has_encoder_output = rec_params.has_encoder_output; + auto model_output = + runtime_.executor->forward(decoder_input, runtime_.worker.kv_caches_); hidden_states = model_output.hidden_states; } @@ -1111,8 +1099,7 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( } #endif - auto& onerec_params = - processed_inputs.input_params.mutable_onerec_xattention_params(); + auto& onerec_params = processed_inputs.mutable_onerec_xattention_params(); const auto& args = runtime_.context->get_model_args(); const auto& parallel_args = runtime_.context->get_parallel_args(); const int64_t decoder_kv_heads = args.decoder_n_kv_heads().value_or( @@ -1121,9 +1108,7 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( decoder_kv_heads / std::max(parallel_args.world_size(), 1); const int64_t head_dim = args.decoder_head_dim(); const int64_t batch_size = std::max( - onerec_params.bs > 0 ? onerec_params.bs - : inputs.input_params.meta.num_sequences, - 1); + onerec_params.bs > 0 ? 
onerec_params.bs : inputs.meta.num_sequences, 1); int64_t shared_kv_tokens = 0; if (onerec_params.decoder_context_embedding.defined()) { const int64_t hidden_size = @@ -1135,8 +1120,7 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( } if (shared_kv_tokens <= 0) { shared_kv_tokens = - batch_size * - std::max(processed_inputs.input_params.meta.q_max_seq_len, 1); + batch_size * std::max(processed_inputs.meta.q_max_seq_len, 1); } const int32_t decoder_layers = static_cast(args.n_layers()); auto fp_options = torch::TensorOptions() @@ -1160,7 +1144,7 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( fp_options)); } prepare_unshared_kv_caches_for_input(inputs, onerec_params); - processed_inputs.input_params.attention.device.block_tables = + processed_inputs.attention.device.block_tables = torch::arange(batch_size, torch::TensorOptions() .dtype(torch::kInt32) @@ -1266,7 +1250,7 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( }; ForwardInput mutable_input = input; - CHECK(mutable_input.input_params.onerec_xattention_params() != nullptr) + CHECK(mutable_input.onerec_xattention_params() != nullptr) << "OneRec xattention pipeline requires onerec_xattention_params."; struct RoundResult { @@ -1280,7 +1264,7 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( int32_t current_step, const torch::Tensor& sequence_group, int32_t request_beam_width) -> std::optional { - auto* round_params = mutable_input.input_params.onerec_xattention_params(); + auto* round_params = mutable_input.onerec_xattention_params(); CHECK(round_params != nullptr) << "OneRec xattention pipeline requires onerec_xattention_params."; @@ -1395,17 +1379,14 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( LOG(ERROR) << "OneRec xattention prefill requires encoder context."; return std::nullopt; } - ModelInputParams decoder_params = mutable_input.input_params; + ForwardInput decoder_input = 
mutable_input; auto& decoder_onerec_params = - decoder_params.mutable_onerec_xattention_params(); + decoder_input.mutable_onerec_xattention_params(); decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = round_params->has_encoder_output; - auto model_output = - runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = runtime_.executor->forward( + decoder_input, runtime_.worker.kv_caches_); #if defined(USE_NPU) validate_selected_token_idxes_stage("decoder_forward"); #endif @@ -1423,40 +1404,34 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( return std::nullopt; } - ModelInputParams encoder_params = mutable_input.input_params; + ForwardInput encoder_input = mutable_input; auto& encoder_onerec_params = - encoder_params.mutable_onerec_xattention_params(); + encoder_input.mutable_onerec_xattention_params(); encoder_onerec_params.is_encoder_forward = true; encoder_onerec_params.is_hybrid_mode = has_sparse_embedding; - torch::Tensor encoder_tokens; if (has_sparse_embedding) { - encoder_tokens = round_params->encoder_sparse_embedding; + encoder_input.token_ids = round_params->encoder_sparse_embedding; } else { encoder_onerec_params.is_hybrid_mode = false; - encoder_tokens = round_params->encoder_token_ids; + encoder_input.token_ids = round_params->encoder_token_ids; } + encoder_input.positions = round_params->encoder_positions; - auto encoder_output = - runtime_.executor->forward(encoder_tokens, - round_params->encoder_positions, - runtime_.worker.kv_caches_, - encoder_params); + auto encoder_output = runtime_.executor->forward( + encoder_input, runtime_.worker.kv_caches_); #if defined(USE_NPU) validate_selected_token_idxes_stage("encoder_forward"); #endif - ModelInputParams decoder_params = mutable_input.input_params; + ForwardInput decoder_input = mutable_input; auto& decoder_onerec_params = - 
decoder_params.mutable_onerec_xattention_params(); + decoder_input.mutable_onerec_xattention_params(); decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = encoder_output.hidden_states.defined(); - auto model_output = - runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = runtime_.executor->forward( + decoder_input, runtime_.worker.kv_caches_); #if defined(USE_NPU) validate_selected_token_idxes_stage("decoder_forward"); #endif @@ -1467,16 +1442,14 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( LOG(ERROR) << "OneRec xattention decode requires encoder context."; return std::nullopt; } - ModelInputParams decoder_params = mutable_input.input_params; + ForwardInput decoder_input = mutable_input; auto& decoder_onerec_params = - decoder_params.mutable_onerec_xattention_params(); + decoder_input.mutable_onerec_xattention_params(); decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = round_params->has_encoder_output; - auto model_output = runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = + runtime_.executor->forward(decoder_input, runtime_.worker.kv_caches_); #if defined(USE_NPU) validate_selected_token_idxes_stage("decode_forward"); #endif @@ -1577,8 +1550,7 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( int32_t beam_width, const std::vector& decode_positions_vec, const torch::Tensor& sequence_group) { - auto& round_params = - mutable_input.input_params.mutable_onerec_xattention_params(); + auto& round_params = mutable_input.mutable_onerec_xattention_params(); const int32_t decode_step = std::max(round - 1, 0); round_params.rec_stage = OneRecModelInputParams::RecStage::DECODE; @@ -1624,11 +1596,10 @@ std::optional 
RecWorkerImpl::OneRecXAttentionWorkPipeline::step( torch::tensor(selected_token_idxes, int_options); mutable_input.decoder_sampling_params.num_return_sequences = mutable_input.sampling_params.num_return_sequences; - mutable_input.input_params.meta.batch_forward_type = - BatchForwardType::DECODE; - mutable_input.input_params.meta.num_sequences = batch_size * beam_width; - mutable_input.input_params.embedding.input_embedding = torch::Tensor(); - mutable_input.input_params.attn_metadata = nullptr; + mutable_input.meta.batch_forward_type = BatchForwardType::DECODE; + mutable_input.meta.num_sequences = batch_size * beam_width; + mutable_input.embedding.input_embedding = torch::Tensor(); + mutable_input.attn_metadata = nullptr; }; auto step_meta = mutable_input.step_meta(); @@ -1861,8 +1832,7 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( std::swap(beam_tensors.sequence_group, beam_tensors.out_seqgroup); std::swap(beam_tensors.acc_logprob, beam_tensors.out_log_probs); if (round > 0 && round < total_rounds - 1) { - auto& round_params = - mutable_input.input_params.mutable_onerec_xattention_params(); + auto& round_params = mutable_input.mutable_onerec_xattention_params(); execute_cache_select( beam_tensors.out_token_index, beam_tensors.out_beam_count_prefix_sums, @@ -2021,7 +1991,7 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline:: ForwardInput& processed_inputs) { auto device = runtime_.worker.device(); auto int_options = torch::TensorOptions().dtype(torch::kInt32).device(device); - auto& input_params = processed_inputs.input_params; + ForwardInput& input_params = processed_inputs; auto& llm_rec_params = input_params.mutable_llmrec_params(); const auto* step_meta = inputs.step_meta(); @@ -2208,10 +2178,8 @@ std::optional RecWorkerImpl::LlmRecMultiRoundPipeline::step( next_round_async_result); #endif - auto model_output = runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - 
mutable_input.input_params); + auto model_output = + runtime_.executor->forward(mutable_input, runtime_.worker.kv_caches_); if (!model_output.hidden_states.defined()) { return std::nullopt; } @@ -2357,9 +2325,9 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::execute_cache_select( auto block_table = torch::arange(batch_size, int32_options); const auto& unshared_k_caches = - input.input_params.mutable_llmrec_params().unshared_k_caches; + input.mutable_llmrec_params().unshared_k_caches; const auto& unshared_v_caches = - input.input_params.mutable_llmrec_params().unshared_v_caches; + input.mutable_llmrec_params().unshared_v_caches; xllm::kernel::npu::select_unshared_kv( /*beam_index=*/beam_index_local.reshape({-1}), @@ -2373,9 +2341,9 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::execute_cache_select( #elif defined(USE_CUDA) xllm::kernel::cuda::cache_select( beam_tensors.out_token_index, - input.input_params.mutable_llmrec_params().unshared_k_caches, - input.input_params.mutable_llmrec_params().unshared_v_caches, - input.input_params.attention.device.block_tables, + input.mutable_llmrec_params().unshared_k_caches, + input.mutable_llmrec_params().unshared_v_caches, + input.attention.device.block_tables, round - 1, beam_width, num_layers); @@ -2410,17 +2378,16 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_two_stage_round_input( #if defined(USE_NPU) // TODO: implement prepare_two_stage_round_input for NPU #elif defined(USE_CUDA) - auto& llm_rec_params = input.input_params.mutable_llmrec_params(); + auto& llm_rec_params = input.mutable_llmrec_params(); CHECK_EQ(FLAGS_enable_xattention_one_stage, false) << "prepare_two_stage_round_input should only be called when " "two-stage decode is enabled"; - input.input_params.attention.device.paged_kv_indices = torch::Tensor(); - input.input_params.attention.device.paged_kv_indptr = torch::Tensor(); - input.input_params.attention.device.paged_kv_last_page_len = torch::Tensor(); - input.input_params.meta.num_sequences = 
- llm_rec_params.batch_size * - std::max(llm_rec_params.beam_width, 1); + input.attention.device.paged_kv_indices = torch::Tensor(); + input.attention.device.paged_kv_indptr = torch::Tensor(); + input.attention.device.paged_kv_last_page_len = torch::Tensor(); + input.meta.num_sequences = llm_rec_params.batch_size * + std::max(llm_rec_params.beam_width, 1); // previous_step corresponds to the decode step that produced tokens for // this round. @@ -2443,8 +2410,8 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_two_stage_round_input( llm_rec_params.decode_positions_tensor_list[previous_step]; } - input.input_params.meta.batch_forward_type = BatchForwardType(2); - input.input_params.embedding.input_embedding = torch::Tensor(); + input.meta.batch_forward_type = BatchForwardType(2); + input.embedding.input_embedding = torch::Tensor(); cached_current_round_tensor_.fill_(previous_step); llm_rec_params.current_round_tensor = cached_current_round_tensor_; @@ -2497,11 +2464,10 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_two_stage_round_input( llm_rec_params.two_stage_paged_kv_indptr_expanded.copy_( paged_kv_indptr_values, /*non_blocking=*/true); - if (input.input_params.attention.device.block_tables.defined() && - input.input_params.attention.device.block_tables.numel() >= total_beam) { + if (input.attention.device.block_tables.defined() && + input.attention.device.block_tables.numel() >= total_beam) { llm_rec_params.two_stage_paged_kv_indices_expanded.copy_( - input.input_params.attention.device.block_tables.view({-1}).slice( - 0, 0, total_beam), + input.attention.device.block_tables.view({-1}).slice(0, 0, total_beam), /*non_blocking=*/true); } else { auto paged_kv_indices_values = torch::arange(total_beam, int_options); @@ -2511,7 +2477,7 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_two_stage_round_input( llm_rec_params.two_stage_paged_kv_last_page_len_expanded.fill_(previous_step + 1); - input.input_params.attn_metadata = nullptr; + 
input.attn_metadata = nullptr; #endif } @@ -2520,7 +2486,7 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_round_input_for_npu( int32_t round, const torch::Tensor& top_tokens, const BeamSearchTensors& beam_tensors) { - auto& llm_rec_params = input.input_params.mutable_llmrec_params(); + auto& llm_rec_params = input.mutable_llmrec_params(); CHECK(cached_current_round_tensor_.defined()); CHECK(cached_beam_width_tensor_.defined()); @@ -2528,7 +2494,7 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_round_input_for_npu( llm_rec_params.beam_width_tensor = cached_beam_width_tensor_; cached_current_round_tensor_.fill_(round); llm_rec_params.current_round_tensor = cached_current_round_tensor_; - input.input_params.attn_metadata = nullptr; + input.attn_metadata = nullptr; if (round > 0) { if (round == 1) { @@ -2547,8 +2513,8 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_round_input_for_npu( llm_rec_params.decode_positions_tensor_list[decode_step]; } - input.input_params.meta.batch_forward_type = BatchForwardType::DECODE; - input.input_params.embedding.input_embedding = torch::Tensor(); + input.meta.batch_forward_type = BatchForwardType::DECODE; + input.embedding.input_embedding = torch::Tensor(); } } @@ -2560,14 +2526,12 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_input_for_current_round( const BeamSearchTensors& beam_tensors) { #if defined(USE_CUDA) if (FLAGS_enable_xattention_one_stage) { - input.input_params.attention.device.paged_kv_indices = - results.paged_kv_indices; - input.input_params.attention.device.paged_kv_indptr = - results.paged_kv_indptr; - input.input_params.attention.device.paged_kv_last_page_len = + input.attention.device.paged_kv_indices = results.paged_kv_indices; + input.attention.device.paged_kv_indptr = results.paged_kv_indptr; + input.attention.device.paged_kv_last_page_len = results.paged_kv_last_page_len; - input.input_params.meta.num_sequences = - 
input.input_params.attention.device.paged_kv_last_page_len.numel(); + input.meta.num_sequences = + input.attention.device.paged_kv_last_page_len.numel(); } else { prepare_two_stage_round_input(input, round, top_tokens, beam_tensors); return; @@ -2586,7 +2550,7 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_input_for_current_round( input.token_ids = beam_tensors.out_token_ids.reshape({-1}); } - auto& llm_rec_params = input.input_params.mutable_llmrec_params(); + auto& llm_rec_params = input.mutable_llmrec_params(); if (!llm_rec_params.decode_positions_tensor_list.empty() && previous_step >= 0 && previous_step < static_cast( @@ -2595,11 +2559,11 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_input_for_current_round( llm_rec_params.decode_positions_tensor_list[previous_step]; } - input.input_params.meta.batch_forward_type = BatchForwardType(2); - input.input_params.embedding.input_embedding = torch::Tensor(); + input.meta.batch_forward_type = BatchForwardType(2); + input.embedding.input_embedding = torch::Tensor(); cached_current_round_tensor_.fill_(previous_step); llm_rec_params.current_round_tensor = cached_current_round_tensor_; - input.input_params.attn_metadata = nullptr; + input.attn_metadata = nullptr; } folly::SemiFuture< @@ -2741,12 +2705,12 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline:: // Phase B: schedule async computation for the next round, if any. 
if (round < total_rounds - 1) { - next_round_async_result = compute_next_round_input_async( - input.input_params.attention.device.kv_seq_lens, - round, - batch_size, - beam_width, - max_decode_step); + next_round_async_result = + compute_next_round_input_async(input.attention.device.kv_seq_lens, + round, + batch_size, + beam_width, + max_decode_step); } } @@ -2990,15 +2954,14 @@ void RecWorkerImpl::prepare_work_before_execute( } void RecWorkerImpl::prepare_multi_modal_data(ForwardInput& processed_inputs) { - if (!processed_inputs.input_params.multimodal.mm_data.valid()) { + if (!processed_inputs.multimodal.mm_data.valid()) { return; } torch::Tensor multi_modal_values; torch::Tensor multi_modal_indices; - const auto& processed_mm_data = - processed_inputs.input_params.multimodal.mm_data; + const auto& processed_mm_data = processed_inputs.multimodal.mm_data; if (auto res = processed_mm_data.get("MULTI_MODAL_VALUES")) { multi_modal_values = res.value(); } @@ -3026,8 +2989,7 @@ void RecWorkerImpl::prepare_multi_modal_data(ForwardInput& processed_inputs) { torch::indexing::Slice()}; input_tokens_embedding.index_put_(indices, multi_modal_values); - processed_inputs.input_params.embedding.input_embedding = - input_tokens_embedding; + processed_inputs.embedding.input_embedding = input_tokens_embedding; } std::optional RecWorkerImpl::step(const ForwardInput& input) { @@ -3056,8 +3018,7 @@ folly::SemiFuture> RecWorkerImpl::step_async( input_on_device); if (hierarchy_kv_cache_transfer_ != nullptr) { - hierarchy_kv_cache_transfer_->set_layer_synchronizer( - input_on_device.input_params); + hierarchy_kv_cache_transfer_->set_layer_synchronizer(input_on_device); } const auto output = work_pipelines_[index]->step(input_on_device); diff --git a/xllm/core/runtime/spec_input_builder.cpp b/xllm/core/runtime/spec_input_builder.cpp index a9e012c3b1..b94c1959a1 100644 --- a/xllm/core/runtime/spec_input_builder.cpp +++ b/xllm/core/runtime/spec_input_builder.cpp @@ -48,7 +48,7 @@ Slice 
get_positions(const ForwardInput& input) { } Slice get_kv_seq_lens(const ForwardInput& input) { - return input.input_params.attention.host.kv_seq_lens; + return input.attention.host.kv_seq_lens; } // Resolves a row token from either input token_ids[seq_id] or row.token_id. @@ -165,7 +165,7 @@ void update_kv_seq_lens_and_max(std::vector& kv_seq_lens_vec, DecodeRowContext make_decode_row_context(const ForwardInput& input) { DecodeRowContext ctx; - ctx.num_sequences = input.input_params.meta.num_sequences; + ctx.num_sequences = input.meta.num_sequences; CHECK_GE(ctx.num_sequences, 0) << "invalid num_sequences"; if (input.token_ids_host.defined()) { @@ -179,10 +179,9 @@ DecodeRowContext make_decode_row_context(const ForwardInput& input) { << ctx.positions.size() << ", num_sequences=" << ctx.num_sequences; ctx.kv_seq_lens = get_kv_seq_lens(input); - CHECK(input.input_params.attention.host.block_tables.defined()) + CHECK(input.attention.host.block_tables.defined()) << "host block_tables must be defined for decode row build"; - ctx.block_tables_owner = - input.input_params.attention.host.block_tables.contiguous(); + ctx.block_tables_owner = input.attention.host.block_tables.contiguous(); CHECK_EQ(ctx.block_tables_owner.dim(), 2) << "block_tables must be 2D, got " << ctx.block_tables_owner.sizes(); CHECK_LE(ctx.num_sequences, ctx.block_tables_owner.size(0)) @@ -303,16 +302,16 @@ void append_decode_row_from_last_step(const DecodeRowContext& ctx, append_decode_row(ctx, row, block_size, buf); } -torch::Tensor build_q_cu_seq_lens_tensor(const ModelInputParams& params, +torch::Tensor build_q_cu_seq_lens_tensor(const AttentionHostInput& attention, torch::Device device) { - CHECK_EQ(params.attention.host.q_seq_lens.empty(), - params.attention.host.q_cu_seq_lens.empty()) + CHECK_EQ(attention.q_seq_lens.empty(), attention.q_cu_seq_lens.empty()) << "q_seq_lens and q_cu_seq_lens must be provided together"; - return torch::tensor(params.attention.host.q_cu_seq_lens, + return 
torch::tensor(attention.q_cu_seq_lens, torch::dtype(torch::kInt).device(device)); } -void update_input_params(ModelInputParams& input_params, +void update_input_params(BatchInputMeta& meta, + AttentionInput& attention, DecodeBuildBuffers& buf, int32_t q_max_seq_len, std::vector q_seq_lens_vec, @@ -322,15 +321,14 @@ void update_input_params(ModelInputParams& input_params, bool update_block_tables) { CHECK_EQ(q_seq_lens_vec.empty(), q_cu_seq_lens_vec.empty()) << "q_seq_lens and q_cu_seq_lens must be provided together"; - input_params.meta.q_max_seq_len = q_max_seq_len; - input_params.attention.host.q_seq_lens = std::move(q_seq_lens_vec); - input_params.attention.host.q_cu_seq_lens = std::move(q_cu_seq_lens_vec); - input_params.meta.kv_max_seq_len = kv_max_seq_len; - input_params.attention.host.kv_seq_lens = std::move(kv_seq_lens_vec); - input_params.attention.host.new_cache_slots = - std::move(buf.out_new_cache_slots); + meta.q_max_seq_len = q_max_seq_len; + attention.host.q_seq_lens = std::move(q_seq_lens_vec); + attention.host.q_cu_seq_lens = std::move(q_cu_seq_lens_vec); + meta.kv_max_seq_len = kv_max_seq_len; + attention.host.kv_seq_lens = std::move(kv_seq_lens_vec); + attention.host.new_cache_slots = std::move(buf.out_new_cache_slots); if (update_block_tables) { - input_params.attention.host.block_tables = + attention.host.block_tables = create_flat_2d_tensor(buf.out_block_tables, buf.out_block_table_rows, buf.out_block_table_stride); diff --git a/xllm/core/runtime/spec_input_builder.h b/xllm/core/runtime/spec_input_builder.h index 64a1e909e2..f0dcf9c6f9 100644 --- a/xllm/core/runtime/spec_input_builder.h +++ b/xllm/core/runtime/spec_input_builder.h @@ -21,11 +21,11 @@ limitations under the License. 
#include #include +#include "framework/model/model_input_params.h" #include "util/slice.h" namespace xllm { -struct ModelInputParams; struct ForwardInput; namespace specBuilder { @@ -134,11 +134,12 @@ void update_kv_seq_lens_and_max(std::vector& kv_seq_lens_vec, int32_t& kv_max_seq_len); // Builds q_cu_seq_lens tensor from upstream-provided host values. -torch::Tensor build_q_cu_seq_lens_tensor(const ModelInputParams& params, +torch::Tensor build_q_cu_seq_lens_tensor(const AttentionHostInput& attention, torch::Device device = torch::kCPU); -// Updates common decode-side ModelInputParams fields from built buffers. -void update_input_params(ModelInputParams& input_params, +// Updates common decode-side model input fields from built buffers. +void update_input_params(BatchInputMeta& meta, + AttentionInput& attention, DecodeBuildBuffers& buf, int32_t q_max_seq_len, std::vector q_seq_lens_vec, diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp index 930946bf76..d88a111919 100644 --- a/xllm/core/runtime/speculative_worker_impl.cpp +++ b/xllm/core/runtime/speculative_worker_impl.cpp @@ -104,7 +104,7 @@ std::optional SpeculativeWorkerImpl::step( return step_empty(input); } - if (!input.input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { return step_prefill(input); } else { return step_decode(input); @@ -116,7 +116,7 @@ ForwardInput SpeculativeWorkerImpl::update_input_by_last_step_output( // only process decode batch, so prepare draft input here. 
ForwardInput& new_inputs = inputs; - auto& input_params = new_inputs.input_params; + ForwardInput& input_params = new_inputs; const int32_t num_sequences = input_params.meta.num_sequences; const int32_t block_size = options_.block_size(); @@ -205,7 +205,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs( ForwardInput& validate_input) { validate_input = input.to(device_, dtype_); validate_input.device_tensors_ready = false; - auto& input_params = validate_input.input_params; + ForwardInput& input_params = validate_input; torch::TensorOptions token_options = validate_input.token_ids.options(); torch::TensorOptions position_options = validate_input.positions.options(); @@ -219,7 +219,7 @@ void SpeculativeWorkerImpl::prepare_validate_inputs( Slice token_ids = tensor_slice(input.token_ids_host); Slice positions = tensor_slice(input.positions_host); - Slice kv_seq_lens = input.input_params.attention.host.kv_seq_lens; + Slice kv_seq_lens = input.attention.host.kv_seq_lens; specBuilder::DecodeBuildBuffers buf; buf.out_token_ids.reserve(total_num_val_tokens); buf.out_positions.reserve(total_num_val_tokens); @@ -288,7 +288,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs( input_params.meta.batch_forward_type = BatchForwardType::CHUNKED_PREFILL; } if (FLAGS_enable_atb_spec_kernel) { - specBuilder::update_input_params(input_params, + specBuilder::update_input_params(input_params.meta, + input_params.attention, buf, num_val_tokens, std::move(atb_q_seq_lens_vec), @@ -296,7 +297,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs( atb_kv_max_seq_len, std::move(atb_kv_seq_lens_vec)); } else { - specBuilder::update_input_params(input_params, + specBuilder::update_input_params(input_params.meta, + input_params.attention, buf, 1, std::move(buf.out_q_seq_lens), diff --git a/xllm/core/runtime/suffix_worker_impl.cpp b/xllm/core/runtime/suffix_worker_impl.cpp index 660c15c563..068e240180 100644 --- a/xllm/core/runtime/suffix_worker_impl.cpp +++ 
b/xllm/core/runtime/suffix_worker_impl.cpp @@ -77,13 +77,13 @@ SuffixWorkerImpl::SuffixWorkerImpl(const ParallelArgs& parallel_args, std::optional SuffixWorkerImpl::step_empty( const ForwardInput& input) { - if (!input.input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { auto output = impl_->step(input); output->sample_output.embeddings = torch::Tensor(); return output; } else { ForwardInput new_input = input; - for (auto& it : new_input.input_params.parallel.dp_global_token_nums) { + for (auto& it : new_input.parallel.dp_global_token_nums) { it *= options_.num_speculative_tokens() + 1; } @@ -103,9 +103,8 @@ std::optional SuffixWorkerImpl::step_prefill( COUNTER_ADD(speculative_execution_latency_seconds_target, timer.elapsed_seconds()); - const auto& input_params = input.input_params; - const int32_t num_sequences = input_params.meta.num_sequences; - const auto& request_ids = input_params.embedding.request_ids; + const int32_t num_sequences = input.meta.num_sequences; + const auto& request_ids = input.embedding.request_ids; if (suffix_cache_ != nullptr && request_ids.size() == static_cast(num_sequences)) { @@ -115,7 +114,7 @@ std::optional SuffixWorkerImpl::step_prefill( int32_t start_idx = 0; for (int32_t seq_id = 0; seq_id < num_sequences; ++seq_id) { - int32_t q_len = input_params.get_q_seq_len(seq_id); + int32_t q_len = input.get_q_seq_len(seq_id); Slice seq_tokens = tokens_ids_slice.slice(start_idx, start_idx + q_len); start_idx += q_len; @@ -174,9 +173,9 @@ std::optional SuffixWorkerImpl::step_prefill( std::optional SuffixWorkerImpl::step_decode( const ForwardInput& input) { const int32_t num_speculative_tokens = options_.num_speculative_tokens(); - const int32_t num_sequences = input.input_params.meta.num_sequences; + const int32_t num_sequences = input.meta.num_sequences; const int32_t num_val_tokens = num_speculative_tokens + 1; - const auto& request_ids = input.input_params.embedding.request_ids; + const 
auto& request_ids = input.embedding.request_ids; const bool has_request_ids = suffix_cache_ != nullptr && diff --git a/xllm/core/runtime/vlm_executor_impl.cpp b/xllm/core/runtime/vlm_executor_impl.cpp index ab17f324ae..d2dd2bf6e6 100644 --- a/xllm/core/runtime/vlm_executor_impl.cpp +++ b/xllm/core/runtime/vlm_executor_impl.cpp @@ -43,22 +43,21 @@ ForwardInput VlmExecutorImpl::prepare_inputs(Batch& batch) { options_.num_decoding_tokens(), 0, args_, options_.cp_size()); } -MMDict VlmExecutorImpl::encode(const ModelInputParams& params) { - return dynamic_cast(model_)->encode(params); +MMDict VlmExecutorImpl::encode(const ForwardInput& input) { + return dynamic_cast(model_)->encode(input); } -ModelOutput VlmExecutorImpl::run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) { +ModelOutput VlmExecutorImpl::run(const ForwardInput& input, + std::vector& kv_caches) { torch::NoGradGuard no_grad; - auto& mm_data = params.multimodal.mm_data; + ForwardInput model_input = input; + auto& mm_data = model_input.multimodal.mm_data; EncoderInputGatherVisitor input_gather; mm_data.foreach (input_gather); CHECK(input_gather.finish(mm_data)); mm_data.to(device_); - auto embedding = encode(params); + auto embedding = encode(model_input); EncoderOutputScatterVisitor scatter(embedding); mm_data.foreach (scatter); CHECK(scatter.finish()); @@ -67,14 +66,14 @@ ModelOutput VlmExecutorImpl::run(const torch::Tensor& tokens, mm_data.foreach (gather); CHECK(gather.finish(mm_data)); - params.embedding.input_embedding = - model_->get_input_embeddings(tokens, params); + model_input.embedding.input_embedding = model_->get_input_embeddings( + model_input.token_ids, model_input.multimodal); if (llm_executor_) { - return llm_executor_->run(tokens, positions, kv_caches, params); + return llm_executor_->run(model_input, kv_caches); } - return model_->forward(tokens, positions, kv_caches, params); + return model_->forward(model_input, 
kv_caches); } } // namespace xllm diff --git a/xllm/core/runtime/vlm_executor_impl.h b/xllm/core/runtime/vlm_executor_impl.h index 2a6044aae2..6d5fcdd5cb 100644 --- a/xllm/core/runtime/vlm_executor_impl.h +++ b/xllm/core/runtime/vlm_executor_impl.h @@ -43,12 +43,10 @@ class VlmExecutorImpl : public ExecutorImpl { ForwardInput prepare_inputs(Batch& batch) override; - ModelOutput run(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& params) override; + ModelOutput run(const ForwardInput& input, + std::vector& kv_caches) override; - virtual MMDict encode(const ModelInputParams& params); + virtual MMDict encode(const ForwardInput& input); private: // not own diff --git a/xllm/core/runtime/vlm_worker_impl.cpp b/xllm/core/runtime/vlm_worker_impl.cpp index d7b4d1c83b..5aa872feec 100755 --- a/xllm/core/runtime/vlm_worker_impl.cpp +++ b/xllm/core/runtime/vlm_worker_impl.cpp @@ -58,8 +58,7 @@ std::optional VLMWorkerImpl::step(const ForwardInput& input) { Timer timer; // TODO guojinrong, to adapt multi stream parallel later // call model executor forward to get hidden states - auto model_output = model_executor_->forward( - input.token_ids, input.positions, kv_caches_, input.input_params); + auto model_output = model_executor_->forward(input, kv_caches_); auto& sampling_params = input.sampling_params; torch::Tensor logits; if (sampling_params.selected_token_idxes.defined()) { diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index f4f0a0e545..86248c923f 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -119,7 +119,7 @@ class ScopedAtenLoadThreads { }; #if defined(USE_NPU) -void prepare_input_params_for_linear_attention(ModelInputParams& input_params) { +void prepare_input_params_for_linear_attention(ForwardInput& input_params) { int64_t batch_size = input_params.attention.device.block_tables.size(0); 
input_params.parallel.query_start_loc.resize(batch_size + 1, 0); int64_t max_seq_len = 0; @@ -469,69 +469,63 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, !enable_schedule_overlap() && options_.backend() == "llm"; auto prepare_input_on_current_stream = [&]() { processed_input = input.to(device_, dtype_); - auto& input_params = processed_input.input_params; + ForwardInput& input_params = processed_input; #if defined(USE_NPU) CpPrefillInputs tmp_cp_inputs; if (parallel_args_.cp_size() > 1 && - input.input_params.meta.batch_forward_type.is_prefill()) { - tmp_cp_inputs = prepare_cp_prefill_inputs( - parallel_args_.cp_size(), - input.token_ids, - input.positions, - input.input_params.attention.device.q_seq_lens); - processed_input.input_params.parallel.cp_prefill_inputs = - tmp_cp_inputs.to(device_); - CpEpPadding cp_ep_padding( - input.token_ids, - context_.get_model_args().num_experts_per_tok(), - context_.get_parallel_args().mapping_data(), - /*device=*/device_, - dtype_, - /*is_prefill=*/ - input.input_params.meta.batch_forward_type.is_prefill()); - processed_input.input_params.parallel.cp_ep_padding_data = - cp_ep_padding.build(); + input.meta.batch_forward_type.is_prefill()) { + tmp_cp_inputs = + prepare_cp_prefill_inputs(parallel_args_.cp_size(), + input.token_ids, + input.positions, + input.attention.device.q_seq_lens); + processed_input.parallel.cp_prefill_inputs = tmp_cp_inputs.to(device_); + CpEpPadding cp_ep_padding(input.token_ids, + context_.get_model_args().num_experts_per_tok(), + context_.get_parallel_args().mapping_data(), + /*device=*/device_, + dtype_, + /*is_prefill=*/ + input.meta.batch_forward_type.is_prefill()); + processed_input.parallel.cp_ep_padding_data = cp_ep_padding.build(); } #endif - apply_kv_block_swaps(input_params); + apply_kv_block_swaps(input_params.block_copy); #if defined(USE_NPU) if (context_.get_model_args().enable_mla() && input_params.meta.batch_forward_type.is_chunked_prefill()) { - 
prepare_mla_prefixcache_inputs(input_params); + prepare_mla_prefixcache_inputs(input_params.attention); } if (!context_.get_parallel_args().mapping_data().empty() && !(context_.get_parallel_args().cp_size() > 1) && (context_.get_parallel_args().dp_size() > 1 || context_.get_parallel_args().ep_size() > 1)) { - torch::Tensor token_size_per_dp_group = torch::tensor( - processed_input.input_params.parallel.dp_global_token_nums, - torch::TensorOptions() - .device(torch::kCPU) - .dtype(torch::kInt32) - .pinned_memory(true)); - bool is_prefill = - processed_input.input_params.meta.batch_forward_type.is_prefill(); + torch::Tensor token_size_per_dp_group = + torch::tensor(processed_input.parallel.dp_global_token_nums, + torch::TensorOptions() + .device(torch::kCPU) + .dtype(torch::kInt32) + .pinned_memory(true)); + bool is_prefill = processed_input.meta.batch_forward_type.is_prefill(); DpEpPadding dp_ep_padding(token_size_per_dp_group, context_.get_model_args().num_experts_per_tok(), context_.get_parallel_args().mapping_data(), device_, dtype_, is_prefill); - processed_input.input_params.parallel.dp_ep_padding_data = - dp_ep_padding.build(); + processed_input.parallel.dp_ep_padding_data = dp_ep_padding.build(); if (FLAGS_enable_eplb) { // expert_load_data_.fill_(0); - processed_input.input_params.expert.expert_load_data = - expert_load_data_; + processed_input.expert.expert_load_data = expert_load_data_; } } if (has_linear_attention_layers(context_.get_model_args())) { - prepare_input_params_for_linear_attention(processed_input.input_params); + prepare_input_params_for_linear_attention(processed_input); } #endif }; @@ -548,22 +542,21 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, } } -void WorkerImpl::apply_kv_block_swaps(const ModelInputParams& input_params) { +void WorkerImpl::apply_kv_block_swaps(const BlockCopyInput& block_copy) { #if defined(USE_CUDA) if (FLAGS_enable_block_copy_kernel && - can_use_cuda_block_copy_kernel(input_params)) { - 
execute_cuda_block_copy_kernel(input_params); + can_use_cuda_block_copy_kernel(block_copy)) { + execute_cuda_block_copy_kernel(block_copy); return; } #endif #if defined(USE_NPU) - if (input_params.block_copy.swap_blocks.size() == 0 || - FLAGS_enable_block_copy_kernel) { + if (block_copy.swap_blocks.size() == 0 || FLAGS_enable_block_copy_kernel) { return; } #elif defined(USE_CUDA) - if (input_params.block_copy.swap_blocks.size() == 0) { + if (block_copy.swap_blocks.size() == 0) { return; } #else @@ -572,10 +565,10 @@ void WorkerImpl::apply_kv_block_swaps(const ModelInputParams& input_params) { #if defined(USE_NPU) || defined(USE_CUDA) std::vector src_indices, dst_indices; - src_indices.reserve(input_params.block_copy.swap_blocks.size()); - dst_indices.reserve(input_params.block_copy.swap_blocks.size()); + src_indices.reserve(block_copy.swap_blocks.size()); + dst_indices.reserve(block_copy.swap_blocks.size()); - for (const auto& block : input_params.block_copy.swap_blocks) { + for (const auto& block : block_copy.swap_blocks) { src_indices.push_back(block.src_block_id); dst_indices.push_back(block.dst_block_id); } @@ -644,25 +637,25 @@ void WorkerImpl::refresh_cuda_block_copy_runtime_state() { } bool WorkerImpl::can_use_cuda_block_copy_kernel( - const ModelInputParams& input_params) const { + const BlockCopyInput& block_copy) const { return cuda_block_copy_runtime_state_.valid() && - input_params.block_copy.src_block_indices.defined() && - input_params.block_copy.dst_block_indices.defined() && - input_params.block_copy.cum_sum.defined() && - input_params.block_copy.src_block_indices.numel() > 0 && - input_params.block_copy.dst_block_indices.numel() > 0 && - input_params.block_copy.cum_sum.numel() > 0; + block_copy.src_block_indices.defined() && + block_copy.dst_block_indices.defined() && + block_copy.cum_sum.defined() && + block_copy.src_block_indices.numel() > 0 && + block_copy.dst_block_indices.numel() > 0 && + block_copy.cum_sum.numel() > 0; } void 
WorkerImpl::execute_cuda_block_copy_kernel( - const ModelInputParams& input_params) { + const BlockCopyInput& block_copy) { CHECK(!kv_caches_.empty()); xllm::kernel::cuda::block_copy( cuda_block_copy_runtime_state_.k_cache_ptrs_device, cuda_block_copy_runtime_state_.v_cache_ptrs_device, - input_params.block_copy.src_block_indices, - input_params.block_copy.dst_block_indices, - input_params.block_copy.cum_sum, + block_copy.src_block_indices, + block_copy.dst_block_indices, + block_copy.cum_sum, cuda_block_copy_runtime_state_.numel_per_block, kv_caches_.front().get_k_cache().scalar_type()); } @@ -680,7 +673,7 @@ folly::SemiFuture> WorkerImpl::step_async( input = std::move(input_on_device), promise = std::move(promise)]() mutable { if (hierarchy_kv_cache_transfer_ != nullptr) { - hierarchy_kv_cache_transfer_->set_layer_synchronizer(input.input_params); + hierarchy_kv_cache_transfer_->set_layer_synchronizer(input); } // run the model on the given input in working thread @@ -689,7 +682,7 @@ folly::SemiFuture> WorkerImpl::step_async( promise.setValue(output); } else { if (last_step_output_valid_ && input.token_ids.numel() > 0 && - input.input_params.meta.batch_forward_type.has_decode()) { + input.meta.batch_forward_type.has_decode()) { // replace step i model input with true output of step i-1 input = update_input_by_last_step_output(input); } @@ -1148,40 +1141,36 @@ void WorkerImpl::init_hierarchy_kv_cache_transfer() { transfer_options, device_, &kv_caches_); } } -void WorkerImpl::prepare_mla_prefixcache_inputs( - ModelInputParams& input_params) { - int32_t sum_prefix = - input_params.attention.device.kv_cache_tokens_nums.sum().item(); - input_params.attention.device.history_compressed_kv = +void WorkerImpl::prepare_mla_prefixcache_inputs(AttentionInput& attention) { + int32_t sum_prefix = attention.device.kv_cache_tokens_nums.sum().item(); + attention.device.history_compressed_kv = torch::empty({sum_prefix, context_.get_model_args().kv_lora_rank()}, 
torch::TensorOptions().dtype(dtype_).pinned_memory(true)) .to(device_); - input_params.attention.device.history_k_rope = + attention.device.history_k_rope = torch::empty({sum_prefix, context_.get_model_args().qk_rope_head_dim()}, torch::TensorOptions().dtype(dtype_).pinned_memory(true)) .to(device_); ; - input_params.attention.device.ring_cur_seqlen = - torch::stack({input_params.attention.device.q_seq_lens, - input_params.attention.device.q_seq_lens}) + attention.device.ring_cur_seqlen = + torch::stack({attention.device.q_seq_lens, attention.device.q_seq_lens}) .to(device_); - input_params.attention.device.ring_cache_seqlen = - torch::stack( - {input_params.attention.device.q_seq_lens, - input_params.attention.device.kv_cache_tokens_nums.to(device_)}) + attention.device.ring_cache_seqlen = + torch::stack({attention.device.q_seq_lens, + attention.device.kv_cache_tokens_nums.to(device_)}) .to(device_); torch::Tensor ring_cur_seqlen_host = - input_params.attention.device.ring_cur_seqlen.cpu().contiguous(); + attention.device.ring_cur_seqlen.cpu().contiguous(); torch::Tensor ring_cache_seqlen_host = - input_params.attention.device.ring_cache_seqlen.cpu().contiguous(); - input_params.attention.host.ring_cur_seqlen = std::vector( + attention.device.ring_cache_seqlen.cpu().contiguous(); + attention.host.ring_cur_seqlen = std::vector( ring_cur_seqlen_host.data_ptr(), ring_cur_seqlen_host.data_ptr() + ring_cur_seqlen_host.numel()); - input_params.attention.host.ring_cache_seqlen = std::vector( + attention.host.ring_cache_seqlen = std::vector( ring_cache_seqlen_host.data_ptr(), ring_cache_seqlen_host.data_ptr() + ring_cache_seqlen_host.numel()); } diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h index 1686ec980b..21a2ff1452 100644 --- a/xllm/core/runtime/worker_impl.h +++ b/xllm/core/runtime/worker_impl.h @@ -118,7 +118,7 @@ class WorkerImpl { ForwardInput& processed_inputs); // Internal helper shared by worker pipelines before model execution. 
- virtual void apply_kv_block_swaps(const ModelInputParams& input_params); + virtual void apply_kv_block_swaps(const BlockCopyInput& block_copy); virtual std::optional step(const ForwardInput& inputs) = 0; @@ -199,7 +199,7 @@ class WorkerImpl { protected: void update_last_step_output(const std::optional& output); // Only used for deepseek chunked prefill ops on npu device - void prepare_mla_prefixcache_inputs(ModelInputParams& input_params); + void prepare_mla_prefixcache_inputs(AttentionInput& attention); void init_hierarchy_kv_cache_transfer(); @@ -211,9 +211,8 @@ class WorkerImpl { #if defined(USE_CUDA) void refresh_cuda_block_copy_runtime_state(); - bool can_use_cuda_block_copy_kernel( - const ModelInputParams& input_params) const; - void execute_cuda_block_copy_kernel(const ModelInputParams& input_params); + bool can_use_cuda_block_copy_kernel(const BlockCopyInput& block_copy) const; + void execute_cuda_block_copy_kernel(const BlockCopyInput& block_copy); struct CudaBlockCopyRuntimeState { torch::Tensor k_cache_ptrs_device; diff --git a/xllm/models/dit/pipeline_longcat_image.h b/xllm/models/dit/pipeline_longcat_image.h index 51497269ab..a301632760 100644 --- a/xllm/models/dit/pipeline_longcat_image.h +++ b/xllm/models/dit/pipeline_longcat_image.h @@ -548,10 +548,9 @@ class LongCatImagePipelineImpl : public torch::nn::Module { const auto& text_encoder_args = context_.get_model_args("text_encoder"); std::vector kv_caches(text_encoder_args.n_layers()); - ModelInputParams input_params = + ForwardInput input = build_longcat_input_params(tokens_flat, positions_2d, attention_mask); - auto model_output = text_encoder_->forward( - tokens_flat, positions_2d, kv_caches, input_params); + auto model_output = text_encoder_->forward(input, kv_caches); torch::Tensor hidden_states_flat = model_output.hidden_states; int64_t hidden_size = hidden_states_flat.size(-1); @@ -628,12 +627,13 @@ class LongCatImagePipelineImpl : public torch::nn::Module { std::string 
prompt_template_encode_prefix_; std::string prompt_template_encode_suffix_; - // Build ModelInputParams for LongCat-Image text encoding - ModelInputParams build_longcat_input_params( - const torch::Tensor& tokens, - const torch::Tensor& positions, - const torch::Tensor& attention_mask) { - ModelInputParams params; + // Build ForwardInput for LongCat-Image text encoding + ForwardInput build_longcat_input_params(const torch::Tensor& tokens, + const torch::Tensor& positions, + const torch::Tensor& attention_mask) { + ForwardInput params; + params.token_ids = tokens; + params.positions = positions; int64_t actual_seq_len; if (positions.dim() == 2) { @@ -657,7 +657,7 @@ class LongCatImagePipelineImpl : public torch::nn::Module { // Let Qwen2_5_VL build multimodal-aware embeddings from tokens and params. params.embedding.input_embedding = - text_encoder_->get_input_embeddings(tokens, params); + text_encoder_->get_input_embeddings(tokens, params.multimodal); if (attention_mask.defined() && attention_mask.size(0) > 0) { params.graph.attn_mask = attention_mask.view({-1}).to(torch::kFloat32); diff --git a/xllm/models/dit/pipeline_longcat_image_edit.h b/xllm/models/dit/pipeline_longcat_image_edit.h index e56c9049c3..4ae868a0af 100644 --- a/xllm/models/dit/pipeline_longcat_image_edit.h +++ b/xllm/models/dit/pipeline_longcat_image_edit.h @@ -543,10 +543,9 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { std::vector kv_caches(text_encoder_args.n_layers()); std::vector mm_data_list(static_cast(batch_size), mm_data); MMBatchData mm_batch(std::move(mm_data_list)); - ModelInputParams input_params = build_longcat_input_params( + ForwardInput input = build_longcat_input_params( tokens_flat, positions_2d, attention_mask, mm_batch); - auto model_output = text_encoder_->forward( - tokens_flat, positions_2d, kv_caches, input_params); + auto model_output = text_encoder_->forward(input, kv_caches); torch::Tensor hidden_states_flat = model_output.hidden_states; int64_t 
hidden_size = hidden_states_flat.size(-1); @@ -672,7 +671,7 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { return {latents, image_latents, latents_ids, image_latents_ids}; } - // Build ModelInputParams for LongCat-Image text encoding (reuse from + // Build ForwardInput for LongCat-Image text encoding (reuse from // T2I pipeline). When mm_data_opt is provided (e.g. pixel_values + // image_grid_thw), the text encoder will use vision embeddings. torch::Tensor build_qwen2_5_vl_mrope_positions( @@ -794,12 +793,14 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { return position_ids.reshape({3, -1}).contiguous(); } - ModelInputParams build_longcat_input_params( + ForwardInput build_longcat_input_params( const torch::Tensor& tokens, const torch::Tensor& positions, const torch::Tensor& attention_mask, std::optional mm_data_opt = std::nullopt) { - ModelInputParams params; + ForwardInput params; + params.token_ids = tokens; + params.positions = positions; int64_t actual_seq_len; if (positions.dim() == 2) { @@ -831,10 +832,15 @@ class LongCatImageEditPipelineImpl : public torch::nn::Module { // (may be needed by multimodal processing) params.attn_metadata = std::make_shared( layer::AttentionMetadataBuilder::build( - params, context_.get_model_args("text_encoder").enable_mla())); + params.meta, + params.attention, + params.graph, + params.llmrec_params(), + params.enable_cuda_graph, + context_.get_model_args("text_encoder").enable_mla())); params.attn_metadata->is_causal = true; params.embedding.input_embedding = - text_encoder_->get_input_embeddings(tokens, params); + text_encoder_->get_input_embeddings(tokens, params.multimodal); if (attention_mask.defined() && attention_mask.size(0) > 0) { params.graph.attn_mask = attention_mask.view({-1}).to(torch::kFloat32); } diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 092fbb1b4e..0007e8a2c2 100644 --- a/xllm/models/llm/deepseek_v2.h +++ 
b/xllm/models/llm/deepseek_v2.h @@ -69,9 +69,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { ModelOutput forward_native(torch::Tensor tokens, torch::Tensor positions, std::vector& kv_caches, - const ModelInputParams& input_params) { + const ForwardInput& input_params) { // for dp, if tokens is empty, set tokens to 1 and positions to 0 - ModelInputParams modified_input_params = input_params; + ForwardInput modified_input_params = input_params; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -83,8 +83,13 @@ class DeepseekV2ModelImpl : public torch::nn::Module { if (!modified_input_params.attn_metadata) { modified_input_params.attn_metadata = std::make_shared( - layer::AttentionMetadataBuilder::build(modified_input_params, - model_args_.enable_mla())); + layer::AttentionMetadataBuilder::build( + modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + model_args_.enable_mla())); } auto& attn_metadata = *(modified_input_params.attn_metadata); torch::Tensor hidden_states = embed_tokens_(tokens); @@ -113,11 +118,11 @@ class DeepseekV2ModelImpl : public torch::nn::Module { } // Provide batched signature to satisfy callers that pass vectors - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return forward_native(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; + return forward_native(tokens, positions, kv_caches, input); } // load the weight from the checkpoint diff --git a/xllm/models/llm/deepseek_v32.h b/xllm/models/llm/deepseek_v32.h index 2b93f1b951..e35bcc7ac8 100644 --- a/xllm/models/llm/deepseek_v32.h 
+++ b/xllm/models/llm/deepseek_v32.h @@ -52,22 +52,26 @@ class DeepseekV32ModelImpl : public DeepseekV2ModelImpl { CHECK(!sp_config_error.has_value()) << sp_config_error.value(); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - ModelInputParams modified_input_params = input_params; + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; + ForwardInput modified_input_params = input; if (!modified_input_params.attn_metadata) { modified_input_params.attn_metadata = std::make_shared( - layer::AttentionMetadataBuilder::build(modified_input_params, - model_args_.enable_mla())); + layer::AttentionMetadataBuilder::build( + modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + model_args_.enable_mla())); } auto& attn_metadata = *modified_input_params.attn_metadata; std::optional sp_ctx; const bool requested_sequence_parallel = - FLAGS_enable_prefill_sp && - input_params.meta.batch_forward_type.no_decode(); + FLAGS_enable_prefill_sp && input.meta.batch_forward_type.no_decode(); if (requested_sequence_parallel) { if (sequence_parallel_group_ == nullptr) { CHECK_EQ(parallel_world_size_, 1) @@ -75,7 +79,7 @@ class DeepseekV32ModelImpl : public DeepseekV2ModelImpl { } else if (sequence_parallel_group_->world_size() > 1) { sp_ctx = layer::v32_sp::build_deepseek_v32_sp_context( attn_metadata, - input_params.meta.batch_forward_type, + input.meta.batch_forward_type, tokens, sequence_parallel_group_, sequence_parallel_group_->rank(), @@ -86,7 +90,7 @@ class DeepseekV32ModelImpl : public DeepseekV2ModelImpl { // Fallback to the normal TP path when SP is disabled or the current // prefill batch cannot be split across all SP ranks. 
active_sequence_parallel_context_ = nullptr; - return DeepseekV2ModelImpl::forward( + return DeepseekV2ModelImpl::forward_native( tokens, positions, kv_caches, modified_input_params); } diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index fbb4681145..89213a50ed 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -29,6 +29,7 @@ limitations under the License. #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/common/lm_head.h" #include "core/layers/common/rms_norm.h" +#include "core/runtime/forward_params.h" #include "core/util/rec_model_utils.h" #include "models/model_registry.h" @@ -57,15 +58,15 @@ class LlmModelImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; // test torch::Tensor h; if (inputs_embeds.defined()) { @@ -74,14 +75,19 @@ class LlmModelImplBase : public torch::nn::Module { h = embed_tokens_(tokens); } - auto modified_input_params = input_params; + ForwardInput modified_input_params = input; auto& dp_token_nums = modified_input_params.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); if (!modified_input_params.attn_metadata) { modified_input_params.attn_metadata = std::make_shared( - layer::AttentionMetadataBuilder::build(modified_input_params, - 
model_args_.enable_mla())); + layer::AttentionMetadataBuilder::build( + modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + model_args_.enable_mla())); } auto& attn_metadata = *(modified_input_params.attn_metadata); if (positions.dim() == 2) { @@ -183,11 +189,9 @@ class LlmForCausalLMImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_(tokens, positions, kv_caches, input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_(input, kv_caches); } // hidden_states: [num_tokens, hidden_size] diff --git a/xllm/models/llm/mtp_model_base.h b/xllm/models/llm/mtp_model_base.h index abf911a07b..23ea73eca4 100644 --- a/xllm/models/llm/mtp_model_base.h +++ b/xllm/models/llm/mtp_model_base.h @@ -53,11 +53,11 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { torch::Tensor positions, const layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params) { + const ForwardInput& input) { // Layer norm on token inputs auto enorm_out = std::get<0>(enorm_(embed)); - torch::Tensor embedding_data = input_params.embedding.input_embedding; + torch::Tensor embedding_data = input.embedding.input_embedding; // for dummy data parallel run, we set a empty embedding if (attn_metadata.is_dummy) { embedding_data = torch::zeros({embed.size(0), model_args_.hidden_size()}, @@ -73,12 +73,8 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { auto hidden_states = eh_proj_(concat_emb); // Pass through mtp block - hidden_states = mtp_block_(hidden_states, - residual, - positions, - attn_metadata, 
- kv_cache, - input_params); + hidden_states = mtp_block_( + hidden_states, residual, positions, attn_metadata, kv_cache, input); return hidden_states; } @@ -154,12 +150,12 @@ class MtpModelImplBase : public torch::nn::Module { } // Provide batched signature to satisfy callers that pass vectors - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; // for dp, if tokens is empty, set tokens to 1 and positions to 0 - ModelInputParams modified_input_params = input_params; + ForwardInput modified_input_params = input; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -171,8 +167,13 @@ class MtpModelImplBase : public torch::nn::Module { if (!modified_input_params.attn_metadata) { modified_input_params.attn_metadata = std::make_shared( - layer::AttentionMetadataBuilder::build(modified_input_params, - model_args_.enable_mla())); + layer::AttentionMetadataBuilder::build( + modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + model_args_.enable_mla())); } auto& attn_metadata = *(modified_input_params.attn_metadata); torch::Tensor hidden_states = embed_tokens_(tokens); diff --git a/xllm/models/llm/musa/qwen3.h b/xllm/models/llm/musa/qwen3.h index 934c59f564..a1718a60c6 100644 --- a/xllm/models/llm/musa/qwen3.h +++ b/xllm/models/llm/musa/qwen3.h @@ -97,20 +97,19 @@ class QWen3ModelImpl : public LlmModelImplBase { return std::make_pair(cos_pos, sin_pos); } - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - bool use_deepstack = 
input_params.multimodal.deep_stacks.size() > 0; - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + const bool use_deepstack = input.multimodal.deep_stacks.size() > 0; + ForwardInput modified_input_params = input; std::vector deep_stacks; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -119,21 +118,26 @@ class QWen3ModelImpl : public LlmModelImplBase { } if (use_deepstack) { deep_stacks = - input_params.multimodal.deep_stacks; // [num_deepstack, hidden_size] + input.multimodal.deep_stacks; // [num_deepstack, hidden_size] } - auto& dp_token_nums = input_params_new.parallel.dp_global_token_nums; + auto& dp_token_nums = modified_input_params.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - if (!input_params_new.attn_metadata) { - input_params_new.attn_metadata = + if (!modified_input_params.attn_metadata) { + modified_input_params.attn_metadata = std::make_shared( - get_attention_metadata(input_params_new, h)); + get_attention_metadata(modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + h)); } - auto& attn_metadata = *(input_params_new.attn_metadata); + auto& attn_metadata = *(modified_input_params.attn_metadata); torch::Tensor& new_cache_slots = - input_params_new.attention.device.new_cache_slots; + modified_input_params.attention.device.new_cache_slots; // musa cache slots should be (block_id, id_in_block) 
// todo: add this as an optional change to build input_params phase? new_cache_slots = torch::stack( @@ -143,8 +147,9 @@ class QWen3ModelImpl : public LlmModelImplBase { std::optional residual; for (size_t i = 0; i < layers_.size(); i++) { - if (is_rec_multi_round_mode() && input_params_new.has_llmrec_params()) { - const auto& llmrec_params = input_params_new.llmrec_params(); + if (is_rec_multi_round_mode() && + modified_input_params.has_llmrec_params()) { + const auto& llmrec_params = modified_input_params.llmrec_params(); attn_metadata.full_k_cache = llmrec_params->full_k_caches[i]; attn_metadata.full_v_cache = llmrec_params->full_v_caches[i]; attn_metadata.unshared_k_cache = llmrec_params->unshared_k_caches[i]; @@ -159,12 +164,12 @@ class QWen3ModelImpl : public LlmModelImplBase { positions, attn_metadata, kv_caches[i], - input_params_new); + modified_input_params); if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } } @@ -174,9 +179,17 @@ class QWen3ModelImpl : public LlmModelImplBase { private: layer::AttentionMetadata get_attention_metadata( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, const torch::Tensor& h) { - return layer::AttentionMetadataBuilder::build(params, + return layer::AttentionMetadataBuilder::build(meta, + attention, + graph, + llmrec_params, + enable_cuda_graph, model_args_.enable_mla()); } }; diff --git a/xllm/models/llm/npu/deepseek_v2.h b/xllm/models/llm/npu/deepseek_v2.h index 810bd7d5a2..664ba1ddea 100644 --- a/xllm/models/llm/npu/deepseek_v2.h +++ b/xllm/models/llm/npu/deepseek_v2.h @@ -41,17 +41,11 @@ class DeepseekV2DecoderLayerImpl : public torch::nn::Module { torch::Tensor& sin_pos, torch::Tensor& attn_mask, 
KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag); + return decoder_layer_( + x, cos_pos, sin_pos, attn_mask, kv_cache, input, event, event_flag); } void load_state_dict(const StateDict& state_dict) { @@ -147,10 +141,10 @@ class DeepseekV2ModelImpl : public torch::nn::Module { num_experts_per_tok_ = model_args.num_experts_per_tok(); } - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -166,9 +160,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (FLAGS_enable_prefix_cache && - !input_params.meta.batch_forward_type.is_decode()) { + !input.meta.batch_forward_type.is_decode()) { attn_mask = attn_mask_.get_attn_mask(512, dtype_, device_); - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else if (num_speculative_tokens_ > 0) { // TODO :the judgement of gen_free_mask need more check @@ -180,12 +174,11 @@ class DeepseekV2ModelImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); - event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + if (input.parallel.layer_synchronizer != nullptr) { + event = 
input.parallel.layer_synchronizer->get_event(i); + event_flag = input.parallel.layer_synchronizer->get_event_flag(i); } - if (!input_params.synchronize_layer(i)) { + if (!input.synchronize_layer(i)) { return ModelOutput(); } @@ -197,7 +190,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params, + input, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/deepseek_v32.h b/xllm/models/llm/npu/deepseek_v32.h index 3270173a8e..daf9be7b09 100644 --- a/xllm/models/llm/npu/deepseek_v32.h +++ b/xllm/models/llm/npu/deepseek_v32.h @@ -41,17 +41,11 @@ class DeepseekV32DecoderLayerImpl : public torch::nn::Module { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag); + return decoder_layer_( + x, cos_pos, sin_pos, attn_mask, kv_cache, input, event, event_flag); } void load_state_dict(const StateDict& state_dict) { @@ -151,10 +145,10 @@ class DeepseekV32ModelImpl : public torch::nn::Module { } } - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -170,7 +164,7 @@ class DeepseekV32ModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (num_speculative_tokens_ == 0 || - input_params.meta.batch_forward_type.is_prefill()) { + input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( @@ 
-181,12 +175,11 @@ class DeepseekV32ModelImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); - event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + if (input.parallel.layer_synchronizer != nullptr) { + event = input.parallel.layer_synchronizer->get_event(i); + event_flag = input.parallel.layer_synchronizer->get_event_flag(i); } - if (!input_params.synchronize_layer(i)) { + if (!input.synchronize_layer(i)) { return ModelOutput(); } @@ -198,7 +191,7 @@ class DeepseekV32ModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params, + input, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/glm4.h b/xllm/models/llm/npu/glm4.h index f4a865f54c..522b54f7d1 100644 --- a/xllm/models/llm/npu/glm4.h +++ b/xllm/models/llm/npu/glm4.h @@ -71,18 +71,17 @@ class Glm4ModelImpl } } - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -119,24 +118,24 @@ class Glm4ModelImpl sin_pos = sin_pos.reshape({-1, sin_pos.sizes().back() / 2, 2}); torch::Tensor 
attn_mask; if (FLAGS_enable_chunked_prefill) { - int max_kv_seq = input_params.meta.kv_max_seq_len; - int num_sequences = input_params.meta.num_sequences; + int max_kv_seq = input.meta.kv_max_seq_len; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - max_kv_seq, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); } - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask( 128, cos_pos.dtype().toScalarType(), cos_pos.device()); } @@ -145,12 +144,13 @@ class Glm4ModelImpl aclrtEvent* event{nullptr}; std::atomic* event_flag{nullptr}; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -160,7 +160,7 @@ class Glm4ModelImpl sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); } diff --git a/xllm/models/llm/npu/glm4_moe.h b/xllm/models/llm/npu/glm4_moe.h index aac03244fc..0230e68a13 100644 --- 
a/xllm/models/llm/npu/glm4_moe.h +++ b/xllm/models/llm/npu/glm4_moe.h @@ -39,17 +39,11 @@ class Glm4MoeDecoderLayerImpl : public torch::nn::Module { torch::Tensor sin_pos, torch::Tensor attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag); + return decoder_layer_( + x, cos_pos, sin_pos, attn_mask, kv_cache, input, event, event_flag); } void load_state_dict(const StateDict& state_dict) { @@ -138,10 +132,11 @@ class Glm4MoeModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -149,7 +144,7 @@ class Glm4MoeModelImpl : public torch::nn::Module { } } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -189,41 +184,40 @@ class Glm4MoeModelImpl : public torch::nn::Module { sin_pos = sin_pos.view(at::IntArrayRef{-1, 2, sin_pos.size(-1) / 2}); torch::Tensor attn_mask; if (FLAGS_enable_chunked_prefill) { - int max_kv_seq = input_params.meta.kv_max_seq_len; - int num_sequences = input_params.meta.num_sequences; + int max_kv_seq = input.meta.kv_max_seq_len; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { 
- auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - max_kv_seq, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); } - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } - ModelInputParams& input_params_new = - const_cast(input_params); - input_params_new.expert.expert_array = expert_array; + modified_input_params.expert.expert_array = expert_array; RollingLayerGuard rolling_guard(rolling_mgr_); for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -235,7 +229,7 @@ class Glm4MoeModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/glm4_moe_lite.h b/xllm/models/llm/npu/glm4_moe_lite.h index 61baab5f5b..cc5dacf6f2 100644 --- a/xllm/models/llm/npu/glm4_moe_lite.h +++ b/xllm/models/llm/npu/glm4_moe_lite.h @@ -40,17 +40,11 @@ class 
Glm4MoeDecoderLiteLayerImpl : public torch::nn::Module { torch::Tensor sin_pos, torch::Tensor attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag); + return decoder_layer_( + x, cos_pos, sin_pos, attn_mask, kv_cache, input, event, event_flag); } void load_state_dict(const StateDict& state_dict) { @@ -157,10 +151,11 @@ class Glm4MoeLiteModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -168,7 +163,7 @@ class Glm4MoeLiteModelImpl : public torch::nn::Module { } } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -188,9 +183,9 @@ class Glm4MoeLiteModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (FLAGS_enable_prefix_cache && - !input_params.meta.batch_forward_type.is_decode()) { + !input.meta.batch_forward_type.is_decode()) { attn_mask = attn_mask_.get_attn_mask(512, dtype_, device_); - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else if (num_speculative_tokens_ > 0) { // TODO :the judgement of gen_free_mask need more check @@ -198,22 +193,21 @@ class 
Glm4MoeLiteModelImpl : public torch::nn::Module { num_speculative_tokens_ + 1, dtype_, device_); } - ModelInputParams& input_params_new = - const_cast(input_params); - input_params_new.expert.expert_array = expert_array; + modified_input_params.expert.expert_array = expert_array; RollingLayerGuard rolling_guard(rolling_mgr_); for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -225,7 +219,7 @@ class Glm4MoeLiteModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/glm5_moe.h b/xllm/models/llm/npu/glm5_moe.h index 6ee2da5a33..18ba1e665f 100644 --- a/xllm/models/llm/npu/glm5_moe.h +++ b/xllm/models/llm/npu/glm5_moe.h @@ -72,10 +72,10 @@ class GlmMoeDsaModelImpl : public torch::nn::Module { } } - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -91,7 +91,7 @@ class GlmMoeDsaModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (num_speculative_tokens_ 
== 0 || - input_params.meta.batch_forward_type.is_prefill()) { + input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else { attn_mask = attn_mask_.gen_free_mask( @@ -102,12 +102,11 @@ class GlmMoeDsaModelImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); - event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + if (input.parallel.layer_synchronizer != nullptr) { + event = input.parallel.layer_synchronizer->get_event(i); + event_flag = input.parallel.layer_synchronizer->get_event_flag(i); } - if (!input_params.synchronize_layer(i)) { + if (!input.synchronize_layer(i)) { return ModelOutput(); } @@ -119,7 +118,7 @@ class GlmMoeDsaModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params, + input, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/joyai_llm_flash.h b/xllm/models/llm/npu/joyai_llm_flash.h index e1f0a57161..37f28147e2 100644 --- a/xllm/models/llm/npu/joyai_llm_flash.h +++ b/xllm/models/llm/npu/joyai_llm_flash.h @@ -68,10 +68,10 @@ class JoyAILLMFlashModelImpl : public torch::nn::Module { num_experts_per_tok_ = model_args.num_experts_per_tok(); } - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -87,9 +87,9 @@ class JoyAILLMFlashModelImpl : public torch::nn::Module { torch::Tensor attn_mask; if (FLAGS_enable_prefix_cache && - 
!input_params.meta.batch_forward_type.is_decode()) { + !input.meta.batch_forward_type.is_decode()) { attn_mask = attn_mask_.get_attn_mask(512, dtype_, device_); - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); } else if (num_speculative_tokens_ > 0) { // TODO :the judgement of gen_free_mask need more check @@ -101,12 +101,11 @@ class JoyAILLMFlashModelImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); - event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + if (input.parallel.layer_synchronizer != nullptr) { + event = input.parallel.layer_synchronizer->get_event(i); + event_flag = input.parallel.layer_synchronizer->get_event_flag(i); } - if (!input_params.synchronize_layer(i)) { + if (!input.synchronize_layer(i)) { return ModelOutput(); } @@ -118,7 +117,7 @@ class JoyAILLMFlashModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params, + input, event, event_flag); rolling_guard.after_layer(layer_index); diff --git a/xllm/models/llm/npu/llama.h b/xllm/models/llm/npu/llama.h index 06813da690..d7637563a7 100644 --- a/xllm/models/llm/npu/llama.h +++ b/xllm/models/llm/npu/llama.h @@ -36,10 +36,10 @@ class LlamaDecoderLayerImpl : public torch::nn::Module { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, int node_id) { return decoder_layer_( - x, cos_pos, sin_pos, attn_mask, kv_cache, input_params, node_id); + x, cos_pos, sin_pos, attn_mask, kv_cache, input, node_id); } // load the weight from the checkpoint @@ -142,20 +142,18 @@ class LlamaModelImpl : public torch::nn::Module { // tokens: 
[num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput layer_input = input; + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; torch::Tensor h = npu_embed_tokens_(tokens, 0); auto cos_pos = cos_pos_.index_select(0, positions); auto sin_pos = sin_pos_.index_select(0, positions); - ModelInputParams& input_params_new = - const_cast(input_params); // torch::Tensor max_of_seq = // torch::max(input_params.attention.device.kv_seq_lens); max_seq_len_ = // std::max(max_of_seq.item(), max_seq_len_); - torch::Tensor max_of_seq = - torch::max(input_params.attention.device.kv_seq_lens); + torch::Tensor max_of_seq = torch::max(input.attention.device.kv_seq_lens); max_seq_len_ = FLAGS_enable_chunked_prefill ? std::max(max_of_seq.item(), max_seq_len_) : 128; @@ -163,14 +161,15 @@ class LlamaModelImpl : public torch::nn::Module { max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); if (FLAGS_enable_chunked_prefill) { - int batch_size = input_params.attention.host.q_seq_lens.size(); + const int32_t batch_size = + static_cast(input.attention.host.q_seq_lens.size()); std::vector req_mask_vec; req_mask_vec.reserve(batch_size); - for (int i = 0; i < batch_size; i++) { - int start = input_params.attention.host.kv_seq_lens[i] - - input_params.attention.host.q_seq_lens[i]; - int end = input_params.attention.host.kv_seq_lens[i]; + for (int32_t i = 0; i < batch_size; i++) { + const int32_t start = input.attention.host.kv_seq_lens[i] - + input.attention.host.q_seq_lens[i]; + const int32_t end = input.attention.host.kv_seq_lens[i]; auto req_mask_slice = attn_mask.slice(0, start, end); req_mask_vec.emplace_back(req_mask_slice); @@ -180,9 +179,15 @@ class LlamaModelImpl : public torch::nn::Module { 
RollingLayerGuard rolling_guard(rolling_mgr_); for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; - const int32_t layer_index = i; + const int32_t layer_index = static_cast(i); rolling_guard.before_layer(layer_index); - layer(h, cos_pos, sin_pos, attn_mask, kv_caches[i], input_params_new, i); + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + layer_input, + layer_index); rolling_guard.after_layer(layer_index); } auto hidden_states = norm_(h, 0); diff --git a/xllm/models/llm/npu/llm_model_base.h b/xllm/models/llm/npu/llm_model_base.h index 1fc178c27f..7ddbc2a48f 100644 --- a/xllm/models/llm/npu/llm_model_base.h +++ b/xllm/models/llm/npu/llm_model_base.h @@ -41,6 +41,7 @@ limitations under the License. #include "core/layers/npu/npu_pos_embedding_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" #include "core/layers/npu/npu_word_embedding_impl.h" +#include "core/runtime/forward_params.h" #include "models/model_registry.h" #include "xllm_atb_layers/core/include/atb_speed/log.h" @@ -63,15 +64,15 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { - if (input_params.block_copy.src_block_indices.numel() > 0) { + if (input.block_copy.src_block_indices.numel() > 0) { block_copy_(kv_cache.get_k_cache(), kv_cache.get_v_cache(), - input_params.block_copy.src_block_indices, - input_params.block_copy.dst_block_indices, - input_params.block_copy.cum_sum, + input.block_copy.src_block_indices, + input.block_copy.dst_block_indices, + input.block_copy.cum_sum, 0); } @@ -80,7 +81,7 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { sin_pos, attn_mask, kv_cache, - input_params, + input, event, event_flag, layer_id_); @@ -145,15 +146,16 @@ class LlmModelImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the 
sequence - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; // test torch::Tensor h; if (inputs_embeds.defined()) { @@ -188,30 +190,27 @@ class LlmModelImplBase : public torch::nn::Module { {positions.sizes().front(), -1, sin_pos.sizes().back()})); } - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor attn_mask; - max_seq_len_ = - FLAGS_enable_chunked_prefill - ? std::max(input_params.meta.kv_max_seq_len, max_seq_len_) - : 128; + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
std::max(input.meta.kv_max_seq_len, max_seq_len_) + : 128; if (model_type_ == "qwen2") { attn_mask = attn_mask_.get_attn_mask( max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); } else { if (FLAGS_enable_chunked_prefill) { - int num_sequences = input_params.meta.num_sequences; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - max_seq_len_, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + max_seq_len_, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -227,12 +226,13 @@ class LlmModelImplBase : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -250,7 +250,7 @@ class LlmModelImplBase : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); @@ -407,11 +407,9 @@ class LlmForCausalLMImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // 
returns: [num_tokens, hidden_size] - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_(tokens, positions, kv_caches, input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_(input, kv_caches); } // hidden_states: [num_tokens, hidden_size] diff --git a/xllm/models/llm/npu/mtp_model_base.h b/xllm/models/llm/npu/mtp_model_base.h index bf07668cce..b1589f1958 100644 --- a/xllm/models/llm/npu/mtp_model_base.h +++ b/xllm/models/llm/npu/mtp_model_base.h @@ -88,10 +88,11 @@ class MtpModelImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (dp_size_ > 1 && (!tokens.defined() || tokens.numel() == 0)) { auto options = torch::TensorOptions().dtype(torch::kInt32).device(device_); @@ -101,7 +102,7 @@ class MtpModelImplBase : public torch::nn::Module { torch::Tensor h = embed_tokens_(tokens, 0); torch::Tensor enorm = enorm_(h, 0); - torch::Tensor input_embedding = input_params.embedding.input_embedding; + torch::Tensor input_embedding = input.embedding.input_embedding; if (input_embedding.defined()) { h = input_embedding; } @@ -127,18 +128,18 @@ class MtpModelImplBase : public torch::nn::Module { torch::Tensor attn_mask; // TODO(liangzhiwei20): support prefix cache for deepseek . 
if (FLAGS_enable_chunked_prefill) { - int num_sequences = input_params.meta.num_sequences; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - input_params.meta.kv_max_seq_len, - h.dtype().toScalarType(), - h.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + input.meta.kv_max_seq_len, + h.dtype().toScalarType(), + h.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -148,7 +149,7 @@ class MtpModelImplBase : public torch::nn::Module { attn_mask_.get_attn_mask(128, h.dtype().toScalarType(), h.device()); } } else if (model_type_ == "deepseek_v3" && FLAGS_enable_prefix_cache && - !input_params.meta.batch_forward_type.is_decode()) { + !input.meta.batch_forward_type.is_decode()) { attn_mask = attn_mask_.get_attn_mask(512, h.dtype().toScalarType(), h.device()); } else { @@ -163,23 +164,23 @@ class MtpModelImplBase : public torch::nn::Module { torch::TensorOptions().dtype(torch::kInt32).device(tokens.device())); // TODO(liangzhiwei20): MTP need more support for layer wise copy. 
- if (input_params.parallel.layer_wise_load_synchronizer != nullptr) { + if (modified_input_params.parallel.layer_wise_load_synchronizer != + nullptr) { LOG(FATAL) << "MTP not support layer wise copy!"; } - ModelInputParams& input_params_new = - const_cast(input_params); - input_params_new.expert.expert_array = expert_array; + modified_input_params.expert.expert_array = expert_array; for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -195,7 +196,7 @@ class MtpModelImplBase : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); } @@ -306,11 +307,9 @@ class MtpForCausalLMImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_(tokens, positions, kv_caches, input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_(input, kv_caches); } // hidden_states: [num_tokens, hidden_size] diff --git a/xllm/models/llm/npu/oxygen.h b/xllm/models/llm/npu/oxygen.h index 299ae108dd..b59fa9e713 100644 --- a/xllm/models/llm/npu/oxygen.h +++ b/xllm/models/llm/npu/oxygen.h @@ -25,18 +25,19 @@ 
class OxygenModelImpl : public QWen3ModelImpl { public: OxygenModelImpl(const ModelContext& context) : QWen3ModelImpl(context) {} - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - bool use_deepstack = input_params.multimodal.deep_stacks.size() > 0; + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + const bool use_deepstack = input.multimodal.deep_stacks.size() > 0; + ForwardInput modified_input_params = input; std::vector deep_stacks; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -45,7 +46,7 @@ class OxygenModelImpl : public QWen3ModelImpl { } if (use_deepstack) { deep_stacks = - input_params.multimodal.deep_stacks; // [num_deepstack, hidden_size] + input.multimodal.deep_stacks; // [num_deepstack, hidden_size] } auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); @@ -75,21 +76,21 @@ class OxygenModelImpl : public QWen3ModelImpl { torch::Tensor attn_mask; // for chunked prefill, generate the attn mask. 
- if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { if (FLAGS_enable_chunked_prefill) { - int max_kv_seq = input_params.meta.kv_max_seq_len; - int num_sequences = input_params.meta.num_sequences; + int max_kv_seq = input.meta.kv_max_seq_len; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - max_kv_seq, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -100,18 +101,17 @@ class OxygenModelImpl : public QWen3ModelImpl { } } - ModelInputParams& input_params_new = - const_cast(input_params); for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event{nullptr}; std::atomic* event_flag{nullptr}; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -122,13 +122,13 @@ class OxygenModelImpl : public QWen3ModelImpl { sin_pos, attn_mask, kv_caches[i], - input_params_new, + modified_input_params, event, event_flag); if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { h = 
deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } } diff --git a/xllm/models/llm/npu/qwen3.h b/xllm/models/llm/npu/qwen3.h index c8800d7162..85617f9944 100755 --- a/xllm/models/llm/npu/qwen3.h +++ b/xllm/models/llm/npu/qwen3.h @@ -123,18 +123,19 @@ class QWen3ModelImpl : public LlmModelImplBase { << layers_to_capture_set_.size(); } - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - bool use_deepstack = input_params.multimodal.deep_stacks.size() > 0; + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + const bool use_deepstack = input.multimodal.deep_stacks.size() > 0; + ForwardInput modified_input_params = input; std::vector deep_stacks; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -143,7 +144,7 @@ class QWen3ModelImpl : public LlmModelImplBase { } if (use_deepstack) { deep_stacks = - input_params.multimodal.deep_stacks; // [num_deepstack, hidden_size] + input.multimodal.deep_stacks; // [num_deepstack, hidden_size] } auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); @@ -185,21 +186,21 @@ class QWen3ModelImpl : public LlmModelImplBase { torch::Tensor attn_mask; // for chunked prefill, generate the attn mask. 
- if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { if (FLAGS_enable_chunked_prefill) { - int max_kv_seq = input_params.meta.kv_max_seq_len; - int num_sequences = input_params.meta.num_sequences; + int max_kv_seq = input.meta.kv_max_seq_len; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - max_kv_seq, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + max_kv_seq, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -210,8 +211,6 @@ class QWen3ModelImpl : public LlmModelImplBase { } } - ModelInputParams& input_params_new = - const_cast(input_params); const int64_t num_tokens = h.size(0); const int64_t hidden_size = h.size(-1); int64_t capture_idx = 0; @@ -220,12 +219,13 @@ class QWen3ModelImpl : public LlmModelImplBase { aclrtEvent* event{nullptr}; std::atomic* event_flag{nullptr}; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -251,7 +251,7 @@ class QWen3ModelImpl : public LlmModelImplBase { sin_pos, attn_mask, kv_caches[i], - input_params_new, + 
modified_input_params, event, event_flag); @@ -259,7 +259,7 @@ class QWen3ModelImpl : public LlmModelImplBase { if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } } diff --git a/xllm/models/llm/npu/qwen3_eagle3.h b/xllm/models/llm/npu/qwen3_eagle3.h index f5fc4b51c5..0d2ef841f9 100644 --- a/xllm/models/llm/npu/qwen3_eagle3.h +++ b/xllm/models/llm/npu/qwen3_eagle3.h @@ -62,7 +62,7 @@ class QWen3Eagle3DecoderLayerImpl : public torch::nn::Module { torch::Tensor& sin_pos, torch::Tensor& attn_mask, KVCache& kv_cache, - ModelInputParams& input_params, + ForwardInput& input, aclrtEvent* event, std::atomic* event_flag) { return decoder_layer_(hidden_states, @@ -71,7 +71,7 @@ class QWen3Eagle3DecoderLayerImpl : public torch::nn::Module { sin_pos, attn_mask, kv_cache, - input_params, + input, event, event_flag, layer_id_); @@ -141,12 +141,11 @@ class QWen3Eagle3ModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; // Handle empty tokens case for dp if (dp_size_ > 1 && tokens.numel() == 0) { @@ -158,7 +157,7 @@ class QWen3Eagle3ModelImpl : public torch::nn::Module { // Get hidden_states_extra from input_params.embedding.input_embedding // In EAGLE-3, hidden_states_extra comes from verifier layers // (3 layers concatenated) - torch::Tensor hidden_states_extra = input_params.embedding.input_embedding; + torch::Tensor 
hidden_states_extra = input.embedding.input_embedding; if (!hidden_states_extra.defined() || hidden_states_extra.size(0) == 0) { LOG(WARNING) << "hidden_states_extra use embedding from tokens."; hidden_states_extra = hidden_states; @@ -179,20 +178,20 @@ class QWen3Eagle3ModelImpl : public torch::nn::Module { // Generate attention mask torch::Tensor attn_mask; - if (!input_params.meta.batch_forward_type.is_decode()) { + if (!input.meta.batch_forward_type.is_decode()) { if (FLAGS_enable_chunked_prefill) { - int num_sequences = input_params.meta.num_sequences; + int num_sequences = input.meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], - input_params.meta.kv_max_seq_len, - cos_pos.dtype().toScalarType(), - cos_pos.device()); + auto mask = + attn_mask_.gen_append_mask(input.attention.host.q_seq_lens[j], + input.attention.host.kv_seq_lens[j], + input.meta.kv_max_seq_len, + cos_pos.dtype().toScalarType(), + cos_pos.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -207,11 +206,12 @@ class QWen3Eagle3ModelImpl : public torch::nn::Module { aclrtEvent* event{nullptr}; std::atomic* event_flag{nullptr}; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(0); - event_flag = input_params.parallel.layer_synchronizer->get_event_flag(0); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(0); + event_flag = + modified_input_params.parallel.layer_synchronizer->get_event_flag(0); } - if (!input_params.synchronize_layer(0)) { + if (!modified_input_params.synchronize_layer(0)) { return ModelOutput(); } @@ -221,7 +221,7 @@ class QWen3Eagle3ModelImpl : public 
torch::nn::Module { sin_pos, attn_mask, kv_caches[0], - input_params_new, + modified_input_params, event, event_flag); auto aux_hidden_states = hidden_states.clone(); @@ -319,11 +319,9 @@ class QWen3Eagle3ForCausalLMImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: ModelOutput with hidden_states [num_tokens, hidden_size] - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_(tokens, positions, kv_caches, input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_(input, kv_caches); } // hidden_states: [num_tokens, hidden_size] diff --git a/xllm/models/llm/npu/qwen3_moe.h b/xllm/models/llm/npu/qwen3_moe.h index 89520a1c11..e8bd308bdd 100644 --- a/xllm/models/llm/npu/qwen3_moe.h +++ b/xllm/models/llm/npu/qwen3_moe.h @@ -40,7 +40,7 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { torch::Tensor sin_pos, torch::Tensor attn_mask, KVCache& kv_cache, - const ModelInputParams& input_params, + const ForwardInput& input, aclrtEvent* event = nullptr, std::atomic* event_flag = nullptr) { return decoder_layer_(x, @@ -49,7 +49,7 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { sin_pos, attn_mask, kv_cache, - input_params, + input, event, event_flag); } @@ -215,17 +215,18 @@ class Qwen3MoeModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + ForwardInput modified_input_params = input; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = 
torch::tensor({1}).to(torch::kInt32).to(device_); positions = torch::tensor({0}).to(torch::kInt32).to(device_); } } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -273,35 +274,34 @@ class Qwen3MoeModelImpl : public torch::nn::Module { torch::Tensor attn_mask; // for chunked prefill, generate the attn mask. - if (!input_params.meta.batch_forward_type.is_decode()) { - max_seq_len_ = - FLAGS_enable_chunked_prefill - ? std::max(input_params.meta.kv_max_seq_len, max_seq_len_) - : 128; + if (!input.meta.batch_forward_type.is_decode()) { + max_seq_len_ = FLAGS_enable_chunked_prefill + ? std::max(input.meta.kv_max_seq_len, max_seq_len_) + : 128; if (FLAGS_enable_chunked_prefill) { attn_mask = attn_mask_.get_attn_mask( max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - int batch_size = input_params.attention.host.q_seq_lens.size(); + int batch_size = input.attention.host.q_seq_lens.size(); if (batch_size > 0) { std::vector req_mask_vec; req_mask_vec.reserve(batch_size); for (int j = 0; j < batch_size; j++) { - int start = input_params.attention.host.kv_seq_lens[j] - - input_params.attention.host.q_seq_lens[j]; - int end = input_params.attention.host.kv_seq_lens[j]; + int start = input.attention.host.kv_seq_lens[j] - + input.attention.host.q_seq_lens[j]; + int end = input.attention.host.kv_seq_lens[j]; auto req_mask_slice = attn_mask.slice(0, start, end); req_mask_vec.emplace_back(req_mask_slice); } attn_mask = torch::cat(req_mask_vec, 0); } - } else if (input_params.meta.batch_forward_type.is_prefill()) { + } else if (input.meta.batch_forward_type.is_prefill()) { attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } } - auto deep_stacks = input_params.multimodal.deep_stacks; + auto deep_stacks = input.multimodal.deep_stacks; int deep_stack_size = deep_stacks.size(); int64_t input_length = h.size(0); @@ 
-309,9 +309,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { 0, input_length * num_experts_per_tok_, torch::TensorOptions().dtype(torch::kInt32).device(tokens.device())); - ModelInputParams& input_params_new = - const_cast(input_params); - input_params_new.expert.expert_array = expert_array; + modified_input_params.expert.expert_array = expert_array; RollingLayerGuard rolling_guard(rolling_mgr_); @@ -325,12 +323,13 @@ class Qwen3MoeModelImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if (input_params.parallel.layer_synchronizer != nullptr) { - event = input_params.parallel.layer_synchronizer->get_event(i); + if (modified_input_params.parallel.layer_synchronizer != nullptr) { + event = modified_input_params.parallel.layer_synchronizer->get_event(i); event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + modified_input_params.parallel.layer_synchronizer->get_event_flag( + i); } - if (!input_params.synchronize_layer(i)) { + if (!modified_input_params.synchronize_layer(i)) { return ModelOutput(); } @@ -356,14 +355,14 @@ class Qwen3MoeModelImpl : public torch::nn::Module { sin_pos, attn_mask, kv_caches[i], - input_params, + modified_input_params, event, event_flag); rolling_guard.after_layer(layer_index); if (deep_stack_size && i < deep_stack_size) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } diff --git a/xllm/models/llm/oxygen.h b/xllm/models/llm/oxygen.h index e88dcbae59..b729ba659b 100644 --- a/xllm/models/llm/oxygen.h +++ b/xllm/models/llm/oxygen.h @@ -25,20 +25,19 @@ class OxygenModelImpl : public QWen3ModelImpl { public: OxygenModelImpl(const ModelContext& context) : QWen3ModelImpl(context) {} - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& 
input_params) { - bool use_deepstack = input_params.multimodal.deep_stacks.size() > 0; - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + const bool use_deepstack = input.multimodal.deep_stacks.size() > 0; + ForwardInput modified_input_params = input; std::vector deep_stacks; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -47,18 +46,23 @@ class OxygenModelImpl : public QWen3ModelImpl { } if (use_deepstack) { deep_stacks = - input_params.multimodal.deep_stacks; // [num_deepstack, hidden_size] + input.multimodal.deep_stacks; // [num_deepstack, hidden_size] } - auto& dp_token_nums = input_params_new.parallel.dp_global_token_nums; + auto& dp_token_nums = modified_input_params.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - if (!input_params_new.attn_metadata) { - input_params_new.attn_metadata = + if (!modified_input_params.attn_metadata) { + modified_input_params.attn_metadata = std::make_shared( - get_attention_metadata(input_params_new, h)); + get_attention_metadata(modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + h)); } - auto& attn_metadata = *(input_params_new.attn_metadata); + auto& attn_metadata = *(modified_input_params.attn_metadata); bool only_prefill = (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill); if (positions.dim() == 2 && only_prefill && !mrope_section_.empty()) { @@ 
-68,8 +72,9 @@ class OxygenModelImpl : public QWen3ModelImpl { std::optional residual; for (size_t i = 0; i < layers_.size(); i++) { - if (is_rec_multi_round_mode() && input_params_new.has_llmrec_params()) { - const auto& llmrec_params = input_params_new.llmrec_params(); + if (is_rec_multi_round_mode() && + modified_input_params.has_llmrec_params()) { + const auto& llmrec_params = modified_input_params.llmrec_params(); attn_metadata.full_k_cache = llmrec_params->full_k_caches[i]; attn_metadata.full_v_cache = llmrec_params->full_v_caches[i]; attn_metadata.unshared_k_cache = llmrec_params->unshared_k_caches[i]; @@ -81,16 +86,16 @@ class OxygenModelImpl : public QWen3ModelImpl { positions, attn_metadata, kv_caches[i], - input_params_new); - if (!input_params_new.record_layer(static_cast(i), - h.device())) { + modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + h.device())) { return ModelOutput(); } if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } } diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index ec7c58faec..c41f0c1f40 100644 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -107,20 +107,19 @@ class QWen3ModelImpl : public LlmModelImplBase { return std::make_pair(cos_pos, sin_pos); } - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - bool use_deepstack = input_params.multimodal.deep_stacks.size() > 0; - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + const bool use_deepstack = input.multimodal.deep_stacks.size() > 0; + ForwardInput 
modified_input_params = input; std::vector deep_stacks; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -129,18 +128,23 @@ class QWen3ModelImpl : public LlmModelImplBase { } if (use_deepstack) { deep_stacks = - input_params.multimodal.deep_stacks; // [num_deepstack, hidden_size] + input.multimodal.deep_stacks; // [num_deepstack, hidden_size] } - auto& dp_token_nums = input_params_new.parallel.dp_global_token_nums; + auto& dp_token_nums = modified_input_params.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - if (!input_params_new.attn_metadata) { - input_params_new.attn_metadata = + if (!modified_input_params.attn_metadata) { + modified_input_params.attn_metadata = std::make_shared( - get_attention_metadata(input_params_new, h)); + get_attention_metadata(modified_input_params.meta, + modified_input_params.attention, + modified_input_params.graph, + modified_input_params.llmrec_params(), + modified_input_params.enable_cuda_graph, + h)); } - auto& attn_metadata = *(input_params_new.attn_metadata); + auto& attn_metadata = *(modified_input_params.attn_metadata); bool only_prefill = (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill); if (positions.dim() == 2 && only_prefill && !mrope_section_.empty()) { @@ -150,8 +154,9 @@ class QWen3ModelImpl : public LlmModelImplBase { std::optional residual; for (size_t i = 0; i < layers_.size(); i++) { - if (is_rec_multi_round_mode() && input_params_new.has_llmrec_params()) { - const auto& llmrec_params = input_params_new.llmrec_params(); + if (is_rec_multi_round_mode() && + modified_input_params.has_llmrec_params()) { + const auto& llmrec_params = 
modified_input_params.llmrec_params(); attn_metadata.full_k_cache = llmrec_params->full_k_caches[i]; attn_metadata.full_v_cache = llmrec_params->full_v_caches[i]; attn_metadata.unshared_k_cache = llmrec_params->unshared_k_caches[i]; @@ -172,16 +177,16 @@ class QWen3ModelImpl : public LlmModelImplBase { positions, attn_metadata, kv_caches[i], - input_params_new); - if (!input_params_new.record_layer(static_cast(i), - h.device())) { + modified_input_params); + if (!modified_input_params.record_layer(static_cast(i), + h.device())) { return ModelOutput(); } if (use_deepstack) { if (deep_stacks.size() > 0 && i < deep_stacks.size()) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } } @@ -191,27 +196,30 @@ class QWen3ModelImpl : public LlmModelImplBase { protected: layer::AttentionMetadata get_attention_metadata( - const ModelInputParams& params, + const BatchInputMeta& meta, + const AttentionInput& attention, + const GraphInput& graph, + const LlmRecMultiRoundParams* llmrec_params, + bool enable_cuda_graph, const torch::Tensor& h) { #if defined(USE_NPU) - max_seq_len_ = std::max(params.meta.kv_max_seq_len, max_seq_len_); + max_seq_len_ = std::max(meta.kv_max_seq_len, max_seq_len_); // NOTE: Enabling chunked prefill here is known to cause garbled output in // this model. TODO: investigate and fix the output corruption. 
torch::Tensor attn_mask; if (FLAGS_enable_chunked_prefill) { - const int32_t max_kv_seq = params.meta.kv_max_seq_len; - const int32_t num_sequences = params.meta.num_sequences; + const int32_t max_kv_seq = meta.kv_max_seq_len; + const int32_t num_sequences = meta.num_sequences; if (num_sequences > 0) { std::vector req_mask_vec; req_mask_vec.reserve(num_sequences); for (int32_t j = 0; j < num_sequences; ++j) { - auto mask = - attn_mask_.gen_append_mask(params.attention.host.q_seq_lens[j], - params.attention.host.kv_seq_lens[j], - max_kv_seq, - h.dtype().toScalarType(), - h.device()); + auto mask = attn_mask_.gen_append_mask(attention.host.q_seq_lens[j], + attention.host.kv_seq_lens[j], + max_kv_seq, + h.dtype().toScalarType(), + h.device()); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -223,10 +231,19 @@ class QWen3ModelImpl : public LlmModelImplBase { attn_mask = attn_mask_.get_attn_mask( max_seq_len_, h.dtype().toScalarType(), h.device()); } - return layer::AttentionMetadataBuilder::build( - params, model_args_.enable_mla(), attn_mask); + return layer::AttentionMetadataBuilder::build(meta, + attention, + graph, + llmrec_params, + enable_cuda_graph, + model_args_.enable_mla(), + attn_mask); #else - return layer::AttentionMetadataBuilder::build(params, + return layer::AttentionMetadataBuilder::build(meta, + attention, + graph, + llmrec_params, + enable_cuda_graph, model_args_.enable_mla()); #endif } diff --git a/xllm/models/llm/qwen3_5_mtp.h b/xllm/models/llm/qwen3_5_mtp.h index eec7706885..115d392428 100644 --- a/xllm/models/llm/qwen3_5_mtp.h +++ b/xllm/models/llm/qwen3_5_mtp.h @@ -100,10 +100,10 @@ class Qwen3_5MtpModelImpl : public Qwen3HybridModelImplBase { } } - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) override { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + torch::Tensor tokens = 
input.token_ids; + torch::Tensor positions = input.positions; torch::NoGradGuard no_grad; if (dp_size_ > 1 && tokens.sizes() == 0) { @@ -112,12 +112,16 @@ class Qwen3_5MtpModelImpl : public Qwen3HybridModelImplBase { } auto attn_metadata = layer::AttentionMetadataBuilder::build( - input_params, + input.meta, + input.attention, + input.graph, + input.llmrec_params(), + input.enable_cuda_graph, model_args_.enable_mla(), - build_attention_mask(input_params)); + build_attention_mask(input.meta, input.attention)); torch::Tensor embedding = embed_tokens_(tokens); - torch::Tensor hidden = input_params.embedding.input_embedding; + torch::Tensor hidden = input.embedding.input_embedding; if (hidden.defined() == false) { hidden = embedding; } @@ -129,12 +133,8 @@ class Qwen3_5MtpModelImpl : public Qwen3HybridModelImplBase { CHECK_EQ(kv_caches.size(), layers_.size()); std::optional residual = std::nullopt; for (size_t i = 0; i < layers_.size(); ++i) { - mtp_hidden = layers_[i]->forward(mtp_hidden, - residual, - positions, - attn_metadata, - kv_caches[i], - input_params); + mtp_hidden = layers_[i]->forward( + mtp_hidden, residual, positions, attn_metadata, kv_caches[i], input); } auto [new_mtp_hidden, new_res] = norm_->forward(mtp_hidden, residual); mtp_hidden = new_mtp_hidden; diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 5c92a3c250..bbe4b867a1 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -106,18 +106,18 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) override { - ModelInputParams modified_input_params = input_params; + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; + 
ForwardInput modified_input_params = input; if (tokens.numel() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); positions = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); } auto& dp_token_nums = modified_input_params.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input.embedding.input_embedding; torch::Tensor h; if (inputs_embeds.defined()) { h = inputs_embeds; @@ -125,7 +125,7 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { h = embed_tokens_(tokens); } - auto deep_stacks = input_params.multimodal.deep_stacks; + auto deep_stacks = input.multimodal.deep_stacks; int deep_stack_size = deep_stacks.size(); if (!modified_input_params.attn_metadata) { modified_input_params.attn_metadata = @@ -179,7 +179,7 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { if (deep_stack_size && i < deep_stack_size) { h = deepstack_process( - h, input_params.multimodal.visual_pos_masks, deep_stacks[i]); + h, input.multimodal.visual_pos_masks, deep_stacks[i]); } } auto [hidden_states, residual_out] = norm_(h, residual); @@ -232,9 +232,8 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { } private: - layer::AttentionMetadata get_attention_metadata( - const ModelInputParams& params, - const torch::Tensor& h) { + layer::AttentionMetadata get_attention_metadata(const ForwardInput& params, + const torch::Tensor& h) { #if defined(USE_NPU) max_seq_len_ = std::max(params.meta.kv_max_seq_len, max_seq_len_); torch::Tensor attn_mask; @@ -263,10 +262,19 @@ class Qwen3MoeModelImpl : public LlmModelImplBase { attn_mask = attn_mask_.get_attn_mask( max_seq_len_, h.dtype().toScalarType(), h.device()); } - return layer::AttentionMetadataBuilder::build( - params, model_args_.enable_mla(), attn_mask); + return layer::AttentionMetadataBuilder::build(params.meta, + params.attention, + params.graph, + params.llmrec_params(), + 
params.enable_cuda_graph, + model_args_.enable_mla(), + attn_mask); #else - return layer::AttentionMetadataBuilder::build(params, + return layer::AttentionMetadataBuilder::build(params.meta, + params.attention, + params.graph, + params.llmrec_params(), + params.enable_cuda_graph, model_args_.enable_mla()); #endif } diff --git a/xllm/models/llm/qwen3_next_hybrid_base.h b/xllm/models/llm/qwen3_next_hybrid_base.h index 516843727e..14475a0a94 100644 --- a/xllm/models/llm/qwen3_next_hybrid_base.h +++ b/xllm/models/llm/qwen3_next_hybrid_base.h @@ -34,15 +34,14 @@ limitations under the License. #include "core/layers/common/qwen3_next_rms_norm.h" #include "core/layers/common/word_embedding.h" #include "core/layers/npu_torch/qwen3_next_hybrid_decoder_layer_base.h" +#include "core/runtime/forward_params.h" namespace xllm { class Qwen3HybridModelModule : public torch::nn::Module { public: - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) = 0; + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) = 0; virtual void load_state_dict(const StateDict& state_dict) = 0; virtual void verify_loaded_weights(const std::string& prefix) const = 0; virtual layer::WordEmbedding get_word_embedding() = 0; @@ -78,10 +77,10 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) override { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) override { + torch::Tensor tokens = input.token_ids; + torch::Tensor positions = input.positions; // Disable gradient computation to reduce memory usage during inference torch::NoGradGuard no_grad; if (dp_size_ > 1) { @@ -93,9 +92,13 @@ class Qwen3HybridModelImplBase : public 
Qwen3HybridModelModule { layer::AttentionMetadata attn_metadata = layer::AttentionMetadataBuilder::build( - input_params, + input.meta, + input.attention, + input.graph, + input.llmrec_params(), + input.enable_cuda_graph, model_args_.enable_mla(), - build_attention_mask(input_params)); + build_attention_mask(input.meta, input.attention)); torch::Tensor h = embed_tokens_(tokens); torch::Tensor mrope_cos_sin; @@ -112,7 +115,7 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { positions, attn_metadata, kv_caches[i], - input_params, + input, mrope_cos_sin); } auto [hidden_states, residual_out] = norm_->forward(h, residual); @@ -154,13 +157,14 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { } protected: - torch::Tensor build_attention_mask(const ModelInputParams& input_params) { - max_seq_len_ = std::max(input_params.meta.kv_max_seq_len, max_seq_len_); + torch::Tensor build_attention_mask(const BatchInputMeta& meta, + const AttentionInput& attention) { + max_seq_len_ = std::max(meta.kv_max_seq_len, max_seq_len_); if (!FLAGS_enable_chunked_prefill) { return attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } - const int32_t num_sequences = input_params.meta.num_sequences; + const int32_t num_sequences = meta.num_sequences; if (num_sequences <= 0) { return attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } @@ -169,8 +173,8 @@ class Qwen3HybridModelImplBase : public Qwen3HybridModelModule { req_mask_vec.reserve(num_sequences); for (int32_t j = 0; j < num_sequences; ++j) { req_mask_vec.emplace_back( - attn_mask_.gen_append_mask(input_params.attention.host.q_seq_lens[j], - input_params.attention.host.kv_seq_lens[j], + attn_mask_.gen_append_mask(attention.host.q_seq_lens[j], + attention.host.kv_seq_lens[j], max_seq_len_, dtype_, device_)); @@ -200,11 +204,9 @@ class Qwen3HybridForCausalLMImplBase : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, 
hidden_size] - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_->forward(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_->forward(input, kv_caches); } // hidden_states: [num_tokens, hidden_size] diff --git a/xllm/models/rec/npu/onerec.h b/xllm/models/rec/npu/onerec.h index 848e2fb495..ddcf3da3ce 100644 --- a/xllm/models/rec/npu/onerec.h +++ b/xllm/models/rec/npu/onerec.h @@ -45,30 +45,29 @@ class OneRecModelImpl final : public torch::nn::Module { "decoder", OneRecStack(context, /*is_decode=*/true, shared_)); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; if (!tokens.defined()) { return ModelOutput(); } (void)positions; (void)kv_caches; - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); if (onerec_params != nullptr) { if (onerec_params->is_encoder_forward) { std::vector encoder_kv_caches; auto encoder_output = - encoder_(tokens, positions, encoder_kv_caches, input_params); + encoder_(tokens, positions, encoder_kv_caches, input); torch::Tensor cached_encoder_output; if (encoder_output.defined() && onerec_params->encoder_max_seq_len > 0 && !onerec_params->encoder_seq_lens.empty()) { - cached_encoder_output = - pad_encoder_output(encoder_output, input_params); + cached_encoder_output = pad_encoder_output(encoder_output, input); } else { cached_encoder_output = encoder_output; } @@ -95,8 +94,8 @@ class OneRecModelImpl final : public torch::nn::Module { return ModelOutput(); } - auto decoder_output = decoder_( - tokens, 
positions, kv_caches, input_params, cached_encoder_output); + auto decoder_output = + decoder_(tokens, positions, kv_caches, input, cached_encoder_output); return ModelOutput(decoder_output); } diff --git a/xllm/models/rec/npu/onerec_npu_impl.h b/xllm/models/rec/npu/onerec_npu_impl.h index 3f43484671..11ac94383b 100644 --- a/xllm/models/rec/npu/onerec_npu_impl.h +++ b/xllm/models/rec/npu/onerec_npu_impl.h @@ -26,8 +26,8 @@ limitations under the License. namespace xllm { inline torch::Tensor pad_encoder_output(const torch::Tensor& encoder_output, - const ModelInputParams& input_params) { - const auto* onerec_params = input_params.onerec_params(); + const ForwardInput& input) { + const auto* onerec_params = input.onerec_params(); CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; const int64_t bs = onerec_params->bs; @@ -71,7 +71,7 @@ inline torch::Tensor compute_onerec_position_bias( int64_t max_distance = 128, const torch::TensorOptions& options = torch::kFloat32, bool is_decode_stage = false, - const ModelInputParams* input_params = nullptr) { + const ForwardInput* input = nullptr) { auto device = options.device(); auto dtype = options.dtype(); @@ -138,9 +138,9 @@ inline torch::Tensor compute_onerec_position_bias( values = values.permute({2, 0, 1}); } - if (is_decode_stage && input_params != nullptr && - !input_params->attention.host.kv_seq_lens.empty()) { - const int32_t seq_kv_len = input_params->attention.host.kv_seq_lens[0]; + if (is_decode_stage && input != nullptr && + !input->attention.host.kv_seq_lens.empty()) { + const int32_t seq_kv_len = input->attention.host.kv_seq_lens[0]; values = values.slice(1, -1, values.size(1)).slice(2, 0, seq_kv_len); } else if (is_decode_stage) { values = values.slice(1, -1, values.size(1)); @@ -190,11 +190,11 @@ class OneRecStackImpl : public torch::nn::Module { torch::Tensor forward(const torch::Tensor& tokens, const torch::Tensor& positions, std::vector& kv_caches, - const ModelInputParams& input_params, 
+ const ForwardInput& input, const torch::Tensor& encoder_output = torch::Tensor()) { (void)positions; - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; torch::Tensor h; @@ -244,12 +244,12 @@ class OneRecStackImpl : public torch::nn::Module { const bool is_prefill = onerec_params->rec_stage == OneRecModelInputParams::RecStage::PREFILL; - auto [query_length, key_length] = compute_sequence_lengths( - input_params.meta.q_max_seq_len, is_prefill, input_params); + auto [query_length, key_length] = + compute_sequence_lengths(input.meta.q_max_seq_len, is_prefill, input); - ModelInputParams input_params_local = input_params; + ForwardInput input_params_local = input; auto& mutable_onerec_params = input_params_local.mutable_onerec_params(); - const auto* onerec_xattn_params = input_params.onerec_xattention_params(); + const auto* onerec_xattn_params = input.onerec_xattention_params(); auto validate_selected_token_idxes_stage = [&](const char* stage_name) { if (!is_decoder_ || onerec_xattn_params == nullptr || @@ -278,13 +278,12 @@ class OneRecStackImpl : public torch::nn::Module { const bool is_decode_stage = is_decoder_ && !is_prefill; torch::Tensor effective_attn_mask; if (use_absolute_position_embedding_) { - const int64_t batch_size = - std::max(1, input_params.meta.num_sequences); + const int64_t batch_size = std::max(1, input.meta.num_sequences); effective_attn_mask = create_moe_attention_mask(query_length, h, is_decoder_, batch_size); } else { effective_attn_mask = compute_position_bias_mask( - query_length, key_length, h, is_decode_stage, input_params); + query_length, key_length, h, is_decode_stage, input); } auto preprocessed_attn_mask = @@ -309,12 +308,11 @@ class OneRecStackImpl : public torch::nn::Module { for (size_t i = 0; i < layers_.size(); ++i) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; - if 
(input_params.parallel.layer_synchronizer) { - event = input_params.parallel.layer_synchronizer->get_event(i); - event_flag = - input_params.parallel.layer_synchronizer->get_event_flag(i); + if (input.parallel.layer_synchronizer) { + event = input.parallel.layer_synchronizer->get_event(i); + event_flag = input.parallel.layer_synchronizer->get_event_flag(i); } - if (!input_params.synchronize_layer(i)) { + if (!input.synchronize_layer(i)) { return torch::Tensor(); } @@ -336,8 +334,8 @@ class OneRecStackImpl : public torch::nn::Module { event_flag, expert_array); - if (input_params.parallel.layer_synchronizer != nullptr && - !input_params.parallel.layer_synchronizer->synchronize_layer(i)) { + if (input.parallel.layer_synchronizer != nullptr && + !input.parallel.layer_synchronizer->synchronize_layer(i)) { return torch::Tensor(); } validate_selected_token_idxes_stage( @@ -397,11 +395,11 @@ class OneRecStackImpl : public torch::nn::Module { std::pair compute_sequence_lengths( int64_t seq_length, bool is_prefill, - const ModelInputParams& input_params) const { + const ForwardInput& input) const { int64_t query_length = seq_length; int64_t key_length = seq_length; - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); CHECK(onerec_params != nullptr) << "OneRec requires onerec_params()."; if (is_decoder_) { @@ -410,10 +408,10 @@ class OneRecStackImpl : public torch::nn::Module { key_length = seq_length; } else { query_length = 1; - if (!input_params.attention.host.kv_seq_lens.empty()) { + if (!input.attention.host.kv_seq_lens.empty()) { key_length = - *std::max_element(input_params.attention.host.kv_seq_lens.begin(), - input_params.attention.host.kv_seq_lens.end()); + *std::max_element(input.attention.host.kv_seq_lens.begin(), + input.attention.host.kv_seq_lens.end()); } // Decode keeps a square bias/mask shape expected by OneRec NPU block. 
query_length = key_length; @@ -460,12 +458,11 @@ class OneRecStackImpl : public torch::nn::Module { return effective_attn_mask; } - torch::Tensor compute_position_bias_mask( - int64_t query_length, - int64_t key_length, - const torch::Tensor& h, - bool is_decode_stage, - const ModelInputParams& input_params) { + torch::Tensor compute_position_bias_mask(int64_t query_length, + int64_t key_length, + const torch::Tensor& h, + bool is_decode_stage, + const ForwardInput& input) { CHECK(!position_bias_embedding_.is_empty()) << "position_bias_embedding is required for relative attention."; @@ -479,7 +476,7 @@ class OneRecStackImpl : public torch::nn::Module { relative_attention_max_distance_, torch::dtype(h.dtype()).device(h.device()), is_decode_stage, - &input_params); + &input); auto effective_attn_mask = layer_position_bias.is_contiguous() ? layer_position_bias diff --git a/xllm/models/rec/onerec.h b/xllm/models/rec/onerec.h index d8d2fabe65..96ebc452f0 100644 --- a/xllm/models/rec/onerec.h +++ b/xllm/models/rec/onerec.h @@ -38,17 +38,17 @@ class OneRecModelImpl : public torch::nn::Module { shared_ = register_module("shared", layer::WordEmbedding(context)); } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + const torch::Tensor& tokens = input.token_ids; + const torch::Tensor& positions = input.positions; if (!tokens.defined()) { return ModelOutput(); } (void)positions; (void)kv_caches; - const auto* onerec_params = input_params.onerec_params(); + const auto* onerec_params = input.onerec_params(); const bool is_encoder_forward = (onerec_params != nullptr) && onerec_params->is_encoder_forward; diff --git a/xllm/models/rec/rec_model_base.h b/xllm/models/rec/rec_model_base.h index 8827e8388a..f885411ac5 100644 --- a/xllm/models/rec/rec_model_base.h +++ b/xllm/models/rec/rec_model_base.h @@ 
-29,6 +29,7 @@ limitations under the License. #include "core/framework/model_loader.h" #include "core/layers/common/lm_head.h" #include "core/layers/common/word_embedding.h" +#include "core/runtime/forward_params.h" namespace xllm { @@ -46,11 +47,9 @@ class RecForCausalLMImplBase : public torch::nn::Module { lm_head_ = register_module("lm_head", layer::LmHead(context)); } - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_->forward(tokens, positions, kv_caches, input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return model_->forward(input, kv_caches); } virtual torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/glm4v.h b/xllm/models/vlm/npu/glm4v.h index 2ae60a7a65..8a4b690c87 100644 --- a/xllm/models/vlm/npu/glm4v.h +++ b/xllm/models/vlm/npu/glm4v.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "core/framework/model/model_input_params.h" #include "core/framework/model/model_output.h" #include "core/layers/npu/npu_lm_head_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/glm4.h" #include "models/model_registry.h" #include "processors/glm4v_image_processor.h" @@ -125,15 +126,9 @@ class Glm4_VisionBlockImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { - return encoder_layer_(x, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params, - node_id); + return encoder_layer_( + x, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, node_id); } // load the weight from the checkpoint @@ -509,9 +504,7 @@ class Glm4VisionTransformerImpl : public torch::nn::Module { return std::make_tuple(rotary_pos_emb, pos_ids); } - torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, - const ModelInputParams& input_params) { + torch::Tensor forward(torch::Tensor hidden_states, torch::Tensor grid_thw) { hidden_states = patch_embed_(hidden_states); hidden_states = post_conv_layernorm_(hidden_states); @@ -553,21 +546,14 @@ class Glm4VisionTransformerImpl : public torch::nn::Module { grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1)); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); cu_seqlens = cu_seqlens.to(hidden_states.device()); for (int idx = 0; idx < blocks_->size(); ++idx) { - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); } hidden_states = post_layernorm_(hidden_states); hidden_states = hidden_states.view( @@ -690,10 +676,10 @@ class 
Glm4vForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", Glm4ForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -717,18 +703,17 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Glm4VVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -754,8 +739,7 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module { auto flatten_video_grid_thw = torch::cat(temp_frames_hw, 0); // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - flatten_video_grid_thw, - input_params); + flatten_video_grid_thw); auto t = video_input->video_grid_thw.index({torch::indexing::Slice(), 0}); auto video_tokens = ((video_input->video_grid_thw.prod(-1) / merge_size / merge_size) / t) @@ -789,8 +773,8 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const 
ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -805,11 +789,9 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/glm4v_moe.h b/xllm/models/vlm/npu/glm4v_moe.h index 38abfa8339..49cd5f3830 100644 --- a/xllm/models/vlm/npu/glm4v_moe.h +++ b/xllm/models/vlm/npu/glm4v_moe.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "core/framework/model_context.h" #include "core/layers/npu/npu_lm_head_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" +#include "core/runtime/forward_params.h" #include "glm4v.h" #include "models/llm/npu/glm4_moe.h" #include "models/model_registry.h" @@ -52,10 +53,10 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", Glm4MoeForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -79,18 +80,17 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Glm4VVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -116,8 +116,7 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { auto flatten_video_grid_thw = torch::cat(temp_frames_hw, 0); // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - flatten_video_grid_thw, - input_params); + flatten_video_grid_thw); auto t = 
video_input->video_grid_thw.index({torch::indexing::Slice(), 0}); auto video_tokens = ((video_input->video_grid_thw.prod(-1) / merge_size / merge_size) / t) @@ -151,8 +150,8 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -167,11 +166,9 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/minicpmv.h b/xllm/models/vlm/npu/minicpmv.h index eb63c1af80..ba282f8a84 100644 --- a/xllm/models/vlm/npu/minicpmv.h +++ b/xllm/models/vlm/npu/minicpmv.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "core/framework/model_context.h" #include "core/layers/npu/multi_head_attention.h" #include "core/layers/npu/npu_siglip_encoder_layer_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" #include "processors/minicpmv_image_processor.h" @@ -821,9 +822,9 @@ class MiniCPMV2_6Impl : public torch::nn::Module { mlp_ = register_module("mlp", VisionAdapterMLP(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; std::vector pixel_values; if (const auto& res = @@ -949,9 +950,9 @@ class MiniCPMV2_6Impl : public torch::nn::Module { return llm_embedding; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_inputs; - prepare_encoder_input(input_params, image_inputs); + prepare_encoder_input(input.multimodal, image_inputs); MMDict multimodal_embeds; if (!image_inputs.has_value()) { return multimodal_embeds; @@ -1013,8 +1014,8 @@ class MiniCPMV2_6Impl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -1032,11 +1033,9 @@ class MiniCPMV2_6Impl : public torch::nn::Module { inputs_embeds, multimodal_embeds, image_bounds); return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, 
kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/oxygen_vlm.h b/xllm/models/vlm/npu/oxygen_vlm.h index c420138cbc..b3837fcc93 100644 --- a/xllm/models/vlm/npu/oxygen_vlm.h +++ b/xllm/models/vlm/npu/oxygen_vlm.h @@ -27,6 +27,7 @@ limitations under the License. #include "core/framework/model/model_input_params.h" #include "core/framework/model/model_output.h" #include "core/layers/npu/npu_lm_head_impl.h" +#include "core/runtime/forward_params.h" #include "glm4v.h" #include "models/llm/npu/oxygen.h" #include "models/model_registry.h" @@ -47,10 +48,10 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", OxygenForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -74,18 +75,17 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Glm4VVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); 
+ image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -111,8 +111,7 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { auto flatten_video_grid_thw = torch::cat(temp_frames_hw, 0); // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - flatten_video_grid_thw, - input_params); + flatten_video_grid_thw); auto t = video_input->video_grid_thw.index({torch::indexing::Slice(), 0}); auto video_tokens = ((video_input->video_grid_thw.prod(-1) / merge_size / merge_size) / t) @@ -146,8 +145,8 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -162,11 +161,9 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - auto emb = language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + auto emb = language_model_(input, kv_caches); return emb; } diff --git a/xllm/models/vlm/npu/qwen2_5_vl.h b/xllm/models/vlm/npu/qwen2_5_vl.h index 75c999d05f..a17aafe680 100644 --- a/xllm/models/vlm/npu/qwen2_5_vl.h +++ b/xllm/models/vlm/npu/qwen2_5_vl.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "core/layers/npu/npu_qwen2_decoder_layer_impl.h" #include "core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -50,15 +51,9 @@ class Qwen2_5_VisionBlockImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { - return encoder_layer_(x, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params, - node_id); + return encoder_layer_( + x, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, node_id); } // load the weight from the checkpoint @@ -418,8 +413,7 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { } torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] // patchify // hidden_states = x.to(device=self.device, dtype=self.dtype); hidden_states = patch_embed_(hidden_states); @@ -472,8 +466,6 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { m_cos = m_cos.repeat({1, 2}); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); torch::Tensor cu_window_seqlens_cpu = cu_window_seqlens.cpu(); std::vector cu_seqlens_vec( @@ -493,13 +485,8 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { cu_seqlens_now = cu_window_seqlens; cu_seqlens_now_vec = cu_w_seqlens_vec; } - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens_now, - cu_seqlens_now_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens_now, cu_seqlens_now_vec, idx); } // adapter hidden_states = merger_(hidden_states); @@ -581,10 +568,10 @@ 
class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { } void prepare_encoder_input( - const ModelInputParams& input_params, + const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -614,17 +601,16 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { pixel_values_videos, video_grid_thw, second_per_grid_ts}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -639,8 +625,7 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { if (video_input) { // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - video_input->video_grid_thw, - input_params); + video_input->video_grid_thw); auto video_tokens = (video_input->video_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -673,8 +658,8 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; 
torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -689,11 +674,9 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/qwen2_5_vl_mm_embedding.h b/xllm/models/vlm/npu/qwen2_5_vl_mm_embedding.h index 8e898fee95..aa79c5437e 100644 --- a/xllm/models/vlm/npu/qwen2_5_vl_mm_embedding.h +++ b/xllm/models/vlm/npu/qwen2_5_vl_mm_embedding.h @@ -47,9 +47,9 @@ class Qwen2_5_VLForMMEmbeddingImpl : public torch::nn::Module { return images_size; } - MMDict encode(const ModelInputParams& input_params) { + MMDict encode(const ForwardInput& input) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = input.multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) @@ -66,8 +66,7 @@ class Qwen2_5_VLForMMEmbeddingImpl : public torch::nn::Module { CHECK(image_inputs.has_value()); auto image_embeds = visual_(image_inputs->pixel_values.to(options_), - image_inputs->image_grid_thw, - input_params); + image_inputs->image_grid_thw); std::vector mm_embeddings; @@ -118,12 +117,12 @@ class MMEmbeddingVLMImpl const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - MMDict encode(const ModelInputParams& input_params) override { - return model_->encode(input_params); + MMDict encode(const ForwardInput& input) override { + return model_->encode(input); }; torch::Tensor get_input_embeddings(const torch::Tensor& input_ids, - const 
ModelInputParams& input_params) { + const MultiModalInput& multimodal) { return torch::Tensor{}; } @@ -132,10 +131,8 @@ class MMEmbeddingVLMImpl return torch::Tensor(); } - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { return ModelOutput(); } virtual void prepare_expert_weight(int32_t layer_id, diff --git a/xllm/models/vlm/npu/qwen2_vl.h b/xllm/models/vlm/npu/qwen2_vl.h index 54d1947ead..f1d6166abb 100644 --- a/xllm/models/vlm/npu/qwen2_vl.h +++ b/xllm/models/vlm/npu/qwen2_vl.h @@ -27,6 +27,7 @@ limitations under the License. #include "core/layers/npu/npu_lm_head_impl.h" #include "core/layers/npu/npu_qwen2_vision_encoder_layer_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/qwen2.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -48,15 +49,9 @@ class Qwen2_VisionBlockImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { - return encoder_layer_(x, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params, - node_id); + return encoder_layer_( + x, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, node_id); } // load the weight from the checkpoint @@ -363,8 +358,7 @@ class Qwen2_VisionTransformerImpl : public torch::nn::Module { } torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] // patchify // hidden_states = x.to(device=self.device, dtype=self.dtype); hidden_states = patch_embed_(hidden_states); @@ -388,20 +382,13 @@ class Qwen2_VisionTransformerImpl : public torch::nn::Module { m_cos = m_cos.repeat({1, 2}); 
m_sin = rotary_pos_emb.sin().type_as(hidden_states); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), // full seqlen vec cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); for (int idx = 0; idx < blocks_->size(); ++idx) { - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); } // adapter hidden_states = merger_(hidden_states); @@ -478,10 +465,10 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", QWen2ForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -511,17 +498,16 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { pixel_values_videos, video_grid_thw, second_per_grid_ts}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = 
(image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -554,8 +540,8 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -570,11 +556,9 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/qwen3_vl.h b/xllm/models/vlm/npu/qwen3_vl.h index 92bdd7a9df..2aa30cbc6d 100644 --- a/xllm/models/vlm/npu/qwen3_vl.h +++ b/xllm/models/vlm/npu/qwen3_vl.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "core/layers/npu/npu_lm_head_impl.h" #include "core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/qwen3.h" #include "models/model_registry.h" #include "processors/qwen3_vl_image_processor.h" @@ -100,15 +101,9 @@ class Qwen3_VisionBlockImpl : public torch::nn::Module { torch::Tensor& m_sin_pos, torch::Tensor& cu_seq_len, std::vector& cu_seq_len_vec, - ModelInputParams& input_params, int node_id) { - return encoder_layer_(x, - m_cos_pos, - m_sin_pos, - cu_seq_len, - cu_seq_len_vec, - input_params, - node_id); + return encoder_layer_( + x, m_cos_pos, m_sin_pos, cu_seq_len, cu_seq_len_vec, node_id); } // load the weight from the checkpoint @@ -483,8 +478,7 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { std::tuple> forward( torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] hidden_states = patch_embed_(hidden_states); auto pos_embeds = fast_pos_embed_interpolate(grid_thw); hidden_states = hidden_states + pos_embeds; @@ -508,8 +502,6 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { m_sin = rotary_pos_emb.sin().type_as(hidden_states); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), // full seqlen vec @@ -517,13 +509,8 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { std::vector deepstack_feature_lists; deepstack_feature_lists.reserve(deepstack_visual_indexes_.size()); for (int idx = 0; idx < blocks_->size(); ++idx) { - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); auto it = 
std::find(deepstack_visual_indexes_.begin(), deepstack_visual_indexes_.end(), idx); @@ -637,10 +624,10 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", QWen3ForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -664,18 +651,17 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); MMDict multimodal_embeds; auto merge_size = model_args_.mm_image_merge_size(); if (image_input) { auto [image_embeds, deep_stacks] = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw.to(options_.device()), - input_params); + image_input->image_grid_thw.to(options_.device())); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) @@ -698,8 +684,7 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { if (video_input) { auto [video_embeds, deep_stacks] = visual_(video_input->pixel_values_videos.to(options_), - video_input->video_grid_thw.to(options_.device()), - input_params); + video_input->video_grid_thw.to(options_.device())); auto video_tokens = (video_input->video_grid_thw.prod(-1) / merge_size / merge_size) @@ -730,9 +715,8 @@ class Qwen3_VLForConditionalGenerationImpl 
: public torch::nn::Module { return is_multimodal; } - std::vector get_deep_stacks( - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + std::vector get_deep_stacks(const ForwardInput& input) { + const auto& mm_data = input.multimodal.mm_data; if (!mm_data.has("embedding|deepstack_0")) { return {}; } @@ -753,8 +737,8 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -764,19 +748,17 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } auto is_multimodal = generate_multimodal_mask(input_ids); - input_params.multimodal.visual_pos_masks = is_multimodal; + multimodal.visual_pos_masks = is_multimodal; inputs_embeds = merge_multimodal_embeddings( inputs_embeds, multimodal_embeds, is_multimodal); return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - input_params.multimodal.deep_stacks = - std::move(get_deep_stacks(input_params)); - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput model_input = input; + model_input.multimodal.deep_stacks = get_deep_stacks(model_input); + return language_model_(model_input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/npu/qwen3_vl_mm_embedding.h b/xllm/models/vlm/npu/qwen3_vl_mm_embedding.h index 634ade0a17..376aba71ef 100644 --- a/xllm/models/vlm/npu/qwen3_vl_mm_embedding.h 
+++ b/xllm/models/vlm/npu/qwen3_vl_mm_embedding.h @@ -50,9 +50,9 @@ class Qwen3_VLForMMEmbeddingImpl : public torch::nn::Module { return torch::cat(tensors, 0); } - MMDict encode(const ModelInputParams& input_params) { + MMDict encode(const ForwardInput& input) { torch::NoGradGuard no_grad; - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = input.multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -66,11 +66,8 @@ class Qwen3_VLForMMEmbeddingImpl : public torch::nn::Module { image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; CHECK(image_inputs.has_value()); - auto [image_embeds, deep_stacks] = - visual_(image_inputs->pixel_values.to(options_), - image_inputs->image_grid_thw, - input_params); - input_params.multimodal.deep_stacks = deep_stacks; + auto [image_embeds, deep_stacks] = visual_( + image_inputs->pixel_values.to(options_), image_inputs->image_grid_thw); std::vector image_sizes = get_images_size(image_grid_thw); @@ -131,12 +128,12 @@ class MMEmbeddingVLMImpl const torch::TensorOptions& options) : model_(std::move(model)), options_(options) {} - MMDict encode(const ModelInputParams& input_params) override { - return model_->encode(input_params); + MMDict encode(const ForwardInput& input) override { + return model_->encode(input); }; torch::Tensor get_input_embeddings(const torch::Tensor& input_ids, - const ModelInputParams& input_params) { + const MultiModalInput& multimodal) { return torch::Tensor{}; } @@ -145,10 +142,8 @@ class MMEmbeddingVLMImpl return torch::Tensor(); } - virtual ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { return ModelOutput(); } virtual void prepare_expert_weight(int32_t layer_id, diff --git a/xllm/models/vlm/npu/qwen3_vl_moe.h 
b/xllm/models/vlm/npu/qwen3_vl_moe.h index 2d60f9d37a..d82ebaf79b 100644 --- a/xllm/models/vlm/npu/qwen3_vl_moe.h +++ b/xllm/models/vlm/npu/qwen3_vl_moe.h @@ -24,6 +24,7 @@ limitations under the License. #include "core/layers/npu/npu_lm_head_impl.h" #include "core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h" #include "core/layers/npu/npu_rms_norm_impl.h" +#include "core/runtime/forward_params.h" #include "models/llm/npu/qwen3_moe.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -47,10 +48,10 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", Qwen3MoeForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -74,18 +75,17 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); MMDict multimodal_embeds; auto merge_size = model_args_.mm_image_merge_size(); if (image_input) { auto [image_embeds, deep_stacks] = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw.to(options_.device()), - input_params); + image_input->image_grid_thw.to(options_.device())); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) @@ -116,9 +116,8 @@ class 
Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { return is_multimodal; } - std::vector get_deep_stacks( - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + std::vector get_deep_stacks(const ForwardInput& input) { + const auto& mm_data = input.multimodal.mm_data; if (!mm_data.has("embedding|deepstack_0")) { return {}; } @@ -139,8 +138,8 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -150,19 +149,17 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } auto is_multimodal = generate_multimodal_mask(input_ids); - input_params.multimodal.visual_pos_masks = is_multimodal; + multimodal.visual_pos_masks = is_multimodal; inputs_embeds = merge_multimodal_embeddings( inputs_embeds, multimodal_embeds, is_multimodal); return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - input_params.multimodal.deep_stacks = - std::move(get_deep_stacks(input_params)); - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput model_input = input; + model_input.multimodal.deep_stacks = get_deep_stacks(model_input); + return language_model_(model_input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/oxygen_vlm.h b/xllm/models/vlm/oxygen_vlm.h index 0592fb39fe..171a46b3dd 100644 --- a/xllm/models/vlm/oxygen_vlm.h 
+++ b/xllm/models/vlm/oxygen_vlm.h @@ -20,6 +20,7 @@ limitations under the License. #include "core/layers/oxygen_vision_layer.h" #include "core/layers/qwen2_5_vision_layer.h" #include "core/layers/qwen2_decoder_layer.h" +#include "core/runtime/forward_params.h" #include "models/llm/oxygen.h" #include "models/model_registry.h" #include "processors/input_processor.h" @@ -408,9 +409,7 @@ class OxygenVisionTransformerImpl : public torch::nn::Module { return std::make_tuple(rotary_pos_emb, pos_ids); } - torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, - const ModelInputParams& input_params) { + torch::Tensor forward(torch::Tensor hidden_states, torch::Tensor grid_thw) { hidden_states = patch_embed_(hidden_states); hidden_states = std::get<0>(post_conv_layernorm_(hidden_states)); @@ -452,21 +451,14 @@ class OxygenVisionTransformerImpl : public torch::nn::Module { grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1)); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); cu_seqlens = cu_seqlens.to(hidden_states.device()); for (int idx = 0; idx < blocks_->size(); ++idx) { - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); } hidden_states = std::get<0>(post_layernorm_(hidden_states)); hidden_states = hidden_states.view( @@ -569,10 +561,10 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", OxygenForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = 
input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -596,18 +588,17 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { video_inputs = OxygenVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -633,8 +624,7 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { auto flatten_video_grid_thw = torch::cat(temp_frames_hw, 0); // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - flatten_video_grid_thw, - input_params); + flatten_video_grid_thw); // Split based on original video count, not frame count // video_grid_thw has shape [num_videos, 3], video_embeds is flattened // We need to split video_embeds back to match num_videos @@ -669,8 +659,8 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -685,11 +675,9 
@@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h index 09d96598bd..a597ea4f5b 100644 --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -19,6 +19,7 @@ limitations under the License. #include "core/layers/common/lm_head.h" #include "core/layers/qwen2_5_vision_layer.h" #include "core/layers/qwen2_decoder_layer.h" +#include "core/runtime/forward_params.h" #include "models/llm/qwen2.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -355,8 +356,7 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { } torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] // patchify // hidden_states = x.to(device=self.device, dtype=self.dtype); hidden_states = patch_embed_(hidden_states); @@ -400,8 +400,6 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { m_cos = m_cos.repeat({1, 2}); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); torch::Tensor cu_window_seqlens_cpu = cu_window_seqlens.cpu(); std::vector cu_seqlens_vec( @@ -420,13 +418,8 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { cu_seqlens_now = cu_window_seqlens; cu_seqlens_now_vec = cu_w_seqlens_vec; } - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens_now, 
- cu_seqlens_now_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens_now, cu_seqlens_now_vec, idx); } // adapter hidden_states = merger_(hidden_states); @@ -491,10 +484,10 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", QWen2ForCausalLM(context)); } void prepare_encoder_input( - const ModelInputParams& input_params, + const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -524,17 +517,16 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { pixel_values_videos, video_grid_thw, second_per_grid_ts}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const MultiModalInput& multimodal) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -550,8 +542,7 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { if (video_input) { // visual auto video_embeds = visual_(video_input->pixel_values_videos.to(options_), - video_input->video_grid_thw, - input_params); + video_input->video_grid_thw); auto video_tokens = (video_input->video_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -567,6 +558,10 @@ class 
Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return multimodal_embeds; } + MMDict get_multimodal_embeddings(const ForwardInput& input) { + return get_multimodal_embeddings(input.multimodal); + } + torch::Tensor generate_multimodal_mask(torch::Tensor input_ids) { auto special_token_ids = torch::tensor( {model_args_.image_token_id(), model_args_.video_token_id()}, @@ -584,15 +579,15 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); } else if (mm_data.get("pixel_values").has_value() && mm_data.get("image_grid_thw").has_value()) { // Compute vision embeddings from pixel_values and merge with text - auto mm_dict = get_multimodal_embeddings(input_params); + auto mm_dict = get_multimodal_embeddings(multimodal); if (mm_dict.count("image|embedding")) { const auto& image_embeds_list = std::get>(mm_dict["image|embedding"]); @@ -604,7 +599,9 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { auto video_embeds = torch::cat(video_embeds_list, 0); multimodal_embeds = multimodal_embeds.defined() - ? torch::cat({multimodal_embeds, video_embeds}, 0) + ? 
torch::cat(std::vector{multimodal_embeds, + video_embeds}, + 0) : video_embeds; } } @@ -618,11 +615,9 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/qwen2_vl.h b/xllm/models/vlm/qwen2_vl.h index 94c360b9a1..f996471098 100644 --- a/xllm/models/vlm/qwen2_vl.h +++ b/xllm/models/vlm/qwen2_vl.h @@ -19,6 +19,7 @@ limitations under the License. #include "core/layers/common/lm_head.h" #include "core/layers/qwen2_decoder_layer.h" #include "core/layers/qwen2_vision_layer.h" +#include "core/runtime/forward_params.h" #include "models/llm/qwen2.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -315,8 +316,7 @@ class Qwen2_VisionTransformerImpl : public torch::nn::Module { } torch::Tensor forward(torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] // patchify // hidden_states = x.to(device=self.device, dtype=self.dtype); hidden_states = patch_embed_(hidden_states); @@ -338,20 +338,13 @@ class Qwen2_VisionTransformerImpl : public torch::nn::Module { m_cos = m_cos.repeat({1, 2}); m_sin = rotary_pos_emb.sin().type_as(hidden_states); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), // full seqlen vec cu_seqlens_cpu.data_ptr() + cu_seqlens_cpu.numel()); for (int idx = 0; idx < blocks_->size(); ++idx) { - hidden_states 
= layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); } // adapter hidden_states = merger_(hidden_states); @@ -413,10 +406,10 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", QWen2ForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -446,17 +439,16 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { pixel_values_videos, video_grid_thw, second_per_grid_ts}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); auto merge_size = model_args_.mm_image_merge_size(); MMDict multimodal_embeds; if (image_input) { // visual auto image_embeds = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw, - input_params); + image_input->image_grid_thw); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) .cpu() @@ -489,8 +481,8 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor 
multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -505,11 +497,9 @@ class Qwen2_VLForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + return language_model_(input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git a/xllm/models/vlm/qwen3_5.h b/xllm/models/vlm/qwen3_5.h index ea8d40ff93..f31f4e5cc9 100644 --- a/xllm/models/vlm/qwen3_5.h +++ b/xllm/models/vlm/qwen3_5.h @@ -107,43 +107,43 @@ class Qwen3_5ModelImpl final return std::make_pair(cos_pos, sin_pos); } - virtual ModelOutput forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - ModelInputParams& input_params_new = - const_cast(input_params); + virtual ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput input_params_new = input; std::vector deep_stacks; if (dp_size_ > 1) { - if (tokens.numel() == 0) { - tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); - positions = torch::tensor({1}).to(torch::kInt32).to(positions.device()); + if (input_params_new.token_ids.numel() == 0) { + input_params_new.token_ids = + torch::tensor({1}).to(torch::kInt32).to(input.token_ids.device()); + input_params_new.positions = + torch::tensor({1}).to(torch::kInt32).to(input.positions.device()); } auto& dp_token_nums = input_params_new.parallel.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); } - auto inputs_embeds = input_params.embedding.input_embedding; + auto inputs_embeds = input_params_new.embedding.input_embedding; torch::Tensor h; if 
(inputs_embeds.defined()) { h = inputs_embeds; } else { - h = embed_tokens_(tokens); + h = embed_tokens_(input_params_new.token_ids); } if (!input_params_new.attn_metadata) { input_params_new.attn_metadata = std::make_shared( - get_attention_metadata(input_params_new, h)); + get_attention_metadata(input_params_new)); } auto& attn_metadata = *(input_params_new.attn_metadata); bool only_prefill = (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill); - if (positions.dim() == 2 && only_prefill && !mrope_section_.empty()) { + if (input_params_new.positions.dim() == 2 && only_prefill && + !mrope_section_.empty()) { std::tie(attn_metadata.mrope_cos, attn_metadata.mrope_sin) = - apply_mrope(positions); + apply_mrope(input_params_new.positions); } std::optional residual; @@ -151,7 +151,7 @@ class Qwen3_5ModelImpl final auto& layer = layers_[i]; h = layer(h, residual, - positions, + input_params_new.positions, attn_metadata, kv_caches[i], input_params_new); @@ -166,10 +166,14 @@ class Qwen3_5ModelImpl final private: int32_t dp_size_ = 1; layer::Qwen3NextRMSNorm rms_norm_{nullptr}; - layer::AttentionMetadata get_attention_metadata( - const ModelInputParams& params, - const torch::Tensor& h) { - auto attn_metadata = layer::AttentionMetadataBuilder::build(params, false); + layer::AttentionMetadata get_attention_metadata(const ForwardInput& params) { + auto attn_metadata = + layer::AttentionMetadataBuilder::build(params.meta, + params.attention, + params.graph, + params.llmrec_params(), + params.enable_cuda_graph, + /*enable_mla=*/false); // TODO: support linear attention return attn_metadata; } diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h index 01b6b666d7..4109aaef31 100644 --- a/xllm/models/vlm/qwen3_vl.h +++ b/xllm/models/vlm/qwen3_vl.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "core/framework/model/model_output.h" #include "core/layers/common/lm_head.h" #include "core/layers/qwen3_vision_layer.h" +#include "core/runtime/forward_params.h" #include "models/llm/qwen3.h" #include "models/model_registry.h" #include "models/vlm/qwen3_vl_base.h" @@ -437,8 +438,7 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { std::tuple> forward( torch::Tensor hidden_states, - torch::Tensor grid_thw, // [batch,thw] - const ModelInputParams& input_params) { + torch::Tensor grid_thw) { // [batch,thw] hidden_states = patch_embed_(hidden_states); auto pos_embeds = fast_pos_embed_interpolate(grid_thw); hidden_states = hidden_states + pos_embeds; @@ -459,8 +459,6 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { m_sin = rotary_pos_emb.sin().type_as(hidden_states); m_sin = m_sin.repeat({1, 2}); - ModelInputParams& input_params_new = - const_cast(input_params); torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu(); std::vector cu_seqlens_vec( cu_seqlens_cpu.data_ptr(), // full seqlen vec @@ -468,13 +466,8 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module { std::vector deepstack_feature_lists; deepstack_feature_lists.reserve(deepstack_visual_indexes_.size()); for (int idx = 0; idx < layers_.size(); ++idx) { - hidden_states = layers_[idx](hidden_states, - m_cos, - m_sin, - cu_seqlens, - cu_seqlens_vec, - input_params_new, - idx); + hidden_states = layers_[idx]( + hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, idx); auto it = std::find(deepstack_visual_indexes_.begin(), deepstack_visual_indexes_.end(), idx); diff --git a/xllm/models/vlm/qwen3_vl_base.h b/xllm/models/vlm/qwen3_vl_base.h index a159c4d276..4772428c2b 100644 --- a/xllm/models/vlm/qwen3_vl_base.h +++ b/xllm/models/vlm/qwen3_vl_base.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "core/framework/model/model_output.h" #include "core/framework/request/mm_data_item.h" #include "core/layers/common/lm_head.h" +#include "core/runtime/forward_params.h" #include "models/model_registry.h" namespace xllm { @@ -43,10 +44,10 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { language_model_ = register_module("language_model", LanguageModel(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -70,18 +71,17 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); MMDict multimodal_embeds; auto merge_size = model_args_.mm_image_merge_size(); if (image_input) { auto [image_embeds, deep_stacks] = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw.to(options_.device()), - input_params); + image_input->image_grid_thw.to(options_.device())); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) @@ -104,8 +104,7 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { if (video_input) { auto [video_embeds, deep_stacks] = visual_(video_input->pixel_values_videos.to(options_), - video_input->video_grid_thw.to(options_.device()), - input_params); + 
video_input->video_grid_thw.to(options_.device())); auto video_tokens = (video_input->video_grid_thw.prod(-1) / merge_size / merge_size) @@ -137,8 +136,8 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { } std::vector get_deep_stacks( - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; if (!mm_data.has("embedding|deepstack_0")) { return {}; } @@ -159,8 +158,8 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -170,19 +169,18 @@ class Qwen3VLForConditionalGenerationBase : public torch::nn::Module { return inputs_embeds; } auto is_multimodal = generate_multimodal_mask(input_ids); - input_params.multimodal.visual_pos_masks = is_multimodal; + multimodal.visual_pos_masks = is_multimodal; inputs_embeds = merge_multimodal_embeddings( inputs_embeds, multimodal_embeds, is_multimodal); return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - input_params.multimodal.deep_stacks = - std::move(get_deep_stacks(input_params)); - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput model_input = input; + model_input.multimodal.deep_stacks = + std::move(get_deep_stacks(model_input.multimodal)); + return language_model_(model_input, kv_caches); } torch::Tensor pooler(const torch::Tensor& hidden_states, diff --git 
a/xllm/models/vlm/qwen3_vl_moe.h b/xllm/models/vlm/qwen3_vl_moe.h index d6ae502777..acdf03e10d 100644 --- a/xllm/models/vlm/qwen3_vl_moe.h +++ b/xllm/models/vlm/qwen3_vl_moe.h @@ -18,6 +18,7 @@ limitations under the License. #include "core/framework/model/model_output.h" #include "core/layers/common/lm_head.h" #include "core/layers/qwen3_vision_layer.h" +#include "core/runtime/forward_params.h" #include "models/llm/qwen3_moe.h" #include "models/model_registry.h" #include "processors/qwen2_vl_image_processor.h" @@ -41,10 +42,10 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { register_module("language_model", Qwen3MoeForCausalLM(context)); } - void prepare_encoder_input(const ModelInputParams& input_params, + void prepare_encoder_input(const MultiModalInput& multimodal, std::optional& image_inputs, std::optional& video_inputs) { - const auto& mm_data = input_params.multimodal.mm_data; + const auto& mm_data = multimodal.mm_data; torch::Tensor pixel_values; if (const auto& res = mm_data.get("pixel_values")) pixel_values = res.value(); @@ -68,18 +69,17 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { video_inputs = Qwen3_VLVideoInputs{pixel_values_videos, video_grid_thw}; } - MMDict get_multimodal_embeddings(const ModelInputParams& input_params) { + MMDict get_multimodal_embeddings(const ForwardInput& input) { std::optional image_input; std::optional video_input; - prepare_encoder_input(input_params, image_input, video_input); + prepare_encoder_input(input.multimodal, image_input, video_input); MMDict multimodal_embeds; auto merge_size = model_args_.mm_image_merge_size(); if (image_input) { auto [image_embeds, deep_stacks] = visual_(image_input->pixel_values.to(options_), - image_input->image_grid_thw.to(options_.device()), - input_params); + image_input->image_grid_thw.to(options_.device())); auto image_tokens = (image_input->image_grid_thw.prod(-1) / merge_size / merge_size) @@ -111,8 +111,8 @@ class 
Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { } std::vector get_deep_stacks( - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; if (!mm_data.has("embedding|deepstack_0")) { return {}; } @@ -133,8 +133,8 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { } torch::Tensor get_input_embeddings(const torch::Tensor input_ids, - const ModelInputParams& input_params) { - const auto& mm_data = input_params.multimodal.mm_data; + const MultiModalInput& multimodal) { + const auto& mm_data = multimodal.mm_data; torch::Tensor multimodal_embeds; if (const auto& emb = mm_data.get("embedding")) { multimodal_embeds = emb.value(); @@ -144,19 +144,18 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { return inputs_embeds; } auto is_multimodal = generate_multimodal_mask(input_ids); - input_params.multimodal.visual_pos_masks = is_multimodal; + multimodal.visual_pos_masks = is_multimodal; inputs_embeds = merge_multimodal_embeddings( inputs_embeds, multimodal_embeds, is_multimodal); return inputs_embeds; } - ModelOutput forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - input_params.multimodal.deep_stacks = - std::move(get_deep_stacks(input_params)); - return language_model_(tokens, positions, kv_caches, input_params); + ModelOutput forward(const ForwardInput& input, + std::vector& kv_caches) { + ForwardInput model_input = input; + model_input.multimodal.deep_stacks = + get_deep_stacks(model_input.multimodal); + return language_model_(model_input, kv_caches); } torch::Tensor logits(const torch::Tensor& hidden_states,