Skip to content

feat: add TileLang chunk_gated_delta_rule_fwd_h kernel.#1498

Open
fengz72 wants to merge 2 commits into
jd-opensource:mainfrom
fengz72:main
Open

feat: add TileLang chunk_gated_delta_rule_fwd_h kernel.#1498
fengz72 wants to merge 2 commits into
jd-opensource:mainfrom
fengz72:main

Conversation

@fengz72
Copy link
Copy Markdown

@fengz72 fengz72 commented May 20, 2026

No description provided.

@yingxudeng yingxudeng changed the title add TileLang chunk_gated_delta_rule_fwd_h kernel feat: add TileLang chunk_gated_delta_rule_fwd_h kernel. May 20, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the chunk_gated_delta_rule_fwd_h kernel for NPU, including the Python builder, C++ wrapper, and unit tests. The review identified critical dimension and shape handling issues in the C++ wrapper for multi-batch scenarios, requested missing type annotations in the Python kernel builder, addressed a hardcoded dtype issue, and recommended renaming an inverted boolean variable for better maintainability.

Comment on lines +1197 to +1236
auto cu_prepared =
params.cu_seqlens.has_value()
? std::optional<torch::Tensor>(
params.cu_seqlens.value().to(torch::kInt32).contiguous())
: std::nullopt;
auto g_cumsum =
npu::npu_chunk_local_cumsum(params.g, chunk_size, cu_prepared);
const float scale_value = params.scale.has_value()
? params.scale.value()
: std::pow(static_cast<float>(head_dim), -0.5f);
auto matrix_a = npu::npu_chunk_scaled_dot_kkt_fwd(
k_prepared, params.beta, g_cumsum, chunk_size, cu_prepared);
auto matrix_a_inv = npu::npu_solve_tril(
matrix_a, chunk_size, cu_prepared, params.k.scalar_type());
auto [w, u] = npu::npu_recompute_w_u_fwd(
k_prepared, params.v, params.beta, g_cumsum, matrix_a_inv, cu_prepared);
auto init_state_prepared =
params.initial_state.has_value()
? std::optional<torch::Tensor>(
params.initial_state.value().to(torch::kBFloat16).contiguous())
: std::nullopt;
auto [h, v_new, final_state] = npu::tilelang::chunk_gated_delta_rule_fwd_h(
k_prepared.squeeze(0),
w.squeeze(0),
u.squeeze(0),
g_cumsum.squeeze(0),
init_state_prepared,
params.output_final_state,
chunk_size,
/*save_new_value=*/true,
cu_prepared,
/*chunk_offsets=*/std::nullopt);
auto out = npu::npu_chunk_fwd_o(q_prepared,
k_prepared,
v_new.unsqueeze(0),
h.unsqueeze(0),
g_cumsum,
scale_value,
chunk_size,
cu_prepared);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This block contains several critical issues for batch_size > 1:

  1. Dimension Mismatch: squeeze(0) (lines 1219-1222) only works when batch_size == 1. For B > 1, the tensors remain 4D, which will cause a CHECK failure in the C++ wrapper (CHECK_EQ(k.dim(), 3)).
  2. Output Shape Corruption: unsqueeze(0) (lines 1231-1232) incorrectly wraps the outputs. h is already [B, NT, H, K, V], so unsqueeze(0) makes it [1, B, NT, H, K, V], which is incompatible with npu_chunk_fwd_o expectations.
  3. Missing cu_seqlens: The TileLang C++ wrapper requires cu_seqlens to be defined (it performs a CHECK(cu_seqlens.has_value())). If params.cu_seqlens is not provided, this will crash. A default linear cu_seqlens must be generated for the fixed-batch case.
  auto cu_prepared = params.cu_seqlens.has_value()
                         ? std::optional<torch::Tensor>(
                               params.cu_seqlens.value().to(torch::kInt32).contiguous())
                         : std::nullopt;
  if (!cu_prepared.has_value()) {
    cu_prepared = torch::arange(0, (batch_size + 1) * seq_len, seq_len,
                                torch::TensorOptions().dtype(torch::kInt32).device(params.q.device()));
  }
  auto g_cumsum =
      npu::npu_chunk_local_cumsum(params.g, chunk_size, cu_prepared);
  const float scale_value = params.scale.has_value()
                                ? params.scale.value()
                                : std::pow(static_cast<float>(head_dim), -0.5f);
  auto matrix_a = npu::npu_chunk_scaled_dot_kkt_fwd(
      k_prepared, params.beta, g_cumsum, chunk_size, cu_prepared);
  auto matrix_a_inv = npu::npu_solve_tril(
      matrix_a, chunk_size, cu_prepared, params.k.scalar_type());
  auto [w, u] = npu::npu_recompute_w_u_fwd(
      k_prepared, params.v, params.beta, g_cumsum, matrix_a_inv, cu_prepared);
  auto init_state_prepared =
      params.initial_state.has_value()
          ? std::optional<torch::Tensor>(
                params.initial_state.value().to(torch::kBFloat16).contiguous())
          : std::nullopt;
  auto [h, v_new, final_state] = npu::tilelang::chunk_gated_delta_rule_fwd_h(
      k_prepared.flatten(0, 1),
      w.flatten(0, 1),
      u.flatten(0, 1),
      g_cumsum.flatten(0, 1),
      init_state_prepared,
      params.output_final_state,
      chunk_size,
      /*save_new_value=*/true,
      cu_prepared,
      /*chunk_offsets=*/std::nullopt);
  auto out = npu::npu_chunk_fwd_o(q_prepared,
                                  k_prepared,
                                  v_new.view({batch_size, seq_len, num_heads_v, -1}),
                                  h,
                                  g_cumsum,
                                  scale_value,
                                  chunk_size,
                                  cu_prepared);

