Skip to content

fused_sigmoid_gating_tilelang tilelang adapt in qwen3.x#1465

Open
BikingNow wants to merge 6 commits into
jd-opensource:preview/qwen3.5-qwen3.6from
BikingNow:qwen_tl
Open

fused_sigmoid_gating_tilelang tilelang adapt in qwen3.x#1465
BikingNow wants to merge 6 commits into
jd-opensource:preview/qwen3.5-qwen3.6from
BikingNow:qwen_tl

Conversation

@BikingNow
Copy link
Copy Markdown

adapt fused_sigmoid_gating_tilelang in tilelang

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the fused_sigmoid_gating_delta_rule kernel for Ascend NPUs, encompassing its TileLang implementation, C++ wrapper, and integration into the ops_api. It also optimizes the fused_gdn_gating and split_qkv_rmsnorm_mrope kernels by removing unnecessary temporary buffers. The review feedback primarily addresses precision loss by suggesting that the SSM state be stored in float32 throughout the computation. Additionally, the reviewer identified several style guide violations concerning Python type annotations, logging, and the improper use of auto for primitive types in C++.

ssm_state_indices: T.Tensor([max_num_seqs], "int32"),
cu_seqlens: T.Tensor([max_num_seqs + 1], "int32"),
out: T.Tensor([total_tokens_padded, nv, dv], input_dtype),
final_state: T.Tensor([max_num_seqs, nv, dk, dv], input_dtype),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The SSM state (final_state) should be saved in high precision (accum_dtype) to avoid precision loss during the recurrent update. Currently, it is being cast to input_dtype (bf16), which can lead to significant errors in SSM scans over long sequences.

Suggested change
final_state: T.Tensor([max_num_seqs, nv, dk, dv], input_dtype),
final_state: T.Tensor([max_num_seqs, nv, dk, dv], accum_dtype),

Comment on lines +282 to +285
T.tile.cast(h_store_vec, h_vec, "CAST_RINT", vec_block_v * dk)
T.set_flag("v", "mte3", 5)
T.wait_flag("v", "mte3", 5)
T.copy(h_store_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If final_state is updated to accum_dtype, the cast to input_dtype is unnecessary and should be removed to maintain precision.

Suggested change
T.tile.cast(h_store_vec, h_vec, "CAST_RINT", vec_block_v * dk)
T.set_flag("v", "mte3", 5)
T.wait_flag("v", "mte3", 5)
T.copy(h_store_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v])
T.set_flag("v", "mte3", 5)
T.wait_flag("v", "mte3", 5)
T.copy(h_vec, final_state[seq_idx, v_head_idx, :, v_offset : v_offset + vec_block_v])

Comment on lines +234 to +237
CHECK(init_state.scalar_type() == torch::kFloat32 ||
init_state.scalar_type() == torch::kBFloat16)
<< "TileLang fused_sigmoid_gating_delta_rule: init_state must be float32 "
"or bf16";
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The TileLang kernel is compiled with accum_dtype = "float" (float32) for init_state. Allowing kBFloat16 here will result in a type mismatch and incorrect memory access in the kernel. This check should strictly enforce kFloat32.

  CHECK_EQ(init_state.scalar_type(), torch::kFloat32)
      << "TileLang fused_sigmoid_gating_delta_rule: init_state must be float32";

Comment on lines +390 to +404
def golden(
A_log,
a,
dt_bias,
query,
key,
value,
beta,
init_state,
ssm_state_indices,
cu_seqlens,
scale=None,
use_qk_l2norm=True,
softplus_beta=1.0,
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Python function signatures must include type annotations for all parameters and return types per the repository style guide.

References
  1. Type annotations are required on all function signatures (parameters and return types). (link)


torch.testing.assert_close(out.cpu(), out_golden, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(final_state.cpu(), final_state_golden, rtol=2e-2, atol=2e-2)
print("Kernel Output Match!")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Do not use print() for logging. All diagnostic output must go through the shared logger at scripts/logger.py.

Suggested change
print("Kernel Output Match!")
logger.info("Kernel Output Match!")
References
  1. All Python diagnostic output MUST go through the shared logger at scripts/logger.py. Do not use print() for logging. (link)

Comment on lines +189 to +192
const auto nv = value.size(1);
const auto nk = query.size(1);
const auto dk = query.size(2);
const auto dv = value.size(2);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Do not use auto for simple/primitive types like int64_t.

Suggested change
const auto nv = value.size(1);
const auto nk = query.size(1);
const auto dk = query.size(2);
const auto dv = value.size(2);
const int64_t nv = value.size(1);
const int64_t nk = query.size(1);
const int64_t dk = query.size(2);
const int64_t dv = value.size(2);
References
  1. Do not use auto for simple/primitive types (int32_t, float, bool, std::string, etc.). (link)

const auto options = query.options();

auto out = torch::empty({query.size(0), nv, dv}, options);
auto final_state = torch::empty({compiled_n, nv, dk, dv}, options);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To maintain precision for the SSM state, final_state should be allocated as float32 if the kernel is updated to return accum_dtype.

Suggested change
auto final_state = torch::empty({compiled_n, nv, dk, dv}, options);
torch::Tensor final_state = torch::empty({compiled_n, nv, dk, dv}, options.dtype(torch::kFloat32));

Comment on lines +850 to +852
auto q = params.q;
auto k = params.k;
auto v = params.v;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Avoid using auto for torch::Tensor types to improve readability and adhere to the style guide's preference for explicit types over auto for non-complex types.

Suggested change
auto q = params.q;
auto k = params.k;
auto v = params.v;
torch::Tensor q = params.q;
torch::Tensor k = params.k;
torch::Tensor v = params.v;
References
  1. Do not use auto for simple/primitive types. auto is acceptable for complex types (iterators, lambdas, template-deduced types) but not for int32_t, float, bool, std::string, etc. (link)

Comment on lines +873 to +874
auto init_state_small = torch::index_select(
params.initial_state_source, 0, indices);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TileLang kernel expects init_state to be float32. Ensure the input is cast to kFloat32 before passing it to the wrapper.

Suggested change
auto init_state_small = torch::index_select(
params.initial_state_source, 0, indices);
torch::Tensor init_state_small = torch::index_select(
params.initial_state_source, 0, indices).to(torch::kFloat32);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant