fused_sigmoid_gating_tilelang tilelang adapt in qwen3.x#1465
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the fused_sigmoid_gating_delta_rule kernel for Ascend NPUs, encompassing its TileLang implementation, C++ wrapper, and integration into the ops_api. It also optimizes the fused_gdn_gating and split_qkv_rmsnorm_mrope kernels by removing unnecessary temporary buffers. The review feedback primarily addresses precision loss by suggesting that the SSM state be stored in float32 throughout the computation. Additionally, the reviewer identified several style guide violations concerning Python type annotations, logging, and the improper use of auto for primitive types in C++.
| ssm_state_indices: T.Tensor([max_num_seqs], "int32"), | ||
| cu_seqlens: T.Tensor([max_num_seqs + 1], "int32"), | ||
| out: T.Tensor([total_tokens_padded, nv, dv], input_dtype), | ||
| final_state: T.Tensor([max_num_seqs, nv, dk, dv], input_dtype), |
There was a problem hiding this comment.
The SSM state (final_state) should be saved in high precision (accum_dtype) to avoid precision loss during the recurrent update. Currently, it is being cast to input_dtype (bf16), which can lead to significant errors in SSM scans over long sequences.
| final_state: T.Tensor([max_num_seqs, nv, dk, dv], input_dtype), | |
| final_state: T.Tensor([max_num_seqs, nv, dk, dv], accum_dtype), |
| T.tile.cast(h_store_vec, h_vec, "CAST_RINT", vec_block_v * dk) | ||
| T.set_flag("v", "mte3", 5) | ||
| T.wait_flag("v", "mte3", 5) | ||
| T.copy(h_store_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v]) |
There was a problem hiding this comment.
If final_state is updated to accum_dtype, the cast to input_dtype is unnecessary and should be removed to maintain precision.
| T.tile.cast(h_store_vec, h_vec, "CAST_RINT", vec_block_v * dk) | |
| T.set_flag("v", "mte3", 5) | |
| T.wait_flag("v", "mte3", 5) | |
| T.copy(h_store_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v]) | |
| T.set_flag("v", "mte3", 5) | |
| T.wait_flag("v", "mte3", 5) | |
| T.copy(h_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v]) |
| CHECK(init_state.scalar_type() == torch::kFloat32 || | ||
| init_state.scalar_type() == torch::kBFloat16) | ||
| << "TileLang fused_sigmoid_gating_delta_rule: init_state must be float32 " | ||
| "or bf16"; |
There was a problem hiding this comment.
The TileLang kernel is compiled with accum_dtype = "float" (float32) for init_state. Allowing kBFloat16 here will result in a type mismatch and incorrect memory access in the kernel. This check should strictly enforce kFloat32.
CHECK_EQ(init_state.scalar_type(), torch::kFloat32)
<< "TileLang fused_sigmoid_gating_delta_rule: init_state must be float32";| def golden( | ||
| A_log, | ||
| a, | ||
| dt_bias, | ||
| query, | ||
| key, | ||
| value, | ||
| beta, | ||
| init_state, | ||
| ssm_state_indices, | ||
| cu_seqlens, | ||
| scale=None, | ||
| use_qk_l2norm=True, | ||
| softplus_beta=1.0, | ||
| ): |
There was a problem hiding this comment.
Python function signatures must include type annotations for all parameters and return types per the repository style guide.
References
- Type annotations are required on all function signatures (parameters and return types). (link)
|
|
||
| torch.testing.assert_close(out.cpu(), out_golden, rtol=2e-2, atol=2e-2) | ||
| torch.testing.assert_close(final_state.cpu(), final_state_golden, rtol=2e-2, atol=2e-2) | ||
| print("Kernel Output Match!") |
There was a problem hiding this comment.
Do not use print() for logging. All diagnostic output must go through the shared logger at scripts/logger.py.
| print("Kernel Output Match!") | |
| logger.info("Kernel Output Match!") |
References
- All Python diagnostic output MUST go through the shared logger at scripts/logger.py. Do not use print() for logging. (link)
| const auto nv = value.size(1); | ||
| const auto nk = query.size(1); | ||
| const auto dk = query.size(2); | ||
| const auto dv = value.size(2); |
There was a problem hiding this comment.
Do not use auto for simple/primitive types like int64_t.
| const auto nv = value.size(1); | |
| const auto nk = query.size(1); | |
| const auto dk = query.size(2); | |
| const auto dv = value.size(2); | |
| const int64_t nv = value.size(1); | |
| const int64_t nk = query.size(1); | |
| const int64_t dk = query.size(2); | |
| const int64_t dv = value.size(2); |
References
- Do not use auto for simple/primitive types (int32_t, float, bool, std::string, etc.). (link)
| const auto options = query.options(); | ||
|
|
||
| auto out = torch::empty({query.size(0), nv, dv}, options); | ||
| auto final_state = torch::empty({compiled_n, nv, dk, dv}, options); |
There was a problem hiding this comment.
To maintain precision for the SSM state, final_state should be allocated as float32 if the kernel is updated to return accum_dtype.
| auto final_state = torch::empty({compiled_n, nv, dk, dv}, options); | |
| torch::Tensor final_state = torch::empty({compiled_n, nv, dk, dv}, options.dtype(torch::kFloat32)); |
| auto q = params.q; | ||
| auto k = params.k; | ||
| auto v = params.v; |
There was a problem hiding this comment.
Avoid using auto for torch::Tensor types to improve readability and adhere to the style guide's preference for explicit types over auto for non-complex types.
| auto q = params.q; | |
| auto k = params.k; | |
| auto v = params.v; | |
| torch::Tensor q = params.q; | |
| torch::Tensor k = params.k; | |
| torch::Tensor v = params.v; |
References
- Do not use auto for simple/primitive types. auto is acceptable for complex types (iterators, lambdas, template-deduced types) but not for int32_t, float, bool, std::string, etc. (link)
| auto init_state_small = torch::index_select( | ||
| params.initial_state_source, 0, indices); |
There was a problem hiding this comment.
The TileLang kernel expects init_state to be float32. Ensure the input is cast to kFloat32 before passing it to the wrapper.
| auto init_state_small = torch::index_select( | |
| params.initial_state_source, 0, indices); | |
| torch::Tensor init_state_small = torch::index_select( | |
| params.initial_state_source, 0, indices).to(torch::kFloat32); |
adapt fused_sigmoid_gating_tilelang in tilelang