Comment on lines +63 to +74
def _build_chunk_gated_delta_rule_fwd_h_kernel(
H: int,
Hg: int,
K: int,
V: int,
dtype: str = DEFAULT_DTYPE,
accum_dtype: str = DEFAULT_ACCUM_DTYPE,
bt: int = COMPILE_BT,
use_g: bool = True,
store_final_state: bool = True,
save_new_value: bool = True,
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The function signature is missing a return type annotation, which is required by the Python style guide (Section 11).

Suggested change
def _build_chunk_gated_delta_rule_fwd_h_kernel(
H: int,
Hg: int,
K: int,
V: int,
dtype: str = DEFAULT_DTYPE,
accum_dtype: str = DEFAULT_ACCUM_DTYPE,
bt: int = COMPILE_BT,
use_g: bool = True,
store_final_state: bool = True,
save_new_value: bool = True,
):
def _build_chunk_gated_delta_rule_fwd_h_kernel(
H: int,
Hg: int,
K: int,
V: int,
dtype: str = DEFAULT_DTYPE,
accum_dtype: str = DEFAULT_ACCUM_DTYPE,
bt: int = COMPILE_BT,
use_g: bool = True,
store_final_state: bool = True,
save_new_value: bool = True,
) -> tilelang.language.prim_func:
References
  1. Type annotations are required on all function signatures (parameters and return types). (link)

save_new_value: bool = True,
):
V_half = V // 2
input_dtype = "bfloat16"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The input_dtype is hardcoded to "bfloat16", which ignores the dtype argument passed to the builder. This makes the kernel builder inflexible and potentially incorrect if a different dtype (e.g., float16) is requested in the future.

Suggested change
input_dtype = "bfloat16"
input_dtype = "bfloat16" if dtype == "bf16" else dtype

) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
BT = chunk_size
USE_G = g is not None
IS_VARLEN = cu_seqlens is None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for IS_VARLEN is inverted. If cu_seqlens is None, it indicates a standard batch (fixed length), not a variable-length sequence. The subsequent code block (lines 550-563) correctly handles the fixed-length batch case by generating a linear cu_seqlens, but the variable name is misleading and could lead to maintenance errors. Please rename it to IS_FIXED_BATCH and update its usage throughout the function (lines 550, 587, 616).

Suggested change
IS_VARLEN = cu_seqlens is None
IS_FIXED_BATCH = cu_seqlens is None

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant