feat: add TileLang chunk_gated_delta_rule_fwd_h kernel. #1498
Code Review
This pull request implements the chunk_gated_delta_rule_fwd_h kernel for NPU, including the Python builder, C++ wrapper, and unit tests. The review identifies critical dimension and shape handling issues in the C++ wrapper for multi-batch scenarios, requests a missing return type annotation on the Python kernel builder, flags a hardcoded dtype, and recommends renaming an inverted boolean variable for better maintainability.
auto cu_prepared =
    params.cu_seqlens.has_value()
        ? std::optional<torch::Tensor>(
              params.cu_seqlens.value().to(torch::kInt32).contiguous())
        : std::nullopt;
auto g_cumsum =
    npu::npu_chunk_local_cumsum(params.g, chunk_size, cu_prepared);
const float scale_value = params.scale.has_value()
                              ? params.scale.value()
                              : std::pow(static_cast<float>(head_dim), -0.5f);
auto matrix_a = npu::npu_chunk_scaled_dot_kkt_fwd(
    k_prepared, params.beta, g_cumsum, chunk_size, cu_prepared);
auto matrix_a_inv = npu::npu_solve_tril(
    matrix_a, chunk_size, cu_prepared, params.k.scalar_type());
auto [w, u] = npu::npu_recompute_w_u_fwd(
    k_prepared, params.v, params.beta, g_cumsum, matrix_a_inv, cu_prepared);
auto init_state_prepared =
    params.initial_state.has_value()
        ? std::optional<torch::Tensor>(
              params.initial_state.value().to(torch::kBFloat16).contiguous())
        : std::nullopt;
auto [h, v_new, final_state] = npu::tilelang::chunk_gated_delta_rule_fwd_h(
    k_prepared.squeeze(0),
    w.squeeze(0),
    u.squeeze(0),
    g_cumsum.squeeze(0),
    init_state_prepared,
    params.output_final_state,
    chunk_size,
    /*save_new_value=*/true,
    cu_prepared,
    /*chunk_offsets=*/std::nullopt);
auto out = npu::npu_chunk_fwd_o(q_prepared,
                                k_prepared,
                                v_new.unsqueeze(0),
                                h.unsqueeze(0),
                                g_cumsum,
                                scale_value,
                                chunk_size,
                                cu_prepared);
This block contains several critical issues for batch_size > 1:
- Dimension Mismatch: squeeze(0) (lines 1219-1222) only works when batch_size == 1. For B > 1, the tensors remain 4D, which will cause a CHECK failure in the C++ wrapper (CHECK_EQ(k.dim(), 3)).
- Output Shape Corruption: unsqueeze(0) (lines 1231-1232) incorrectly wraps the outputs. h is already [B, NT, H, K, V], so unsqueeze(0) makes it [1, B, NT, H, K, V], which is incompatible with npu_chunk_fwd_o expectations.
- Missing cu_seqlens: The TileLang C++ wrapper requires cu_seqlens to be defined (it performs a CHECK(cu_seqlens.has_value())). If params.cu_seqlens is not provided, this will crash. A default linear cu_seqlens must be generated for the fixed-batch case.
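The shape arithmetic behind the first point can be checked without a device. The following is a minimal sketch in plain Python that mimics only the shape behavior of Tensor.squeeze(0) and Tensor.flatten(0, 1). The helper names squeeze0 and flatten01 and the sizes T, H, K are hypothetical and chosen for illustration; they are not part of the PR.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def squeeze0(shape: Shape) -> Shape:
    """Shape after squeeze(0): dim 0 is dropped only when its size is 1."""
    if shape and shape[0] == 1:
        return shape[1:]
    return shape


def flatten01(shape: Shape) -> Shape:
    """Shape after flatten(0, 1): the first two dims are merged into one."""
    return (shape[0] * shape[1],) + shape[2:]


# Hypothetical sizes, used only to illustrate the rank of the result.
T, H, K = 128, 4, 64
for B in (1, 2):
    k_shape = (B, T, H, K)
    print(
        f"B={B}: squeeze(0) -> rank {len(squeeze0(k_shape))}, "
        f"flatten(0, 1) -> rank {len(flatten01(k_shape))}"
    )
```

With B = 1 both operations yield a rank-3 shape, but with B = 2 squeeze(0) leaves the tensor at rank 4 while flatten(0, 1) still yields rank 3, which is what a check such as CHECK_EQ(k.dim(), 3) expects.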
auto cu_prepared = params.cu_seqlens.has_value()
    ? std::optional<torch::Tensor>(
          params.cu_seqlens.value().to(torch::kInt32).contiguous())
    : std::nullopt;
if (!cu_prepared.has_value()) {
  cu_prepared = torch::arange(0, (batch_size + 1) * seq_len, seq_len,
      torch::TensorOptions().dtype(torch::kInt32).device(params.q.device()));
}
auto g_cumsum =
    npu::npu_chunk_local_cumsum(params.g, chunk_size, cu_prepared);
const float scale_value = params.scale.has_value()
    ? params.scale.value()
    : std::pow(static_cast<float>(head_dim), -0.5f);
auto matrix_a = npu::npu_chunk_scaled_dot_kkt_fwd(
    k_prepared, params.beta, g_cumsum, chunk_size, cu_prepared);
auto matrix_a_inv = npu::npu_solve_tril(
    matrix_a, chunk_size, cu_prepared, params.k.scalar_type());
auto [w, u] = npu::npu_recompute_w_u_fwd(
    k_prepared, params.v, params.beta, g_cumsum, matrix_a_inv, cu_prepared);
auto init_state_prepared =
    params.initial_state.has_value()
    ? std::optional<torch::Tensor>(
          params.initial_state.value().to(torch::kBFloat16).contiguous())
    : std::nullopt;
auto [h, v_new, final_state] = npu::tilelang::chunk_gated_delta_rule_fwd_h(
    k_prepared.flatten(0, 1),
    w.flatten(0, 1),
    u.flatten(0, 1),
    g_cumsum.flatten(0, 1),
    init_state_prepared,
    params.output_final_state,
    chunk_size,
    /*save_new_value=*/true,
    cu_prepared,
    /*chunk_offsets=*/std::nullopt);
auto out = npu::npu_chunk_fwd_o(q_prepared,
                                k_prepared,
                                v_new.view({batch_size, seq_len, num_heads_v, -1}),
                                h,
                                g_cumsum,
                                scale_value,
                                chunk_size,
                                cu_prepared);

def _build_chunk_gated_delta_rule_fwd_h_kernel(
    H: int,
    Hg: int,
    K: int,
    V: int,
    dtype: str = DEFAULT_DTYPE,
    accum_dtype: str = DEFAULT_ACCUM_DTYPE,
    bt: int = COMPILE_BT,
    use_g: bool = True,
    store_final_state: bool = True,
    save_new_value: bool = True,
):
The function signature is missing a return type annotation, which is required by the Python style guide (Section 11).
 def _build_chunk_gated_delta_rule_fwd_h_kernel(
     H: int,
     Hg: int,
     K: int,
     V: int,
     dtype: str = DEFAULT_DTYPE,
     accum_dtype: str = DEFAULT_ACCUM_DTYPE,
     bt: int = COMPILE_BT,
     use_g: bool = True,
     store_final_state: bool = True,
     save_new_value: bool = True,
-):
+) -> tilelang.language.prim_func:
References
- Type annotations are required on all function signatures (parameters and return types).
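A rule like this one can be checked mechanically. The sketch below uses the standard inspect module to list which parts of a signature lack annotations. The helper missing_annotations and the two toy builders are hypothetical stand-ins, not code from the PR, and Callable[..., object] is only a placeholder return type for the illustration.

```python
import inspect
from typing import Callable, List


def missing_annotations(func: Callable[..., object]) -> List[str]:
    """Return parameter names without annotations, plus 'return' if the
    return annotation is absent."""
    sig = inspect.signature(func)
    missing = [
        name
        for name, param in sig.parameters.items()
        if param.annotation is inspect.Parameter.empty
    ]
    if sig.return_annotation is inspect.Signature.empty:
        missing.append("return")
    return missing


# Hypothetical builder with annotated parameters but no return type.
def build_kernel_unannotated(H: int, K: int, use_g: bool = True):
    return lambda x: x


# Same builder with a placeholder return annotation added.
def build_kernel_annotated(
    H: int, K: int, use_g: bool = True
) -> Callable[..., object]:
    return lambda x: x


print(missing_annotations(build_kernel_unannotated))
print(missing_annotations(build_kernel_annotated))
```

A check of this kind could run in a unit test or a pre-commit hook, so that a builder without a return annotation is caught before review.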
    save_new_value: bool = True,
):
    V_half = V // 2
    input_dtype = "bfloat16"
The input_dtype is hardcoded to "bfloat16", which ignores the dtype argument passed to the builder. This makes the kernel builder inflexible and potentially incorrect if a different dtype (e.g., float16) is requested in the future.
-    input_dtype = "bfloat16"
+    input_dtype = "bfloat16" if dtype == "bf16" else dtype
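If more than one short alias ever needs handling, the same idea can be written as a lookup table. This is a sketch under assumptions: only the bf16 to bfloat16 mapping comes from the suggestion above; the fp16 and fp32 entries and the helper name resolve_input_dtype are hypothetical.

```python
# Only "bf16" -> "bfloat16" is taken from the review suggestion;
# the other aliases are hypothetical examples.
_DTYPE_ALIASES = {
    "bf16": "bfloat16",
    "fp16": "float16",
    "fp32": "float32",
}


def resolve_input_dtype(dtype: str) -> str:
    """Map a short alias to its full dtype name; pass other names through."""
    return _DTYPE_ALIASES.get(dtype, dtype)


print(resolve_input_dtype("bf16"))
print(resolve_input_dtype("float16"))
```

Names that are not in the table are returned unchanged, so a caller that already passes a full dtype name is unaffected.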
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    BT = chunk_size
    USE_G = g is not None
    IS_VARLEN = cu_seqlens is None
The logic for IS_VARLEN is inverted. If cu_seqlens is None, it indicates a standard batch (fixed length), not a variable-length sequence. The subsequent code block (lines 550-563) correctly handles the fixed-length batch case by generating a linear cu_seqlens, but the variable name is misleading and could lead to maintenance errors. Please rename it to IS_FIXED_BATCH and update its usage throughout the function (lines 550, 587, 616).
-    IS_VARLEN = cu_seqlens is None
+    IS_FIXED_BATCH = cu_seqlens is None
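For reference, the fixed-batch fallback described in this comment can be sketched in plain Python. Lists stand in for tensors, and the helper names linear_cu_seqlens and resolve_cu_seqlens are hypothetical; the actual kernel wrapper may build the offsets differently.

```python
from typing import List, Optional, Tuple


def linear_cu_seqlens(batch_size: int, seq_len: int) -> List[int]:
    """Cumulative offsets [0, T, 2T, ..., B*T] for B sequences of length T."""
    if batch_size < 0 or seq_len <= 0:
        raise ValueError("batch_size must be >= 0 and seq_len must be > 0")
    return list(range(0, (batch_size + 1) * seq_len, seq_len))


def resolve_cu_seqlens(
    cu_seqlens: Optional[List[int]], batch_size: int, seq_len: int
) -> Tuple[List[int], bool]:
    """Return (offsets, is_fixed_batch); generate linear offsets when
    the caller supplied none."""
    is_fixed_batch = cu_seqlens is None  # no offsets means equal-length batch
    if cu_seqlens is None:
        cu_seqlens = linear_cu_seqlens(batch_size, seq_len)
    return cu_seqlens, is_fixed_batch


print(resolve_cu_seqlens(None, batch_size=3, seq_len=4))
print(resolve_cu_seqlens([0, 3, 10], batch_size=2, seq_len=5))
```

The flag is true exactly when no offsets were passed in, which is why a name such as IS_FIXED_BATCH describes the condition more accurately than IS_VARLEN.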