From da09c8eedc57a067a1b488944b8d3000cd6c09ef Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Thu, 26 Feb 2026 09:38:21 +0000 Subject: [PATCH 01/13] [FEAT](tle): WIP - add tle features --- CMakeLists.txt | 2 + include/triton/Dialect/Triton/IR/TritonOps.td | 211 +++++++++++++++++- lib/Dialect/Triton/IR/Ops.cpp | 44 ++++ python/setup.py | 6 + python/setup_tools/utils/__init__.py | 2 +- python/triton/experimental/__init__.py | 1 + python/triton/experimental/tle/__init__.py | 8 + .../experimental/tle/language/__init__.py | 4 + .../tle/language/ascend/__init__.py | 21 ++ .../experimental/tle/language/ascend/core.py | 61 +++++ .../tle/language/ascend/semantic.py | 1 + .../experimental/tle/language/builder.py | 49 ++++ .../experimental/tle/language/dsa/README.md | 62 +++++ .../experimental/tle/language/dsa/__init__.py | 31 +++ .../experimental/tle/language/dsa/core.py | 168 ++++++++++++++ .../experimental/tle/language/dsa/semantic.py | 78 +++++++ python/triton/experimental/tle/src/tle_ir.cc | 111 +++++++++ .../spec/triton/compiler/code_generator.py | 12 +- .../backend/spec/triton/compiler/compiler.py | 3 +- .../backend/spec/triton/language/core.py | 4 +- .../backend/spec/triton/language/semantic.py | 8 +- third_party/ascend/python/src/ir.cc | 4 +- third_party/ascend/python/src/main.cc | 2 + 23 files changed, 880 insertions(+), 13 deletions(-) create mode 100644 python/triton/experimental/__init__.py create mode 100644 python/triton/experimental/tle/__init__.py create mode 100644 python/triton/experimental/tle/language/__init__.py create mode 100644 python/triton/experimental/tle/language/ascend/__init__.py create mode 100644 python/triton/experimental/tle/language/ascend/core.py create mode 100644 python/triton/experimental/tle/language/ascend/semantic.py create mode 100644 python/triton/experimental/tle/language/builder.py create mode 100644 python/triton/experimental/tle/language/dsa/README.md create mode 100644 python/triton/experimental/tle/language/dsa/__init__.py 
create mode 100644 python/triton/experimental/tle/language/dsa/core.py create mode 100644 python/triton/experimental/tle/language/dsa/semantic.py create mode 100644 python/triton/experimental/tle/src/tle_ir.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 534647a6d..e61b0f566 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -422,10 +422,12 @@ if(TRITON_BUILD_PYTHON_MODULE) elseif(FLAGTREE_BACKEND STREQUAL "ascend") set(PYTHON_ROOT_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) set(BUFFER_IR_SRC_PATH ${FLAGTREE_BACKEND_DIR}/python/triton/extension/buffer/src) + set(TLE_IR_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/experimental/tle/src) include_directories(${PYTHON_ROOT_SRC_PATH}) add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc ${BUFFER_IR_SRC_PATH}/buffer_ir.cc + ${TLE_IR_SRC_PATH}/tle_ir.cc ${PYTHON_ROOT_SRC_PATH}/passes.cc ${PYTHON_ROOT_SRC_PATH}/interpreter.cc ${PYTHON_ROOT_SRC_PATH}/llvm.cc) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 283dd9165..e50304619 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -252,12 +252,32 @@ def TT_LoadOp : TT_Op<"load", [ OptionalAttr:$padding, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict, - DefaultValuedAttr:$isVolatile + DefaultValuedAttr:$isVolatile, + DefaultValuedAttr:$optMask ); let results = (outs TT_Type:$result); let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins 
"Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, @@ -1256,5 +1276,194 @@ def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< }]; } +///////////// Definitions for DSA +// +// Alloc Op +// +def TT_DSAAllocOp : TT_Op<"dsa_alloc", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined alloc operation"; + let description = [{ + `tt.dsa_alloc` triton alloc op is designed to performs memory allocation. + }]; + let arguments = ( + ins + I64ArrayAttr:$shape, + StrAttr:$layout, + StrAttr:$scope + ); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$shape `,` $layout `,` $scope attr-dict `:` `->` type($result)"; +} + +// +// Copy OP +// +def TT_DSACopyOp : TT_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined copy operation"; + let description = [{ + 'tt.dsa_copy' triton copy op is designed to copy data between memory regions. 
+ Example: + ```mlir + tt.dsa_copy %src, %dst, %shape : tensor<128xf32> + ``` + }]; + let arguments = (ins + AnyType:$src, + AnyType:$dst, + Variadic:$shape + ); + + // let builders = [ + // OpBuilder<(ins "Value":$src, "Value":$dst, "ValueRange": $shape)> + // ]; + + //assemble + let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`"; +} + +// +// Add Op +// +def TT_DSAAddOp : TT_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined add operation"; + let description = [{ + `tt.dsa_add` triton dsa_add op is designed to perform element-wise addition. + }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Sub op +// +def TT_DSASubOp : TT_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined sub operation"; + let description = [{ + `tt.dsa_sub` triton dsa_sub op is designed to perform element-wise subtraction. + }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Mul op +// +def TT_DSAMulOp : TT_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined mul operation"; + let description = [{ + `tt.dsa_mul` triton dsa_mul op is designed to perform element-wise multiplication. 
+ }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Div op +// +def TT_DSADivOp : TT_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined div operation"; + let description = [{ + `tt.dsa_div` triton dsa_div op is designed to perform element-wise division. + }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Max op +// +def TT_DSAMaxOp : TT_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined Max operation"; + let description = [{ + `tt.dsa_max` triton dsa_max op is designed to perform the element-wise maximum. + }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Min op +// +def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined Min operation"; + let description = [{ + `tt.dsa_min` triton dsa_min op is designed to perform the element-wise minimum. + }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Dot op +// +def TT_DSADotOp : TT_Op<"dsa_dot", [Pure, + MemoryEffects<[MemWrite]>, + SameOperandsElementType, + DotLike]> { + let summary = "self-defined Dot operation"; + let description = [{ + $d = matrix_multiply($a, $b) + $c. 
+ }]; + + let arguments = ( + ins + TT_PtrLike:$inA, + TT_PtrLike:$inB, + TT_PtrLike:$res, + I64ArrayAttr:$size, + DefaultValuedAttr:$initC, + DefaultValuedAttr:$traA, + DefaultValuedAttr:$traB, + DefaultValuedAttr:$enableHf32 + ); + + let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; +} + +// +// ToTensor op +// +def TT_ToTensorOp : TT_Op<"to_tensor", [Pure, MemoryEffects<[MemWrite]>, + TypesMatchWith<"result matches ptr type", "src", "result", "getPointeeType($_self)">]> { + let summary = "self-defined to_tensor operation"; + let description = [{ + `tt.to_tensor` triton to_tensor op is designed to performs the conversion from buffer to tensor. + }]; + let arguments = (ins TT_PtrLike:$src); + + let results = (outs TT_Tensor:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +// +// ToBuffer op +// +def TT_ToBufferOp : TT_Op<"to_buffer", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined to_buffer operation"; + let description = [{ + `tt.to_buffer` triton to_buffer op is designed to performs the conversion from tensor to buffer. 
+ }]; + let arguments = (ins TT_Tensor:$src); + + let results = (outs TT_PtrLike:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} +///////////// Definition for DSA end #endif // Triton_OPS diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 294683d33..14f0e3aa7 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -46,6 +46,50 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, cache, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile, optMask); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile, optMask); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, 
CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile, optMask); +} + void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, ArrayRef boundaryCheck, std::optional padding, CacheModifier cache, diff --git a/python/setup.py b/python/setup.py index 9880184d3..1796615ea 100644 --- a/python/setup.py +++ b/python/setup.py @@ -701,6 +701,12 @@ def get_packages(): "triton/runtime", "triton/backends", "triton/tools", + + # for tle + "triton/experimental", + "triton/experimental/tle", + "triton/experimental/tle/language", + "triton/experimental/tle/language/dsa", ] if helper.flagtree_backend and helper.flagtree_backend in helper.configs.language_extra_backends: packages.append(f"triton/language/extra/{helper.get_device_name()}") diff --git a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py index 4f93cd8cc..a0fda077b 100644 --- a/python/setup_tools/utils/__init__.py +++ b/python/setup_tools/utils/__init__.py @@ -10,7 +10,7 @@ commit_id="380b87122c88af131530903a702d5318ec59bb33", dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")), "flir": - tools.Module(name="flir", url="https://github.com/FlagTree/flir.git", + tools.Module(name="flir", url="https://github.com/flagos-ai/flir.git", dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")), } diff --git a/python/triton/experimental/__init__.py b/python/triton/experimental/__init__.py new file mode 100644 index 000000000..ef36c171c --- /dev/null +++ b/python/triton/experimental/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
\ No newline at end of file diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py new file mode 100644 index 000000000..c00f9b0e8 --- /dev/null +++ b/python/triton/experimental/tle/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from .language import dsa, ascend + +__all__ = [ + "dsa", + "ascend", +] \ No newline at end of file diff --git a/python/triton/experimental/tle/language/__init__.py b/python/triton/experimental/tle/language/__init__.py new file mode 100644 index 000000000..a3e8b3ab8 --- /dev/null +++ b/python/triton/experimental/tle/language/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from . import dsa +from . import ascend diff --git a/python/triton/experimental/tle/language/ascend/__init__.py b/python/triton/experimental/tle/language/ascend/__init__.py new file mode 100644 index 000000000..4f5dffe80 --- /dev/null +++ b/python/triton/experimental/tle/language/ascend/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from .core import ( + ND, + NZ, + UB, + L1, + L0A, + L0B, + L0C, +) + +__all__ = [ + "ND", + "NZ", + "UB", + "L1", + "L0A", + "L0B", + "L0C", +] diff --git a/python/triton/experimental/tle/language/ascend/core.py b/python/triton/experimental/tle/language/ascend/core.py new file mode 100644 index 000000000..c649a0b52 --- /dev/null +++ b/python/triton/experimental/tle/language/ascend/core.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +from triton.language.core import ( + _unwrap_if_constexpr, +) + +class layout: + ASCEND = ['ND', 'NZ'] + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in layout.ASCEND, name + + def __str__(self): + return self.name + + def codegen_name(self): + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +ND = layout('ND') +NZ = layout('NZ') + +class scope: + ASCEND = ['UB', 'L1', 'L0A', 'L0B', 'L0C'] + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in scope.ASCEND, name + + def __str__(self): + return self.name + + def codegen_name(self): + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + +UB = scope('UB') +L1 = scope('L1') +L0A = scope('L0A') +L0B = scope('L0B') +L0C = scope('L0C') diff --git a/python/triton/experimental/tle/language/ascend/semantic.py b/python/triton/experimental/tle/language/ascend/semantic.py new file mode 100644 index 000000000..36c08a4d7 --- /dev/null +++ b/python/triton/experimental/tle/language/ascend/semantic.py @@ -0,0 +1 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. \ No newline at end of file diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py new file mode 100644 index 000000000..7e8856eb8 --- /dev/null +++ b/python/triton/experimental/tle/language/builder.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +def create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name): + delegate_method = getattr(delegate_builder, method_name) + + def wrapper(*args, **kwargs): + saved_ip = main_builder.get_insertion_point() + saved_loc = main_builder.get_loc() + delegate_builder.restore_insertion_point(saved_ip) + if saved_loc: + delegate_builder.set_loc(saved_loc) + result = delegate_method(*args, **kwargs) + main_builder.restore_insertion_point(saved_ip) + if saved_loc: + main_builder.set_loc(saved_loc) + return result + + wrapper.__name__ = method_name + wrapper.__doc__ = getattr(delegate_method, '__doc__', None) + return wrapper + + +def attach_builder_methods_with_tle_builder(main_builder, delegate_builder, method_names): + """Attach multiple methods from a delegate builder to the main builder.""" + for method_name in method_names: + wrapper = create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name) + + if hasattr(main_builder, method_name): + raise AttributeError(f"Method '{method_name}' already exists in the main builder.") + setattr(main_builder, method_name, wrapper) + + +def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): + """Set up a unified builder interface by attaching methods from specialized builders.""" + main_builder._buffer_builder = buffer_builder + buffer_methods = [ + 'create_dsa_alloc', + 'create_dsa_copy', + 'create_dsa_add', + 'create_dsa_sub', + 'create_dsa_mul', + 'create_dsa_div', + 'create_dsa_max', + 'create_dsa_min', + 'create_dsa_dot', + 'dsa_to_buffer', + 'dsa_to_tensor', + ] + attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/README.md b/python/triton/experimental/tle/language/dsa/README.md new file mode 100644 index 000000000..00cce40f3 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/README.md @@ -0,0 +1,62 @@ +# TLE (Triton 
Language Extension) + +TLE is a language extension for Triton that exposes on-chip memory, pipeline compile hints and the accompanying calculation operations for high-performance computing. This extension is specifically optimized for Ascend 910B devices. + +## Features + +- **On-chip Memory Management**: `tle.alloc()` - Allocate memory on UB/L1/L0C +- **Data Movement**: `tle.copy()` - Efficient bidirectional copying between memory spaces +- **compute Operations**: `tle.npu_add()` - Addition on UB +- **Pipeline Optimization**: `tle.pipeline()` - Hardware-aware pipeline iteration + +## Memory Scopes & Layouts + +- **Scopes**: `tle.UB` (UB memory), `tle.L1` (L1 memory), `tle.L0C` (L0C memory) +- **Layouts**: `tle.ND`, `tle.NZ` + +## Quick Example + +```python +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Allocate UB memory + a_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) + b_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) + c_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) + + # Tail block processing + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + # Copy data from GM to UB + tle.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.copy(y_ptr + offsets, b_ub, [tail_size]) + + # Addition + tle.npu_add(a_ub, b_ub, c_ub) + + # Copy result back to GM + tle.copy(c_ub, output_ptr + offsets, [tail_size]) + +``` + +## Testing + +```bash +cd ascend/examples/tle/pytest_ut/ +python3 test_vec_add.py +``` + +## Learn More + +See other examples in `ascend/examples/tle/pytest_ut/`: +- `test_matmul.py` - GEMM implementation and pipeline usage \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py new file mode 100644 index 000000000..1fdd09bc7 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from .core import ( + alloc, + copy, + pipeline, + to_tensor, + to_buffer, + add, + sub, + mul, + div, + max, + min, + dot, +) + +__all__ = [ + "alloc", + "copy", + "pipeline", + "to_tensor", + "to_buffer", + "add", + "sub", + "mul", + "div", + "max", + "min", + "dot", +] diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py new file mode 100644 index 000000000..e70627e8c --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +from triton.language import semantic as tl_semantic +from triton.language.core import ( + _shape_check_impl, + _constexpr_to_value, + _unwrap_if_constexpr, + builtin, + constexpr +) +from . import semantic as tle_semantic + +class range(): + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + :param warp_specialize: Enable automatic warp specialization on the loop. + The compiler will attempt to partition memory, MMA, and vector + operations in the loop into separate async partitions. This will + increase the total number of warps required by the kernel. 
+ :param disable_licm: Tells the compiler it shouldn't hoist loop invariant + code outside the loop. This is often useful to avoid creating long liveranges + within a loop. + + Note that warp specialization is only supported on Blackwell GPUs and + only works on simple matmul loops. Support for arbitrary loops will be + expanded over time. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + self.warp_specialize = warp_specialize + self.disable_licm = disable_licm + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + +class pipeline(range): + """ + Iterator that counts upward forever, with software pipeline semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + +@builtin +def alloc(shape, dtype, layout=None, scope=None, _builder=None): + """ + Returns a pointer for the given :code:`shape` and :code:`dtype`. 
+ + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype + """ + shape = _shape_check_impl(shape) + dtype = _constexpr_to_value(dtype) + layout = _constexpr_to_value(layout) + scope = _constexpr_to_value(scope) + return tle_semantic.alloc(shape, dtype, layout, scope, _builder) + + +@builtin +def copy(src, dst, shape, _builder=None): + assert len(shape) != 0, f"Can't deduce copy extents from args" + + shape = _constexpr_to_value(shape) + tle_semantic.copy(src, dst, shape, _builder) + + +@builtin +def to_tensor(buffer, _builder=None): + """ + Create a tensor-like type from a buffer-like type. + + :param buffer: the input buffer-like object. + """ + return tle_semantic.to_tensor(buffer, _builder) + +@builtin +def to_buffer(src, _builder=None): + """ + Create a buffer-like type from a tensor-like type. + + :param src: the input tensor-like object. + """ + return tle_semantic.to_buffer(src, _builder) + + +@builtin +def add(input, other, result, _builder=None): + tle_semantic.add(input, other, result, _builder) + +@builtin +def sub(input, other, result, _builder=None): + tle_semantic.sub(input, other, result, _builder) + +@builtin +def mul(input, other, result, _builder=None): + tle_semantic.mul(input, other, result, _builder) + +@builtin +def div(input, other, result, _builder=None): + tle_semantic.div(input, other, result, _builder) + +@builtin +def max(input, other, result, _builder=None): + # elementwise binary vector maximum op + tle_semantic.max(input, other, result, _builder) + +@builtin +def min(input, other, result, _builder=None): + # elementwise binary vector minimum op + tle_semantic.min(input, other, result, _builder) + +@builtin +def dot(inputA, inputB, result, size, initC, a_transpose=False, b_transpose=False, enable_hf32=False, _builder=None): + initC = _constexpr_to_value(initC) + a_transpose = _constexpr_to_value(a_transpose) 
+ b_transpose = _constexpr_to_value(b_transpose) + enable_hf32 = _constexpr_to_value(enable_hf32) + tle_semantic.dot(inputA, inputB, result, size, initC, a_transpose, b_transpose, enable_hf32, _builder) \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py new file mode 100644 index 000000000..be09a4727 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from typing import List, Optional, Union, Tuple +from triton.language import core as tl +from triton.language.semantic import ( + binary_op_type_checking_impl, +) +from triton._C.libtriton import ir + +def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + # assert value.numel.value == 1, "only accepts size-1 tensor" + if isinstance(value, tl.constexpr): + value = builder.get_int32(value) + return tl.tensor(value, dtype) + + if value.dtype.is_int(): + return tl.tensor(value.handle, dtype) + +def alloc(shape: List[tl.tensor], dtype: tl.dtype, layout, scope, builder: ir.builder) -> tl.tensor: + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_dsa_alloc(shape, str(layout), str(scope), + dtype.to_ir(builder)), ret_ty) + + +def copy(src, dst, shape: List[Union[tl.constexpr, int]], builder: ir.builder): + """ + Generate tt.copy(src, dst, shape) and return dst-like tensor. + Lowering to hivm.load/hivm.store is done in MLIR pass. 
+ """ + shape = [scalar_constant(x, tl.int32, builder) for x in shape] + builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape]) + + +def to_tensor(buffer: tl.tensor, builder: ir.builder) -> tl.tensor: + if not isinstance(buffer, tl.tensor): + raise TypeError("buffer must be tensor of pointers") + + tensor_ty = buffer.type + element_ty = tensor_ty.element_ty + if not element_ty.is_ptr(): + raise TypeError("The basic elements of a buffer must be pointers") + + return tl.tensor(builder.dsa_to_tensor(buffer.handle), tensor_ty) + +def to_buffer(src: tl.tensor, builder: ir.builder) -> tl.tensor: + if not isinstance(src, tl.tensor): + raise TypeError("src of to_buffer must be tensor") + + return tl.tensor(builder.dsa_to_buffer(src.handle), src.type) + +def add(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_add(input.handle, other.handle, result.handle) + +def sub(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_sub(input.handle, other.handle, result.handle) + +def mul(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_mul(input.handle, other.handle, result.handle) + +def div(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_div(input.handle, other.handle, result.handle) + +def max(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_max(input.handle, other.handle, result.handle) + +def min(input: tl.tensor, other: tl.tensor, result: 
tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_min(input.handle, other.handle, result.handle) + +def dot(inputA: tl.tensor, inputB: tl.tensor, result: tl.tensor, size: List[int], initC: bool, a_transpose: bool, b_transpose: bool, enable_hf32: bool, builder: ir.builder): + assert len(size) == 3, f"Please set the M、N、K value." + + builder.create_dsa_dot(inputA.handle, inputB.handle, result.handle, size, initC, a_transpose, b_transpose, enable_hf32) diff --git a/python/triton/experimental/tle/src/tle_ir.cc b/python/triton/experimental/tle/src/tle_ir.cc new file mode 100644 index 000000000..b1dee7c7a --- /dev/null +++ b/python/triton/experimental/tle/src/tle_ir.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#include +#include + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "ir.h" + +using namespace mlir; +namespace py = pybind11; + +struct DSAOpBuilder : public TritonOpBuilder {}; + +void init_tle_ir(py::module &&m) +{ + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "tle_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + // Add alloc op + .def("create_dsa_alloc", + [](TritonOpBuilder &self, std::vector &shape, + std::string &layout, std::string &scope, Type type)-> Value { + auto shapeAttr = self.getBuilder().getI64ArrayAttr(shape); + auto layoutAttr = self.getBuilder().getStringAttr(layout); + auto scopeAttr = self.getBuilder().getStringAttr(scope); + + auto ptrType = triton::PointerType::get(type, 1); + auto tensorPtrType = RankedTensorType::get(shape, ptrType); + return self.create(tensorPtrType, shapeAttr, + 
layoutAttr, scopeAttr); + }) + // Add copy op + .def("create_dsa_copy", + [](TritonOpBuilder &self, Value &src, Value &dst, std::vector &shape)-> void { + self.create(src, dst, shape); + }) + // Add op + .def("create_dsa_add", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Sub op + .def("create_dsa_sub", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Mul op + .def("create_dsa_mul", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Div op + .def("create_dsa_div", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Max op + .def("create_dsa_max", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Min op + .def("create_dsa_min", + [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Dot op + .def("create_dsa_dot", + [](TritonOpBuilder &self, Value &inA, Value &inB, Value &res, + std::vector &size, bool &initC, bool &traA, bool &traB, + bool &enable_hf32) -> void { + auto &builder = self.getBuilder(); + auto sizeAttr = builder.getI64ArrayAttr(size); + + // convert bool to boolattr. 
+ auto initC_attr = builder.getBoolAttr(initC); + auto traA_attr = builder.getBoolAttr(traA); + auto traB_attr = builder.getBoolAttr(traB); + auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); + + self.create(inA, inB, res, sizeAttr, initC_attr, + traA_attr, traB_attr, enable_hf32_attr); + }) + // ToTensor op + .def("dsa_to_tensor", + [](TritonOpBuilder &self, Value &src) -> Value { + return self.create(src); + }) + // ToBuffer op + .def("dsa_to_buffer", + [](TritonOpBuilder &self, Value &src) -> Value { + auto srcType = src.getType(); + auto tensorTy = cast(srcType); + Type elementType = tensorTy.getElementType(); + auto ptrType = triton::PointerType::get(elementType, 1); + auto shape = tensorTy.getShape(); + auto tensorPtrType = RankedTensorType::get(shape, ptrType); + return self.create(tensorPtrType, src); + }); + +} \ No newline at end of file diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 172ba90b4..2f64a50ed 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -11,12 +11,14 @@ import triton.language.extra.cann.extension as extension from triton.extension.buffer.language import core as bl from triton.extension.buffer.language.builder import setup_unified_builder_with_buffer_builder +from triton.experimental.tle.language.builder import setup_unified_builder_with_tle_builder from .. 
import language -from .._C.libtriton import ir, buffer_ir +from .._C.libtriton import ir, buffer_ir, tle_ir from .._C.libtriton.ascend import ir as ascend_ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value +from ..experimental.tle import dsa from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction @@ -230,7 +232,10 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n setup_unified_builder(self.builder, self.ascend_builder) self.buffer_builder = buffer_ir.buffer_builder(context) self.buffer_builder.set_loc(file_name, begin_line, 0) + self.tle_builder = tle_ir.tle_builder(context) + self.tle_builder.set_loc(file_name, begin_line, 0) setup_unified_builder_with_buffer_builder(self.builder, self.buffer_builder) + setup_unified_builder_with_tle_builder(self.builder, self.tle_builder) # dict of functions provided by the backend. Below are the list of possible functions: # Convert custom types not natively supported on HW. 
@@ -967,7 +972,7 @@ def visit_For(self, node): warp_specialize = False disable_licm = False bind_sub_block = None - if IteratorClass in [language.range, extension.parallel]: + if IteratorClass in [language.range, extension.parallel, dsa.pipeline]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -976,6 +981,9 @@ def visit_For(self, node): ub = iterator.end step = iterator.step num_stages = iterator.num_stages + if num_stages is not None and num_stages > 2: + raise AssertionError('Only `range` iterator supports num_stages <= 2') + loop_unroll_factor = iterator.loop_unroll_factor disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer flatten = iterator.flatten diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index cc7ba30e7..8dedbd21f 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib import json -from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir +from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir, tle_ir from .._C.libtriton.ascend import ir as ascend_ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor @@ -270,6 +270,7 @@ def compile(src, target=None, options=None): context = ir.context() ir.load_dialects(context) buffer_ir.load_dialects(context) + tle_ir.load_dialects(context) ascend_ir.load_dialects(context) backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() diff --git a/third_party/ascend/backend/spec/triton/language/core.py b/third_party/ascend/backend/spec/triton/language/core.py index 9ed34a899..fae1912ac 100644 --- a/third_party/ascend/backend/spec/triton/language/core.py +++ 
b/third_party/ascend/backend/spec/triton/language/core.py @@ -1589,7 +1589,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, @builtin def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", - volatile=False, care_padding=True, _builder=None): + volatile=False, care_padding=True, mask_opt = False, _builder=None): """ Return a tensor of data whose values are loaded from memory at location defined by `pointer`: @@ -1650,7 +1650,7 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c volatile = _constexpr_to_value(volatile) care_padding = _constexpr_to_value(care_padding) return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, - volatile, care_padding, _builder) + volatile, care_padding, mask_opt, _builder) @builtin diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 7be3b70db..31693177e 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -1119,7 +1119,7 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) -def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, builder): +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, mask_opt, builder): # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` if not ptr.type.scalar.is_ptr(): raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") @@ -1181,7 +1181,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ else: ret = tl.tensor( 
builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, - is_volatile), dst_ty) + is_volatile, mask_opt), dst_ty) # Do not cast back to int1 when is_bool=true. We directly use the int8 tensor given by tl.load if is_bool: ret.was_bool_to_int8 = True @@ -1191,7 +1191,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, care_padding: bool, - builder: ir.builder) -> tl.tensor: + mask_opt:bool, builder: ir.builder) -> tl.tensor: # Cache, eviction and padding options cache = _str_to_load_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1203,7 +1203,7 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], else: # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, - builder) + mask_opt, builder) def tensormap_create( diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index d52868ea6..56843a2bc 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -1309,10 +1309,10 @@ void init_triton_ir(py::module &&m) { .def("create_masked_load", [](TritonOpBuilder &self, Value &ptrs, Value &mask, std::optional &other, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy, bool isVolatile) -> Value { + EvictionPolicy evictionPolicy, bool isVolatile, bool optMask) -> Value { return self.create(ptrs, mask, other.value_or(Value()), cacheModifier, evictionPolicy, - isVolatile); + isVolatile, optMask); }) .def("create_masked_store", [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, diff --git a/third_party/ascend/python/src/main.cc 
b/third_party/ascend/python/src/main.cc index 7664c6bda..525d7c7c8 100644 --- a/third_party/ascend/python/src/main.cc +++ b/third_party/ascend/python/src/main.cc @@ -38,6 +38,7 @@ namespace py = pybind11; void init_triton_env_vars(pybind11::module &m); void init_triton_ir(pybind11::module &&m); void init_buffer_ir(pybind11::module &&m); +void init_tle_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); @@ -50,6 +51,7 @@ PYBIND11_MODULE(libtriton, m) { init_triton_env_vars(m); init_triton_ir(m.def_submodule("ir")); init_buffer_ir(m.def_submodule("buffer_ir")); + init_tle_ir(m.def_submodule("tle_ir")); init_triton_passes(m.def_submodule("passes")); init_triton_interpreter(m.def_submodule("interpreter")); init_triton_llvm(m.def_submodule("llvm")); From c0c2da970d4977bf1e9f1fd925ddcec8ea2b81f0 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Sat, 28 Feb 2026 02:28:35 +0000 Subject: [PATCH 02/13] [FEAT]: WIP - refactor tle * move tle.ascend to tle.dsa.ascend * move tle_ir to third_party/tle/dsa * reimplement alloc/to_tensor/to_buffer reference to buffer_ir in third_party/ascend * reimplement tle.dsa.ascend scope with address_space in ascend --- CMakeLists.txt | 2 +- include/triton/Dialect/Triton/IR/TritonOps.td | 212 +----------------- lib/Dialect/Triton/IR/Ops.cpp | 44 ---- python/setup.py | 1 + python/triton/experimental/tle/__init__.py | 3 +- .../experimental/tle/language/__init__.py | 1 - .../experimental/tle/language/ascend/core.py | 61 ----- .../experimental/tle/language/builder.py | 2 +- .../experimental/tle/language/dsa/__init__.py | 4 +- .../tle/language/{ => dsa}/ascend/__init__.py | 4 - .../tle/language/dsa/ascend/core.py | 70 ++++++ .../tle/language/{ => dsa}/ascend/semantic.py | 0 .../experimental/tle/language/dsa/core.py | 133 +++++++---- .../experimental/tle/language/dsa/semantic.py | 109 +++++++-- .../experimental/tle/language/dsa/types.py 
| 103 +++++++++ python/triton/experimental/tle/src/tle_ir.cc | 111 --------- third_party/ascend/AscendNPU-IR | 1 + .../triton/Dialect/Triton/IR/TritonOps.td | 190 ++++++++++++++++ .../backend/spec/triton/language/core.py | 4 +- .../backend/spec/triton/language/semantic.py | 8 +- third_party/ascend/python/src/ir.cc | 4 +- third_party/tle/dsa/tle_ir.cc | 152 +++++++++++++ 22 files changed, 710 insertions(+), 509 deletions(-) delete mode 100644 python/triton/experimental/tle/language/ascend/core.py rename python/triton/experimental/tle/language/{ => dsa}/ascend/__init__.py (84%) create mode 100644 python/triton/experimental/tle/language/dsa/ascend/core.py rename python/triton/experimental/tle/language/{ => dsa}/ascend/semantic.py (100%) create mode 100644 python/triton/experimental/tle/language/dsa/types.py delete mode 100644 python/triton/experimental/tle/src/tle_ir.cc create mode 160000 third_party/ascend/AscendNPU-IR create mode 100644 third_party/tle/dsa/tle_ir.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e61b0f566..fe7398134 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -422,7 +422,7 @@ if(TRITON_BUILD_PYTHON_MODULE) elseif(FLAGTREE_BACKEND STREQUAL "ascend") set(PYTHON_ROOT_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) set(BUFFER_IR_SRC_PATH ${FLAGTREE_BACKEND_DIR}/python/triton/extension/buffer/src) - set(TLE_IR_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/triton/experimental/tle/src) + set(TLE_IR_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/tle/dsa) include_directories(${PYTHON_ROOT_SRC_PATH}) add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index e50304619..0d843b358 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -252,32 +252,12 @@ def TT_LoadOp : TT_Op<"load", [ OptionalAttr:$padding, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict, - 
DefaultValuedAttr:$isVolatile, - DefaultValuedAttr:$optMask + DefaultValuedAttr:$isVolatile ); let results = (outs TT_Type:$result); let builders = [ - // A tensor of pointers or a pointer to a scalar - OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, - // A tensor pointer with boundary check and padding - OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, - "std::optional":$padding, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, - // A tensor of pointers or a pointer to a scalar with mask - OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, - // A tensor of pointers or a pointer to a scalar with mask and other - OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, - // A utility function to build the operation with all attributes - OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, - "ArrayRef":$boundaryCheck, - "std::optional":$padding, "triton::CacheModifier":$cache, - "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, - // A tensor of pointers or a pointer to a scalar OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, @@ -1276,194 +1256,4 @@ def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< }]; } -///////////// Definitions for DSA -// -// Alloc Op -// -def TT_DSAAllocOp : TT_Op<"dsa_alloc", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "self-defined alloc operation"; - let description = [{ - `tt.dsa_alloc` triton alloc op is designed to performs memory allocation. 
- }]; - let arguments = ( - ins - I64ArrayAttr:$shape, - StrAttr:$layout, - StrAttr:$scope - ); - - let results = (outs TT_PtrLike:$result); - - let assemblyFormat = "$shape `,` $layout `,` $scope attr-dict `:` `->` type($result)"; -} - -// -// Copy OP -// -def TT_DSACopyOp : TT_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "self-defined copy operation"; - let description = [{ - 'tt.dsa_copy' triton copy op is designed to copy data between memory regions. - Example: - ```mlir - tt.dsa_copy %src, %dst, %shape : tensor<128xf32> - ``` - }]; - let arguments = (ins - AnyType:$src, - AnyType:$dst, - Variadic:$shape - ); - - // let builders = [ - // OpBuilder<(ins "Value":$src, "Value":$dst, "ValueRange": $shape)> - // ]; - - //assemble - let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`"; -} - -// -// Add Op -// -def TT_DSAAddOp : TT_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined add operation"; - let description = [{ - `tt.dsa_add` triton dsa_add op is designed to performs element-wise addition. - }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Sub op -// -def TT_DSASubOp : TT_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined sub operation"; - let description = [{ - `tt.dsa_sub` triton dsa_sub op is designed to performs element-wise addition. 
- }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Mul op -// -def TT_DSAMulOp : TT_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined mul operation"; - let description = [{ - `tt.dsa_mul` triton dsa_mul op is designed to performs element-wise addition. - }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Div op -// -def TT_DSADivOp : TT_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined div operation"; - let description = [{ - `tt.dsa_div` triton dsa_div op is designed to performs element-wise addition. - }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Max op -// -def TT_DSAMaxOp : TT_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined Max operation"; - let description = [{ - `tt.dsa_max` triton dsa_max op is designed to performs element-wise addition. - }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Min op -// -def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined Min operation"; - let description = [{ - `tt.dsa_min` triton dsa_min op is designed to performs element-wise addition. 
- }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Dot op -// -def TT_DSADotOp : TT_Op<"dsa_dot", [Pure, - MemoryEffects<[MemWrite]>, - SameOperandsElementType, - DotLike]> { - let summary = "self-defined Dot operation"; - let description = [{ - $d = matrix_multiply($a, $b) + $c. - }]; - - let arguments = ( - ins - TT_PtrLike:$inA, - TT_PtrLike:$inB, - TT_PtrLike:$res, - I64ArrayAttr:$size, - DefaultValuedAttr:$initC, - DefaultValuedAttr:$traA, - DefaultValuedAttr:$traB, - DefaultValuedAttr:$enableHf32 - ); - - let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; -} - -// -// ToTensor op -// -def TT_ToTensorOp : TT_Op<"to_tensor", [Pure, MemoryEffects<[MemWrite]>, - TypesMatchWith<"result matches ptr type", "src", "result", "getPointeeType($_self)">]> { - let summary = "self-defined to_tensor operation"; - let description = [{ - `tt.to_tensor` triton to_tensor op is designed to performs the conversion from buffer to tensor. - }]; - let arguments = (ins TT_PtrLike:$src); - - let results = (outs TT_Tensor:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} - -// -// ToBuffer op -// -def TT_ToBufferOp : TT_Op<"to_buffer", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "self-defined to_buffer operation"; - let description = [{ - `tt.to_buffer` triton to_buffer op is designed to performs the conversion from tensor to buffer. 
- }]; - let arguments = (ins TT_Tensor:$src); - - let results = (outs TT_PtrLike:$result); - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} -///////////// Definition for DSA end - #endif // Triton_OPS diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 14f0e3aa7..294683d33 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -46,50 +46,6 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, cache, evict, isVolatile); } -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - CacheModifier cache, EvictionPolicy evict, bool isVolatile, bool optMask) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, - cache, evict, isVolatile, optMask); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - ArrayRef boundaryCheck, - std::optional padding, CacheModifier cache, - EvictionPolicy evict, bool isVolatile, bool optMask) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, - padding, cache, evict, isVolatile, optMask); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, CacheModifier cache, EvictionPolicy evict, - bool isVolatile, bool optMask) { - LoadOp::build(builder, state, ptr, mask, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, CacheModifier cache, - EvictionPolicy evict, bool isVolatile, bool optMask) { - LoadOp::build(builder, state, ptr, mask, other, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, ArrayRef boundaryCheck, - std::optional padding, 
+            packages.append(f"triton/experimental/tle/language/dsa/{helper.get_device_name()}")
import ascend diff --git a/python/triton/experimental/tle/language/ascend/core.py b/python/triton/experimental/tle/language/ascend/core.py deleted file mode 100644 index c649a0b52..000000000 --- a/python/triton/experimental/tle/language/ascend/core.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. - -from triton.language.core import ( - _unwrap_if_constexpr, -) - -class layout: - ASCEND = ['ND', 'NZ'] - - def __init__(self, name): - name = _unwrap_if_constexpr(name) - self.name = name - assert name in layout.ASCEND, name - - def __str__(self): - return self.name - - def codegen_name(self): - return self.name - - @property - def cache_key_part(self) -> str: - """See cache_key_part() in triton.cc.""" - return self.name - - def __repr__(self): - """Output of repr needs to be an evaluatable expression""" - return f'triton.language.{self.codegen_name()}' - - -ND = layout('ND') -NZ = layout('NZ') - -class scope: - ASCEND = ['UB', 'L1', 'L0A', 'L0B', 'L0C'] - - def __init__(self, name): - name = _unwrap_if_constexpr(name) - self.name = name - assert name in scope.ASCEND, name - - def __str__(self): - return self.name - - def codegen_name(self): - return self.name - - @property - def cache_key_part(self) -> str: - """See cache_key_part() in triton.cc.""" - return self.name - - def __repr__(self): - """Output of repr needs to be an evaluatable expression""" - return f'triton.language.{self.codegen_name()}' - -UB = scope('UB') -L1 = scope('L1') -L0A = scope('L0A') -L0B = scope('L0B') -L0C = scope('L0C') diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py index 7e8856eb8..82581c1eb 100644 --- a/python/triton/experimental/tle/language/builder.py +++ b/python/triton/experimental/tle/language/builder.py @@ -42,7 +42,7 @@ def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): 'create_dsa_div', 'create_dsa_max', 'create_dsa_min', - 'create_dsa_dot', + # 
'create_dsa_dot', 'dsa_to_buffer', 'dsa_to_tensor', ] diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py index 1fdd09bc7..286b1e991 100644 --- a/python/triton/experimental/tle/language/dsa/__init__.py +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -12,9 +12,10 @@ div, max, min, - dot, ) +from . import ascend + __all__ = [ "alloc", "copy", @@ -27,5 +28,4 @@ "div", "max", "min", - "dot", ] diff --git a/python/triton/experimental/tle/language/ascend/__init__.py b/python/triton/experimental/tle/language/dsa/ascend/__init__.py similarity index 84% rename from python/triton/experimental/tle/language/ascend/__init__.py rename to python/triton/experimental/tle/language/dsa/ascend/__init__.py index 4f5dffe80..e3348e629 100644 --- a/python/triton/experimental/tle/language/ascend/__init__.py +++ b/python/triton/experimental/tle/language/dsa/ascend/__init__.py @@ -1,8 +1,6 @@ # Copyright (c) 2025 XCoreSigma Inc. All rights reserved. from .core import ( - ND, - NZ, UB, L1, L0A, @@ -11,8 +9,6 @@ ) __all__ = [ - "ND", - "NZ", "UB", "L1", "L0A", diff --git a/python/triton/experimental/tle/language/dsa/ascend/core.py b/python/triton/experimental/tle/language/dsa/ascend/core.py new file mode 100644 index 000000000..c06e627b8 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/ascend/core.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +from triton.language.extra.cann.extension.core import ascend_address_space + +UB = ascend_address_space.UB +L1 = ascend_address_space.L1 +L0A = ascend_address_space.L0A +L0B = ascend_address_space.L0B +L0C = ascend_address_space.L0C + + +### from triton.language.core import ( +### _unwrap_if_constexpr, +### ) +### +### class layout: +### ASCEND = ['ND', 'NZ'] +### +### def __init__(self, name): +### name = _unwrap_if_constexpr(name) +### self.name = name +### assert name in layout.ASCEND, name +### +### def __str__(self): +### return self.name +### +### def codegen_name(self): +### return self.name +### +### @property +### def cache_key_part(self) -> str: +### """See cache_key_part() in triton.cc.""" +### return self.name +### +### def __repr__(self): +### """Output of repr needs to be an evaluatable expression""" +### return f'triton.language.{self.codegen_name()}' +### +### +### ND = layout('ND') +### NZ = layout('NZ') +### +### class scope: +### ASCEND = ['UB', 'L1', 'L0A', 'L0B', 'L0C'] +### +### def __init__(self, name): +### name = _unwrap_if_constexpr(name) +### self.name = name +### assert name in scope.ASCEND, name +### +### def __str__(self): +### return self.name +### +### def codegen_name(self): +### return self.name +### +### @property +### def cache_key_part(self) -> str: +### """See cache_key_part() in triton.cc.""" +### return self.name +### +### def __repr__(self): +### """Output of repr needs to be an evaluatable expression""" +### return f'triton.language.{self.codegen_name()}' +### +### UB = scope('UB') +### L1 = scope('L1') +### L0A = scope('L0A') +### L0B = scope('L0B') +### L0C = scope('L0C') diff --git a/python/triton/experimental/tle/language/ascend/semantic.py b/python/triton/experimental/tle/language/dsa/ascend/semantic.py similarity index 100% rename from python/triton/experimental/tle/language/ascend/semantic.py rename to python/triton/experimental/tle/language/dsa/ascend/semantic.py diff --git 
a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py index e70627e8c..d9e683493 100644 --- a/python/triton/experimental/tle/language/dsa/core.py +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -1,6 +1,6 @@ # Copyright (c) 2025 XCoreSigma Inc. All rights reserved. -from triton.language import semantic as tl_semantic +import triton.language.core as tl from triton.language.core import ( _shape_check_impl, _constexpr_to_value, @@ -8,7 +8,15 @@ builtin, constexpr ) +from triton.language import semantic as real_semantic +from triton._C.libtriton import ir + +import importlib +from typing import List + from . import semantic as tle_semantic +from .types import address_space, buffer + class range(): """ @@ -89,21 +97,21 @@ class pipeline(range): def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) -@builtin -def alloc(shape, dtype, layout=None, scope=None, _builder=None): - """ - Returns a pointer for the given :code:`shape` and :code:`dtype`. - - :param shape: Shape of the new array, e.g., (8, 16) or (8, ) - :type shape: tuple of ints - :param dtype: Data type of the new array, e.g., :code:`tl.float16` - :type dtype: tl.dtype - """ - shape = _shape_check_impl(shape) - dtype = _constexpr_to_value(dtype) - layout = _constexpr_to_value(layout) - scope = _constexpr_to_value(scope) - return tle_semantic.alloc(shape, dtype, layout, scope, _builder) +### @builtin +### def alloc(shape, dtype, layout=None, scope=None, _builder=None): +### """ +### Returns a pointer for the given :code:`shape` and :code:`dtype`. 
+### +### :param shape: Shape of the new array, e.g., (8, 16) or (8, ) +### :type shape: tuple of ints +### :param dtype: Data type of the new array, e.g., :code:`tl.float16` +### :type dtype: tl.dtype +### """ +### shape = _shape_check_impl(shape) +### dtype = _constexpr_to_value(dtype) +### layout = _constexpr_to_value(layout) +### scope = _constexpr_to_value(scope) +### return tle_semantic.alloc(shape, dtype, layout, scope, _builder) @builtin @@ -114,23 +122,23 @@ def copy(src, dst, shape, _builder=None): tle_semantic.copy(src, dst, shape, _builder) -@builtin -def to_tensor(buffer, _builder=None): - """ - Create a tensor-like type from a buffer-like type. - - :param buffer: the input buffer-like object. - """ - return tle_semantic.to_tensor(buffer, _builder) - -@builtin -def to_buffer(src, _builder=None): - """ - Create a buffer-like type from a tensor-like type. - - :param src: the input tensor-like object. - """ - return tle_semantic.to_buffer(src, _builder) +### @builtin +### def to_tensor(buffer, _builder=None): +### """ +### Create a tensor-like type from a buffer-like type. +### +### :param buffer: the input buffer-like object. +### """ +### return tle_semantic.to_tensor(buffer, _builder) +### +### @builtin +### def to_buffer(src, _builder=None): +### """ +### Create a buffer-like type from a tensor-like type. +### +### :param src: the input tensor-like object. 
+### """ +### return tle_semantic.to_buffer(src, _builder) @builtin @@ -159,10 +167,57 @@ def min(input, other, result, _builder=None): # elementwise binary vector minimum op tle_semantic.min(input, other, result, _builder) +### @builtin +### def dot(inputA, inputB, result, size, initC, a_transpose=False, b_transpose=False, enable_hf32=False, _builder=None): +### initC = _constexpr_to_value(initC) +### a_transpose = _constexpr_to_value(a_transpose) +### b_transpose = _constexpr_to_value(b_transpose) +### enable_hf32 = _constexpr_to_value(enable_hf32) +### tle_semantic.dot(inputA, inputB, result, size, initC, a_transpose, b_transpose, enable_hf32, _builder) + + + +@builtin +def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_space = None, _builder=None) -> buffer: + """ + Allocates a region of local memory with the specified shape and type. + + :param etype: the element type of the buffer. + :type etype: tl.dtype + :param shape: A list of non-negative integers representing the shape of the buffer. + :type shape: List[tl.constexpr] + :param _address_space: (Optional) backend-specific local memory address space + :type _address_space: bl.address_space + """ + return tle_semantic.alloc(dtype, shape, mem_addr_space, _builder) + + +@builtin +def to_buffer(tensor: tl.tensor, space: address_space = None, bind_buffer: buffer = None, _builder=None) -> buffer: + """ + Convert a tensor to a buffer. + + :param tensor: the tensor to convert. + :type tensor: tl.tensor + :param space: the address space for the buffer (optional). + :type space: address_space + """ + return tle_semantic.to_buffer(tensor, space, bind_buffer, _builder) + + +@builtin +def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder=None) -> tl.tensor: + """ + Create a tl.tensor from a bl.buffer. + + :param memref: the input bl.buffer object. 
+ :memref type: bl.buffer + :param writable: If set true, the resultant tensor is considered "writable" during bufferization. + :type writable: bool + """ + return tle_semantic.to_tensor(memref, writable, _builder, target_shape=target_shape) + @builtin -def dot(inputA, inputB, result, size, initC, a_transpose=False, b_transpose=False, enable_hf32=False, _builder=None): - initC = _constexpr_to_value(initC) - a_transpose = _constexpr_to_value(a_transpose) - b_transpose = _constexpr_to_value(b_transpose) - enable_hf32 = _constexpr_to_value(enable_hf32) - tle_semantic.dot(inputA, inputB, result, size, initC, a_transpose, b_transpose, enable_hf32, _builder) \ No newline at end of file +def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> buffer: + pass \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py index be09a4727..c5cfc5ca3 100644 --- a/python/triton/experimental/tle/language/dsa/semantic.py +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -6,6 +6,7 @@ binary_op_type_checking_impl, ) from triton._C.libtriton import ir +from .types import buffer, buffer_type, address_space def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # assert value.numel.value == 1, "only accepts size-1 tensor" @@ -16,10 +17,10 @@ def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if value.dtype.is_int(): return tl.tensor(value.handle, dtype) -def alloc(shape: List[tl.tensor], dtype: tl.dtype, layout, scope, builder: ir.builder) -> tl.tensor: - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_dsa_alloc(shape, str(layout), str(scope), - dtype.to_ir(builder)), ret_ty) +### def alloc(shape: List[tl.tensor], dtype: tl.dtype, layout, scope, builder: ir.builder) -> tl.tensor: +### ret_ty = tl.block_type(dtype, shape) +### return 
tl.tensor(builder.create_dsa_alloc(shape, str(layout), str(scope), +### dtype.to_ir(builder)), ret_ty) def copy(src, dst, shape: List[Union[tl.constexpr, int]], builder: ir.builder): @@ -31,22 +32,22 @@ def copy(src, dst, shape: List[Union[tl.constexpr, int]], builder: ir.builder): builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape]) -def to_tensor(buffer: tl.tensor, builder: ir.builder) -> tl.tensor: - if not isinstance(buffer, tl.tensor): - raise TypeError("buffer must be tensor of pointers") - - tensor_ty = buffer.type - element_ty = tensor_ty.element_ty - if not element_ty.is_ptr: - raise TypeError("The basic elements of a buffer must be pointers") - - return tl.tensor(builder.dsa_to_tensor(buffer.handle), tensor_ty) - -def to_buffer(src: tl.tensor, builder: ir.builder) -> tl.tensor: - if not isinstance(src, tl.tensor): - raise TypeError("src of to_buffer must be tensor") - - return tl.tensor(builder.dsa_to_buffer(src.handle), src.type) +### def to_tensor(buffer: tl.tensor, builder: ir.builder) -> tl.tensor: +### if not isinstance(buffer, tl.tensor): +### raise TypeError("buffer must be tensor of pointers") +### +### tensor_ty = buffer.type +### element_ty = tensor_ty.element_ty +### if not element_ty.is_ptr: +### raise TypeError("The basic elements of a buffer must be pointers") +### +### return tl.tensor(builder.dsa_to_tensor(buffer.handle), tensor_ty) +### +### def to_buffer(src: tl.tensor, builder: ir.builder) -> tl.tensor: +### if not isinstance(src, tl.tensor): +### raise TypeError("src of to_buffer must be tensor") +### +### return tl.tensor(builder.dsa_to_buffer(src.handle), src.type) def add(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) @@ -72,7 +73,67 @@ def min(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.build input, other = binary_op_type_checking_impl(input, other, builder, True, True) 
builder.create_dsa_min(input.handle, other.handle, result.handle) -def dot(inputA: tl.tensor, inputB: tl.tensor, result: tl.tensor, size: List[int], initC: bool, a_transpose: bool, b_transpose: bool, enable_hf32: bool, builder: ir.builder): - assert len(size) == 3, f"Please set the M、N、K value." - - builder.create_dsa_dot(inputA.handle, inputB.handle, result.handle, size, initC, a_transpose, b_transpose, enable_hf32) +### def dot(inputA: tl.tensor, inputB: tl.tensor, result: tl.tensor, size: List[int], initC: bool, a_transpose: bool, b_transpose: bool, enable_hf32: bool, builder: ir.builder): +### assert len(size) == 3, f"Please set the M、N、K value." +### +### builder.create_dsa_dot(inputA.handle, inputB.handle, result.handle, size, initC, a_transpose, b_transpose, enable_hf32) + +def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: address_space, + builder: ir.builder) -> buffer: + shape = tl._unwrap_shape(shape) + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + etype = tl._constexpr_to_value(etype) + address_space = tl._constexpr_to_value(address_space) + element_ty_ir = etype.to_ir(builder) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) + memref_ty = builder.get_buffer_ty(shape, element_ty_ir, addr_space_attr) + handle = builder.create_dsa_alloc(memref_ty) + buffer_ty = buffer_type(element_ty=etype, shape=shape, space=address_space) + return buffer(handle, buffer_ty) + + +def to_buffer( + tensor: tl.tensor, + address_space: address_space, + bind_buffer: buffer, + builder: ir.builder, +) -> buffer: + if not isinstance(tensor.shape, (tuple, list)) or not tensor.shape: + raise TypeError("scalar type cannot be converted to buffer") + if isinstance(bind_buffer, buffer): + builder.create_bind_buffer(tensor.handle, bind_buffer.handle) + return bind_buffer + if not (bind_buffer is None): + raise ValueError("bind_buffer must be a buffer or None") + address_space = 
tl._constexpr_to_value(address_space) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) + handle = builder.dsa_to_buffer(tensor.handle, addr_space_attr) + buffer_ty = buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) + return buffer(handle, buffer_ty) + + +def to_tensor(memref: buffer, writable: bool, builder: ir.builder, target_shape=None) -> tl.tensor: + if not isinstance(memref, buffer): + raise TypeError("memref must be buffer") + + need_convert_layout = False + shape = memref.shape + if target_shape: + need_convert_layout = True + shape = tl._unwrap_shape(target_shape) + assert shape != memref.shape, "target shape is the same as source shape" + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + tensor_type = tl.block_type(memref.dtype, shape) + + memref_value = memref.handle + if need_convert_layout: + buffer_ty = buffer_type( + element_ty=memref.dtype, + shape=shape, + space=memref.space, + ) + memref_value = builder.create_convert_layout(memref_value, buffer_ty.to_ir(builder)) + + return tl.tensor(builder.dsa_to_tensor(memref_value, writable), tensor_type) \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py new file mode 100644 index 000000000..60fe3c3bb --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +from triton._C.libtriton import ir + +from typing import List +import triton.language.core as tl +from triton.language.core import builtin + +class address_space: + """Represents a buffer's address space. + + The :code:`address_space` of a buffer is a target-specific attribute. 
+ """ + + def to_ir(self, builder: ir.builder) -> ir.type: + raise NotImplementedError("Abstract address_space cannot be converted to ir") + + +class buffer_type(tl.dtype): + + def __init__(self, element_ty: tl.dtype, shape: List, space: address_space = None, strides: List = None): + self.element_ty = element_ty + self.shape = shape if isinstance(shape, list) else list(shape) + self.space = space + self.strides = strides if strides is not None else [] + self.name = self._make_name() + + def _make_name(self): + res = '' + + def to_ir(self, builder: ir.builder) -> ir.type: + element_ty_ir = self.element_ty.to_ir(builder) + addr_space_attr = self.space.to_ir(builder) if self.space else builder.get_null_attr() + + # use the method with strides if strides is not empty + if self.strides: + return builder.get_buffer_ty_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr) + else: + return builder.get_buffer_ty(self.shape, element_ty_ir, addr_space_attr) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def __eq__(self, other) -> bool: + if not isinstance(other, buffer_type): + return False + return (self.element_ty == other.element_ty and self.shape == other.shape and self.space == other.space + and self.strides == other.strides) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +# ----------------------- +# buffer +# ----------------------- + + +class buffer(tl._value): + """Represents a region of memory. + + :code:`buffer` is the fundamental data structure for Triton programs using + the buffer language extension. Most functions in + :py:mod:`triton.extension.buffer.language` operate on and return buffers. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + .. rubric:: Constructors + .. 
+ For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, buffer_ty: buffer_type): + """Not called by user code.""" + super().__init__(handle) + self.type = buffer_ty + self.dtype = buffer_ty.element_ty.scalar + self.shape = buffer_ty.shape + self.space = buffer_ty.space + self.strides = buffer_ty.strides + + def __str__(self) -> str: + # ex. "<16x32xfloat32, address_space>" + res = '<' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.dtype) + if self.space: + res += ', ' + str(self.space) + return res + '>' diff --git a/python/triton/experimental/tle/src/tle_ir.cc b/python/triton/experimental/tle/src/tle_ir.cc deleted file mode 100644 index b1dee7c7a..000000000 --- a/python/triton/experimental/tle/src/tle_ir.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. - -#include -#include - -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" - -#include "ir.h" - -using namespace mlir; -namespace py = pybind11; - -struct DSAOpBuilder : public TritonOpBuilder {}; - -void init_tle_ir(py::module &&m) -{ - m.def("load_dialects", [](MLIRContext &context) { - DialectRegistry registry; - registry.insert(); - context.appendDialectRegistry(registry); - context.loadAllAvailableDialects(); - }); - - py::class_(m, "tle_builder", py::module_local(), py::dynamic_attr()) - .def(py::init()) - // Add alloc op - .def("create_dsa_alloc", - [](TritonOpBuilder &self, std::vector &shape, - std::string &layout, std::string &scope, Type type)-> Value { - auto shapeAttr = self.getBuilder().getI64ArrayAttr(shape); - auto layoutAttr = self.getBuilder().getStringAttr(layout); - auto scopeAttr = self.getBuilder().getStringAttr(scope); - - auto 
ptrType = triton::PointerType::get(type, 1); - auto tensorPtrType = RankedTensorType::get(shape, ptrType); - return self.create(tensorPtrType, shapeAttr, - layoutAttr, scopeAttr); - }) - // Add copy op - .def("create_dsa_copy", - [](TritonOpBuilder &self, Value &src, Value &dst, std::vector &shape)-> void { - self.create(src, dst, shape); - }) - // Add op - .def("create_dsa_add", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Sub op - .def("create_dsa_sub", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Mul op - .def("create_dsa_mul", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Div op - .def("create_dsa_div", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Max op - .def("create_dsa_max", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Min op - .def("create_dsa_min", - [](TritonOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Dot op - .def("create_dsa_dot", - [](TritonOpBuilder &self, Value &inA, Value &inB, Value &res, - std::vector &size, bool &initC, bool &traA, bool &traB, - bool &enable_hf32) -> void { - auto &builder = self.getBuilder(); - auto sizeAttr = builder.getI64ArrayAttr(size); - - // convert bool to boolattr. 
- auto initC_attr = builder.getBoolAttr(initC); - auto traA_attr = builder.getBoolAttr(traA); - auto traB_attr = builder.getBoolAttr(traB); - auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); - - self.create(inA, inB, res, sizeAttr, initC_attr, - traA_attr, traB_attr, enable_hf32_attr); - }) - // ToTensor op - .def("dsa_to_tensor", - [](TritonOpBuilder &self, Value &src) -> Value { - return self.create(src); - }) - // ToBuffer op - .def("dsa_to_buffer", - [](TritonOpBuilder &self, Value &src) -> Value { - auto srcType = src.getType(); - auto tensorTy = cast(srcType); - Type elementType = tensorTy.getElementType(); - auto ptrType = triton::PointerType::get(elementType, 1); - auto shape = tensorTy.getShape(); - auto tensorPtrType = RankedTensorType::get(shape, ptrType); - return self.create(tensorPtrType, src); - }); - -} \ No newline at end of file diff --git a/third_party/ascend/AscendNPU-IR b/third_party/ascend/AscendNPU-IR new file mode 160000 index 000000000..5a3921f87 --- /dev/null +++ b/third_party/ascend/AscendNPU-IR @@ -0,0 +1 @@ +Subproject commit 5a3921f87197bad7f4c8037648c9935f205fae35 diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index f024cfa81..dd2301621 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1380,5 +1380,195 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpIn let hasVerifier = 1; } +///////////// Definitions for DSA +/// // +/// // Alloc Op +/// // +/// def TT_DSAAllocOp : TT_Op<"dsa_alloc", [Pure, MemoryEffects<[MemWrite]>]> { +/// let summary = "self-defined alloc operation"; +/// let description = [{ +/// `tt.dsa_alloc` triton alloc op is designed to performs memory allocation. 
+/// }];
+/// let arguments = (
+/// ins
+/// I64ArrayAttr:$shape,
+/// StrAttr:$layout,
+/// StrAttr:$scope
+/// );
+///
+/// let results = (outs TT_PtrLike:$result);
+///
+/// let assemblyFormat = "$shape `,` $layout `,` $scope attr-dict `:` `->` type($result)";
+/// }
+
+//
+// Copy OP
+//
+def TT_DSACopyOp : TT_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> {
+ let summary = "self-defined copy operation";
+ let description = [{
+ 'tt.dsa_copy' triton copy op is designed to copy data between memory regions.
+ Example:
+ ```mlir
+ tt.dsa_copy %src, %dst, %shape : tensor<128xf32>
+ ```
+ }];
+ let arguments = (ins
+ AnyType:$src,
+ AnyType:$dst,
+ Variadic:$shape
+ );
+
+ // let builders = [
+ // OpBuilder<(ins "Value":$src, "Value":$dst, "ValueRange": $shape)>
+ // ];
+
+ //assemble
+ let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`";
+}
+
+//
+// Add Op
+//
+def TT_DSAAddOp : TT_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined add operation";
+ let description = [{
+ `tt.dsa_add` triton dsa_add op is designed to perform element-wise addition.
+ }];
+ let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res);
+
+ let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Sub op
+//
+def TT_DSASubOp : TT_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined sub operation";
+ let description = [{
+ `tt.dsa_sub` triton dsa_sub op is designed to perform element-wise subtraction.
+ }];
+ let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res);
+
+ let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Mul op
+//
+def TT_DSAMulOp : TT_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined mul operation";
+ let description = [{
+ `tt.dsa_mul` triton dsa_mul op is designed to perform element-wise multiplication.
+ }];
+ let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res);
+
+ let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Div op
+//
+def TT_DSADivOp : TT_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined div operation";
+ let description = [{
+ `tt.dsa_div` triton dsa_div op is designed to perform element-wise division.
+ }];
+ let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res);
+
+ let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Max op
+//
+def TT_DSAMaxOp : TT_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined Max operation";
+ let description = [{
+ `tt.dsa_max` triton dsa_max op is designed to perform element-wise maximum.
+ }];
+ let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res);
+
+ let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Min op
+//
+def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>,
+ SameOperandsShape, SameOperandsElementType]> {
+ let summary = "self-defined Min operation";
+ let description = [{
+ `tt.dsa_min` triton dsa_min op is designed to perform element-wise minimum.
+ }]; + let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +/// // +/// // Dot op +/// // +/// def TT_DSADotOp : TT_Op<"dsa_dot", [Pure, +/// MemoryEffects<[MemWrite]>, +/// SameOperandsElementType, +/// DotLike]> { +/// let summary = "self-defined Dot operation"; +/// let description = [{ +/// $d = matrix_multiply($a, $b) + $c. +/// }]; +/// +/// let arguments = ( +/// ins +/// TT_PtrLike:$inA, +/// TT_PtrLike:$inB, +/// TT_PtrLike:$res, +/// I64ArrayAttr:$size, +/// DefaultValuedAttr:$initC, +/// DefaultValuedAttr:$traA, +/// DefaultValuedAttr:$traB, +/// DefaultValuedAttr:$enableHf32 +/// ); +/// +/// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; +/// } +/// +/// // +/// // ToTensor op +/// // +/// def TT_ToTensorOp : TT_Op<"to_tensor", [Pure, MemoryEffects<[MemWrite]>, +/// TypesMatchWith<"result matches ptr type", "src", "result", "getPointeeType($_self)">]> { +/// let summary = "self-defined to_tensor operation"; +/// let description = [{ +/// `tt.to_tensor` triton to_tensor op is designed to performs the conversion from buffer to tensor. +/// }]; +/// let arguments = (ins TT_PtrLike:$src); +/// +/// let results = (outs TT_Tensor:$result); +/// +/// let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +/// } +/// +/// // +/// // ToBuffer op +/// // +/// def TT_ToBufferOp : TT_Op<"to_buffer", [Pure, MemoryEffects<[MemWrite]>]> { +/// let summary = "self-defined to_buffer operation"; +/// let description = [{ +/// `tt.to_buffer` triton to_buffer op is designed to performs the conversion from tensor to buffer. 
+/// }]; +/// let arguments = (ins TT_Tensor:$src); +/// +/// let results = (outs TT_PtrLike:$result); +/// +/// let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +/// } +///////////// Definition for DSA ends + #endif // Triton_OPS diff --git a/third_party/ascend/backend/spec/triton/language/core.py b/third_party/ascend/backend/spec/triton/language/core.py index fae1912ac..9ed34a899 100644 --- a/third_party/ascend/backend/spec/triton/language/core.py +++ b/third_party/ascend/backend/spec/triton/language/core.py @@ -1589,7 +1589,7 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, @builtin def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", - volatile=False, care_padding=True, mask_opt = False, _builder=None): + volatile=False, care_padding=True, _builder=None): """ Return a tensor of data whose values are loaded from memory at location defined by `pointer`: @@ -1650,7 +1650,7 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c volatile = _constexpr_to_value(volatile) care_padding = _constexpr_to_value(care_padding) return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, - volatile, care_padding, mask_opt, _builder) + volatile, care_padding, _builder) @builtin diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 31693177e..7be3b70db 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -1119,7 +1119,7 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) -def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, 
care_padding, mask_opt, builder): +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, builder): # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` if not ptr.type.scalar.is_ptr(): raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") @@ -1181,7 +1181,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ else: ret = tl.tensor( builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, - is_volatile, mask_opt), dst_ty) + is_volatile), dst_ty) # Do not cast back to int1 when is_bool=true. We directly use the int8 tensor given by tl.load if is_bool: ret.was_bool_to_int8 = True @@ -1191,7 +1191,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, care_padding: bool, - mask_opt:bool, builder: ir.builder) -> tl.tensor: + builder: ir.builder) -> tl.tensor: # Cache, eviction and padding options cache = _str_to_load_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1203,7 +1203,7 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], else: # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, - mask_opt, builder) + builder) def tensormap_create( diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index 56843a2bc..d52868ea6 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -1309,10 +1309,10 @@ void init_triton_ir(py::module &&m) { .def("create_masked_load", [](TritonOpBuilder &self, Value &ptrs, 
Value &mask, std::optional &other, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy, bool isVolatile, bool optMask) -> Value { + EvictionPolicy evictionPolicy, bool isVolatile) -> Value { return self.create(ptrs, mask, other.value_or(Value()), cacheModifier, evictionPolicy, - isVolatile, optMask); + isVolatile); }) .def("create_masked_store", [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc new file mode 100644 index 000000000..5cb4f5a06 --- /dev/null +++ b/third_party/tle/dsa/tle_ir.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#include +#include + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "ir.h" + +using namespace mlir; +namespace py = pybind11; + +constexpr unsigned kIntegerAttrBitWidth = 64; + +struct DSAOpBuilder : public TritonOpBuilder {}; + +void init_tle_ir(py::module &&m) +{ + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_(m, "tle_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + // Add alloc op + /// .def("create_dsa_alloc", + /// [](DSAOpBuilder &self, std::vector &shape, + /// std::string &layout, std::string &scope, Type type)-> Value { + /// auto shapeAttr = self.getBuilder().getI64ArrayAttr(shape); + /// auto layoutAttr = self.getBuilder().getStringAttr(layout); + /// auto scopeAttr = self.getBuilder().getStringAttr(scope); + + /// auto ptrType = triton::PointerType::get(type, 1); + /// 
auto tensorPtrType = RankedTensorType::get(shape, ptrType); + /// return self.create(tensorPtrType, shapeAttr, + /// layoutAttr, scopeAttr); + /// }) + // Add copy op + .def("create_dsa_alloc", + [](DSAOpBuilder &self, Type memrefType) -> Value { + return self.create(mlir::cast(memrefType)); + }) + .def("create_dsa_copy", + [](DSAOpBuilder &self, Value &src, Value &dst, std::vector &shape)-> void { + self.create(src, dst, shape); + }) + // Add op + .def("create_dsa_add", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Sub op + .def("create_dsa_sub", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Mul op + .def("create_dsa_mul", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Div op + .def("create_dsa_div", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Max op + .def("create_dsa_max", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Min op + .def("create_dsa_min", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Dot op + /// .def("create_dsa_dot", + /// [](DSAOpBuilder &self, Value &inA, Value &inB, Value &res, + /// std::vector &size, bool &initC, bool &traA, bool &traB, + /// bool &enable_hf32) -> void { + /// auto &builder = self.getBuilder(); + /// auto sizeAttr = builder.getI64ArrayAttr(size); + + /// // convert bool to boolattr. 
+ /// auto initC_attr = builder.getBoolAttr(initC); + /// auto traA_attr = builder.getBoolAttr(traA); + /// auto traB_attr = builder.getBoolAttr(traB); + /// auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); + + /// self.create(inA, inB, res, sizeAttr, initC_attr, + /// traA_attr, traB_attr, enable_hf32_attr); + /// }) + /// // ToTensor op + /// .def("dsa_to_tensor", + /// [](DSAOpBuilder &self, Value &src) -> Value { + /// return self.create(src); + /// }) + /// // ToBuffer op + /// .def("dsa_to_buffer", + /// [](DSAOpBuilder &self, Value &src) -> Value { + /// auto srcType = src.getType(); + /// auto tensorTy = cast(srcType); + /// Type elementType = tensorTy.getElementType(); + /// auto ptrType = triton::PointerType::get(elementType, 1); + /// auto shape = tensorTy.getShape(); + /// auto tensorPtrType = RankedTensorType::get(shape, ptrType); + /// return self.create(tensorPtrType, src); + /// }) + .def("dsa_to_buffer", + [](DSAOpBuilder &self, Value &src, + const Attribute &addressSpace) -> Value { + auto tensorType = dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = MemRefType::get( + tensorType.getShape(), tensorType.getElementType(), + MemRefLayoutAttrInterface{}, addressSpace); + return self.create(memrefType, src); + }) + .def("dsa_to_tensor", + [](DSAOpBuilder &self, Value &src, bool writable) -> Value { + const auto &memrefType = mlir::cast(src.getType()); + auto hasAddressSpace = memrefType.getMemorySpace(); + if (hasAddressSpace) { + return self.create( + self.create( + MemRefType::get(memrefType.getShape(), + memrefType.getElementType(), + memrefType.getLayout()), + src), + true, writable); + } + return self.create(src, true, writable); + }) + ; + +} \ No newline at end of file From 7e339c29eadb1e62e4c6f2561fd64db4f7c6b70f Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Sat, 28 Feb 2026 10:31:18 +0000 Subject: [PATCH 03/13] [FEAT](tle): support add, sub, mul, 
div, max, min in tle.dsaf --- python/test/tle/test_bind_buffer.py | 38 ++++++++ python/test/tle/test_vec_add.py | 52 ++++++++++ python/test/tle/test_vec_add_2d.py | 74 ++++++++++++++ python/test/tle/test_vec_mathOps.py | 77 +++++++++++++++ .../experimental/tle/language/builder.py | 3 + .../experimental/tle/language/dsa/core.py | 96 +++++++++++-------- .../experimental/tle/language/dsa/semantic.py | 12 +-- .../experimental/tle/language/dsa/types.py | 4 +- .../triton/Dialect/Triton/IR/TritonOps.td | 44 ++------- third_party/tle/dsa/tle_ir.cc | 16 ++++ 10 files changed, 331 insertions(+), 85 deletions(-) create mode 100644 python/test/tle/test_bind_buffer.py create mode 100755 python/test/tle/test_vec_add.py create mode 100755 python/test/tle/test_vec_add_2d.py create mode 100755 python/test/tle/test_vec_mathOps.py diff --git a/python/test/tle/test_bind_buffer.py b/python/test/tle/test_bind_buffer.py new file mode 100644 index 000000000..f1c12c7b3 --- /dev/null +++ b/python/test/tle/test_bind_buffer.py @@ -0,0 +1,38 @@ +import triton +import triton.experimental.tle as tle +import triton.language as tl + +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir, tle_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + tle_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + +@triton.jit +def bind_buffer(): + # tle.dsa.ascend.UB is triton.language.extra.extension.cann.core.ascend_address_space.UB + buffer1 = tle.dsa.alloc(shape=[32, 32], dtype=tl.float32, 
mem_addr_space=tle.dsa.ascend.UB) + tle.dsa.to_tensor(buffer1, writable=True) + +if __name__ == "__main__": + print("=" * 60) + mlir = compile_kernel(bind_buffer, {}, {}) + print(mlir) diff --git a/python/test/tle/test_vec_add.py b/python/test/tle/test_vec_add.py new file mode 100755 index 000000000..1960dc95a --- /dev/null +++ b/python/test/tle/test_vec_add.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +import torch +import triton +import triton.language as tl +# import triton.language.extra.tle.ascend as tle +import triton.experimental.tle as tle + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # mem_addr_space is language.extra.cann.core.ascend_address_space + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return output + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, 
device='npu', dtype=torch.float) + y = torch.rand(size, device='npu', dtype=torch.float) + output_torch = x + y + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_add_2d.py b/python/test/tle/test_vec_add_2d.py new file mode 100755 index 000000000..debdb2780 --- /dev/null +++ b/python/test/tle/test_vec_add_2d.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +import torch +import triton +import triton.language as tl +import triton.experimental.tle as tle + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + n_cols, n_rows, + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. 
+ ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # 2D offsets + block_start_m = pid_m * BLOCK_SIZE + block_start_n = pid_n * BLOCK_SIZE + offs_m = block_start_m + tl.arange(0, BLOCK_SIZE) + offs_n = block_start_n + tl.arange(0, BLOCK_SIZE) + + # 计算线性地址(row-major) + x_ptrs = x_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + y_ptrs = y_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + out_ptrs = output_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + + a_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start_m + t1 = n_elements - block_start_n + tail_size_m = tl.minimum(t0, BLOCK_SIZE) + tail_size_n = tl.minimum(t1, BLOCK_SIZE) + + tle.dsa.copy(x_ptrs, a_ub, [tail_size_m, tail_size_n]) + tle.dsa.copy(y_ptrs, b_ub, [tail_size_m, tail_size_n]) + + tle.dsa.add(a_ub, b_ub, c_ub) + + tle.dsa.copy(c_ub, out_ptrs, [tail_size_m, tail_size_n]) + +def custom_func(x: torch.Tensor, y: torch.Tensor, size: int): + output = torch.empty_like(x) + n_elements = output.numel() + BLOCK_SIZE = 16 + # grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + grid = (triton.cdiv(size, BLOCK_SIZE), triton.cdiv(size, BLOCK_SIZE)) + add_kernel[grid](x, y, output, n_elements, size, size-1, BLOCK_SIZE) + return output + +def test_add(): + torch.manual_seed(0) + size = 128 + x = torch.rand((size,size-1), device='npu', dtype=torch.float) + y = torch.rand((size,size-1), device='npu', dtype=torch.float) + output_torch = x + y + output_triton = custom_func(x, y, size) + print(f"============X===========") + print(x) + print(f"============Y===========") + print(y) + print(f"============outTorch===========") + print(output_torch) + print(f"============outTriton===========") + print(output_triton) + 
print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_mathOps.py b/python/test/tle/test_vec_mathOps.py new file mode 100755 index 000000000..2d161d433 --- /dev/null +++ b/python/test/tle/test_vec_mathOps.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +from typing import Callable, Tuple +import torch +import triton +import triton.language as tl +import triton.experimental.tle as tle + +@triton.jit +def run_test( + x_ptr, y_ptr, output_ptr, n_elements, + OP_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + tle.dsa.copy(x_ptr + offsets, a_ub, [BLOCK_SIZE]) + tle.dsa.copy(y_ptr + offsets, b_ub, [BLOCK_SIZE]) + + if OP_ID == 0: # add + tle.dsa.add(a_ub, b_ub, c_ub) + elif OP_ID == 1: # sub + tle.dsa.sub(a_ub, b_ub, c_ub) + elif OP_ID == 2: # mul + tle.dsa.mul(a_ub, b_ub, c_ub) + elif OP_ID == 3: # div + tle.dsa.div(a_ub, b_ub, c_ub) + + tle.dsa.copy(c_ub, output_ptr + offsets, [BLOCK_SIZE]) + +OP_REGISTRY = { + 'add': (0, torch.add), + 'sub': (1, torch.sub), + 'mul': (2, torch.mul), + 'div': (3, torch.div), +} + +def common_test(op_name: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if op_name not in OP_REGISTRY: + raise ValueError(f"Unsupported op: {op_name}") + + op_id, _ = OP_REGISTRY[op_name] + output = torch.empty_like(x) + n_elements = output.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + run_test[grid]( + x, y, output, n_elements, + OP_ID=op_id, + BLOCK_SIZE=128, 
+ ) + return output + +def test_binary_op(size: int = 1024, dtype=torch.float32, device='npu'): + x = torch.rand(size, device=device, dtype=dtype) + y = torch.rand(size, device=device, dtype=dtype) + y = y + 0.1 + + print(f"Testing {len(OP_REGISTRY)} operators with size={size}, dtype={dtype}") + + for op_name in OP_REGISTRY: + torch_fn = OP_REGISTRY[op_name][1] + triton_out = common_test(op_name, x, y) + torch_out = torch_fn(x, y) + + max_diff = torch.max(torch.abs(torch_out - triton_out)).item() + status = "SUCCESS" if max_diff < 1e-5 else "FAIL" + print(f"{status} {op_name:8}: max diff = {max_diff:.2e}") + +if __name__ == "__main__": + test_binary_op(size=1024, dtype=torch.float32) diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py index 82581c1eb..f9a1f2e9e 100644 --- a/python/triton/experimental/tle/language/builder.py +++ b/python/triton/experimental/tle/language/builder.py @@ -45,5 +45,8 @@ def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): # 'create_dsa_dot', 'dsa_to_buffer', 'dsa_to_tensor', + 'dsa_get_null_attr', + 'dsa_get_buffer_type', + 'dsa_get_buffer_type_with_strides', ] attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py index d9e683493..60e9eb07d 100644 --- a/python/triton/experimental/tle/language/dsa/core.py +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -2,21 +2,46 @@ import triton.language.core as tl from triton.language.core import ( - _shape_check_impl, _constexpr_to_value, - _unwrap_if_constexpr, - builtin, constexpr ) -from triton.language import semantic as real_semantic from triton._C.libtriton import ir -import importlib -from typing import List +from typing import List, TypeVar +from functools import wraps from . 
import semantic as tle_semantic from .types import address_space, buffer +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" +TLE_BUILTIN = "__tle_builtin__" + +def builtin(fn: T) -> T: + """ + Decorator for builtin functions to mark a function as a tle language builtin function. + """ + assert callable + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + setattr(wrapper, TLE_BUILTIN, True) + + return wrapper + +def is_builtin(fn) -> bool: + """ + Returns whether a function is a builtin function. + """ + return getattr(fn, TLE_BUILTIN, False) + class range(): """ @@ -97,22 +122,14 @@ class pipeline(range): def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) -### @builtin -### def alloc(shape, dtype, layout=None, scope=None, _builder=None): -### """ -### Returns a pointer for the given :code:`shape` and :code:`dtype`. 
-### -### :param shape: Shape of the new array, e.g., (8, 16) or (8, ) -### :type shape: tuple of ints -### :param dtype: Data type of the new array, e.g., :code:`tl.float16` -### :type dtype: tl.dtype -### """ -### shape = _shape_check_impl(shape) -### dtype = _constexpr_to_value(dtype) -### layout = _constexpr_to_value(layout) -### scope = _constexpr_to_value(scope) -### return tle_semantic.alloc(shape, dtype, layout, scope, _builder) +@builtin +def from_buffer_to_tensor_pointer(src: buffer, _builder=None) -> tl.tensor: + buffer_ty = src.type + ele_type = buffer_ty.element_ty + shape = buffer_ty.shape + block_type = tl.block_type(ele_type, shape) + return tl.tensor(src.handle, block_type) @builtin def copy(src, dst, shape, _builder=None): @@ -122,49 +139,48 @@ def copy(src, dst, shape, _builder=None): tle_semantic.copy(src, dst, shape, _builder) -### @builtin -### def to_tensor(buffer, _builder=None): -### """ -### Create a tensor-like type from a buffer-like type. -### -### :param buffer: the input buffer-like object. -### """ -### return tle_semantic.to_tensor(buffer, _builder) -### -### @builtin -### def to_buffer(src, _builder=None): -### """ -### Create a buffer-like type from a tensor-like type. -### -### :param src: the input tensor-like object. 
-### """ -### return tle_semantic.to_buffer(src, _builder) - - @builtin def add(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.add(input, other, result, _builder) @builtin def sub(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.sub(input, other, result, _builder) @builtin def mul(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.mul(input, other, result, _builder) @builtin def div(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.div(input, other, result, _builder) @builtin def max(input, other, result, _builder=None): # elementwise binary vector maximum op + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.max(input, other, result, _builder) @builtin def min(input, other, result, _builder=None): # elementwise binary vector minimum op + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.min(input, other, result, _builder) ### @builtin diff --git 
a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py index c5cfc5ca3..d34561185 100644 --- a/python/triton/experimental/tle/language/dsa/semantic.py +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -86,8 +86,8 @@ def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: address_spa etype = tl._constexpr_to_value(etype) address_space = tl._constexpr_to_value(address_space) element_ty_ir = etype.to_ir(builder) - addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) - memref_ty = builder.get_buffer_ty(shape, element_ty_ir, addr_space_attr) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.dsa_get_null_attr()) + memref_ty = builder.dsa_get_buffer_type(shape, element_ty_ir, addr_space_attr) handle = builder.create_dsa_alloc(memref_ty) buffer_ty = buffer_type(element_ty=etype, shape=shape, space=address_space) return buffer(handle, buffer_ty) @@ -101,13 +101,13 @@ def to_buffer( ) -> buffer: if not isinstance(tensor.shape, (tuple, list)) or not tensor.shape: raise TypeError("scalar type cannot be converted to buffer") - if isinstance(bind_buffer, buffer): - builder.create_bind_buffer(tensor.handle, bind_buffer.handle) - return bind_buffer + # if isinstance(bind_buffer, buffer): + # builder.create_bind_buffer(tensor.handle, bind_buffer.handle) + # return bind_buffer if not (bind_buffer is None): raise ValueError("bind_buffer must be a buffer or None") address_space = tl._constexpr_to_value(address_space) - addr_space_attr = (address_space.to_ir(builder) if address_space else builder.get_null_attr()) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.dsa_get_null_attr()) handle = builder.dsa_to_buffer(tensor.handle, addr_space_attr) buffer_ty = buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) return buffer(handle, buffer_ty) diff --git 
a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py index 60fe3c3bb..7c2c11a53 100644 --- a/python/triton/experimental/tle/language/dsa/types.py +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -39,9 +39,9 @@ def to_ir(self, builder: ir.builder) -> ir.type: # use the method with strides if strides is not empty if self.strides: - return builder.get_buffer_ty_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr) + return builder.dsa_get_buffer_type_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr) else: - return builder.get_buffer_ty(self.shape, element_ty_ir, addr_space_attr) + return builder.dsa_get_buffer_ty(self.shape, element_ty_ir, addr_space_attr) def __str__(self): return self.name diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index dd2301621..2aa6a6045 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1436,7 +1436,7 @@ def TT_DSAAddOp : TT_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_add` triton dsa_add op is designed to performs element-wise addition. }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1450,7 +1450,7 @@ def TT_DSASubOp : TT_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_sub` triton dsa_sub op is designed to performs element-wise addition. 
}]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1464,7 +1464,7 @@ def TT_DSAMulOp : TT_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_mul` triton dsa_mul op is designed to performs element-wise addition. }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1478,7 +1478,7 @@ def TT_DSADivOp : TT_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_div` triton dsa_div op is designed to performs element-wise addition. }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1492,7 +1492,7 @@ def TT_DSAMaxOp : TT_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_max` triton dsa_max op is designed to performs element-wise addition. }]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1506,7 +1506,7 @@ def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, let description = [{ `tt.dsa_min` triton dsa_min op is designed to performs element-wise addition. 
}]; - let arguments = (ins TT_PtrLike:$lhs, TT_PtrLike:$rhs, TT_PtrLike:$res); + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; } @@ -1537,37 +1537,7 @@ def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, /// /// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; /// } -/// -/// // -/// // ToTensor op -/// // -/// def TT_ToTensorOp : TT_Op<"to_tensor", [Pure, MemoryEffects<[MemWrite]>, -/// TypesMatchWith<"result matches ptr type", "src", "result", "getPointeeType($_self)">]> { -/// let summary = "self-defined to_tensor operation"; -/// let description = [{ -/// `tt.to_tensor` triton to_tensor op is designed to performs the conversion from buffer to tensor. -/// }]; -/// let arguments = (ins TT_PtrLike:$src); -/// -/// let results = (outs TT_Tensor:$result); -/// -/// let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -/// } -/// -/// // -/// // ToBuffer op -/// // -/// def TT_ToBufferOp : TT_Op<"to_buffer", [Pure, MemoryEffects<[MemWrite]>]> { -/// let summary = "self-defined to_buffer operation"; -/// let description = [{ -/// `tt.to_buffer` triton to_buffer op is designed to performs the conversion from tensor to buffer. 
-/// }]; -/// let arguments = (ins TT_Tensor:$src); -/// -/// let results = (outs TT_PtrLike:$result); -/// -/// let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -/// } + ///////////// Definition for DSA ends diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index 5cb4f5a06..b63ad1a18 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc @@ -49,6 +49,22 @@ void init_tle_ir(py::module &&m) /// layoutAttr, scopeAttr); /// }) // Add copy op + .def("dsa_get_null_attr", [](DSAOpBuilder &self) { return Attribute(); }) + .def("dsa_get_buffer_type", + [](DSAOpBuilder &self, std::vector &shape, + Type &elementType, const Attribute &memorySpace) -> Type { + return MemRefType::get(shape, elementType, + MemRefLayoutAttrInterface{}, memorySpace); + }) + .def("dsa_get_buffer_type_with_strides", + [](TritonOpBuilder &self, std::vector &shape, + Type &elementType, const std::vector &strides, + const Attribute &memorySpace) -> Type { + // create a layout with strides, using dynamic offset + auto layout = StridedLayoutAttr::get( + self.getBuilder().getContext(), ShapedType::kDynamic, strides); + return MemRefType::get(shape, elementType, layout, memorySpace); + }) .def("create_dsa_alloc", [](DSAOpBuilder &self, Type memrefType) -> Value { return self.create(mlir::cast(memrefType)); From 3e4241a0fb251cc2886521ae35294b71bcb14778 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Tue, 3 Mar 2026 03:24:34 +0000 Subject: [PATCH 04/13] [FIX](tle): fix to_tensor in test_add_vec_mix.py * [FIX] remove memory_space_cast in dsa_to_tensor because the op removes the memory space attribute and result in compiling errors * [TESTING] add collect_single method in ascend/testing.py to preserve the original benchmark statistics --- python/test/tle/test_vec_add.py | 7 ++++ python/test/tle/test_vec_add_mix.py | 60 +++++++++++++++++++++++++++ third_party/ascend/backend/testing.py | 45 +++++++++++++++++++- third_party/tle/dsa/tle_ir.cc 
| 7 +--- 4 files changed, 111 insertions(+), 8 deletions(-) create mode 100755 python/test/tle/test_vec_add_mix.py diff --git a/python/test/tle/test_vec_add.py b/python/test/tle/test_vec_add.py index 1960dc95a..44a2f0af8 100755 --- a/python/test/tle/test_vec_add.py +++ b/python/test/tle/test_vec_add.py @@ -48,5 +48,12 @@ def test_add(): print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(output_torch - output_triton))}') + from triton.backends.ascend.testing import do_bench_npu + bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False) + bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False) + # 保留两位小数输出 + print(f"torch time : {bench_torch:.2f}") + print(f"triton time: {bench_triton:.2f}") + if __name__ == "__main__": test_add() diff --git a/python/test/tle/test_vec_add_mix.py b/python/test/tle/test_vec_add_mix.py new file mode 100755 index 000000000..2a5512cca --- /dev/null +++ b/python/test/tle/test_vec_add_mix.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +import torch +import triton +import triton.language as tl +import triton.experimental.tle as tle + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + + c_val = tle.dsa.to_tensor(c_ub) + b_val = tle.dsa.to_tensor(b_ub) + + result = c_val - b_val + + #tl.store(output_ptr + offsets, result) + + d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB) + tle.dsa.copy(d_ub, output_ptr + offsets, [tail_size]) + + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return output + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, device='npu', dtype=torch.float) + y = torch.rand(size, device='npu', dtype=torch.float) + output_torch = x + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + +if __name__ == "__main__": + test_add() diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py index 97e36ca2f..f763be03c 100644 --- a/third_party/ascend/backend/testing.py +++ b/third_party/ascend/backend/testing.py @@ -26,7 +26,7 @@ import triton.runtime as runtime -def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None, keep_res=False): +def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None, keep_res=False, collect_prof=True): import 
torch import torch_npu @@ -80,10 +80,51 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None if clear_l2_cache: del buffer - time_cost = _collect_prof_result(torch_path, funcs, warmup, active) + if collect_prof: + time_cost = _collect_prof_result(torch_path, funcs, warmup, active) + else: + time_cost = _collect_single(torch_path) _rm_dic(keep_res, torch_path) return time_cost +# keep the original behavior to get the statistics for the specified kernel func +def _collect_single(base_dir: str, key: str = None) -> float: + if not os.path.exists(base_dir): + return float("inf") + + import pandas as pd + + for root, _, files in os.walk(base_dir): + for file in files: + if file != "op_statistic.csv": + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + print(df) + if key is not None: + key_rows = df[df["OP Type"].str.startswith(key, na=False)] + if not key_rows.empty: + return key_rows["Avg Time(us)"].values[0] + return float("inf") + else: + # default: read the first row except header + # return df.loc[0, "Avg Time(us)"] + # default: extract ZerosLike time (L2 cache clear operation) + filter_cond = df["OP Type"].str.contains(r"^ZerosLike$", case=False, na=False) + filter_df = df[filter_cond] + if not filter_df.empty: + zeroslike_time = filter_df.iloc[0]['Avg Time(us)'] + print("Clear L2 cache time:", zeroslike_time) + + # Calculate total time of all operators excluding ZerosLike + non_zeroslike_df = df[~df["OP Type"].str.contains(r"^ZerosLike$", case=False, na=False)] + all_ops_total_time = non_zeroslike_df['Avg Time(us)'].sum() + all_ops_total_time = round(all_ops_total_time, 3) + print("All ops total time:", all_ops_total_time) + + return all_ops_total_time + + return float("inf") def _rm_dic(keep_res, torch_path): if keep_res: diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index b63ad1a18..4aaeb3f71 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc 
@@ -154,12 +154,7 @@ void init_tle_ir(py::module &&m) auto hasAddressSpace = memrefType.getMemorySpace(); if (hasAddressSpace) { return self.create( - self.create( - MemRefType::get(memrefType.getShape(), - memrefType.getElementType(), - memrefType.getLayout()), - src), - true, writable); + src, true, writable); } return self.create(src, true, writable); }) From 34120fcc007267fe66118874163e3fc4865a7774 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Wed, 4 Mar 2026 02:27:05 +0000 Subject: [PATCH 05/13] [FEAT](tle) add hint, subview, extract_slice, extrace_element in tle.dsa --- .../experimental/tle/language/dsa/__init__.py | 12 ++ .../experimental/tle/language/dsa/core.py | 154 ++++++++++++++- .../experimental/tle/language/dsa/semantic.py | 82 +++++++- .../spec/triton/compiler/code_generator.py | 33 +++- third_party/tle/dsa/tle_ir.cc | 177 ++++++++++++++++-- 5 files changed, 424 insertions(+), 34 deletions(-) diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py index 286b1e991..0e8c7b0ee 100644 --- a/python/triton/experimental/tle/language/dsa/__init__.py +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -4,6 +4,7 @@ alloc, copy, pipeline, + parallel, to_tensor, to_buffer, add, @@ -12,6 +13,11 @@ div, max, min, + hint, + extract_slice, + insert_slice, + extract_element, + subview, ) from . import ascend @@ -20,6 +26,7 @@ "alloc", "copy", "pipeline", + "parallel", "to_tensor", "to_buffer", "add", @@ -28,4 +35,9 @@ "div", "max", "min", + "hint", + "extract_slice", + "insert_slice", + "extract_element", + "subview", ] diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py index 60e9eb07d..ebc2a5ff2 100644 --- a/python/triton/experimental/tle/language/dsa/core.py +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -1,8 +1,10 @@ # Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
import triton.language.core as tl +from triton.language import semantic as tl_semantic from triton.language.core import ( _constexpr_to_value, + tensor, constexpr ) from triton._C.libtriton import ir @@ -123,6 +125,18 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) +class parallel(range): + """ + Iterator that counts upward forever, with parallel execution semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it indicates that there are no dependencies between loop iterations, + allowing them to be executed in parallel. + """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + + @builtin def from_buffer_to_tensor_pointer(src: buffer, _builder=None) -> tl.tensor: buffer_ty = src.type @@ -132,11 +146,16 @@ def from_buffer_to_tensor_pointer(src: buffer, _builder=None) -> tl.tensor: return tl.tensor(src.handle, block_type) @builtin -def copy(src, dst, shape, _builder=None): +def copy(src, dst, shape, inter_no_alias=False, _builder=None): + """Copy data from `src` to `dst` shaped by `shape`. + + :param inter_no_alias: If True, the copy is annotated as no aliasing between different iterations. 
+ """ assert len(shape) != 0, f"Can't deduce copy extents from args" shape = _constexpr_to_value(shape) - tle_semantic.copy(src, dst, shape, _builder) + inter_no_alias = _constexpr_to_value(inter_no_alias) + tle_semantic.copy(src, dst, shape, inter_no_alias, _builder) @builtin @@ -194,7 +213,7 @@ def min(input, other, result, _builder=None): @builtin -def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_space = None, _builder=None) -> buffer: +def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_space, _builder=None) -> buffer: """ Allocates a region of local memory with the specified shape and type. @@ -205,6 +224,7 @@ def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_sp :param _address_space: (Optional) backend-specific local memory address space :type _address_space: bl.address_space """ + assert (mem_addr_space is not None) return tle_semantic.alloc(dtype, shape, mem_addr_space, _builder) @@ -236,4 +256,130 @@ def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder @builtin def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], _builder=None) -> buffer: - pass \ No newline at end of file + ''' + Creates a subview of the source buffer with the specified offsets, sizes, and strides. + + :param src: The source buffer to create a subview from. + :type src: buffer + :param offsets: A list of non-negative integers representing the offsets in each dimension. + :type offsets: List[tl.constexpr] + :param sizes: A list of non-negative integers representing the sizes in each dimension. + :type sizes: List[tl.constexpr] + :param strides: A list of non-negative integers representing the strides in each dimension. + :type strides: List[tl.constexpr] + :return: A new buffer representing the subview of the source buffer. 
+ :rtype: buffer + ''' + # Validate that sizes and strides contain only constexpr values + new_sizes = [] + for i, size in enumerate(sizes): + if isinstance(size, int): + # Convert regular integers to constexpr + new_sizes.append(tl.constexpr(size)) + elif isinstance(size, tl.constexpr): + new_sizes.append(size) + else: + raise TypeError(f"sizes[{i}] must be constexpr, got {type(size).__name__}") + + new_strides = [] + for i, stride in enumerate(strides): + if isinstance(stride, int): + # Convert regular integers to constexpr + new_strides.append(tl.constexpr(stride)) + elif isinstance(stride, tl.constexpr): + new_strides.append(stride) + else: + raise TypeError(f"strides[{i}] must be constexpr, got {type(stride).__name__}") + + new_offsets = [] + for offset in offsets: + if isinstance(offset, tl.constexpr): + # Check that constexpr offset values cannot be negative + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(tl_semantic.to_tensor(offset, _builder)) + elif isinstance(offset, int): + # Convert regular integers to constexpr and then to tensor + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(tl_semantic.to_tensor(tl.constexpr(offset), _builder)) + else: + # Assume it's already a tensor + new_offsets.append(offset) + + return tle_semantic.subview(src, new_offsets, new_sizes, new_strides, _builder) + +def hint(**kwargs): + """Dummy function for AST parsing. Not executed during JIT compilation.""" + raise RuntimeError("tle.hint() cannot be called directly.") + + +@builtin +def insert_slice(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], _builder=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. 
+ :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + new_offsets = [ + tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + out = tle_semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + new_offsets = [ + tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o + for o in offsets + ] + sub = tle_semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) + return sub + + +@builtin +def extract_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. 
+ :type src: Tensor + :param indice: + :type indice: tuple of ints + """ + assert len(src.shape) > 0 + new_indice = [ + tl_semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i + for i in indice + ] + return tle_semantic.extract_element(src, new_indice, _builder) \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py index d34561185..47f85f42f 100644 --- a/python/triton/experimental/tle/language/dsa/semantic.py +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -8,6 +8,14 @@ from triton._C.libtriton import ir from .types import buffer, buffer_type, address_space +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # assert value.numel.value == 1, "only accepts size-1 tensor" if isinstance(value, tl.constexpr): @@ -17,19 +25,14 @@ def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if value.dtype.is_int(): return tl.tensor(value.handle, dtype) -### def alloc(shape: List[tl.tensor], dtype: tl.dtype, layout, scope, builder: ir.builder) -> tl.tensor: -### ret_ty = tl.block_type(dtype, shape) -### return tl.tensor(builder.create_dsa_alloc(shape, str(layout), str(scope), -### dtype.to_ir(builder)), ret_ty) - -def copy(src, dst, shape: List[Union[tl.constexpr, int]], builder: ir.builder): +def copy(src, dst, shape: List[Union[tl.constexpr, int]], inter_no_alias: bool, builder: ir.builder): """ Generate tt.copy(src, dst, shape) and return dst-like tensor. Lowering to hivm.load/hivm.store is done in MLIR pass. 
""" shape = [scalar_constant(x, tl.int32, builder) for x in shape] - builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape]) + builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape], inter_no_alias) ### def to_tensor(buffer: tl.tensor, builder: ir.builder) -> tl.tensor: @@ -136,4 +139,67 @@ def to_tensor(memref: buffer, writable: bool, builder: ir.builder, target_shape= ) memref_value = builder.create_convert_layout(memref_value, buffer_ty.to_ir(builder)) - return tl.tensor(builder.dsa_to_tensor(memref_value, writable), tensor_type) \ No newline at end of file + return tl.tensor(builder.dsa_to_tensor(memref_value, writable), tensor_type) + + +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_dsa_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: + assert(len(ful.shape) == len(offsets)) + assert(len(ful.shape) == len(sizes)) + assert(len(ful.shape) == len(strides)) + assert(all([s>=1 for s in sizes])) + assert(all([s>=0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_dsa_extract_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + +def extract_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + 
+ new_indice = [i.handle for i in indice] + result = builder.create_dsa_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + + +def subview(src: buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], + builder: ir.builder) -> buffer: + + new_offsets = [offset.handle for offset in offsets] + sizes_int = tl._unwrap_shape(sizes) + strides_int = tl._unwrap_shape(strides) + + result_handle = builder.create_dsa_subview(src.handle, new_offsets, sizes_int, strides_int) + + # calculate the memory layout strides of the source buffer + if src.strides: + # use the strides of the source buffer + src_memory_strides = src.strides + else: + # calculate the default row-major strides + src_memory_strides = [] + stride = 1 + for dim_size in reversed(src.shape): + if dim_size < 0: + raise ValueError("Cannot compute strides for buffer with dynamic dimensions") + src_memory_strides.insert(0, stride) + stride *= dim_size + + result_memory_strides = [] + for src_stride, subview_stride in zip(src_memory_strides, strides_int): + result_memory_strides.append(src_stride * subview_stride) + + # create buffer_type with strides + buffer_ty = buffer_type(element_ty=src.dtype, shape=sizes_int, space=src.space, strides=result_memory_strides) + return buffer(result_handle, buffer_ty) \ No newline at end of file diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 2f64a50ed..19daa1c0d 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -276,6 +276,9 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # Are we currently visiting an ast.arg's default value? These have some # special handling. 
self.visiting_arg_default_value = False + # Stack to keep track of active `with`-hints (e.g., tle.hint(...)) + # Each entry is a dict mapping hint names to literal values. + self._with_hints = [] builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} builtin_namespace.update(( @@ -812,13 +815,27 @@ def visit_With(self, node): # Check if context is a Call and dispatch to registered handler if isinstance(context, ast.Call): + # TODO[FIXME]: This is a hack to support `with hint(...)`, maybe should be handled in a better way like scope handler + if isinstance(context.func, ast.Attribute) and context.func.attr == "hint": + hints = {} + for kw in context.keywords: + if not isinstance(kw.value, ast.Constant): + raise self._unsupported(node, "keyword arguments to hint() are only supported for constant values") + hints[kw.arg] = kw.value.value + self._with_hints.append(hints) + withitemClass = self.visit(context.func) handler = WITH_DISPATCH.get(withitemClass) if handler: return handler(self, node) # Fall back to visiting body for unhandled cases - return self.visit_compound_statement(node.body) + try: + self.visit_compound_statement(node.body) + finally: + self._with_hints.pop() + + return def visit_Compare(self, node): if not (len(node.comparators) == 1 and len(node.ops) == 1): @@ -972,7 +989,7 @@ def visit_For(self, node): warp_specialize = False disable_licm = False bind_sub_block = None - if IteratorClass in [language.range, extension.parallel, dsa.pipeline]: + if IteratorClass in [language.range, extension.parallel, dsa.pipeline, dsa.parallel]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -1071,7 +1088,7 @@ def visit_For(self, node): for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) if disable_licm: for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) - if (IteratorClass is extension.parallel): + if 
(IteratorClass is extension.parallel or IteratorClass is dsa.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) self.scf_stack.append(node) @@ -1197,6 +1214,16 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: + # Honor hints coming from an enclosing `with ... hint(...)` block. + # For example, `with tle.hint(inter_no_alias=True): tle.copy(...)` + # should behave like `tle.copy(..., inter_no_alias=True)` when the + # keyword isn't explicitly provided on the call site. + if self._with_hints: + # Only apply to some builtins; currently, 'copy' is relevant. + if fn.__name__ == 'copy': + top_hints = self._with_hints[-1] + if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: + kws['inter_no_alias'] = top_hints['inter_no_alias'] ret = fn(*args, **extra_kwargs, **kws) # Sync the builder's location before return. ip, last_loc = self._get_insertion_point_and_loc(_builder) diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index 4aaeb3f71..81fe42ab5 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc @@ -57,7 +57,7 @@ void init_tle_ir(py::module &&m) MemRefLayoutAttrInterface{}, memorySpace); }) .def("dsa_get_buffer_type_with_strides", - [](TritonOpBuilder &self, std::vector &shape, + [](DSAOpBuilder &self, std::vector &shape, Type &elementType, const std::vector &strides, const Attribute &memorySpace) -> Type { // create a layout with strides, using dynamic offset @@ -70,8 +70,11 @@ void init_tle_ir(py::module &&m) return self.create(mlir::cast(memrefType)); }) .def("create_dsa_copy", - [](DSAOpBuilder &self, Value &src, Value &dst, std::vector &shape)-> void { - self.create(src, dst, shape); + [](DSAOpBuilder &self, Value &src, Value &dst, std::vector &shape, bool inter_no_alias)-> void { + auto copyOp = self.create(src, dst, shape); + if (inter_no_alias) { + copyOp->setAttr("inter_no_alias", self.getBuilder().getBoolAttr(true)); 
+ } }) // Add op .def("create_dsa_add", @@ -120,22 +123,6 @@ void init_tle_ir(py::module &&m) /// self.create(inA, inB, res, sizeAttr, initC_attr, /// traA_attr, traB_attr, enable_hf32_attr); /// }) - /// // ToTensor op - /// .def("dsa_to_tensor", - /// [](DSAOpBuilder &self, Value &src) -> Value { - /// return self.create(src); - /// }) - /// // ToBuffer op - /// .def("dsa_to_buffer", - /// [](DSAOpBuilder &self, Value &src) -> Value { - /// auto srcType = src.getType(); - /// auto tensorTy = cast(srcType); - /// Type elementType = tensorTy.getElementType(); - /// auto ptrType = triton::PointerType::get(elementType, 1); - /// auto shape = tensorTy.getShape(); - /// auto tensorPtrType = RankedTensorType::get(shape, ptrType); - /// return self.create(tensorPtrType, src); - /// }) .def("dsa_to_buffer", [](DSAOpBuilder &self, Value &src, const Attribute &addressSpace) -> Value { @@ -158,6 +145,158 @@ void init_tle_ir(py::module &&m) } return self.create(src, true, writable); }) + .def("create_extract_scalar", + [](DSAOpBuilder &self, Value &src, std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + .def("create_extract_slice", + [](DSAOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector 
strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get(retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, sizes, strides); + }) + .def("create_insert_slice", + [](DSAOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + .def("create_dsa_subview", + [](DSAOpBuilder &self, Value source, std::vector &offsets, + const std::vector &sizes, + const std::vector &strides) -> Value { + SmallVector mixedOffsets; + auto *context = self.getBuilder().getContext(); + auto &builder = self.getBuilder(); + + // Get source memref type for validation + auto sourceType = mlir::cast(source.getType()); + int64_t rank = sourceType.getRank(); + // Verify the number of parameters + if (offsets.size() != rank || sizes.size() != rank || + strides.size() != rank) { + throw std::runtime_error("Number of offsets, sizes, and strides " + "must match memref rank"); + } + + for (const auto &offset : offsets) { + auto indexType = builder.getIndexType(); + if (offset.getType() != indexType) { + Value offset_val = + self.create(indexType, offset); + mixedOffsets.push_back(offset_val); + } else { + 
mixedOffsets.push_back(offset); + } + } + + SmallVector mixedSizes; + SmallVector mixedStrides; + for (int64_t i = 0; i < rank; ++i) { + int64_t size = sizes[i]; + int64_t stride = strides[i]; + int64_t srcDim = sourceType.getDimSize(i); + + // verify sizes cannot be negative or zero + if (size <= 0) { + throw std::runtime_error("Expected sizes to be positive"); + } + + // verify strides cannot be negative or zero + if (stride <= 0) { + throw std::runtime_error("Expected strides to be positive"); + } + + // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic + // dimensions + if (!ShapedType::isDynamic(srcDim)) { + // verify the subview size does not exceed the source dimension + if (size > srcDim) { + throw std::runtime_error( + "Subview size cannot exceed source dimension size"); + } + + // verify strides cannot exceed the source dimension size + if (stride > srcDim) { + throw std::runtime_error( + "Stride cannot exceed source dimension size"); + } + } + + mixedSizes.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), size)); + mixedStrides.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), stride)); + } + + return self.create(source, mixedOffsets, + mixedSizes, mixedStrides); + }) ; } \ No newline at end of file From 6f8678bb4fd3c1349b139a4593277dcb131b192b Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Thu, 5 Mar 2026 07:21:10 +0000 Subject: [PATCH 06/13] [REFACT](tle): decouple tle from TritonOps.td * decouple TleOps from TritonOps and mov to third_party/tle/dsa/dialect * implement the TleOp conversion in third_party/tle/dsa rather than in flir directly, flir just call the conversion in its pass --- CMakeLists.txt | 6 +- python/test/tle/test_bind_buffer.py | 2 +- .../experimental/tle/language/builder.py | 4 + .../triton/Dialect/Triton/IR/TritonOps.td | 161 ------------------ .../spec/triton/compiler/code_generator.py | 2 +- .../backend/spec/triton/compiler/compiler.py | 2 +- 
third_party/ascend/python/src/main.cc | 2 - third_party/tle/dsa/CMakeLists.txt | 23 +++ third_party/tle/dsa/dialect/CMakeLists.txt | 7 + .../dialect/include/Analysis/CMakeLists.txt | 1 + .../tle/dsa/dialect/include/CMakeLists.txt | 5 + .../dialect/include/Conversion/CMakeLists.txt | 1 + .../Conversion/TleToLinalg/DSACopyConverter.h | 40 +++++ .../Conversion/TleToLinalg/MathConverter.h | 117 +++++++++++++ .../tle/dsa/dialect/include/IR/CMakeLists.txt | 19 +++ .../tle/dsa/dialect/include/IR/Dialect.h | 18 ++ .../tle/dsa/dialect/include/IR/TleAttrDefs.td | 11 ++ .../tle/dsa/dialect/include/IR/TleDialect.td | 23 +++ .../tle/dsa/dialect/include/IR/TleOps.td | 154 +++++++++++++++++ .../dsa/dialect/lib/Analysis/CMakeLists.txt | 1 + .../tle/dsa/dialect/lib/CMakeLists.txt | 5 + .../dsa/dialect/lib/Conversion/CMakeLists.txt | 3 + .../lib/Conversion/TleToLinalg/CMakeLists.txt | 10 ++ .../TleToLinalg/DSACopyConverter.cpp | 117 +++++++++++++ .../Conversion/TleToLinalg/MathConverter.cpp | 24 +++ .../tle/dsa/dialect/lib/IR/CMakeLists.txt | 13 ++ .../tle/dsa/dialect/lib/IR/Dialect.cpp | 25 +++ third_party/tle/dsa/dialect/lib/IR/TleOps.cpp | 8 + third_party/tle/dsa/tle_ir.cc | 45 ++--- 29 files changed, 653 insertions(+), 196 deletions(-) create mode 100644 third_party/tle/dsa/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/include/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h create mode 100644 third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h create mode 100644 third_party/tle/dsa/dialect/include/IR/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/include/IR/Dialect.h create mode 100644 third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td create 
mode 100644 third_party/tle/dsa/dialect/include/IR/TleDialect.td create mode 100644 third_party/tle/dsa/dialect/include/IR/TleOps.td create mode 100644 third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/lib/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp create mode 100644 third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp create mode 100644 third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt create mode 100644 third_party/tle/dsa/dialect/lib/IR/Dialect.cpp create mode 100644 third_party/tle/dsa/dialect/lib/IR/TleOps.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index fe7398134..8518361ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -291,6 +291,10 @@ if(TRITON_BUILD_PYTHON_MODULE) add_subdirectory(third_party/proton) endif() + # add TLE plugin + list(APPEND TRITON_PLUGIN_NAMES "tle") + add_subdirectory(third_party/tle/dsa) + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) set(TRITON_LIBRARIES @@ -422,12 +426,10 @@ if(TRITON_BUILD_PYTHON_MODULE) elseif(FLAGTREE_BACKEND STREQUAL "ascend") set(PYTHON_ROOT_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src) set(BUFFER_IR_SRC_PATH ${FLAGTREE_BACKEND_DIR}/python/triton/extension/buffer/src) - set(TLE_IR_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/third_party/tle/dsa) include_directories(${PYTHON_ROOT_SRC_PATH}) add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/ir.cc ${BUFFER_IR_SRC_PATH}/buffer_ir.cc - ${TLE_IR_SRC_PATH}/tle_ir.cc ${PYTHON_ROOT_SRC_PATH}/passes.cc ${PYTHON_ROOT_SRC_PATH}/interpreter.cc ${PYTHON_ROOT_SRC_PATH}/llvm.cc) diff --git a/python/test/tle/test_bind_buffer.py b/python/test/tle/test_bind_buffer.py index 
f1c12c7b3..c1cafdbf1 100644 --- a/python/test/tle/test_bind_buffer.py +++ b/python/test/tle/test_bind_buffer.py @@ -4,7 +4,7 @@ from triton.compiler.compiler import ASTSource from triton.compiler.code_generator import ast_to_ttir -from triton._C.libtriton import ir, tle_ir +from triton._C.libtriton import ir, tle as tle_ir from triton._C.libtriton.ascend import ir as ascend_ir class Options: diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py index f9a1f2e9e..e824f6553 100644 --- a/python/triton/experimental/tle/language/builder.py +++ b/python/triton/experimental/tle/language/builder.py @@ -48,5 +48,9 @@ def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): 'dsa_get_null_attr', 'dsa_get_buffer_type', 'dsa_get_buffer_type_with_strides', + "create_dsa_extract_scalar", + "create_dsa_extract_slice", + "create_dsa_insert_slice", + "create_dsa_subview", ] attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) \ No newline at end of file diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index 2aa6a6045..8cad0f1e5 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1380,165 +1380,4 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpIn let hasVerifier = 1; } -///////////// Definitions for DSA -/// // -/// // Alloc Op -/// // -/// def TT_DSAAllocOp : TT_Op<"dsa_alloc", [Pure, MemoryEffects<[MemWrite]>]> { -/// let summary = "self-defined alloc operation"; -/// let description = [{ -/// `tt.dsa_alloc` triton alloc op is designed to performs memory allocation. 
-/// }]; -/// let arguments = ( -/// ins -/// I64ArrayAttr:$shape, -/// StrAttr:$layout, -/// StrAttr:$scope -/// ); -/// -/// let results = (outs TT_PtrLike:$result); -/// -/// let assemblyFormat = "$shape `,` $layout `,` $scope attr-dict `:` `->` type($result)"; -/// } - -// -// Copy OP -// -def TT_DSACopyOp : TT_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> { - let summary = "self-defined copy operation"; - let description = [{ - 'tt.dsa_copy' triton copy op is designed to copy data between memory regions. - Example: - ```mlir - tt.dsa_copy %src, %dst, %shape : tensor<128xf32> - ``` - }]; - let arguments = (ins - AnyType:$src, - AnyType:$dst, - Variadic:$shape - ); - - // let builders = [ - // OpBuilder<(ins "Value":$src, "Value":$dst, "ValueRange": $shape)> - // ]; - - //assemble - let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`"; -} - -// -// Add Op -// -def TT_DSAAddOp : TT_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined add operation"; - let description = [{ - `tt.dsa_add` triton dsa_add op is designed to performs element-wise addition. - }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Sub op -// -def TT_DSASubOp : TT_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined sub operation"; - let description = [{ - `tt.dsa_sub` triton dsa_sub op is designed to performs element-wise addition. 
- }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Mul op -// -def TT_DSAMulOp : TT_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined mul operation"; - let description = [{ - `tt.dsa_mul` triton dsa_mul op is designed to performs element-wise addition. - }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Div op -// -def TT_DSADivOp : TT_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined div operation"; - let description = [{ - `tt.dsa_div` triton dsa_div op is designed to performs element-wise addition. - }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Max op -// -def TT_DSAMaxOp : TT_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined Max operation"; - let description = [{ - `tt.dsa_max` triton dsa_max op is designed to performs element-wise addition. - }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -// -// Min op -// -def TT_DSAMinOp : TT_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, - SameOperandsShape, SameOperandsElementType]> { - let summary = "self-defined Min operation"; - let description = [{ - `tt.dsa_min` triton dsa_min op is designed to performs element-wise addition. 
- }]; - let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); - - let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; -} - -/// // -/// // Dot op -/// // -/// def TT_DSADotOp : TT_Op<"dsa_dot", [Pure, -/// MemoryEffects<[MemWrite]>, -/// SameOperandsElementType, -/// DotLike]> { -/// let summary = "self-defined Dot operation"; -/// let description = [{ -/// $d = matrix_multiply($a, $b) + $c. -/// }]; -/// -/// let arguments = ( -/// ins -/// TT_PtrLike:$inA, -/// TT_PtrLike:$inB, -/// TT_PtrLike:$res, -/// I64ArrayAttr:$size, -/// DefaultValuedAttr:$initC, -/// DefaultValuedAttr:$traA, -/// DefaultValuedAttr:$traB, -/// DefaultValuedAttr:$enableHf32 -/// ); -/// -/// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; -/// } - -///////////// Definition for DSA ends - - #endif // Triton_OPS diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 19daa1c0d..b691e2839 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -14,7 +14,7 @@ from triton.experimental.tle.language.builder import setup_unified_builder_with_tle_builder from .. 
import language -from .._C.libtriton import ir, buffer_ir, tle_ir +from .._C.libtriton import ir, buffer_ir, tle as tle_ir from .._C.libtriton.ascend import ir as ascend_ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index 8dedbd21f..d352f9233 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib import json -from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir, tle_ir +from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir, tle as tle_ir from .._C.libtriton.ascend import ir as ascend_ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor diff --git a/third_party/ascend/python/src/main.cc b/third_party/ascend/python/src/main.cc index 525d7c7c8..7664c6bda 100644 --- a/third_party/ascend/python/src/main.cc +++ b/third_party/ascend/python/src/main.cc @@ -38,7 +38,6 @@ namespace py = pybind11; void init_triton_env_vars(pybind11::module &m); void init_triton_ir(pybind11::module &&m); void init_buffer_ir(pybind11::module &&m); -void init_tle_ir(pybind11::module &&m); void init_triton_llvm(pybind11::module &&m); void init_triton_interpreter(pybind11::module &&m); void init_triton_passes(pybind11::module &&m); @@ -51,7 +50,6 @@ PYBIND11_MODULE(libtriton, m) { init_triton_env_vars(m); init_triton_ir(m.def_submodule("ir")); init_buffer_ir(m.def_submodule("buffer_ir")); - init_tle_ir(m.def_submodule("tle_ir")); init_triton_passes(m.def_submodule("passes")); init_triton_interpreter(m.def_submodule("interpreter")); init_triton_llvm(m.def_submodule("llvm")); diff --git a/third_party/tle/dsa/CMakeLists.txt 
b/third_party/tle/dsa/CMakeLists.txt new file mode 100644 index 000000000..19b28cffe --- /dev/null +++ b/third_party/tle/dsa/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dialect/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/dialect/include) +add_subdirectory(dialect) + +if (TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonTLE + ${CMAKE_CURRENT_SOURCE_DIR}/tle_ir.cc + + LINK_LIBS + TleIR + TritonIR + ) + + find_package(Python3 REQUIRED COMPONENTS Development Interpreter) + find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") + include_directories(${Python3_INCLUDE_DIRS}) + include_directories(${pybind11_INCLUDE_DIR}) + link_directories(${Python3_LIBRARY_DIRS}) + link_libraries(${Python3_LIBRARIES}) + add_link_options(${Python3_LINK_OPTIONS}) +endif() \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/CMakeLists.txt b/third_party/tle/dsa/dialect/CMakeLists.txt new file mode 100644 index 000000000..799695a7d --- /dev/null +++ b/third_party/tle/dsa/dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${PROJECT_SOURCE_DIR}/python/src) +add_subdirectory(include) +add_subdirectory(lib) diff --git a/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt new file mode 100644 index 000000000..87913e1bd --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
diff --git a/third_party/tle/dsa/dialect/include/CMakeLists.txt b/third_party/tle/dsa/dialect/include/CMakeLists.txt new file mode 100644 index 000000000..181d7332c --- /dev/null +++ b/third_party/tle/dsa/dialect/include/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(IR) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt new file mode 100644 index 000000000..87913e1bd --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h new file mode 100644 index 000000000..0c6388b4a --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h @@ -0,0 +1,40 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +#ifndef TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ +#define TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" + +#include "tle/dsa/dialect/include/IR/Dialect.h" + +namespace TleCopyConverter { + +using namespace mlir; + +class CopyConverter : public OpConversionPattern { + +public: + explicit CopyConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} + +namespace mlir::triton::tle { +void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +} + +#endif \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h new file mode 100644 index 000000000..9b950c2ef --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h @@ -0,0 +1,117 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+#ifndef TRITON_TLE_CONVERSION_MATH_CONVERTER_H +#define TRITON_TLE_CONVERSION_MATH_CONVERTER_H + +#if __has_include("bishengir/Dialect/HIVM/IR/HIVM.h") +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#endif + +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace TleMathConverter { + +using namespace mlir; + +template +class BinaryMathConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = adaptor.getRes(); + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + if (result.getType() != lhs.getType() || + result.getType() != rhs.getType()) { + op->emitError("Unexpected binary calculation type!"); + return failure(); + } + + auto binOp = rewriter.create( + loc, + TypeRange{}, + ValueRange{lhs, rhs}, + ValueRange{result} + ); + + rewriter.replaceOp(op, binOp); + return success(); + } +}; + +template +class UnaryMathConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(MathOp op, PatternRewriter &rewriter) const override { + } +}; + +template +class MatMulConverter : public OpConversionPattern { +public: + static constexpr llvm::StringLiteral fixpipeAlreadyInserted = + "fixpipe_already_inserted"; + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto inA = adaptor.getInA(); + auto inB = adaptor.getInB(); + auto 
res = adaptor.getRes(); + + auto sizeAttr = adaptor.getSize(); + if (sizeAttr.size() > 3) { + op->emitError("Unexpected matmul calculation size!"); + return failure(); + } + + auto mAttr = dyn_cast(sizeAttr[0]); + auto nAttr = dyn_cast(sizeAttr[1]); + auto kAttr = dyn_cast(sizeAttr[2]); + Value M = rewriter.create(loc, mAttr.getInt()); + Value N = rewriter.create(loc, nAttr.getInt()); + Value K = rewriter.create(loc, kAttr.getInt()); + + bool initC = adaptor.getInitC(); + auto initCValue = rewriter.create(loc, + /*value*/ initC, /*width*/ 1); + auto newOp = rewriter.create( + loc, + TypeRange{}, // result types + inA, // Matrix A + inB, // Matrix B + initCValue, // init condition + M, // M + K, // K + N, // N + res, // Matrix C + Value{}, // per channel bias + adaptor.getTraA() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose A + adaptor.getTraB() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose B + adaptor.getEnableHf32() ? rewriter.getUnitAttr() : UnitAttr{}// enable hf32 mode + ); + + newOp->setAttr(fixpipeAlreadyInserted, rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +} // namespace TleMathConverter + +namespace mlir::triton::tle { +void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +} +#endif \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt new file mode 100644 index 000000000..c1ec982de --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TleOps.td) +mlir_tablegen(TleOps.h.inc -gen-op-decls) +mlir_tablegen(TleOps.cpp.inc -gen-op-defs) +add_mlir_doc(TleOps TleOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TleDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tle) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tle) +add_mlir_doc(TleDialect TleDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TleAttrDefs.td) +mlir_tablegen(TleAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TleAttrDefs.cpp.inc -gen-attrdef-defs) + +add_public_tablegen_target(TleTableGen) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/include/IR/Dialect.h b/third_party/tle/dsa/dialect/include/IR/Dialect.h new file mode 100644 index 000000000..9d0f3ce0f --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/Dialect.h @@ -0,0 +1,18 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#ifndef TRITON_TLE_IR_DIALECT_H_ +#define TRITON_TLE_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "tle/dsa/dialect/include/IR/Dialect.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "tle/dsa/dialect/include/IR/TleAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "tle/dsa/dialect/include/IR/TleOps.h.inc" + +#endif \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td new file mode 100644 index 000000000..616ab0e82 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td @@ -0,0 +1,11 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +#ifndef TRITON_TLE_ATTR_DEFS +#define TRITON_TLE_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "tle/dsa/dialect/include/IR/TleDialect.td" + + +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleDialect.td b/third_party/tle/dsa/dialect/include/IR/TleDialect.td new file mode 100644 index 000000000..8f46ab6c1 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleDialect.td @@ -0,0 +1,23 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#ifndef TRITON_TLE_DIALECT +#define TRITON_TLE_DIALECT + +include "mlir/IR/OpBase.td" + +def Tle_Dialect : Dialect { + let name = "tle"; + let cppNamespace = "::mlir::triton::tle"; + let description = [{ + Triton Language Extension Dialect. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "triton::TritonDialect", + ]; + + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleOps.td b/third_party/tle/dsa/dialect/include/IR/TleOps.td new file mode 100644 index 000000000..73d827c35 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleOps.td @@ -0,0 +1,154 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#ifndef TRITON_TLE_OPS +#define TRITON_TLE_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/CommonTypeConstraints.td" +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "tle/dsa/dialect/include/IR/TleDialect.td" + +// +// Op Base +// +class TLE_Op traits = []>: + Op { +} + +// +// Copy OP +// +def TLE_DSACopyOp : TLE_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined copy operation"; + let description = [{ + 'tle.dsa_copy' triton copy op is designed to copy data between memory regions. 
+  Example:
+  ```mlir
+  tle.dsa_copy %src, %dst, %shape : tensor<128xf32>
+  ```
+  }];
+  let arguments = (ins
+    AnyType:$src,
+    AnyType:$dst,
+    Variadic:$shape
+  );
+
+  //assemble
+  let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`";
+}
+
+//
+// Add Op
+//
+def TLE_DSAAddOp : TLE_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined add operation";
+  let description = [{
+    `tle.dsa_add` triton dsa_add op is designed to perform element-wise addition.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Sub op
+//
+def TLE_DSASubOp : TLE_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined sub operation";
+  let description = [{
+    `tle.dsa_sub` triton dsa_sub op is designed to perform element-wise subtraction.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Mul op
+//
+def TLE_DSAMulOp : TLE_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined mul operation";
+  let description = [{
+    `tle.dsa_mul` triton dsa_mul op is designed to perform element-wise multiplication.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Div op
+//
+def TLE_DSADivOp : TLE_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined div operation";
+  let description = [{
+    `tle.dsa_div` triton dsa_div op is designed to perform element-wise division.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Max op
+//
+def TLE_DSAMaxOp : TLE_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined Max operation";
+  let description = [{
+    `tle.dsa_max` triton dsa_max op is designed to compute the element-wise maximum.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+//
+// Min op
+//
+def TLE_DSAMinOp : TLE_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>,
+                   SameOperandsShape, SameOperandsElementType]> {
+  let summary = "self-defined Min operation";
+  let description = [{
+    `tle.dsa_min` triton dsa_min op is designed to compute the element-wise minimum.
+  }];
+  let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res);
+
+  let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)";
+}
+
+/// //
+/// // Dot op
+/// //
+/// def TLE_DSADotOp : TLE_Op<"dsa_dot", [Pure,
+///                   MemoryEffects<[MemWrite]>,
+///                   SameOperandsElementType,
+///                   DotLike]> {
+///   let summary = "self-defined Dot operation";
+///   let description = [{
+///     $d = matrix_multiply($a, $b) + $c.
+/// }]; +/// +/// let arguments = ( +/// ins +/// AnyMemRef:$inA, +/// AnyMemRef:$inB, +/// AnyMemRef:$res, +/// I64ArrayAttr:$size, +/// DefaultValuedAttr:$initC, +/// DefaultValuedAttr:$traA, +/// DefaultValuedAttr:$traB, +/// DefaultValuedAttr:$enableHf32 +/// ); +/// +/// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; +/// } + +#endif diff --git a/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt new file mode 100644 index 000000000..87913e1bd --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. diff --git a/third_party/tle/dsa/dialect/lib/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/CMakeLists.txt new file mode 100644 index 000000000..181d7332c --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(IR) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt new file mode 100644 index 000000000..efba34782 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt @@ -0,0 +1,3 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +add_subdirectory(TleToLinalg) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt new file mode 100644 index 000000000..57675b403 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +add_triton_library(TleToLinalg + DSACopyConverter.cpp + MathConverter.cpp + + DEPENDS + TritonIR + TleTableGen +) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp new file mode 100644 index 000000000..feb3a7c23 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp @@ -0,0 +1,117 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#include "tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h" +#if __has_include("bishengir/Dialect/HIVM/IR/HIVM.h") +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#endif + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace TleCopyConverter { + +using namespace mlir; + +memref::SubViewOp makeSubViewOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter) { + auto srcType = cast(src.getType()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, dyn_cast(dstType), + src, offsets, sizes, strides); +} + +CopyConverter::CopyConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto dst = adaptor.getDst(); + auto loc = op.getLoc(); + + if (!dyn_cast(src.getType()) || + !dyn_cast(dst.getType())) { + op.emitError("Unexpected copy type!"); + return failure(); + } + + llvm::SmallVector shapeValues; + for (auto shape: adaptor.getShape()) { + Value indexShape = rewriter.create(loc, rewriter.getIndexType(), shape); + shapeValues.push_back(indexShape); + } 
+ + // create copyOp + auto srcSubView = makeSubViewOp(src, shapeValues, loc, rewriter); + auto dstSubView = makeSubViewOp(dst, shapeValues, loc, rewriter); + + MemRefType srcMemRefTy = cast(srcSubView.getType()); + MemRefType dstMemRefTy = cast(dstSubView.getType()); + + // Extract AddressSpace from MemRefType. + auto getAddressSpace = [](MemRefType ty) -> hivm::AddressSpace { + auto attr = ty.getMemorySpace(); + if (!attr) { + // The default memory attribute is GM. + return hivm::AddressSpace::GM; + } + auto addrSpaceAttr = dyn_cast(attr); + if (!addrSpaceAttr) { + return hivm::AddressSpace::GM; + } + return addrSpaceAttr.getAddressSpace(); + }; + + hivm::AddressSpace srcAddrSpace = getAddressSpace(srcMemRefTy); + hivm::AddressSpace dstAddrSpace = getAddressSpace(dstMemRefTy); + + Operation *copyOp = nullptr; + if (srcAddrSpace == hivm::AddressSpace::GM && + dstAddrSpace == hivm::AddressSpace::UB || + srcAddrSpace == hivm::AddressSpace::UB && + dstAddrSpace == hivm::AddressSpace::GM) { + copyOp = rewriter.create(loc, srcSubView, dstSubView); + } else if (srcAddrSpace == hivm::AddressSpace::GM && + dstAddrSpace == hivm::AddressSpace::L1) { + copyOp = rewriter.create(loc, /*result_tensor=*/TypeRange{}, + /*src=*/srcSubView, /*dst=*/dstSubView, + /*dst_continuous=*/UnitAttr::get(rewriter.getContext())); + } + /// else if (srcAddrSpace == hivm::AddressSpace::L0C && + /// dstAddrSpace == hivm::AddressSpace::GM) { + /// copyOp = rewriter.create(loc, + /// /*result_tensor=*/TypeRange{}, /*src=*/srcSubView, /*dst=*/dstSubView, + /// /*enable_nz2nd=*/UnitAttr::get(rewriter.getContext()) + + /// // #ifdef BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES + /// /*nullptr, + /// hivm::FixpipeDMAModeAttr::get(rewriter.getContext(), hivm::FixpipeDMAMode::NZ2ND), + /// nullptr, nullptr, nullptr, nullptr, nullptr*/ + /// ); + /// } + else { + op.emitError("Not implemented!"); + return failure(); + } + + copyOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, copyOp); + return 
success(); +} + +} // namespace TleCopyConverter + +namespace mlir::triton::tle { +void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} +} \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp new file mode 100644 index 000000000..6937e1571 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp @@ -0,0 +1,24 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +#include "tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" + +namespace TleMathConverter { + +using namespace mlir; +using namespace triton::tle; + +} + +namespace mlir::triton::tle { +void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + + /// patterns.add>(patterns.getContext()); +} +} \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt new file mode 100644 index 000000000..f4127590e --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +add_triton_library(TleIR + Dialect.cpp + TleOps.cpp + + DEPENDS + TleTableGen + + LINK_LIBS PUBLIC + TritonIR + MLIRIR +) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp new file mode 100644 index 000000000..e9aeab5ef --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. + +#include "mlir/Support/LLVM.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" +#include "tle/dsa/dialect/include/IR/Dialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "tle/dsa/dialect/include/IR/TleAttrDefs.cpp.inc" + +#define GET_OP_CLASSES +#include "tle/dsa/dialect/include/IR/TleOps.cpp.inc" + +namespace mlir::triton::tle { +void TleDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "tle/dsa/dialect/include/IR/TleAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "tle/dsa/dialect/include/IR/TleOps.cpp.inc" + >(); +} +} \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp new file mode 100644 index 000000000..64c964ecb --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp @@ -0,0 +1,8 @@ +// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+ +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" + +namespace mlir::triton::tle { +} \ No newline at end of file diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index 81fe42ab5..b4fb51db7 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc @@ -7,6 +7,8 @@ #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -22,33 +24,19 @@ constexpr unsigned kIntegerAttrBitWidth = 64; struct DSAOpBuilder : public TritonOpBuilder {}; -void init_tle_ir(py::module &&m) -{ +void init_triton_tle(py::module &&m) +{ m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); py::class_(m, "tle_builder", py::module_local(), py::dynamic_attr()) .def(py::init()) - // Add alloc op - /// .def("create_dsa_alloc", - /// [](DSAOpBuilder &self, std::vector &shape, - /// std::string &layout, std::string &scope, Type type)-> Value { - /// auto shapeAttr = self.getBuilder().getI64ArrayAttr(shape); - /// auto layoutAttr = self.getBuilder().getStringAttr(layout); - /// auto scopeAttr = self.getBuilder().getStringAttr(scope); - - /// auto ptrType = triton::PointerType::get(type, 1); - /// auto tensorPtrType = RankedTensorType::get(shape, ptrType); - /// return self.create(tensorPtrType, shapeAttr, - /// layoutAttr, scopeAttr); - /// }) - // Add copy op .def("dsa_get_null_attr", [](DSAOpBuilder &self) { return Attribute(); }) .def("dsa_get_buffer_type", [](DSAOpBuilder &self, std::vector &shape, @@ -69,9 +57,10 @@ void init_tle_ir(py::module &&m) [](DSAOpBuilder &self, Type 
memrefType) -> Value { return self.create(mlir::cast(memrefType)); }) + // Add copy op .def("create_dsa_copy", [](DSAOpBuilder &self, Value &src, Value &dst, std::vector &shape, bool inter_no_alias)-> void { - auto copyOp = self.create(src, dst, shape); + auto copyOp = self.create(src, dst, shape); if (inter_no_alias) { copyOp->setAttr("inter_no_alias", self.getBuilder().getBoolAttr(true)); } @@ -79,32 +68,32 @@ void init_tle_ir(py::module &&m) // Add op .def("create_dsa_add", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Sub op .def("create_dsa_sub", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Mul op .def("create_dsa_mul", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Div op .def("create_dsa_div", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Max op .def("create_dsa_max", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Min op .def("create_dsa_min", [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); + self.create(lhs, rhs, res); }) // Dot op /// .def("create_dsa_dot", @@ -120,7 +109,7 @@ void init_tle_ir(py::module &&m) /// auto traB_attr = builder.getBoolAttr(traB); /// auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); - /// self.create(inA, inB, res, sizeAttr, initC_attr, + /// self.create(inA, inB, res, sizeAttr, initC_attr, /// traA_attr, traB_attr, enable_hf32_attr); /// }) .def("dsa_to_buffer", @@ -145,7 +134,7 @@ void init_tle_ir(py::module &&m) } return self.create(src, true, writable); }) - .def("create_extract_scalar", + .def("create_dsa_extract_scalar", 
[](DSAOpBuilder &self, Value &src, std::vector &indices) -> Value { llvm::SmallVector arg_indices; for (const auto &i : indices) { @@ -161,7 +150,7 @@ void init_tle_ir(py::module &&m) auto ret = self.create(src, arg_indices); return ret; }) - .def("create_extract_slice", + .def("create_dsa_extract_slice", [](DSAOpBuilder &self, Value &ful, std::vector &offs_vec, std::vector &sizs_vec, std::vector &strd_vec) -> Value { llvm::SmallVector offsets; @@ -192,7 +181,7 @@ void init_tle_ir(py::module &&m) return self.create(retTy, ful, offsets, sizes, strides); }) - .def("create_insert_slice", + .def("create_dsa_insert_slice", [](DSAOpBuilder &self, Value &ful, Value &sub, std::vector &offs_vec, std::vector &sizs_vec, std::vector &strd_vec) -> Value { From c7c91dfb0f1d0491cdad20d037f72861464d6bc2 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Thu, 5 Mar 2026 07:52:43 +0000 Subject: [PATCH 07/13] [CHORE]: update doc in tle --- .../experimental/tle/language/dsa/README.md | 34 +++++------ .../tle/language/dsa/ascend/core.py | 61 ------------------- 2 files changed, 17 insertions(+), 78 deletions(-) diff --git a/python/triton/experimental/tle/language/dsa/README.md b/python/triton/experimental/tle/language/dsa/README.md index 00cce40f3..6e50c79d4 100644 --- a/python/triton/experimental/tle/language/dsa/README.md +++ b/python/triton/experimental/tle/language/dsa/README.md @@ -4,15 +4,14 @@ TLE is a language extension for Triton that exposes on-chip memory, pipeline com ## Features -- **On-chip Memory Management**: `tle.alloc()` - Allocate memory on UB/L1/L0C -- **Data Movement**: `tle.copy()` - Efficient bidirectional copying between memory spaces -- **compute Operations**: `tle.npu_add()` - Addition on UB -- **Pipeline Optimization**: `tle.pipeline()` - Hardware-aware pipeline iteration +- **On-chip Memory Management**: `tle.dsa.alloc()` - Allocate memory on UB/L1/L0C +- **Data Movement**: `tle.dsa.copy()` - Efficient bidirectional copying between memory spaces +- **compute 
Operations**: `tle.dsa.add()` - Addition on UB +- **Pipeline Optimization**: `tle.dsa.pipeline()` - Hardware-aware pipeline iteration -## Memory Scopes & Layouts +## Memory Scopes & Layouts for ascend -- **Scopes**: `tle.UB` (UB memory), `tle.L1` (L1 memory), `tle.L0C` (L0C memory) -- **Layouts**: `tle.ND`, `tle.NZ` +- **Scopes**: `tle.dsa.ascend.UB` (UB memory), `tle.dsa.ascend.L1` (L1 memory), `tle.dsa.ascend.L0C` (L0C memory) ## Quick Example @@ -29,34 +28,35 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. offsets = block_start + tl.arange(0, BLOCK_SIZE) # Allocate UB memory - a_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) - b_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) - c_ub = tle.alloc([BLOCK_SIZE], dtype=tl.float32, layout=tle.ND, scope=tle.UB) + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) # Tail block processing t0 = n_elements - block_start tail_size = tl.minimum(t0, BLOCK_SIZE) # Copy data from GM to UB - tle.copy(x_ptr + offsets, a_ub, [tail_size]) - tle.copy(y_ptr + offsets, b_ub, [tail_size]) + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) # Addition - tle.npu_add(a_ub, b_ub, c_ub) + tle.dsa.add(a_ub, b_ub, c_ub) # Copy result back to GM - tle.copy(c_ub, output_ptr + offsets, [tail_size]) + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) ``` ## Testing ```bash -cd ascend/examples/tle/pytest_ut/ +cd python/test/tle python3 test_vec_add.py ``` ## Learn More -See other examples in `ascend/examples/tle/pytest_ut/`: -- `test_matmul.py` - GEMM implementation and pipeline usage \ No newline at end of file +See other examples in `python/test/tle`: +- `test_matmul.py` - GEMM implementation and pipeline usage +- 
`test_vec_mathOps.py` - Vector math operations, such as add, sub, mul, div \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/ascend/core.py b/python/triton/experimental/tle/language/dsa/ascend/core.py index c06e627b8..e1e70dd54 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/core.py +++ b/python/triton/experimental/tle/language/dsa/ascend/core.py @@ -7,64 +7,3 @@ L0A = ascend_address_space.L0A L0B = ascend_address_space.L0B L0C = ascend_address_space.L0C - - -### from triton.language.core import ( -### _unwrap_if_constexpr, -### ) -### -### class layout: -### ASCEND = ['ND', 'NZ'] -### -### def __init__(self, name): -### name = _unwrap_if_constexpr(name) -### self.name = name -### assert name in layout.ASCEND, name -### -### def __str__(self): -### return self.name -### -### def codegen_name(self): -### return self.name -### -### @property -### def cache_key_part(self) -> str: -### """See cache_key_part() in triton.cc.""" -### return self.name -### -### def __repr__(self): -### """Output of repr needs to be an evaluatable expression""" -### return f'triton.language.{self.codegen_name()}' -### -### -### ND = layout('ND') -### NZ = layout('NZ') -### -### class scope: -### ASCEND = ['UB', 'L1', 'L0A', 'L0B', 'L0C'] -### -### def __init__(self, name): -### name = _unwrap_if_constexpr(name) -### self.name = name -### assert name in scope.ASCEND, name -### -### def __str__(self): -### return self.name -### -### def codegen_name(self): -### return self.name -### -### @property -### def cache_key_part(self) -> str: -### """See cache_key_part() in triton.cc.""" -### return self.name -### -### def __repr__(self): -### """Output of repr needs to be an evaluatable expression""" -### return f'triton.language.{self.codegen_name()}' -### -### UB = scope('UB') -### L1 = scope('L1') -### L0A = scope('L0A') -### L0B = scope('L0B') -### L0C = scope('L0C') From c14180544662a67e62c40ab0f14461c0447c21ae Mon Sep 17 00:00:00 2001 From: 
Eugene Wu Date: Thu, 5 Mar 2026 10:09:08 +0000 Subject: [PATCH 08/13] [FEAT]: decouple tle.dsa in backend/ascend/spec * backend/ascend/spec/triton/compiler/code_generator.py still use tle.dsa in its visitor to visit python ast --- python/triton/experimental/tle/__init__.py | 48 +++++++++++++++++++ .../spec/triton/compiler/code_generator.py | 11 ++--- .../backend/spec/triton/compiler/compiler.py | 3 +- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py index 16f30b856..fe46bfdbd 100644 --- a/python/triton/experimental/tle/__init__.py +++ b/python/triton/experimental/tle/__init__.py @@ -1,5 +1,53 @@ # Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +from triton._C.libtriton import ir +from typing import Optional, Dict +from triton.runtime import JITFunction +from .language.builder import setup_unified_builder_with_tle_builder +import importlib + +try: + from triton._C.libtriton import tle as tle_ir +except ImportError: + raise RuntimeError("tle is not available") + +triton_compiler = importlib.import_module("triton.compiler", package=__package__) +def tle_patch_for_triton_compile(): + original_compile_fn = triton_compiler.compile + def tle_compile(src, target=None, options=None): + # ir.context() will return a new MLIRContext each time, here should keep the same context + cur_context = ir.context() + tle_ir.load_dialects(cur_context) + + original_context_fn = ir.context + def patched_context(): + return cur_context + ir.context = patched_context + + try: + compiled_kernel = original_compile_fn(src, target, options) + finally: + ir.context = original_context_fn + + return compiled_kernel + return tle_compile + +code_generator = importlib.import_module("triton.compiler.code_generator", package=__package__) + +class TleCodeGenerator(code_generator.CodeGenerator): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: 
JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + super().__init__(context, prototype, gscope, attributes, constants, function_name, jit_fn, options, + codegen_fns, module_map, module, is_kernel, function_types, noinline, file_name, begin_line) + self.tle_builder = tle_ir.tle_builder(context) + self.tle_builder.set_loc(file_name, begin_line, 0) + setup_unified_builder_with_tle_builder(self.builder, self.tle_builder) + + +triton_compiler.compile = tle_patch_for_triton_compile() +code_generator.CodeGenerator = TleCodeGenerator + from .language import dsa __all__ = [ diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index b691e2839..c1d8594e0 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -7,18 +7,17 @@ import textwrap from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +import importlib import triton.language.extra.cann.extension as extension from triton.extension.buffer.language import core as bl from triton.extension.buffer.language.builder import setup_unified_builder_with_buffer_builder -from triton.experimental.tle.language.builder import setup_unified_builder_with_tle_builder from .. 
import language -from .._C.libtriton import ir, buffer_ir, tle as tle_ir +from .._C.libtriton import ir, buffer_ir from .._C.libtriton.ascend import ir as ascend_ir from ..language import constexpr, tensor, str_to_ty from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value -from ..experimental.tle import dsa from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction @@ -232,10 +231,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n setup_unified_builder(self.builder, self.ascend_builder) self.buffer_builder = buffer_ir.buffer_builder(context) self.buffer_builder.set_loc(file_name, begin_line, 0) - self.tle_builder = tle_ir.tle_builder(context) - self.tle_builder.set_loc(file_name, begin_line, 0) setup_unified_builder_with_buffer_builder(self.builder, self.buffer_builder) - setup_unified_builder_with_tle_builder(self.builder, self.tle_builder) # dict of functions provided by the backend. Below are the list of possible functions: # Convert custom types not natively supported on HW. 
@@ -989,6 +985,7 @@ def visit_For(self, node): warp_specialize = False disable_licm = False bind_sub_block = None + dsa = importlib.import_module("..experimental.tle.dsa", package=__package__) if IteratorClass in [language.range, extension.parallel, dsa.pipeline, dsa.parallel]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -1088,6 +1085,8 @@ def visit_For(self, node): for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) if disable_licm: for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) + + dsa = importlib.import_module("..experimental.tle.dsa", package=__package__) if (IteratorClass is extension.parallel or IteratorClass is dsa.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index d352f9233..cc7ba30e7 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib import json -from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir, tle as tle_ir +from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir from .._C.libtriton.ascend import ir as ascend_ir from ..backends import backends from ..backends.compiler import GPUTarget, AttrsDescriptor @@ -270,7 +270,6 @@ def compile(src, target=None, options=None): context = ir.context() ir.load_dialects(context) buffer_ir.load_dialects(context) - tle_ir.load_dialects(context) ascend_ir.load_dialects(context) backend.load_dialects(context) codegen_fns = backend.get_codegen_implementation() From 53a55e7aa515788bf0fef245ee022ad8620839ea Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Fri, 6 Mar 2026 06:15:21 +0000 Subject: [PATCH 09/13] [FIX]: fix copyright declaration in tle --- python/test/tle/test_vec_add.py | 2 +- 
python/test/tle/test_vec_add_2d.py | 2 +- python/test/tle/test_vec_add_mix.py | 2 +- python/test/tle/test_vec_mathOps.py | 2 +- python/triton/experimental/__init__.py | 2 +- python/triton/experimental/tle/__init__.py | 2 +- python/triton/experimental/tle/language/__init__.py | 2 +- python/triton/experimental/tle/language/builder.py | 2 +- third_party/tle/dsa/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/include/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt | 2 +- .../dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h | 2 +- third_party/tle/dsa/dialect/include/IR/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/include/IR/Dialect.h | 2 +- third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td | 2 +- third_party/tle/dsa/dialect/include/IR/TleDialect.td | 2 +- third_party/tle/dsa/dialect/include/IR/TleOps.td | 2 +- third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/lib/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt | 2 +- .../tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt | 2 +- third_party/tle/dsa/dialect/lib/IR/Dialect.cpp | 2 +- third_party/tle/dsa/dialect/lib/IR/TleOps.cpp | 2 +- third_party/tle/dsa/tle_ir.cc | 2 +- 27 files changed, 27 insertions(+), 27 deletions(-) diff --git a/python/test/tle/test_vec_add.py b/python/test/tle/test_vec_add.py index 44a2f0af8..3a3fbeccf 100755 --- a/python/test/tle/test_vec_add.py +++ b/python/test/tle/test_vec_add.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd import torch import triton import triton.language as tl diff --git a/python/test/tle/test_vec_add_2d.py b/python/test/tle/test_vec_add_2d.py index debdb2780..6c0b273a4 100755 --- a/python/test/tle/test_vec_add_2d.py +++ b/python/test/tle/test_vec_add_2d.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd import torch import triton import triton.language as tl diff --git a/python/test/tle/test_vec_add_mix.py b/python/test/tle/test_vec_add_mix.py index 2a5512cca..4c18d94f4 100755 --- a/python/test/tle/test_vec_add_mix.py +++ b/python/test/tle/test_vec_add_mix.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd import torch import triton import triton.language as tl diff --git a/python/test/tle/test_vec_mathOps.py b/python/test/tle/test_vec_mathOps.py index 2d161d433..4252370a6 100755 --- a/python/test/tle/test_vec_mathOps.py +++ b/python/test/tle/test_vec_mathOps.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd from typing import Callable, Tuple import torch import triton diff --git a/python/triton/experimental/__init__.py b/python/triton/experimental/__init__.py index ef36c171c..a6b9487a8 100644 --- a/python/triton/experimental/__init__.py +++ b/python/triton/experimental/__init__.py @@ -1 +1 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. \ No newline at end of file +# Copyright 2026- Xcoresigma Technology Co., Ltd \ No newline at end of file diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py index fe46bfdbd..9d3e80dc8 100644 --- a/python/triton/experimental/tle/__init__.py +++ b/python/triton/experimental/tle/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd from triton._C.libtriton import ir from typing import Optional, Dict diff --git a/python/triton/experimental/tle/language/__init__.py b/python/triton/experimental/tle/language/__init__.py index f2cd21743..32ed8c87b 100644 --- a/python/triton/experimental/tle/language/__init__.py +++ b/python/triton/experimental/tle/language/__init__.py @@ -1,3 +1,3 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd from . import dsa diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py index e824f6553..9c71697fd 100644 --- a/python/triton/experimental/tle/language/builder.py +++ b/python/triton/experimental/tle/language/builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd def create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name): delegate_method = getattr(delegate_builder, method_name) diff --git a/third_party/tle/dsa/CMakeLists.txt b/third_party/tle/dsa/CMakeLists.txt index 19b28cffe..f9768c71f 100644 --- a/third_party/tle/dsa/CMakeLists.txt +++ b/third_party/tle/dsa/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dialect/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/dialect/include) diff --git a/third_party/tle/dsa/dialect/CMakeLists.txt b/third_party/tle/dsa/dialect/CMakeLists.txt index 799695a7d..e918a4d15 100644 --- a/third_party/tle/dsa/dialect/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) diff --git a/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt index 87913e1bd..a111d0159 100644 --- a/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt @@ -1 +1 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/include/CMakeLists.txt b/third_party/tle/dsa/dialect/include/CMakeLists.txt index 181d7332c..85a8512e3 100644 --- a/third_party/tle/dsa/dialect/include/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd add_subdirectory(Analysis) add_subdirectory(Conversion) diff --git a/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt index 87913e1bd..a111d0159 100644 --- a/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt @@ -1 +1 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h index 9b950c2ef..789231e6d 100644 --- a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_CONVERSION_MATH_CONVERTER_H #define TRITON_TLE_CONVERSION_MATH_CONVERTER_H diff --git a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt index c1ec982de..2bbc0d99e 100644 --- a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) diff --git a/third_party/tle/dsa/dialect/include/IR/Dialect.h b/third_party/tle/dsa/dialect/include/IR/Dialect.h index 9d0f3ce0f..d7a07c85f 100644 --- a/third_party/tle/dsa/dialect/include/IR/Dialect.h +++ b/third_party/tle/dsa/dialect/include/IR/Dialect.h @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_IR_DIALECT_H_ #define TRITON_TLE_IR_DIALECT_H_ diff --git a/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td index 616ab0e82..01e577718 100644 --- a/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td +++ b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_ATTR_DEFS #define TRITON_TLE_ATTR_DEFS diff --git a/third_party/tle/dsa/dialect/include/IR/TleDialect.td b/third_party/tle/dsa/dialect/include/IR/TleDialect.td index 8f46ab6c1..ba85b281a 100644 --- a/third_party/tle/dsa/dialect/include/IR/TleDialect.td +++ b/third_party/tle/dsa/dialect/include/IR/TleDialect.td @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_DIALECT #define TRITON_TLE_DIALECT diff --git a/third_party/tle/dsa/dialect/include/IR/TleOps.td b/third_party/tle/dsa/dialect/include/IR/TleOps.td index 73d827c35..984e2f59c 100644 --- a/third_party/tle/dsa/dialect/include/IR/TleOps.td +++ b/third_party/tle/dsa/dialect/include/IR/TleOps.td @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_OPS #define TRITON_TLE_OPS diff --git a/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt index 87913e1bd..a111d0159 100644 --- a/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt @@ -1 +1 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/lib/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/CMakeLists.txt index 181d7332c..85a8512e3 100644 --- a/third_party/tle/dsa/dialect/lib/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd add_subdirectory(Analysis) add_subdirectory(Conversion) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt index efba34782..21ea47f5d 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt @@ -1,3 +1,3 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd add_subdirectory(TleToLinalg) \ No newline at end of file diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt index 57675b403..167620ff5 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd add_triton_library(TleToLinalg DSACopyConverter.cpp diff --git a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt index f4127590e..d370c8313 100644 --- a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd add_triton_library(TleIR Dialect.cpp diff --git a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp index e9aeab5ef..5bf9562ee 100644 --- a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp +++ b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #include "mlir/Support/LLVM.h" #include "tle/dsa/dialect/include/IR/Dialect.h" diff --git a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp index 64c964ecb..5a3eeb16e 100644 --- a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp +++ b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+// Copyright 2026- Xcoresigma Technology Co., Ltd #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index b4fb51db7..58dba49e5 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #include #include From 400e1d2f999c0043511e7c26a9bb58ddbd2e5356 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Fri, 6 Mar 2026 09:12:41 +0000 Subject: [PATCH 10/13] [FIX](tle): fix extract and apply tle.hints when visit ast * fix tle.dsa.hint for nested usage, see python/test/tle/test_tle_with_hints.py * implement extract_tle in experimental/tle --- python/test/tle/test_tle_with_hints.py | 62 +++++++++++++++++++ python/triton/compiler/code_generator.py | 10 +++ python/triton/experimental/tle/__init__.py | 58 +++++++++++++++++ .../spec/triton/compiler/code_generator.py | 35 +++-------- 4 files changed, 138 insertions(+), 27 deletions(-) create mode 100755 python/test/tle/test_tle_with_hints.py diff --git a/python/test/tle/test_tle_with_hints.py b/python/test/tle/test_tle_with_hints.py new file mode 100755 index 000000000..78e980052 --- /dev/null +++ b/python/test/tle/test_tle_with_hints.py @@ -0,0 +1,62 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import triton +import triton.language as tl +# import triton.language.extra.tle.ascend as tle +import triton.experimental.tle as tle + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # mem_addr_space is language.extra.cann.core.ascend_address_space + with tle.dsa.hint(inter_no_alias=True): + with tle.dsa.hint(test_k1=10): + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + with tle.dsa.hint(test_k2=20): + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return output + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, device='npu', dtype=torch.float) + y = torch.rand(size, device='npu', dtype=torch.float) + output_torch = x + y + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + from triton.backends.ascend.testing import do_bench_npu + bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False) + bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False) + # 保留两位小数输出 + print(f"torch time : {bench_torch:.2f}") + print(f"triton time: {bench_triton:.2f}") + +if __name__ == "__main__": + test_add() diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index d8ca58d8d..e0838aad9 100644 --- 
a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,6 +15,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +import importlib def mangle_ty(ty): @@ -1111,6 +1112,15 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: + if fn.__name__ == 'copy': + # extract tle hints from the generator to identify if node in the tle hints scope + tle = importlib.import_module("..experimental.tle", package=__package__) + top_hints = tle.extract_tle_hints_scope(self) + + # Only apply to some builtins; currently, 'copy' is relevant. + if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: + kws['inter_no_alias'] = top_hints['inter_no_alias'] + return fn(*args, **extra_kwargs, **kws) except Exception as e: # Normally when we raise a CompilationError, we raise it as diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py index 9d3e80dc8..1eeda5be0 100644 --- a/python/triton/experimental/tle/__init__.py +++ b/python/triton/experimental/tle/__init__.py @@ -5,6 +5,8 @@ from triton.runtime import JITFunction from .language.builder import setup_unified_builder_with_tle_builder import importlib +import ast +from typing_extensions import override try: from triton._C.libtriton import tle as tle_ir @@ -42,8 +44,64 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n codegen_fns, module_map, module, is_kernel, function_types, noinline, file_name, begin_line) self.tle_builder = tle_ir.tle_builder(context) self.tle_builder.set_loc(file_name, begin_line, 0) + + # Stack to keep track of active `with`-hints (e.g., tle.hint(...)) + # Each entry is a dict mapping hint names to literal values. 
+ self.with_hints = [] + setup_unified_builder_with_tle_builder(self.builder, self.tle_builder) + @override + def visit_With(self, node): + assert len(node.items) == 1 + context = node.items[0].context_expr + + # extract tle hints + hints = {} + if isinstance(context, ast.Call): + if isinstance(context.func, ast.Attribute) and context.func.attr == "hint": + for kw in context.keywords: + if not isinstance(kw.value, ast.Constant): + raise self._unsupported(node, "keyword arguments to hint() are only supported for constant values") + hints[kw.arg] = kw.value.value + + # append hints to with_hints anyway, to indicate that we're in the with scope + self.with_hints.append(hints) + + super().visit_With(node) + + # pop hints to indicate that we're out of the with scope + self.with_hints.pop() + +def extract_tle_hints_scope(generator: TleCodeGenerator): + """ + with tle.hints(inter_no_alias=True): + with xxxx: + with tle.hints(inter_no_alias=False): + ... + with xxx: + call_fn1(...) + call_fn(...) 
+ + when visit_Call for call_fn1, we can get the hints scope as follows: + [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}, {xxx}] + should get the parent scope hints 'inter_no_alias': False for call_fn1, after visit call_fn1, pop the scope + + when visit_Call for call_fn, we can get the hints scope as follows: + [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}] + and now the hint scope is 'inter_no_alias': False' for call_fn, after visit call_fn, pop the scope + """ + if not generator.with_hints: + return {} + + # visit with_hints backward to find inter_no_alias hint + for i in range(len(generator.with_hints) - 1, -1, -1): + hints = generator.with_hints[i] + if "inter_no_alias" in hints: + return hints + + return {} + triton_compiler.compile = tle_patch_for_triton_compile() code_generator.CodeGenerator = TleCodeGenerator diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index c1d8594e0..18e17b6fa 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -272,9 +272,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # Are we currently visiting an ast.arg's default value? These have some # special handling. self.visiting_arg_default_value = False - # Stack to keep track of active `with`-hints (e.g., tle.hint(...)) - # Each entry is a dict mapping hint names to literal values. 
- self._with_hints = [] builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} builtin_namespace.update(( @@ -811,27 +808,13 @@ def visit_With(self, node): # Check if context is a Call and dispatch to registered handler if isinstance(context, ast.Call): - # TODO[FIXME]: This is a hack to support `with hint(...)`, maybe should be handled in a better way like scope handler - if isinstance(context.func, ast.Attribute) and context.func.attr == "hint": - hints = {} - for kw in context.keywords: - if not isinstance(kw.value, ast.Constant): - raise self._unsupported(node, "keyword arguments to hint() are only supported for constant values") - hints[kw.arg] = kw.value.value - self._with_hints.append(hints) - withitemClass = self.visit(context.func) handler = WITH_DISPATCH.get(withitemClass) if handler: return handler(self, node) # Fall back to visiting body for unhandled cases - try: - self.visit_compound_statement(node.body) - finally: - self._with_hints.pop() - - return + return self.visit_compound_statement(node.body) def visit_Compare(self, node): if not (len(node.comparators) == 1 and len(node.ops) == 1): @@ -1213,16 +1196,14 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - # Honor hints coming from an enclosing `with ... hint(...)` block. - # For example, `with tle.hint(inter_no_alias=True): tle.copy(...)` - # should behave like `tle.copy(..., inter_no_alias=True)` when the - # keyword isn't explicitly provided on the call site. - if self._with_hints: + if fn.__name__ == 'copy': + # extract tle hints from the generator to identify if node in the tle hints scope + tle = importlib.import_module("..experimental.tle", package=__package__) + top_hints = tle.extract_tle_hints_scope(self) + # Only apply to some builtins; currently, 'copy' is relevant. 
- if fn.__name__ == 'copy': - top_hints = self._with_hints[-1] - if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: - kws['inter_no_alias'] = top_hints['inter_no_alias'] + if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: + kws['inter_no_alias'] = top_hints['inter_no_alias'] ret = fn(*args, **extra_kwargs, **kws) # Sync the builder's location before return. ip, last_loc = self._get_insertion_point_and_loc(_builder) From 268a5ce84ad42a658658842e2a2e5cc39e9085a5 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Mon, 9 Mar 2026 09:30:11 +0000 Subject: [PATCH 11/13] [FIX](tle): fix tle module importing in ascend/backend/spec/triton/compiler/code_generator.py and add sparse_flash_attn_tle.py --- python/tutorials/tle/sfa_tle_v1.py | 910 +++++++++++++++++ python/tutorials/tle/sparse_flash_attn_tle.py | 912 ++++++++++++++++++ .../spec/triton/compiler/code_generator.py | 10 +- 3 files changed, 1827 insertions(+), 5 deletions(-) create mode 100644 python/tutorials/tle/sfa_tle_v1.py create mode 100644 python/tutorials/tle/sparse_flash_attn_tle.py diff --git a/python/tutorials/tle/sfa_tle_v1.py b/python/tutorials/tle/sfa_tle_v1.py new file mode 100644 index 000000000..f59f6c54d --- /dev/null +++ b/python/tutorials/tle/sfa_tle_v1.py @@ -0,0 +1,910 @@ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import numpy as np +from datetime import datetime +from triton.backends.ascend.testing import do_bench_npu +import triton.experimental.tle as tle +# import random + +np.random.seed(21) +DEVICE = "npu" +DEVICE_ID = 0 +torch.manual_seed(20) +torch_npu.npu.set_device(int(DEVICE_ID)) +torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) + +ascend_aiv_core_nums = triton.language.constexpr(24) + +# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== +@triton.jit +def fused_pa_rope_to_sparse_kernel( + k_pa_ptr, k_rope_pa_ptr, v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] + block_table_ptr, 
# block_table [B, max_blocks] + sparse_indices_ptr, # sparse_indices [B, N, TOPK] + k_sparse_out_ptr, v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] + stride_k_pa_bn, stride_k_pa_bs, stride_k_pa_n, stride_k_pa_d, # K PA strides + stride_k_rope_pa_bn, stride_k_rope_pa_bs, stride_k_rope_pa_n, stride_k_rope_pa_d, # K_rope PA strides + stride_v_pa_bn, stride_v_pa_bs, stride_v_pa_n, stride_v_pa_d, # V PA strides + stride_bt_b, stride_bt_blk, # block_table strides + stride_si_b, stride_si_n, stride_si_topk, # sparse_indices strides + stride_out_b, stride_out_n, stride_out_topk, stride_out_d, # output strides + stride_v_b, stride_v_n, stride_v_topk, stride_v_d, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + B: tl.constexpr, +): + """ + Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse + Input: K/V in PA_BSND format, K_rope in PA_BSND format + Output: K/V_sparse in BNSD format + """ + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # Process (b, n, topk) combinations + for b_idx in range(B): + b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now + for idx in range(pid, TOPK, num_programs): + # Get batch and sparse index from sparse_indices + n = 0 # KV_N = 1 + + # Load sparse index + sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) + + # Map sparse_idx to PA_BSND position + block_id = sparse_idx // BLOCK_SIZE # Which block + bs_offset = sparse_idx % BLOCK_SIZE # Offset within block + + # Get actual block ID from block_table + actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) + + # Compute PA_BSND offset for K + k_pa_offset = (actual_block_id * stride_k_pa_bn + + bs_offset * stride_k_pa_bs + + n * stride_k_pa_n) + + # Compute PA_BSND offset for K_rope + k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + + bs_offset * stride_k_rope_pa_bs + + n * 
stride_k_rope_pa_n) + + # Compute PA_BSND offset for V + v_pa_offset = (actual_block_id * stride_v_pa_bn + + bs_offset * stride_v_pa_bs + + n * stride_v_pa_n) + # Load K vector (no rope part) + k_vec = tl.load( + k_pa_ptr + k_pa_offset + + tl.arange(0, BLOCK_DK) * stride_k_pa_d + ) + + # Load V vector + v_vec = tl.load( + v_pa_ptr + v_pa_offset + + tl.arange(0, BLOCK_DV) * stride_v_pa_d + ) + # Output to BNSD format: [B, N, TOPK, D] + out_offset = b * stride_out_b + n * stride_out_n + idx * stride_out_topk + out_offset_v = b* stride_v_b + n *stride_v_n + idx*stride_v_topk + + if BLOCK_DK_ROPE > 0: + # Load K_rope vector + full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE,), 0.0, dtype=tl.float16) + k_rope_vec = tl.load( + k_rope_pa_ptr + k_rope_pa_offset + + tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d + ) + full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0,), sizes=(BLOCK_DK,), strides=(1,)) + full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK,), sizes=(BLOCK_DK_ROPE,), strides=(1,)) + tl.store( + k_sparse_out_ptr + out_offset + + tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, + full_k + ) + else: + # No rope, store K directly + tl.store( + k_sparse_out_ptr + out_offset + + tl.arange(0, BLOCK_DK) * stride_out_d, + k_vec + ) + + # Store V + tl.store( + v_sparse_out_ptr + out_offset_v + + tl.arange(0, BLOCK_DV) * stride_v_d, + v_vec + ) + + +def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): + """ + Fused PA_BSND + Rope Concat -> BNSD Sparse conversion + + Args: + k_pa: Key in PA_BSND format [block_num, block_size, n, dk] + k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope + v_pa: Value in PA_BSND format [block_num, block_size, n, dv] + block_table: Block table [B, max_blocks] + sparse_indices: Sparse indices [B, N, TOPK] + block_size: Block size for PA format + + Returns: + k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] + v_sparse: 
Sparse value in BNSD format [B, N, TOPK, dv] + """ + block_num, _, n, dk = k_pa.shape + B = block_table.shape[0] + TOPK = sparse_indices.size(-1) + N = 1 # KV_N = 1 + _, _, _, dv = v_pa.shape + + has_rope = k_rope_pa is not None + dk_rope = k_rope_pa.shape[-1] if has_rope else 0 + dk_total = dk + dk_rope + + # Output BNSD format [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, device=DEVICE) + + # Grid: use 48 programs for parallelism + grid = (min(48, TOPK),) + + # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] + # No squeeze needed - kernel expects [B, N, TOPK] format + sparse_indices_input = sparse_indices + if sparse_indices.dim() == 2: + # If already 2D [B, TOPK], reshape to [B, 1, TOPK] + sparse_indices_input = sparse_indices.unsqueeze(1) + + # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) + k_rope_pa_input = k_rope_pa if has_rope else k_pa + fused_pa_rope_to_sparse_kernel[grid]( + k_pa, k_rope_pa_input, v_pa, + block_table, + sparse_indices_input, + k_sparse, v_sparse, + k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), k_pa.stride(3), + k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), + v_pa.stride(0), v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), + block_table.stride(0), block_table.stride(1), + sparse_indices_input.stride(0), sparse_indices_input.stride(1), sparse_indices_input.stride(2), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=dk, + BLOCK_DV=dv, + BLOCK_DK_ROPE=dk_rope, + TOPK=TOPK, + BLOCK_SIZE=block_size, + B = B + ) + + return k_sparse, v_sparse + +@triton.jit +def gather_kv_bnsd_vec_kernel( + k_ptr, v_ptr, ind_ptr, + k_out_ptr, v_out_ptr, + stride_kb, stride_kn, stride_ks, stride_kd, + stride_vb, stride_vn, stride_vs, 
stride_vd, + stride_ob, stride_on, stride_os, stride_od, + stride_ovb, stride_ovn, stride_ovs, stride_ovd, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + TOPK: tl.constexpr, + B: tl.constexpr, +): + end = TOPK // 48 * 48 + for b_idx in range(B): + # 分批处理所有TOPK个索引,每次48个 + for batch_start in range(0, end, 48): + pid_k = tl.program_id(0) + batch_start + + # 读 index + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + # 处理余数部分(end到TOPK) + for batch_start in range(end, TOPK, 48): + pid_k = tl.program_id(0) + batch_start + + # 必须在计算pid_k之后检查边界 + if pid_k < TOPK: + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + +def triton_gather_kv_bnsd_vec(k, v, indices): + B, N, SK, Dk = k.shape # N=1 + B, N, SK, Dv = v.shape # N=1 + TOPK = indices.size(-1) + 
+ # 输出保持 bnsd [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) + + grid = (48,) # TOPK 个 program,每个搬 Dk/Dv 元素 + gather_kv_bnsd_vec_kernel[grid]( + k, v, indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] + k_sparse, v_sparse, + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=Dk, + BLOCK_DV=Dv, + TOPK=TOPK, + B=B, + ) + return k_sparse, v_sparse + +@triton.jit +def _attn_fwd( + Q, K, V, O, scale_value, + stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, + stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, + stride_vb: tl.constexpr, stride_vn: tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, + stride_ob: tl.constexpr, stride_os: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, + KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N:tl.constexpr, O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + ): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + # off_n = 0 + + q_offset = off_b * stride_qb + off_s * 
stride_qs + o_offset = off_b * stride_ob + off_s * stride_os + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base = Q + q_offset, + shape = (Q_N, Q_D), + strides = (stride_qn, stride_qd), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, Q_D), + order = (1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr(base = K + k_offset, + shape = (KV_S, K_D), + strides = (stride_ks, stride_kd), + offsets = (0, 0), + block_shape = (blk_size, K_D), + order = (1, 0),) + v_block_ptr = tl.make_block_ptr(base = V + v_offset, + shape = (KV_S, V_D), + strides = (stride_vs, stride_vd), + offsets = (0, 0), + block_shape = (blk_size, V_D), + order = (1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. 
+ block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc*coeff[:,None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base = O + o_offset, + shape = (O_N, O_D), + strides = (stride_on, stride_od), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, O_D), + order = (1,0)) + # final normalize + acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + + +@triton.jit +def _attn_fwd_fused_bsnd_to_tnd( + Q, K, V, O, scale_value, + stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, + stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, + stride_vb: tl.constexpr, stride_vn: tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, + stride_ot: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, + KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N:tl.constexpr, O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + ): + # total b * n tasks + BLOCK_QN_NUM = 
Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ot + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base = Q + q_offset, + shape = (Q_N, Q_D), + strides = (stride_qn, stride_qd), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, Q_D), + order = (1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr(base = K + k_offset, + shape = (KV_S, K_D), + strides = (stride_ks, stride_kd), + offsets = (0, 0), + block_shape = (blk_size, K_D), + order = (1, 0),) + v_block_ptr = tl.make_block_ptr(base = V + v_offset, + shape = (KV_S, V_D), + strides = (stride_vs, stride_vd), + offsets = (0, 0), + block_shape = (blk_size, V_D), + order = (1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. 
+ block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc*coeff[:,None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base = O + o_offset, + shape = (O_N, O_D), + strides = (stride_on, stride_od), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, O_D), + order = (1,0)) + # final normalize + acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + + + +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + query, + key, + value, + sparse_indices, + scale_value, + sparse_block_size = 1, + actual_seq_lengths_query = None, + actual_seq_lengths_kv = None, + query_rope = None, + key_rope = None, + layout_query = 'BSND', + layout_kv = 'BSND', + sparse_mode = 0, + block_table = None): + # Save original sparse_indices for PA_BSND case + sparse_indices_orig = sparse_indices.clone() + total_len = 0 + # Handle query layout transformation (TND -> BSND) + if layout_query == 'TND': + actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) + # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat + query, sparse_indices = trans_tnd_to_bsnd_fused( + query, query_rope, sparse_indices, query.shape, actual_seq_lengths_query + ) + else: + if 
query_rope != None: + query = torch.cat([query, query_rope], dim = -1) + + # Handle KV layout and gather sparse K/V + if layout_kv == 'PA_BSND': + # Fused PA -> BNSD + rope concat + sparse gather + block_size = key.shape[1] # Get block_size from PA shape + # Use original sparse_indices [T, N, TOPK] for fused kernel + k_sparse, v_sparse = triton_fused_pa_rope_to_sparse( + key, key_rope, value, block_table, sparse_indices_orig, block_size + ) + # sparse_indices is already in BSND, needs permute to BNSD for downstream use + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + else: + # Original path for non-PA layouts + if key_rope != None: + key = torch.cat([key, key_rope], dim = -1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + + k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + + k_sparse = k_sparse.contiguous() + v_sparse = v_sparse.contiguous() + enable_check_kv_sparse = 0 + if enable_check_kv_sparse: + key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) + key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) + value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) + if key_rope != None: + key = torch.cat([key, key_rope], dim = -1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + print(f"k_sparse={k_sparse}") + print(f"k_sparse_ref={k_sparse_ref}") + print(f"v_sparse={v_sparse}") + print(f"v_sparse_ref={v_sparse_ref}") + assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" 
+ + # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + num_cores = ascend_aiv_core_nums + sparse_size = sparse_indices_bnsd.shape[-1] # 4 + out_shape_bsnd = list(query.shape) + if query_rope != None: + out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] + B, Q_S, Q_N, Q_D = query.shape + _, _, KV_S, K_D = k_sparse.shape + + if layout_query == 'TND': + # t = B*act_q_s + output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, dtype=torch.float32) + _attn_fwd_fused_bsnd_to_tnd[(num_cores,)]( + query, k_sparse, v_sparse, output, scale_value, + query.stride(0), query.stride(1), query.stride(2), query.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + output.stride(0), output.stride(1), output.stride(2), + B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, + KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], + sparse_mode = sparse_mode, O_N = output.shape[1], O_D = output.shape[2], + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16,multibuffer=False + ) + + else: + output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) + _attn_fwd[(num_cores,)]( + query, k_sparse, v_sparse, output, scale_value, + query.stride(0), query.stride(1), query.stride(2), query.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + output.stride(0), output.stride(1), output.stride(2), output.stride(3), + B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, + KV_S = KV_S, 
K_D = K_D, V_D = v_sparse.shape[3], + sparse_mode = sparse_mode, O_N = output.shape[2], O_D = output.shape[3], + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16 + ) + output = output.permute(0, 2, 1, 3).contiguous() + + ctx.save_for_backward(query, k_sparse, v_sparse, output) + ctx.scale_value = scale_value + return output + +def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): + block_num, block_size, n, d = pa_in.shape + b = len(actual_seq_lengths) + output = torch.empty((b, block_num * block_size // b, 1, d), dtype = pa_in.dtype).to(DEVICE) + for i in range(b): + for j in range(20): + output[i, j * block_size: (j + 1) * block_size, 0, :] = \ + pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) + return output + + +@triton.jit +def trans_tnd_to_bsnd_fused_kernel( + query_ptr, query_rope_ptr, sparse_ptr, + query_out_ptr, sparse_out_ptr, # query_out 已经拼接了 rope + act_s, + stride_q_t, stride_q_tn, stride_q_td, + stride_qr_t, stride_qr_tn, stride_qr_td, + stride_s_t, stride_s_tn, stride_s_td, + stride_qob, stride_qobs, stride_qon, stride_qod, # query_out strides + stride_sb, stride_sbs, stride_sbn, stride_sbd, + B: tl.constexpr, + N: tl.constexpr, + D_QUERY: tl.constexpr, + D_ROPE: tl.constexpr, + D_SPARSE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_QUERY: tl.constexpr, + BLOCK_D_ROPE: tl.constexpr, + BLOCK_D_SPARSE: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # 计算 head 的总块数 + num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N + t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 + # 每个 pid 负责处理特定的 (batch, head_block) 组合 + for tn_id in range(B): + # sparse_indices 是单头的,只在第一个 head_block 处理一次 + if pid == 0: + sparse_block_ptr = tl.make_block_ptr(base = sparse_ptr + t_idx * stride_s_t, + shape = (1, D_SPARSE), + strides = (stride_s_tn, stride_s_td), + offsets = (0, 0), + block_shape = (1, D_SPARSE), + order = (1, 0)) + sparse = 
tl.load(sparse_block_ptr) + + sparse_out_block_ptr = tl.make_block_ptr(base = sparse_out_ptr + t_idx * stride_sb, + shape = (1, D_SPARSE), + strides = ( stride_sbn, stride_sbd), + offsets = (0, 0), + block_shape = (1, D_SPARSE), + order = (1, 0)) + tl.store(sparse_out_block_ptr, sparse) + + # query 和 query_rope 是多头的,需要在 head 维度上分块处理 + for head_block_id in range(pid, num_head_blocks, num_programs): + n_offset = head_block_id * BLOCK_N + + # Load q and q_ro + q_block_ptr = tl.make_block_ptr(base = query_ptr + t_idx * stride_q_t, + shape = (N, D_QUERY), + strides = (stride_q_tn, stride_q_td), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_QUERY), + order = (1, 0)) + q_ro_block_ptr = tl.make_block_ptr(base = query_rope_ptr + t_idx * stride_qr_t, + shape = (N, D_ROPE), + strides = (stride_qr_tn, stride_qr_td), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_ROPE), + order = (1, 0)) + q = tl.load(q_block_ptr) + q_ro = tl.load(q_ro_block_ptr) + + # Combine query and query_rope using insert_slice, then store in one operation + full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) + full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) + full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) + + q_out_block_ptr = tl.make_block_ptr(base = query_out_ptr + t_idx * stride_qob, + shape = (N, D_QUERY + D_ROPE), + strides = (stride_qon, stride_qod), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_QUERY + D_ROPE), + order = (1, 0)) + tl.store(q_out_block_ptr, full_q) + t_idx = t_idx + tl.load(act_s + tn_id) + + +def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16,)): + """ + 融合版本的 TND -> BSND 转换(包含 concat) + 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope + """ + t, n, d_query = shape + b = len(act_seq) + s = max(act_seq) + + # 获取各个 tensor 的维度 + d_rope = query_rope.shape[2] if 
query_rope is not None else 0 + d_sparse = sparse_indices.shape[2] + d_query_out = d_query + d_rope # 拼接后的维度 + + # 分配输出(query_out 已经包含 rope) + query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) + sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, device=sparse_indices.device) + assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" + # 启动 fused kernel + # 使用较小的 BLOCK_N 避免内存溢出 + block_n = min(16, n) + # 计算需要的核心数:使用多核心并行处理不同的头 + num_head_blocks = (n + block_n - 1) // block_n + num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 + + trans_tnd_to_bsnd_fused_kernel[num_programs,]( + query, query_rope, sparse_indices, + query_out, sparse_out, + act_seq, + query.stride(0), query.stride(1), query.stride(2), + query_rope.stride(0), query_rope.stride(1), query_rope.stride(2), + sparse_indices.stride(0), sparse_indices.stride(1), sparse_indices.stride(2), + query_out.stride(0), query_out.stride(1), query_out.stride(2), query_out.stride(3), + sparse_out.stride(0), sparse_out.stride(1), sparse_out.stride(2), sparse_out.stride(3), + B=b, + N=n, + D_QUERY=d_query, + D_ROPE=d_rope, + D_SPARSE=d_sparse, + BLOCK_N=block_n, + BLOCK_D_QUERY=d_query, + BLOCK_D_ROPE=d_rope, + BLOCK_D_SPARSE=d_sparse, + ) + return query_out, sparse_out + + +def trans_tnd_actseq(seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().tolist() + list_len = len(seq) + output = [] + output = [seq[0]] + total_len = seq[0] + for i in range(list_len - 1): + new_item = seq[i + 1] - seq[i] + if new_item >= 0: + output.append(new_item) + total_len += new_item + else: + print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 0") + return torch.tensor(output).to(DEVICE), total_len + +def sparse_attention(query, key, value, + sparse_indices, scale_value, sparse_block_size = 1, + actual_seq_lengths_query = None, actual_seq_lengths_kv = None, + query_rope = None, key_rope = 
None,
+                     layout_query = 'BSND', layout_kv = 'BSND',
+                     sparse_mode = 0, block_table = None):
+    return _attention.apply(query, key, value,
+                            sparse_indices, scale_value, sparse_block_size,
+                            actual_seq_lengths_query, actual_seq_lengths_kv,
+                            query_rope, key_rope,
+                            layout_query, layout_kv,
+                            sparse_mode, block_table)
+
+def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope,
+            sparse_size, scale_value,
+            sparse_block_size, sparse_mode, block_size, act_kv_s):
+    assert sparse_size <= KV_S
+    assert KV_N == 1
+    assert sparse_mode in (0, 3)
+    assert sparse_block_size == 1
+    assert (B * KV_S) % block_size == 0
+    assert D == 512
+    assert D_rope in (0, 64)
+    print("*batch_size=",B)
+    qkv_dtype = torch.float16
+    #sparse_size = KV_S
+    query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()
+    key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()
+    value = key.clone()
+
+    act_q_s = T // B # step
+    # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE)
+    # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复
+    # sparse_indices = indices.to(torch.int32)
+    sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1)
+    sparse_indices = sparse_indices.to(torch.int32)
+    # print("sparse_indices=", sparse_indices)
+    actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE)
+    # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE)
+    actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, dtype=torch.int32, device=DEVICE)
+    print(actual_seq_lengths_kv)
+    block_table = torch.tensor([range(B * KV_S // block_size)], dtype=torch.int32, device=DEVICE).reshape(B, -1)
+
+    if D_rope == 0:
+        query_rope = None
+        key_rope = None
+    else:
+        query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, 
std=0.5).requires_grad_() + key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + + print("q.shape=",query.shape) + print("k.shape=",key.shape) + print("v.shape=",value.shape) + print("sparse_indices.shape=",sparse_indices.shape) + print("act_seq_query=",actual_seq_lengths_query) + print("act_seq_kv=", actual_seq_lengths_kv) + + + triton_out = sparse_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv= 'PA_BSND', + sparse_mode = sparse_mode, + block_table= block_table, + ) + npu_out = torch_npu.npu_sparse_flash_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv = 'PA_BSND', + sparse_mode = sparse_mode, + block_table = block_table, + # attention_mode = 2, + ) + triton_out = triton_out.to(npu_out.dtype) + torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) + print("[PASSED]") + + # benchmarking + triton_time = do_bench_npu(lambda:sparse_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv= 'PA_BSND', + sparse_mode = sparse_mode, + block_table = 
block_table, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Triton SFA] Time: {triton_time:.4f} us") + + npu_time = do_bench_npu(lambda:torch_npu.npu_sparse_flash_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv = 'PA_BSND', + sparse_mode = sparse_mode, + block_table = block_table, + # attention_mode = 2, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") + +if __name__ == "__main__": + print(torch_npu.__version__) + print("Test Real Case in DS-v3.2-Exp") + print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + i = 1 + print(f"====================第{i}次测试=================") + test_op(T = 1, B = 1, KV_S = 2560, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 4, B = 4, KV_S = 6400, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 8, B = 8, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 16, B = 16, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) diff --git 
a/python/tutorials/tle/sparse_flash_attn_tle.py b/python/tutorials/tle/sparse_flash_attn_tle.py new file mode 100644 index 000000000..fea4b1682 --- /dev/null +++ b/python/tutorials/tle/sparse_flash_attn_tle.py @@ -0,0 +1,912 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import numpy as np +from datetime import datetime +from triton.backends.ascend.testing import do_bench_npu +import triton.experimental.tle as tle +# import random + +np.random.seed(21) +DEVICE = "npu" +DEVICE_ID = 0 +torch.manual_seed(20) +torch_npu.npu.set_device(int(DEVICE_ID)) +torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) + +ascend_aiv_core_nums = triton.language.constexpr(24) + +# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== +@triton.jit +def fused_pa_rope_to_sparse_kernel( + k_pa_ptr, k_rope_pa_ptr, v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] + block_table_ptr, # block_table [B, max_blocks] + sparse_indices_ptr, # sparse_indices [B, N, TOPK] + k_sparse_out_ptr, v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] + stride_k_pa_bn, stride_k_pa_bs, stride_k_pa_n, stride_k_pa_d, # K PA strides + stride_k_rope_pa_bn, stride_k_rope_pa_bs, stride_k_rope_pa_n, stride_k_rope_pa_d, # K_rope PA strides + stride_v_pa_bn, stride_v_pa_bs, stride_v_pa_n, stride_v_pa_d, # V PA strides + stride_bt_b, stride_bt_blk, # block_table strides + stride_si_b, stride_si_n, stride_si_topk, # sparse_indices strides + stride_out_b, stride_out_n, stride_out_topk, stride_out_d, # output strides + stride_v_b, stride_v_n, stride_v_topk, stride_v_d, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + B: tl.constexpr, +): + """ + Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse + Input: K/V in PA_BSND format, K_rope in PA_BSND format + Output: K/V_sparse in BNSD format + """ + pid = 
tl.program_id(0) + num_programs = tl.num_programs(0) + + # Process (b, n, topk) combinations + for b_idx in range(B): + b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now + for idx in range(pid, TOPK, num_programs): + # Get batch and sparse index from sparse_indices + n = 0 # KV_N = 1 + + # Load sparse index + sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) + + # Map sparse_idx to PA_BSND position + block_id = sparse_idx // BLOCK_SIZE # Which block + bs_offset = sparse_idx % BLOCK_SIZE # Offset within block + + # Get actual block ID from block_table + actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) + + # Compute PA_BSND offset for K + k_pa_offset = (actual_block_id * stride_k_pa_bn + + bs_offset * stride_k_pa_bs + + n * stride_k_pa_n) + + # Compute PA_BSND offset for K_rope + k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + + bs_offset * stride_k_rope_pa_bs + + n * stride_k_rope_pa_n) + + # Compute PA_BSND offset for V + v_pa_offset = (actual_block_id * stride_v_pa_bn + + bs_offset * stride_v_pa_bs + + n * stride_v_pa_n) + # Load K vector (no rope part) + k_vec = tl.load( + k_pa_ptr + k_pa_offset + + tl.arange(0, BLOCK_DK) * stride_k_pa_d + ) + + # Load V vector + v_vec = tl.load( + v_pa_ptr + v_pa_offset + + tl.arange(0, BLOCK_DV) * stride_v_pa_d + ) + # Output to BNSD format: [B, N, TOPK, D] + out_offset = b * stride_out_b + n * stride_out_n + idx * stride_out_topk + out_offset_v = b* stride_v_b + n *stride_v_n + idx*stride_v_topk + + if BLOCK_DK_ROPE > 0: + # Load K_rope vector + full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE,), 0.0, dtype=tl.float16) + k_rope_vec = tl.load( + k_rope_pa_ptr + k_rope_pa_offset + + tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d + ) + full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0,), sizes=(BLOCK_DK,), strides=(1,)) + full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK,), sizes=(BLOCK_DK_ROPE,), 
strides=(1,)) + tl.store( + k_sparse_out_ptr + out_offset + + tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, + full_k + ) + else: + # No rope, store K directly + tl.store( + k_sparse_out_ptr + out_offset + + tl.arange(0, BLOCK_DK) * stride_out_d, + k_vec + ) + + # Store V + tl.store( + v_sparse_out_ptr + out_offset_v + + tl.arange(0, BLOCK_DV) * stride_v_d, + v_vec + ) + + +def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): + """ + Fused PA_BSND + Rope Concat -> BNSD Sparse conversion + + Args: + k_pa: Key in PA_BSND format [block_num, block_size, n, dk] + k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope + v_pa: Value in PA_BSND format [block_num, block_size, n, dv] + block_table: Block table [B, max_blocks] + sparse_indices: Sparse indices [B, N, TOPK] + block_size: Block size for PA format + + Returns: + k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] + v_sparse: Sparse value in BNSD format [B, N, TOPK, dv] + """ + block_num, _, n, dk = k_pa.shape + B = block_table.shape[0] + TOPK = sparse_indices.size(-1) + N = 1 # KV_N = 1 + _, _, _, dv = v_pa.shape + + has_rope = k_rope_pa is not None + dk_rope = k_rope_pa.shape[-1] if has_rope else 0 + dk_total = dk + dk_rope + + # Output BNSD format [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, device=DEVICE) + + # Grid: use 48 programs for parallelism + grid = (min(48, TOPK),) + + # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] + # No squeeze needed - kernel expects [B, N, TOPK] format + sparse_indices_input = sparse_indices + if sparse_indices.dim() == 2: + # If already 2D [B, TOPK], reshape to [B, 1, TOPK] + sparse_indices_input = sparse_indices.unsqueeze(1) + + # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) + k_rope_pa_input = k_rope_pa if has_rope else k_pa + 
fused_pa_rope_to_sparse_kernel[grid]( + k_pa, k_rope_pa_input, v_pa, + block_table, + sparse_indices_input, + k_sparse, v_sparse, + k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), k_pa.stride(3), + k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), + v_pa.stride(0), v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), + block_table.stride(0), block_table.stride(1), + sparse_indices_input.stride(0), sparse_indices_input.stride(1), sparse_indices_input.stride(2), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=dk, + BLOCK_DV=dv, + BLOCK_DK_ROPE=dk_rope, + TOPK=TOPK, + BLOCK_SIZE=block_size, + B = B + ) + + return k_sparse, v_sparse + +@triton.jit +def gather_kv_bnsd_vec_kernel( + k_ptr, v_ptr, ind_ptr, + k_out_ptr, v_out_ptr, + stride_kb, stride_kn, stride_ks, stride_kd, + stride_vb, stride_vn, stride_vs, stride_vd, + stride_ob, stride_on, stride_os, stride_od, + stride_ovb, stride_ovn, stride_ovs, stride_ovd, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + TOPK: tl.constexpr, + B: tl.constexpr, +): + end = TOPK // 48 * 48 + for b_idx in range(B): + # 分批处理所有TOPK个索引,每次48个 + for batch_start in range(0, end, 48): + pid_k = tl.program_id(0) + batch_start + + # 读 index + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + 
tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + # 处理余数部分(end到TOPK) + for batch_start in range(end, TOPK, 48): + pid_k = tl.program_id(0) + batch_start + + # 必须在计算pid_k之后检查边界 + if pid_k < TOPK: + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + +def triton_gather_kv_bnsd_vec(k, v, indices): + B, N, SK, Dk = k.shape # N=1 + B, N, SK, Dv = v.shape # N=1 + TOPK = indices.size(-1) + + # 输出保持 bnsd [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) + + grid = (48,) # TOPK 个 program,每个搬 Dk/Dv 元素 + gather_kv_bnsd_vec_kernel[grid]( + k, v, indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] + k_sparse, v_sparse, + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=Dk, + BLOCK_DV=Dv, + TOPK=TOPK, + B=B, + ) + return k_sparse, v_sparse + +@triton.jit +def _attn_fwd( + Q, K, V, O, scale_value, + stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, + stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, + stride_vb: tl.constexpr, stride_vn: 
tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, + stride_ob: tl.constexpr, stride_os: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, + KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N:tl.constexpr, O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + ): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + # off_n = 0 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ob + off_s * stride_os + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base = Q + q_offset, + shape = (Q_N, Q_D), + strides = (stride_qn, stride_qd), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, Q_D), + order = (1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr(base = K + k_offset, + shape = (KV_S, K_D), + strides = (stride_ks, stride_kd), + offsets = (0, 0), + block_shape = (blk_size, K_D), + order = (1, 0),) + v_block_ptr = tl.make_block_ptr(base = V + v_offset, + shape = (KV_S, V_D), 
+ strides = (stride_vs, stride_vd), + offsets = (0, 0), + block_shape = (blk_size, V_D), + order = (1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. + block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc*coeff[:,None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base = O + o_offset, + shape = (O_N, O_D), + strides = (stride_on, stride_od), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, O_D), + order = (1,0)) + # final normalize + acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + + +@triton.jit +def _attn_fwd_fused_bsnd_to_tnd( + Q, K, V, O, scale_value, + stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, + stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, + stride_vb: tl.constexpr, stride_vn: tl.constexpr, stride_vs: 
tl.constexpr, stride_vd: tl.constexpr, + stride_ot: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, + KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N:tl.constexpr, O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + ): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ot + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base = Q + q_offset, + shape = (Q_N, Q_D), + strides = (stride_qn, stride_qd), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, Q_D), + order = (1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr(base = K + k_offset, + shape = (KV_S, K_D), + strides = (stride_ks, stride_kd), + offsets = (0, 0), + block_shape = (blk_size, K_D), + order = (1, 0),) + v_block_ptr = tl.make_block_ptr(base = V + v_offset, + shape = (KV_S, V_D), + strides = (stride_vs, stride_vd), + offsets = (0, 0), + block_shape = (blk_size, 
V_D), + order = (1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. + block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc*coeff[:,None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base = O + o_offset, + shape = (O_N, O_D), + strides = (stride_on, stride_od), + offsets = (off_n * Q_BLOCK_SIZE, 0), + block_shape = (Q_BLOCK_SIZE, O_D), + order = (1,0)) + # final normalize + acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + + + +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + query, + key, + value, + sparse_indices, + scale_value, + sparse_block_size = 1, + actual_seq_lengths_query = None, + actual_seq_lengths_kv = None, + query_rope = None, + key_rope = None, + layout_query = 'BSND', + layout_kv = 'BSND', + sparse_mode = 0, + block_table = None): + # Save original sparse_indices for PA_BSND case + 
sparse_indices_orig = sparse_indices.clone() + total_len = 0 + # Handle query layout transformation (TND -> BSND) + if layout_query == 'TND': + actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) + # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat + query, sparse_indices = trans_tnd_to_bsnd_fused( + query, query_rope, sparse_indices, query.shape, actual_seq_lengths_query + ) + else: + if query_rope != None: + query = torch.cat([query, query_rope], dim = -1) + + # Handle KV layout and gather sparse K/V + if layout_kv == 'PA_BSND': + # Fused PA -> BNSD + rope concat + sparse gather + block_size = key.shape[1] # Get block_size from PA shape + # Use original sparse_indices [T, N, TOPK] for fused kernel + k_sparse, v_sparse = triton_fused_pa_rope_to_sparse( + key, key_rope, value, block_table, sparse_indices_orig, block_size + ) + # sparse_indices is already in BSND, needs permute to BNSD for downstream use + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + else: + # Original path for non-PA layouts + if key_rope != None: + key = torch.cat([key, key_rope], dim = -1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + + k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + + k_sparse = k_sparse.contiguous() + v_sparse = v_sparse.contiguous() + enable_check_kv_sparse = 0 + if enable_check_kv_sparse: + key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) + key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) + value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) + if key_rope != None: + key = torch.cat([key, key_rope], dim = -1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + 
print(f"k_sparse={k_sparse}") + print(f"k_sparse_ref={k_sparse_ref}") + print(f"v_sparse={v_sparse}") + print(f"v_sparse_ref={v_sparse_ref}") + assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + + # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + num_cores = ascend_aiv_core_nums + sparse_size = sparse_indices_bnsd.shape[-1] # 4 + out_shape_bsnd = list(query.shape) + if query_rope != None: + out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] + B, Q_S, Q_N, Q_D = query.shape + _, _, KV_S, K_D = k_sparse.shape + + if layout_query == 'TND': + # t = B*act_q_s + output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, dtype=torch.float32) + _attn_fwd_fused_bsnd_to_tnd[(num_cores,)]( + query, k_sparse, v_sparse, output, scale_value, + query.stride(0), query.stride(1), query.stride(2), query.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + output.stride(0), output.stride(1), output.stride(2), + B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, + KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], + sparse_mode = sparse_mode, O_N = output.shape[1], O_D = output.shape[2], + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16,multibuffer=False + ) + + else: + output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) + _attn_fwd[(num_cores,)]( + query, k_sparse, v_sparse, output, scale_value, + query.stride(0), query.stride(1), 
query.stride(2), query.stride(3), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + output.stride(0), output.stride(1), output.stride(2), output.stride(3), + B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, + KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], + sparse_mode = sparse_mode, O_N = output.shape[2], O_D = output.shape[3], + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16 + ) + output = output.permute(0, 2, 1, 3).contiguous() + + ctx.save_for_backward(query, k_sparse, v_sparse, output) + ctx.scale_value = scale_value + return output + +def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): + block_num, block_size, n, d = pa_in.shape + b = len(actual_seq_lengths) + output = torch.empty((b, block_num * block_size // b, 1, d), dtype = pa_in.dtype).to(DEVICE) + for i in range(b): + for j in range(20): + output[i, j * block_size: (j + 1) * block_size, 0, :] = \ + pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) + return output + + +@triton.jit +def trans_tnd_to_bsnd_fused_kernel( + query_ptr, query_rope_ptr, sparse_ptr, + query_out_ptr, sparse_out_ptr, # query_out 已经拼接了 rope + act_s, + stride_q_t, stride_q_tn, stride_q_td, + stride_qr_t, stride_qr_tn, stride_qr_td, + stride_s_t, stride_s_tn, stride_s_td, + stride_qob, stride_qobs, stride_qon, stride_qod, # query_out strides + stride_sb, stride_sbs, stride_sbn, stride_sbd, + B: tl.constexpr, + N: tl.constexpr, + D_QUERY: tl.constexpr, + D_ROPE: tl.constexpr, + D_SPARSE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_QUERY: tl.constexpr, + BLOCK_D_ROPE: tl.constexpr, + BLOCK_D_SPARSE: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # 计算 head 的总块数 + num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N + t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 + # 每个 pid 负责处理特定的 
(batch, head_block) 组合 + for tn_id in range(B): + # sparse_indices 是单头的,只在第一个 head_block 处理一次 + if pid == 0: + sparse_block_ptr = tl.make_block_ptr(base = sparse_ptr + t_idx * stride_s_t, + shape = (1, D_SPARSE), + strides = (stride_s_tn, stride_s_td), + offsets = (0, 0), + block_shape = (1, D_SPARSE), + order = (1, 0)) + sparse = tl.load(sparse_block_ptr) + + sparse_out_block_ptr = tl.make_block_ptr(base = sparse_out_ptr + t_idx * stride_sb, + shape = (1, D_SPARSE), + strides = ( stride_sbn, stride_sbd), + offsets = (0, 0), + block_shape = (1, D_SPARSE), + order = (1, 0)) + tl.store(sparse_out_block_ptr, sparse) + + # query 和 query_rope 是多头的,需要在 head 维度上分块处理 + for head_block_id in range(pid, num_head_blocks, num_programs): + n_offset = head_block_id * BLOCK_N + + # Load q and q_ro + q_block_ptr = tl.make_block_ptr(base = query_ptr + t_idx * stride_q_t, + shape = (N, D_QUERY), + strides = (stride_q_tn, stride_q_td), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_QUERY), + order = (1, 0)) + q_ro_block_ptr = tl.make_block_ptr(base = query_rope_ptr + t_idx * stride_qr_t, + shape = (N, D_ROPE), + strides = (stride_qr_tn, stride_qr_td), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_ROPE), + order = (1, 0)) + q = tl.load(q_block_ptr) + q_ro = tl.load(q_ro_block_ptr) + + # Combine query and query_rope using insert_slice, then store in one operation + full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) + full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) + full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) + + q_out_block_ptr = tl.make_block_ptr(base = query_out_ptr + t_idx * stride_qob, + shape = (N, D_QUERY + D_ROPE), + strides = (stride_qon, stride_qod), + offsets = (n_offset, 0), + block_shape = (BLOCK_N, D_QUERY + D_ROPE), + order = (1, 0)) + tl.store(q_out_block_ptr, full_q) + t_idx = t_idx + tl.load(act_s + tn_id) 
+ + +def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16,)): + """ + 融合版本的 TND -> BSND 转换(包含 concat) + 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope + """ + t, n, d_query = shape + b = len(act_seq) + s = max(act_seq) + + # 获取各个 tensor 的维度 + d_rope = query_rope.shape[2] if query_rope is not None else 0 + d_sparse = sparse_indices.shape[2] + d_query_out = d_query + d_rope # 拼接后的维度 + + # 分配输出(query_out 已经包含 rope) + query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) + sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, device=sparse_indices.device) + assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" + # 启动 fused kernel + # 使用较小的 BLOCK_N 避免内存溢出 + block_n = min(16, n) + # 计算需要的核心数:使用多核心并行处理不同的头 + num_head_blocks = (n + block_n - 1) // block_n + num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 + + trans_tnd_to_bsnd_fused_kernel[num_programs,]( + query, query_rope, sparse_indices, + query_out, sparse_out, + act_seq, + query.stride(0), query.stride(1), query.stride(2), + query_rope.stride(0), query_rope.stride(1), query_rope.stride(2), + sparse_indices.stride(0), sparse_indices.stride(1), sparse_indices.stride(2), + query_out.stride(0), query_out.stride(1), query_out.stride(2), query_out.stride(3), + sparse_out.stride(0), sparse_out.stride(1), sparse_out.stride(2), sparse_out.stride(3), + B=b, + N=n, + D_QUERY=d_query, + D_ROPE=d_rope, + D_SPARSE=d_sparse, + BLOCK_N=block_n, + BLOCK_D_QUERY=d_query, + BLOCK_D_ROPE=d_rope, + BLOCK_D_SPARSE=d_sparse, + ) + return query_out, sparse_out + + +def trans_tnd_actseq(seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().tolist() + list_len = len(seq) + output = [] + output = [seq[0]] + total_len = seq[0] + for i in range(list_len - 1): + new_item = seq[i + 1] - seq[i] + if new_item >= 0: + output.append(new_item) + total_len += new_item + else: + 
print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 0") + return torch.tensor(output).to(DEVICE), total_len + +def sparse_attention(query, key, value, + sparse_indices, scale_value, sparse_block_size = 1, + actual_seq_lengths_query = None, actual_seq_lengths_kv = None, + query_rope = None, key_rope = None, + layout_query = 'BSND', layout_kv = 'BSND', + sparse_mode = 0, block_table = None): + return _attention.apply(query, key, value, + sparse_indices, scale_value, sparse_block_size, + actual_seq_lengths_query, actual_seq_lengths_kv, + query_rope, key_rope, + layout_query, layout_kv, + sparse_mode, block_table) + +def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope, + sparse_size, scale_value, + sparse_block_size, sparse_mode, block_size, act_kv_s): + assert sparse_size <= KV_S + assert KV_N == 1 + assert sparse_mode == 0 or 3 + assert sparse_block_size == 1 + assert (B * KV_S) % block_size == 0 + assert D == 512 + assert D_rope == 0 or 64 + print("*batch_size=",B) + qkv_dtype = torch.float16 + #sparse_size = KV_S + query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + value = key.clone() + + act_q_s = T // B # step + # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE) + # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复 + # sparse_indices = indices.to(torch.int32) + sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1) + sparse_indices = sparse_indices.to(torch.int32) + # print("sparse_indices=", sparse_indices) + actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE) + # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE) + actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, 
dtype=torch.int32, device=DEVICE) + print(actual_seq_lengths_kv) + block_table = torch.tensor([range(B * KV_S // block_size)], dtype=torch.int32, device=DEVICE).reshape(B, -1) + + if D_rope == 0: + query_rope = None + key_rope = None + else: + query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + + print("q.shape=",query.shape) + print("k.shape=",key.shape) + print("v.shape=",value.shape) + print("sparse_indices.shape=",sparse_indices.shape) + print("act_seq_query=",actual_seq_lengths_query) + print("act_seq_kv=", actual_seq_lengths_kv) + + + triton_out = sparse_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv= 'PA_BSND', + sparse_mode = sparse_mode, + block_table= block_table, + ) + npu_out = torch_npu.npu_sparse_flash_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv = 'PA_BSND', + sparse_mode = sparse_mode, + block_table = block_table, + # attention_mode = 2, + ) + triton_out = triton_out.to(npu_out.dtype) + torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) + print("[PASSED]") + + # benchmarking + triton_time = do_bench_npu(lambda:sparse_attention( + query = query, + key = key, + value = value, + sparse_indices = 
sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv= 'PA_BSND', + sparse_mode = sparse_mode, + block_table = block_table, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Triton SFA] Time: {triton_time:.4f} us") + + npu_time = do_bench_npu(lambda:torch_npu.npu_sparse_flash_attention( + query = query, + key = key, + value = value, + sparse_indices = sparse_indices, + scale_value = scale_value, + sparse_block_size = sparse_block_size, + actual_seq_lengths_query = actual_seq_lengths_query, + actual_seq_lengths_kv = actual_seq_lengths_kv, + query_rope = query_rope, + key_rope = key_rope, + layout_query = 'TND', + layout_kv = 'PA_BSND', + sparse_mode = sparse_mode, + block_table = block_table, + # attention_mode = 2, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") + +if __name__ == "__main__": + print(torch_npu.__version__) + print("Test Real Case in DS-v3.2-Exp") + print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + i = 1 + print(f"====================第{i}次测试=================") + test_op(T = 1, B = 1, KV_S = 2560, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 4, B = 4, KV_S = 6400, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 8, B = 8, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, 
act_kv_s = 2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T = 16, B = 16, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, + sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, + block_size = 128, act_kv_s = 2560) diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 18e17b6fa..2747e689c 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -968,8 +968,8 @@ def visit_For(self, node): warp_specialize = False disable_licm = False bind_sub_block = None - dsa = importlib.import_module("..experimental.tle.dsa", package=__package__) - if IteratorClass in [language.range, extension.parallel, dsa.pipeline, dsa.parallel]: + tle = importlib.import_module("triton.experimental.tle", package=__package__) + if IteratorClass in [language.range, extension.parallel, tle.dsa.pipeline, tle.dsa.parallel]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -1069,8 +1069,8 @@ def visit_For(self, node): if disable_licm: for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) - dsa = importlib.import_module("..experimental.tle.dsa", package=__package__) - if (IteratorClass is extension.parallel or IteratorClass is dsa.parallel): + tle = importlib.import_module("triton.experimental.tle", package=__package__) + if (IteratorClass is extension.parallel or IteratorClass is tle.dsa.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) self.scf_stack.append(node) @@ -1198,7 +1198,7 @@ def visit_Call(self, node): try: if fn.__name__ == 'copy': # extract tle hints from the generator to identify if node in the tle hints scope - tle = importlib.import_module("..experimental.tle", package=__package__) + tle = 
importlib.import_module("triton.experimental.tle", package=__package__) top_hints = tle.extract_tle_hints_scope(self) # Only apply to some builtins; currently, 'copy' is relevant. From 97bce26e781e0148b1b2d6301e9b1dcc3c1e2079 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Mon, 9 Mar 2026 09:32:57 +0000 Subject: [PATCH 12/13] [FIX]: fix copyright declaration in tle --- .../experimental/tle/language/dsa/__init__.py | 2 +- .../tle/language/dsa/ascend/__init__.py | 2 +- .../tle/language/dsa/ascend/core.py | 2 +- .../tle/language/dsa/ascend/semantic.py | 2 +- .../experimental/tle/language/dsa/core.py | 2 +- .../experimental/tle/language/dsa/semantic.py | 2 +- .../experimental/tle/language/dsa/types.py | 2 +- python/tutorials/tle/sfa_tle_v1.py | 910 ------------------ 8 files changed, 7 insertions(+), 917 deletions(-) delete mode 100644 python/tutorials/tle/sfa_tle_v1.py diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py index 0e8c7b0ee..bc2e3aead 100644 --- a/python/triton/experimental/tle/language/dsa/__init__.py +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd from .core import ( alloc, diff --git a/python/triton/experimental/tle/language/dsa/ascend/__init__.py b/python/triton/experimental/tle/language/dsa/ascend/__init__.py index e3348e629..3c7c09031 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/__init__.py +++ b/python/triton/experimental/tle/language/dsa/ascend/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd from .core import ( UB, diff --git a/python/triton/experimental/tle/language/dsa/ascend/core.py b/python/triton/experimental/tle/language/dsa/ascend/core.py index e1e70dd54..14f215f16 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/core.py +++ b/python/triton/experimental/tle/language/dsa/ascend/core.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd from triton.language.extra.cann.extension.core import ascend_address_space diff --git a/python/triton/experimental/tle/language/dsa/ascend/semantic.py b/python/triton/experimental/tle/language/dsa/ascend/semantic.py index 36c08a4d7..a6b9487a8 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/semantic.py +++ b/python/triton/experimental/tle/language/dsa/ascend/semantic.py @@ -1 +1 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. \ No newline at end of file +# Copyright 2026- Xcoresigma Technology Co., Ltd \ No newline at end of file diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py index ebc2a5ff2..fa963263e 100644 --- a/python/triton/experimental/tle/language/dsa/core.py +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd import triton.language.core as tl from triton.language import semantic as tl_semantic diff --git a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py index 47f85f42f..b295c7b97 100644 --- a/python/triton/experimental/tle/language/dsa/semantic.py +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+# Copyright 2026- Xcoresigma Technology Co., Ltd from typing import List, Optional, Union, Tuple from triton.language import core as tl diff --git a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py index 7c2c11a53..7ece8d035 100644 --- a/python/triton/experimental/tle/language/dsa/types.py +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +# Copyright 2026- Xcoresigma Technology Co., Ltd from triton._C.libtriton import ir diff --git a/python/tutorials/tle/sfa_tle_v1.py b/python/tutorials/tle/sfa_tle_v1.py deleted file mode 100644 index f59f6c54d..000000000 --- a/python/tutorials/tle/sfa_tle_v1.py +++ /dev/null @@ -1,910 +0,0 @@ -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -import numpy as np -from datetime import datetime -from triton.backends.ascend.testing import do_bench_npu -import triton.experimental.tle as tle -# import random - -np.random.seed(21) -DEVICE = "npu" -DEVICE_ID = 0 -torch.manual_seed(20) -torch_npu.npu.set_device(int(DEVICE_ID)) -torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) - -ascend_aiv_core_nums = triton.language.constexpr(24) - -# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== -@triton.jit -def fused_pa_rope_to_sparse_kernel( - k_pa_ptr, k_rope_pa_ptr, v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] - block_table_ptr, # block_table [B, max_blocks] - sparse_indices_ptr, # sparse_indices [B, N, TOPK] - k_sparse_out_ptr, v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] - stride_k_pa_bn, stride_k_pa_bs, stride_k_pa_n, stride_k_pa_d, # K PA strides - stride_k_rope_pa_bn, stride_k_rope_pa_bs, stride_k_rope_pa_n, stride_k_rope_pa_d, # K_rope PA strides - stride_v_pa_bn, stride_v_pa_bs, stride_v_pa_n, stride_v_pa_d, # V PA strides - stride_bt_b, stride_bt_blk, # block_table strides - stride_si_b, stride_si_n, 
stride_si_topk, # sparse_indices strides - stride_out_b, stride_out_n, stride_out_topk, stride_out_d, # output strides - stride_v_b, stride_v_n, stride_v_topk, stride_v_d, - BLOCK_DK: tl.constexpr, - BLOCK_DV: tl.constexpr, - BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope - TOPK: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - B: tl.constexpr, -): - """ - Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse - Input: K/V in PA_BSND format, K_rope in PA_BSND format - Output: K/V_sparse in BNSD format - """ - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - # Process (b, n, topk) combinations - for b_idx in range(B): - b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now - for idx in range(pid, TOPK, num_programs): - # Get batch and sparse index from sparse_indices - n = 0 # KV_N = 1 - - # Load sparse index - sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) - - # Map sparse_idx to PA_BSND position - block_id = sparse_idx // BLOCK_SIZE # Which block - bs_offset = sparse_idx % BLOCK_SIZE # Offset within block - - # Get actual block ID from block_table - actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) - - # Compute PA_BSND offset for K - k_pa_offset = (actual_block_id * stride_k_pa_bn + - bs_offset * stride_k_pa_bs + - n * stride_k_pa_n) - - # Compute PA_BSND offset for K_rope - k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + - bs_offset * stride_k_rope_pa_bs + - n * stride_k_rope_pa_n) - - # Compute PA_BSND offset for V - v_pa_offset = (actual_block_id * stride_v_pa_bn + - bs_offset * stride_v_pa_bs + - n * stride_v_pa_n) - # Load K vector (no rope part) - k_vec = tl.load( - k_pa_ptr + k_pa_offset + - tl.arange(0, BLOCK_DK) * stride_k_pa_d - ) - - # Load V vector - v_vec = tl.load( - v_pa_ptr + v_pa_offset + - tl.arange(0, BLOCK_DV) * stride_v_pa_d - ) - # Output to BNSD format: [B, N, TOPK, D] - out_offset = b * stride_out_b + n * stride_out_n + idx * 
stride_out_topk - out_offset_v = b* stride_v_b + n *stride_v_n + idx*stride_v_topk - - if BLOCK_DK_ROPE > 0: - # Load K_rope vector - full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE,), 0.0, dtype=tl.float16) - k_rope_vec = tl.load( - k_rope_pa_ptr + k_rope_pa_offset + - tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d - ) - full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0,), sizes=(BLOCK_DK,), strides=(1,)) - full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK,), sizes=(BLOCK_DK_ROPE,), strides=(1,)) - tl.store( - k_sparse_out_ptr + out_offset + - tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, - full_k - ) - else: - # No rope, store K directly - tl.store( - k_sparse_out_ptr + out_offset + - tl.arange(0, BLOCK_DK) * stride_out_d, - k_vec - ) - - # Store V - tl.store( - v_sparse_out_ptr + out_offset_v + - tl.arange(0, BLOCK_DV) * stride_v_d, - v_vec - ) - - -def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): - """ - Fused PA_BSND + Rope Concat -> BNSD Sparse conversion - - Args: - k_pa: Key in PA_BSND format [block_num, block_size, n, dk] - k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope - v_pa: Value in PA_BSND format [block_num, block_size, n, dv] - block_table: Block table [B, max_blocks] - sparse_indices: Sparse indices [B, N, TOPK] - block_size: Block size for PA format - - Returns: - k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] - v_sparse: Sparse value in BNSD format [B, N, TOPK, dv] - """ - block_num, _, n, dk = k_pa.shape - B = block_table.shape[0] - TOPK = sparse_indices.size(-1) - N = 1 # KV_N = 1 - _, _, _, dv = v_pa.shape - - has_rope = k_rope_pa is not None - dk_rope = k_rope_pa.shape[-1] if has_rope else 0 - dk_total = dk + dk_rope - - # Output BNSD format [B, N, TOPK, D] - k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) - v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, 
device=DEVICE) - - # Grid: use 48 programs for parallelism - grid = (min(48, TOPK),) - - # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] - # No squeeze needed - kernel expects [B, N, TOPK] format - sparse_indices_input = sparse_indices - if sparse_indices.dim() == 2: - # If already 2D [B, TOPK], reshape to [B, 1, TOPK] - sparse_indices_input = sparse_indices.unsqueeze(1) - - # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) - k_rope_pa_input = k_rope_pa if has_rope else k_pa - fused_pa_rope_to_sparse_kernel[grid]( - k_pa, k_rope_pa_input, v_pa, - block_table, - sparse_indices_input, - k_sparse, v_sparse, - k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), k_pa.stride(3), - k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), - v_pa.stride(0), v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), - block_table.stride(0), block_table.stride(1), - sparse_indices_input.stride(0), sparse_indices_input.stride(1), sparse_indices_input.stride(2), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - BLOCK_DK=dk, - BLOCK_DV=dv, - BLOCK_DK_ROPE=dk_rope, - TOPK=TOPK, - BLOCK_SIZE=block_size, - B = B - ) - - return k_sparse, v_sparse - -@triton.jit -def gather_kv_bnsd_vec_kernel( - k_ptr, v_ptr, ind_ptr, - k_out_ptr, v_out_ptr, - stride_kb, stride_kn, stride_ks, stride_kd, - stride_vb, stride_vn, stride_vs, stride_vd, - stride_ob, stride_on, stride_os, stride_od, - stride_ovb, stride_ovn, stride_ovs, stride_ovd, - BLOCK_DK: tl.constexpr, - BLOCK_DV: tl.constexpr, - TOPK: tl.constexpr, - B: tl.constexpr, -): - end = TOPK // 48 * 48 - for b_idx in range(B): - # 分批处理所有TOPK个索引,每次48个 - for batch_start in range(0, end, 48): - pid_k = tl.program_id(0) + batch_start - - # 读 index - idx = tl.load(ind_ptr + pid_k) - - # 加载 K 向量 [BLOCK_DK] - 直接线性加载 - k_src_off = idx * stride_ks + b_idx * stride_kb - 
k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) - - # 加载 V 向量 [BLOCK_DV] - 直接线性加载 - v_src_off = idx * stride_vs + b_idx * stride_vb - v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) - - # 写回 K: [B, N, TOPK, Dk] - k_dst_off = pid_k * stride_os + b_idx * stride_ob - tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) - - # 写回 V: [B, N, TOPK, Dv] - v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb - tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) - - # 处理余数部分(end到TOPK) - for batch_start in range(end, TOPK, 48): - pid_k = tl.program_id(0) + batch_start - - # 必须在计算pid_k之后检查边界 - if pid_k < TOPK: - idx = tl.load(ind_ptr + pid_k) - - # 加载 K 向量 [BLOCK_DK] - 直接线性加载 - k_src_off = idx * stride_ks + b_idx * stride_kb - k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) - - # 加载 V 向量 [BLOCK_DV] - 直接线性加载 - v_src_off = idx * stride_vs + b_idx * stride_vb - v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) - - # 写回 K: [B, N, TOPK, Dk] - k_dst_off = pid_k * stride_os + b_idx * stride_ob - tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) - - # 写回 V: [B, N, TOPK, Dv] - v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb - tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) - -def triton_gather_kv_bnsd_vec(k, v, indices): - B, N, SK, Dk = k.shape # N=1 - B, N, SK, Dv = v.shape # N=1 - TOPK = indices.size(-1) - - # 输出保持 bnsd [B, N, TOPK, D] - k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) - v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) - - grid = (48,) # TOPK 个 program,每个搬 Dk/Dv 元素 - gather_kv_bnsd_vec_kernel[grid]( - k, v, indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] - k_sparse, v_sparse, - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - k_sparse.stride(0), k_sparse.stride(1), 
k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - BLOCK_DK=Dk, - BLOCK_DV=Dv, - TOPK=TOPK, - B=B, - ) - return k_sparse, v_sparse - -@triton.jit -def _attn_fwd( - Q, K, V, O, scale_value, - stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, - stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, - stride_vb: tl.constexpr, stride_vn: tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, - stride_ob: tl.constexpr, stride_os: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, - B: tl.constexpr, - Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, - KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, - sparse_mode: tl.constexpr, # 0 or 3 - O_N:tl.constexpr, O_D: tl.constexpr, - actual_seq_lengths_query, - actual_seq_lengths_kv, - blk_size: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - ): - # total b * n tasks - BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE - NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM - pid = tl.program_id(0) - num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) - - #最外层循环,沿b*n切 - for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 - off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 - off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 - off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 - # off_n = 0 - - q_offset = off_b * stride_qb + off_s * stride_qs - o_offset = off_b * stride_ob + off_s * stride_os - k_offset = off_b * stride_kb # KV_N = 1 - v_offset = off_b * stride_vb - - cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) - - for i in range(cur_act_s_q): - cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) - logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) - acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] - - # load q - q_block_ptr = tl.make_block_ptr(base = 
Q + q_offset, - shape = (Q_N, Q_D), - strides = (stride_qn, stride_qd), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, Q_D), - order = (1, 0)) - q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] - k_block_ptr = tl.make_block_ptr(base = K + k_offset, - shape = (KV_S, K_D), - strides = (stride_ks, stride_kd), - offsets = (0, 0), - block_shape = (blk_size, K_D), - order = (1, 0),) - v_block_ptr = tl.make_block_ptr(base = V + v_offset, - shape = (KV_S, V_D), - strides = (stride_vs, stride_vd), - offsets = (0, 0), - block_shape = (blk_size, V_D), - order = (1, 0)) - - for k_idx in range(KV_S // blk_size): - # load k - k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) - - # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] - qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] - # online softmax update - # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. - block_max = tl.max(qk, axis=1) # [q_block_size] - # align shapes to (q_block_size, 1) for broadcasting - # block_max = block_max[:, None] # [q_block_size, 1] - new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] - coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] - p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] - # logsum per row - logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] - - # update accumulator: compute per-row pv by summing over block dim - v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] - pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] - acc = acc*coeff[:,None] + pv # [q_block_size, V_D] - cur_max = new_max - - k_block_ptr = k_block_ptr.advance((blk_size, 0)) - v_block_ptr = v_block_ptr.advance((blk_size, 0)) - - o_block_ptr = tl.make_block_ptr(base = O + o_offset, - shape = (O_N, O_D), - strides = (stride_on, stride_od), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, O_D), - 
order = (1,0)) - # final normalize - acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] - tl.store(o_block_ptr, acc) - - - -@triton.jit -def _attn_fwd_fused_bsnd_to_tnd( - Q, K, V, O, scale_value, - stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, - stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, - stride_vb: tl.constexpr, stride_vn: tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, - stride_ot: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, - B: tl.constexpr, - Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, - KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, - sparse_mode: tl.constexpr, # 0 or 3 - O_N:tl.constexpr, O_D: tl.constexpr, - actual_seq_lengths_query, - actual_seq_lengths_kv, - blk_size: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - ): - # total b * n tasks - BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE - NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM - pid = tl.program_id(0) - num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) - - #最外层循环,沿b*n切 - for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 - off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 - off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 - off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 - - q_offset = off_b * stride_qb + off_s * stride_qs - o_offset = off_b * stride_ot - k_offset = off_b * stride_kb # KV_N = 1 - v_offset = off_b * stride_vb - - cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) - - for i in range(cur_act_s_q): - cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) - logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) - acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] - - # load q - q_block_ptr = tl.make_block_ptr(base = Q + q_offset, - shape = (Q_N, Q_D), - strides = (stride_qn, stride_qd), - offsets = 
(off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, Q_D), - order = (1, 0)) - q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] - k_block_ptr = tl.make_block_ptr(base = K + k_offset, - shape = (KV_S, K_D), - strides = (stride_ks, stride_kd), - offsets = (0, 0), - block_shape = (blk_size, K_D), - order = (1, 0),) - v_block_ptr = tl.make_block_ptr(base = V + v_offset, - shape = (KV_S, V_D), - strides = (stride_vs, stride_vd), - offsets = (0, 0), - block_shape = (blk_size, V_D), - order = (1, 0)) - - for k_idx in range(KV_S // blk_size): - # load k - k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) - - # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] - qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] - # online softmax update - # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. - block_max = tl.max(qk, axis=1) # [q_block_size] - # align shapes to (q_block_size, 1) for broadcasting - # block_max = block_max[:, None] # [q_block_size, 1] - new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] - coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] - p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] - # logsum per row - logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] - - # update accumulator: compute per-row pv by summing over block dim - v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] - pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] - acc = acc*coeff[:,None] + pv # [q_block_size, V_D] - cur_max = new_max - - k_block_ptr = k_block_ptr.advance((blk_size, 0)) - v_block_ptr = v_block_ptr.advance((blk_size, 0)) - - o_block_ptr = tl.make_block_ptr(base = O + o_offset, - shape = (O_N, O_D), - strides = (stride_on, stride_od), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, O_D), - order = (1,0)) - # final normalize - acc = acc / logSum[:,None] # [q_block_size, 
V_D] / [q_block_size,1] -> [q_block_size, V_D] - tl.store(o_block_ptr, acc) - - - - -class _attention(torch.autograd.Function): - @staticmethod - def forward( - ctx, - query, - key, - value, - sparse_indices, - scale_value, - sparse_block_size = 1, - actual_seq_lengths_query = None, - actual_seq_lengths_kv = None, - query_rope = None, - key_rope = None, - layout_query = 'BSND', - layout_kv = 'BSND', - sparse_mode = 0, - block_table = None): - # Save original sparse_indices for PA_BSND case - sparse_indices_orig = sparse_indices.clone() - total_len = 0 - # Handle query layout transformation (TND -> BSND) - if layout_query == 'TND': - actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) - # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat - query, sparse_indices = trans_tnd_to_bsnd_fused( - query, query_rope, sparse_indices, query.shape, actual_seq_lengths_query - ) - else: - if query_rope != None: - query = torch.cat([query, query_rope], dim = -1) - - # Handle KV layout and gather sparse K/V - if layout_kv == 'PA_BSND': - # Fused PA -> BNSD + rope concat + sparse gather - block_size = key.shape[1] # Get block_size from PA shape - # Use original sparse_indices [T, N, TOPK] for fused kernel - k_sparse, v_sparse = triton_fused_pa_rope_to_sparse( - key, key_rope, value, block_table, sparse_indices_orig, block_size - ) - # sparse_indices is already in BSND, needs permute to BNSD for downstream use - sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() - else: - # Original path for non-PA layouts - if key_rope != None: - key = torch.cat([key, key_rope], dim = -1) - key_bnsd = key.permute(0, 2, 1, 3).contiguous() - value_bnsd = value.permute(0, 2, 1, 3).contiguous() - sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() - - k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) - - k_sparse = k_sparse.contiguous() - v_sparse = v_sparse.contiguous() - enable_check_kv_sparse = 0 - if 
enable_check_kv_sparse: - key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) - key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) - value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) - if key_rope != None: - key = torch.cat([key, key_rope], dim = -1) - key_bnsd = key.permute(0, 2, 1, 3).contiguous() - value_bnsd = value.permute(0, 2, 1, 3).contiguous() - k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) - print(f"k_sparse={k_sparse}") - print(f"k_sparse_ref={k_sparse_ref}") - print(f"v_sparse={v_sparse}") - print(f"v_sparse_ref={v_sparse_ref}") - assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" - assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" - - # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() - # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" - # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() - # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" 
- num_cores = ascend_aiv_core_nums - sparse_size = sparse_indices_bnsd.shape[-1] # 4 - out_shape_bsnd = list(query.shape) - if query_rope != None: - out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] - B, Q_S, Q_N, Q_D = query.shape - _, _, KV_S, K_D = k_sparse.shape - - if layout_query == 'TND': - # t = B*act_q_s - output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, dtype=torch.float32) - _attn_fwd_fused_bsnd_to_tnd[(num_cores,)]( - query, k_sparse, v_sparse, output, scale_value, - query.stride(0), query.stride(1), query.stride(2), query.stride(3), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - output.stride(0), output.stride(1), output.stride(2), - B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, - KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], - sparse_mode = sparse_mode, O_N = output.shape[1], O_D = output.shape[2], - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - blk_size=128, Q_BLOCK_SIZE=16,multibuffer=False - ) - - else: - output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) - _attn_fwd[(num_cores,)]( - query, k_sparse, v_sparse, output, scale_value, - query.stride(0), query.stride(1), query.stride(2), query.stride(3), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - output.stride(0), output.stride(1), output.stride(2), output.stride(3), - B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, - KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], - sparse_mode = sparse_mode, O_N = output.shape[2], O_D = output.shape[3], - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - blk_size=128, Q_BLOCK_SIZE=16 - ) - output = output.permute(0, 2, 1, 3).contiguous() - - 
ctx.save_for_backward(query, k_sparse, v_sparse, output) - ctx.scale_value = scale_value - return output - -def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): - block_num, block_size, n, d = pa_in.shape - b = len(actual_seq_lengths) - output = torch.empty((b, block_num * block_size // b, 1, d), dtype = pa_in.dtype).to(DEVICE) - for i in range(b): - for j in range(20): - output[i, j * block_size: (j + 1) * block_size, 0, :] = \ - pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) - return output - - -@triton.jit -def trans_tnd_to_bsnd_fused_kernel( - query_ptr, query_rope_ptr, sparse_ptr, - query_out_ptr, sparse_out_ptr, # query_out 已经拼接了 rope - act_s, - stride_q_t, stride_q_tn, stride_q_td, - stride_qr_t, stride_qr_tn, stride_qr_td, - stride_s_t, stride_s_tn, stride_s_td, - stride_qob, stride_qobs, stride_qon, stride_qod, # query_out strides - stride_sb, stride_sbs, stride_sbn, stride_sbd, - B: tl.constexpr, - N: tl.constexpr, - D_QUERY: tl.constexpr, - D_ROPE: tl.constexpr, - D_SPARSE: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_QUERY: tl.constexpr, - BLOCK_D_ROPE: tl.constexpr, - BLOCK_D_SPARSE: tl.constexpr, -): - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - # 计算 head 的总块数 - num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N - t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 - # 每个 pid 负责处理特定的 (batch, head_block) 组合 - for tn_id in range(B): - # sparse_indices 是单头的,只在第一个 head_block 处理一次 - if pid == 0: - sparse_block_ptr = tl.make_block_ptr(base = sparse_ptr + t_idx * stride_s_t, - shape = (1, D_SPARSE), - strides = (stride_s_tn, stride_s_td), - offsets = (0, 0), - block_shape = (1, D_SPARSE), - order = (1, 0)) - sparse = tl.load(sparse_block_ptr) - - sparse_out_block_ptr = tl.make_block_ptr(base = sparse_out_ptr + t_idx * stride_sb, - shape = (1, D_SPARSE), - strides = ( stride_sbn, stride_sbd), - offsets = (0, 0), - block_shape = (1, D_SPARSE), - order = (1, 0)) - tl.store(sparse_out_block_ptr, sparse) - - # query 和 
query_rope 是多头的,需要在 head 维度上分块处理 - for head_block_id in range(pid, num_head_blocks, num_programs): - n_offset = head_block_id * BLOCK_N - - # Load q and q_ro - q_block_ptr = tl.make_block_ptr(base = query_ptr + t_idx * stride_q_t, - shape = (N, D_QUERY), - strides = (stride_q_tn, stride_q_td), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_QUERY), - order = (1, 0)) - q_ro_block_ptr = tl.make_block_ptr(base = query_rope_ptr + t_idx * stride_qr_t, - shape = (N, D_ROPE), - strides = (stride_qr_tn, stride_qr_td), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_ROPE), - order = (1, 0)) - q = tl.load(q_block_ptr) - q_ro = tl.load(q_ro_block_ptr) - - # Combine query and query_rope using insert_slice, then store in one operation - full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) - full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) - full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) - - q_out_block_ptr = tl.make_block_ptr(base = query_out_ptr + t_idx * stride_qob, - shape = (N, D_QUERY + D_ROPE), - strides = (stride_qon, stride_qod), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_QUERY + D_ROPE), - order = (1, 0)) - tl.store(q_out_block_ptr, full_q) - t_idx = t_idx + tl.load(act_s + tn_id) - - -def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16,)): - """ - 融合版本的 TND -> BSND 转换(包含 concat) - 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope - """ - t, n, d_query = shape - b = len(act_seq) - s = max(act_seq) - - # 获取各个 tensor 的维度 - d_rope = query_rope.shape[2] if query_rope is not None else 0 - d_sparse = sparse_indices.shape[2] - d_query_out = d_query + d_rope # 拼接后的维度 - - # 分配输出(query_out 已经包含 rope) - query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) - sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, 
device=sparse_indices.device) - assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" - # 启动 fused kernel - # 使用较小的 BLOCK_N 避免内存溢出 - block_n = min(16, n) - # 计算需要的核心数:使用多核心并行处理不同的头 - num_head_blocks = (n + block_n - 1) // block_n - num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 - - trans_tnd_to_bsnd_fused_kernel[num_programs,]( - query, query_rope, sparse_indices, - query_out, sparse_out, - act_seq, - query.stride(0), query.stride(1), query.stride(2), - query_rope.stride(0), query_rope.stride(1), query_rope.stride(2), - sparse_indices.stride(0), sparse_indices.stride(1), sparse_indices.stride(2), - query_out.stride(0), query_out.stride(1), query_out.stride(2), query_out.stride(3), - sparse_out.stride(0), sparse_out.stride(1), sparse_out.stride(2), sparse_out.stride(3), - B=b, - N=n, - D_QUERY=d_query, - D_ROPE=d_rope, - D_SPARSE=d_sparse, - BLOCK_N=block_n, - BLOCK_D_QUERY=d_query, - BLOCK_D_ROPE=d_rope, - BLOCK_D_SPARSE=d_sparse, - ) - return query_out, sparse_out - - -def trans_tnd_actseq(seq): - if isinstance(seq, torch.Tensor): - seq = seq.cpu().tolist() - list_len = len(seq) - output = [] - output = [seq[0]] - total_len = seq[0] - for i in range(list_len - 1): - new_item = seq[i + 1] - seq[i] - if new_item >= 0: - output.append(new_item) - total_len += new_item - else: - print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 0") - return torch.tensor(output).to(DEVICE), total_len - -def sparse_attention(query, key, value, - sparse_indices, scale_value, sparse_block_size = 1, - actual_seq_lengths_query = None, actual_seq_lengths_kv = None, - query_rope = None, key_rope = None, - layout_query = 'BSND', layout_kv = 'BSND', - sparse_mode = 0, block_table = None): - return _attention.apply(query, key, value, - sparse_indices, scale_value, sparse_block_size, - actual_seq_lengths_query, actual_seq_lengths_kv, - query_rope, key_rope, - layout_query, layout_kv, - sparse_mode, 
block_table) - -def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope, - sparse_size, scale_value, - sparse_block_size, sparse_mode, block_size, act_kv_s): - assert sparse_size <= KV_S - assert KV_N == 1 - assert sparse_mode == 0 or 3 - assert sparse_block_size == 1 - assert (B * KV_S) % block_size == 0 - assert D == 512 - assert D_rope == 0 or 64 - print("*batch_size=",B) - qkv_dtype = torch.float16 - #sparse_size = KV_S - query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - value = key.clone() - - act_q_s = T // B # step - # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE) - # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复 - # sparse_indices = indices.to(torch.int32) - sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1) - sparse_indices = sparse_indices.to(torch.int32) - # print("sparse_indices=", sparse_indices) - actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE) - # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE) - actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, dtype=torch.int32, device=DEVICE) - print(actual_seq_lengths_kv) - block_table = torch.tensor([range(B * KV_S // block_size)], dtype=torch.int32, device=DEVICE).reshape(B, -1) - - if D_rope == 0: - query_rope = None - key_rope = None - else: - query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - - print("q.shape=",query.shape) - print("k.shape=",key.shape) - print("v.shape=",value.shape) - 
print("sparse_indices.shape=",sparse_indices.shape) - print("act_seq_query=",actual_seq_lengths_query) - print("act_seq_kv=", actual_seq_lengths_kv) - - - triton_out = sparse_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv= 'PA_BSND', - sparse_mode = sparse_mode, - block_table= block_table, - ) - npu_out = torch_npu.npu_sparse_flash_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv = 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - # attention_mode = 2, - ) - triton_out = triton_out.to(npu_out.dtype) - torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) - print("[PASSED]") - - # benchmarking - triton_time = do_bench_npu(lambda:sparse_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv= 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - ), clear_l2_cache=True, collect_prof=False) - print(f"[Triton SFA] Time: {triton_time:.4f} us") - - npu_time = do_bench_npu(lambda:torch_npu.npu_sparse_flash_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = 
scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv = 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - # attention_mode = 2, - ), clear_l2_cache=True, collect_prof=False) - print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") - -if __name__ == "__main__": - print(torch_npu.__version__) - print("Test Real Case in DS-v3.2-Exp") - print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - i = 1 - print(f"====================第{i}次测试=================") - test_op(T = 1, B = 1, KV_S = 2560, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 4, B = 4, KV_S = 6400, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 8, B = 8, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 16, B = 16, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) From d0b0e2dbbf4662f03d227839716f9428198e6ca1 Mon Sep 17 00:00:00 2001 From: Eugene Wu Date: Mon, 9 Mar 2026 09:50:41 +0000 Subject: [PATCH 13/13] [FIX](tle): remove redundant code and fix code format --- .github/workflows/ascend-build-and-test.yml | 8 + CMakeLists.txt | 7 +- include/triton/Dialect/Triton/IR/TritonOps.td | 1 + 
python/setup.py | 2 +- python/setup_tools/utils/ascend.py | 5 + python/test/tle/test_bind_buffer.py | 5 +- python/test/tle/test_tle_with_hints.py | 13 +- python/test/tle/test_vec_add.py | 13 +- python/test/tle/test_vec_add_2d.py | 26 +- python/test/tle/test_vec_add_mix.py | 10 +- python/test/tle/test_vec_mathOps.py | 40 +- python/triton/experimental/__init__.py | 2 +- python/triton/experimental/tle/__init__.py | 23 +- .../experimental/tle/language/__init__.py | 4 + .../experimental/tle/language/builder.py | 3 +- .../experimental/tle/language/dsa/README.md | 4 +- .../experimental/tle/language/dsa/__init__.py | 1 + .../tle/language/dsa/ascend/__init__.py | 4 +- .../tle/language/dsa/ascend/core.py | 2 - .../tle/language/dsa/ascend/semantic.py | 2 +- .../experimental/tle/language/dsa/core.py | 53 +- .../experimental/tle/language/dsa/semantic.py | 68 +- .../experimental/tle/language/dsa/types.py | 4 +- .../tutorials/tle/01-sparse-flash-attn-tle.py | 949 ++++++++++++++++++ python/tutorials/tle/sparse_flash_attn_tle.py | 912 ----------------- third_party/ascend/AscendNPU-IR | 1 - .../triton/Dialect/Triton/IR/TritonOps.td | 1 + third_party/ascend/backend/testing.py | 2 + third_party/tle/dsa/CMakeLists.txt | 4 +- .../tle/dsa/dialect/include/CMakeLists.txt | 2 +- .../Conversion/TleToLinalg/DSACopyConverter.h | 14 +- .../Conversion/TleToLinalg/MathConverter.h | 148 ++- .../tle/dsa/dialect/include/IR/CMakeLists.txt | 2 +- .../tle/dsa/dialect/include/IR/Dialect.h | 4 +- .../tle/dsa/dialect/include/IR/TleDialect.td | 2 +- .../tle/dsa/dialect/include/IR/TleOps.td | 6 +- .../tle/dsa/dialect/lib/CMakeLists.txt | 2 +- .../dsa/dialect/lib/Conversion/CMakeLists.txt | 2 +- .../lib/Conversion/TleToLinalg/CMakeLists.txt | 2 +- .../TleToLinalg/DSACopyConverter.cpp | 34 +- .../Conversion/TleToLinalg/MathConverter.cpp | 34 +- .../tle/dsa/dialect/lib/IR/CMakeLists.txt | 2 +- .../tle/dsa/dialect/lib/IR/Dialect.cpp | 8 +- third_party/tle/dsa/dialect/lib/IR/TleOps.cpp | 3 +- 
third_party/tle/dsa/tle_ir.cc | 275 ++--- 45 files changed, 1401 insertions(+), 1308 deletions(-) mode change 100755 => 100644 python/test/tle/test_tle_with_hints.py mode change 100755 => 100644 python/test/tle/test_vec_add.py mode change 100755 => 100644 python/test/tle/test_vec_add_2d.py mode change 100755 => 100644 python/test/tle/test_vec_add_mix.py mode change 100755 => 100644 python/test/tle/test_vec_mathOps.py create mode 100644 python/tutorials/tle/01-sparse-flash-attn-tle.py delete mode 100644 python/tutorials/tle/sparse_flash_attn_tle.py delete mode 160000 third_party/ascend/AscendNPU-IR diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml index fbee99af4..77831002a 100644 --- a/.github/workflows/ascend-build-and-test.yml +++ b/.github/workflows/ascend-build-and-test.yml @@ -104,6 +104,14 @@ jobs: --ignore=test_assume.py \ --ignore=test_index_select.py popd + # flagtree tle test + pushd python/test/tle + python3 test_vec_add.py + python3 test_vec_add_2d.py + python3 test_vec_add_mix.py + python3 test_vec_mathOps.py + python3 test_tle_with_hints.py + popd - name: FlagTree Editable Build And Teston Ascend if: steps.check_files.outputs.only_docs_changed != 'true' diff --git a/CMakeLists.txt b/CMakeLists.txt index 8518361ab..f62e2e14c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -292,8 +292,11 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() # add TLE plugin - list(APPEND TRITON_PLUGIN_NAMES "tle") - add_subdirectory(third_party/tle/dsa) + # just support ascend backend so far + if(FLAGTREE_BACKEND STREQUAL "ascend") + list(APPEND TRITON_PLUGIN_NAMES "tle") + add_subdirectory(third_party/tle/dsa) + endif() get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 0d843b358..283dd9165 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ 
b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1256,4 +1256,5 @@ def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op< }]; } + #endif // Triton_OPS diff --git a/python/setup.py b/python/setup.py index dcfa04ced..ac90f6a87 100644 --- a/python/setup.py +++ b/python/setup.py @@ -702,7 +702,7 @@ def get_packages(): "triton/backends", "triton/tools", - # for tle + # for flagtree tle "triton/experimental", "triton/experimental/tle", "triton/experimental/tle/language", diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py index fc3d8da2f..957b90588 100644 --- a/python/setup_tools/utils/ascend.py +++ b/python/setup_tools/utils/ascend.py @@ -15,6 +15,7 @@ def get_extra_install_packages(): "triton/extension", "triton/extension/buffer", "triton/extension/buffer/language", + "triton/experimental/tle/language/dsa/ascend", ] @@ -24,6 +25,10 @@ def get_package_dir(): package_dict["triton/extension"] = ascend_ext_base package_dict["triton/extension/buffer"] = f"{ascend_ext_base}/buffer" package_dict["triton/extension/buffer/language"] = f"{ascend_ext_base}/buffer/language" + + # flagtree tle ascend + flagtree_tle_ascend_base = "../python/triton/experimental/tle/language/dsa" + package_dict["triton/experimental/tle/language/dsa/ascend"] = f"{flagtree_tle_ascend_base}/ascend" return package_dict diff --git a/python/test/tle/test_bind_buffer.py b/python/test/tle/test_bind_buffer.py index c1cafdbf1..b8cd1e775 100644 --- a/python/test/tle/test_bind_buffer.py +++ b/python/test/tle/test_bind_buffer.py @@ -7,6 +7,7 @@ from triton._C.libtriton import ir, tle as tle_ir from triton._C.libtriton.ascend import ir as ascend_ir + class Options: num_warps = 4 num_stages = 3 @@ -26,13 +27,15 @@ def compile_kernel(kernel, signature, constants): module = ast_to_ttir(kernel, src, context, Options(), {}, {}) return str(module) + @triton.jit def bind_buffer(): # tle.dsa.ascend.UB is triton.language.extra.extension.cann.core.ascend_address_space.UB buffer1 = 
tle.dsa.alloc(shape=[32, 32], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) tle.dsa.to_tensor(buffer1, writable=True) + if __name__ == "__main__": print("=" * 60) - mlir = compile_kernel(bind_buffer, {}, {}) + mlir = compile_kernel(bind_buffer, {}, {}) print(mlir) diff --git a/python/test/tle/test_tle_with_hints.py b/python/test/tle/test_tle_with_hints.py old mode 100755 new mode 100644 index 78e980052..2f4ab5974 --- a/python/test/tle/test_tle_with_hints.py +++ b/python/test/tle/test_tle_with_hints.py @@ -1,10 +1,11 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd import torch import triton +import torch_npu # noqa import triton.language as tl -# import triton.language.extra.tle.ascend as tle import triton.experimental.tle as tle + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. @@ -34,6 +35,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tle.dsa.add(a_ub, b_ub, c_ub) tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + def custom_func(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) n_elements = output.numel() @@ -41,22 +43,23 @@ def custom_func(x: torch.Tensor, y: torch.Tensor): add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) return output + def test_add(): torch.manual_seed(0) size = 1024 - x = torch.rand(size, device='npu', dtype=torch.float) - y = torch.rand(size, device='npu', dtype=torch.float) + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() output_torch = x + y output_triton = custom_func(x, y) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton))}') from triton.backends.ascend.testing import do_bench_npu bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False) bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, 
keep_res=True, collect_prof=False) - # 保留两位小数输出 print(f"torch time : {bench_torch:.2f}") print(f"triton time: {bench_triton:.2f}") + if __name__ == "__main__": test_add() diff --git a/python/test/tle/test_vec_add.py b/python/test/tle/test_vec_add.py old mode 100755 new mode 100644 index 3a3fbeccf..3d0098b74 --- a/python/test/tle/test_vec_add.py +++ b/python/test/tle/test_vec_add.py @@ -1,10 +1,11 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd import torch +import torch_npu # noqa import triton import triton.language as tl -# import triton.language.extra.tle.ascend as tle import triton.experimental.tle as tle + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. @@ -31,6 +32,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tle.dsa.add(a_ub, b_ub, c_ub) tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + def custom_func(x: torch.Tensor, y: torch.Tensor): output = torch.empty_like(x) n_elements = output.numel() @@ -38,22 +40,23 @@ def custom_func(x: torch.Tensor, y: torch.Tensor): add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) return output + def test_add(): torch.manual_seed(0) size = 1024 - x = torch.rand(size, device='npu', dtype=torch.float) - y = torch.rand(size, device='npu', dtype=torch.float) + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() output_torch = x + y output_triton = custom_func(x, y) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton))}') from triton.backends.ascend.testing import do_bench_npu bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False) bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False) - # 保留两位小数输出 print(f"torch time : {bench_torch:.2f}") print(f"triton time: 
{bench_triton:.2f}") + if __name__ == "__main__": test_add() diff --git a/python/test/tle/test_vec_add_2d.py b/python/test/tle/test_vec_add_2d.py old mode 100755 new mode 100644 index 6c0b273a4..10ddd4ca4 --- a/python/test/tle/test_vec_add_2d.py +++ b/python/test/tle/test_vec_add_2d.py @@ -1,16 +1,17 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd import torch +import torch_npu # noqa import triton import triton.language as tl import triton.experimental.tle as tle + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. output_ptr, # *Pointer* to output vector. n_elements, # Size of the vector. - n_cols, n_rows, - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + n_cols, n_rows, BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. # NOTE: `constexpr` so it can be used as a shape value. ): pid_m = tl.program_id(0) @@ -22,7 +23,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. offs_m = block_start_m + tl.arange(0, BLOCK_SIZE) offs_n = block_start_n + tl.arange(0, BLOCK_SIZE) - # 计算线性地址(row-major) + # get address(row-major) x_ptrs = x_ptr + offs_m[:, None] * n_cols + offs_n[None, :] y_ptrs = y_ptr + offs_m[:, None] * n_cols + offs_n[None, :] out_ptrs = output_ptr + offs_m[:, None] * n_cols + offs_n[None, :] @@ -43,32 +44,35 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. 
tle.dsa.copy(c_ub, out_ptrs, [tail_size_m, tail_size_n]) + def custom_func(x: torch.Tensor, y: torch.Tensor, size: int): output = torch.empty_like(x) n_elements = output.numel() BLOCK_SIZE = 16 # grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) grid = (triton.cdiv(size, BLOCK_SIZE), triton.cdiv(size, BLOCK_SIZE)) - add_kernel[grid](x, y, output, n_elements, size, size-1, BLOCK_SIZE) + add_kernel[grid](x, y, output, n_elements, size, size - 1, BLOCK_SIZE) return output + def test_add(): torch.manual_seed(0) size = 128 - x = torch.rand((size,size-1), device='npu', dtype=torch.float) - y = torch.rand((size,size-1), device='npu', dtype=torch.float) + x = torch.rand((size, size - 1), dtype=torch.float).npu() + y = torch.rand((size, size - 1), dtype=torch.float).npu() output_torch = x + y output_triton = custom_func(x, y, size) - print(f"============X===========") + print("============X===========") print(x) - print(f"============Y===========") + print("============Y===========") print(y) - print(f"============outTorch===========") + print("============outTorch===========") print(output_torch) - print(f"============outTriton===========") + print("============outTriton===========") print(output_triton) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton))}') + if __name__ == "__main__": test_add() diff --git a/python/test/tle/test_vec_add_mix.py b/python/test/tle/test_vec_add_mix.py old mode 100755 new mode 100644 index 4c18d94f4..115ab4d84 --- a/python/test/tle/test_vec_add_mix.py +++ b/python/test/tle/test_vec_add_mix.py @@ -1,9 +1,11 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd import torch import triton +import torch_npu # noqa import triton.language as tl import triton.experimental.tle as tle + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. 
@@ -46,15 +48,17 @@ def custom_func(x: torch.Tensor, y: torch.Tensor): add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) return output + def test_add(): torch.manual_seed(0) size = 1024 - x = torch.rand(size, device='npu', dtype=torch.float) - y = torch.rand(size, device='npu', dtype=torch.float) + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() output_torch = x output_triton = custom_func(x, y) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch - output_triton))}') + if __name__ == "__main__": test_add() diff --git a/python/test/tle/test_vec_mathOps.py b/python/test/tle/test_vec_mathOps.py old mode 100755 new mode 100644 index 4252370a6..be84bc16b --- a/python/test/tle/test_vec_mathOps.py +++ b/python/test/tle/test_vec_mathOps.py @@ -1,19 +1,22 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd -from typing import Callable, Tuple import torch +import torch_npu # noqa import triton import triton.language as tl import triton.experimental.tle as tle + @triton.jit def run_test( - x_ptr, y_ptr, output_ptr, n_elements, + x_ptr, + y_ptr, + output_ptr, + n_elements, OP_ID: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) @@ -22,17 +25,18 @@ def run_test( tle.dsa.copy(x_ptr + offsets, a_ub, [BLOCK_SIZE]) tle.dsa.copy(y_ptr + offsets, b_ub, [BLOCK_SIZE]) - if OP_ID == 0: # add + if OP_ID == 0: # add tle.dsa.add(a_ub, b_ub, c_ub) - elif OP_ID == 1: # sub + elif OP_ID == 1: # sub tle.dsa.sub(a_ub, b_ub, c_ub) - elif OP_ID == 2: # mul + elif OP_ID == 2: # mul tle.dsa.mul(a_ub, b_ub, c_ub) - elif OP_ID == 3: # div + elif OP_ID == 3: # div 
tle.dsa.div(a_ub, b_ub, c_ub) tle.dsa.copy(c_ub, output_ptr + offsets, [BLOCK_SIZE]) + OP_REGISTRY = { 'add': (0, torch.add), 'sub': (1, torch.sub), @@ -40,26 +44,31 @@ def run_test( 'div': (3, torch.div), } + def common_test(op_name: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: if op_name not in OP_REGISTRY: raise ValueError(f"Unsupported op: {op_name}") - + op_id, _ = OP_REGISTRY[op_name] output = torch.empty_like(x) n_elements = output.numel() - - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + run_test[grid]( - x, y, output, n_elements, + x, + y, + output, + n_elements, OP_ID=op_id, BLOCK_SIZE=128, ) return output -def test_binary_op(size: int = 1024, dtype=torch.float32, device='npu'): - x = torch.rand(size, device=device, dtype=dtype) - y = torch.rand(size, device=device, dtype=dtype) + +def test_binary_op(size: int = 1024, dtype=torch.float32): + x = torch.rand(size, dtype=dtype).npu() + y = torch.rand(size, dtype=dtype).npu() y = y + 0.1 print(f"Testing {len(OP_REGISTRY)} operators with size={size}, dtype={dtype}") @@ -73,5 +82,6 @@ def test_binary_op(size: int = 1024, dtype=torch.float32, device='npu'): status = "SUCCESS" if max_diff < 1e-5 else "FAIL" print(f"{status} {op_name:8}: max diff = {max_diff:.2e}") + if __name__ == "__main__": test_binary_op(size=1024, dtype=torch.float32) diff --git a/python/triton/experimental/__init__.py b/python/triton/experimental/__init__.py index a6b9487a8..a111d0159 100644 --- a/python/triton/experimental/__init__.py +++ b/python/triton/experimental/__init__.py @@ -1 +1 @@ -# Copyright 2026- Xcoresigma Technology Co., Ltd \ No newline at end of file +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py index 1eeda5be0..1af2373e4 100644 --- a/python/triton/experimental/tle/__init__.py +++ 
b/python/triton/experimental/tle/__init__.py @@ -14,16 +14,21 @@ raise RuntimeError("tle is not available") triton_compiler = importlib.import_module("triton.compiler", package=__package__) + + def tle_patch_for_triton_compile(): original_compile_fn = triton_compiler.compile + def tle_compile(src, target=None, options=None): # ir.context() will return a new MLIRContext each time, here should keep the same context cur_context = ir.context() tle_ir.load_dialects(cur_context) original_context_fn = ir.context + def patched_context(): return cur_context + ir.context = patched_context try: @@ -32,16 +37,20 @@ def patched_context(): ir.context = original_context_fn return compiled_kernel + return tle_compile + code_generator = importlib.import_module("triton.compiler.code_generator", package=__package__) + class TleCodeGenerator(code_generator.CodeGenerator): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0): - super().__init__(context, prototype, gscope, attributes, constants, function_name, jit_fn, options, - codegen_fns, module_map, module, is_kernel, function_types, noinline, file_name, begin_line) + super().__init__(context, prototype, gscope, attributes, constants, function_name, jit_fn, options, codegen_fns, + module_map, module, is_kernel, function_types, noinline, file_name, begin_line) self.tle_builder = tle_ir.tle_builder(context) self.tle_builder.set_loc(file_name, begin_line, 0) @@ -62,7 +71,8 @@ def visit_With(self, node): if isinstance(context.func, ast.Attribute) and context.func.attr == "hint": for kw in context.keywords: if not isinstance(kw.value, ast.Constant): - raise self._unsupported(node, "keyword arguments to hint() are only supported for constant values") + raise self._unsupported(node, + "keyword arguments to hint() are only 
supported for constant values") hints[kw.arg] = kw.value.value # append hints to with_hints anyway, to indicate that we're in the with scope @@ -73,6 +83,7 @@ def visit_With(self, node): # pop hints to indicate that we're out of the with scope self.with_hints.pop() + def extract_tle_hints_scope(generator: TleCodeGenerator): """ with tle.hints(inter_no_alias=True): @@ -86,7 +97,7 @@ def extract_tle_hints_scope(generator: TleCodeGenerator): when visit_Call for call_fn1, we can get the hints scope as follows: [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}, {xxx}] should get the parent scope hints 'inter_no_alias': False for call_fn1, after visit call_fn1, pop the scope - + when visit_Call for call_fn, we can get the hints scope as follows: [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}] and now the hint scope is 'inter_no_alias': False' for call_fn, after visit call_fn, pop the scope @@ -99,7 +110,7 @@ def extract_tle_hints_scope(generator: TleCodeGenerator): hints = generator.with_hints[i] if "inter_no_alias" in hints: return hints - + return {} @@ -110,4 +121,4 @@ def extract_tle_hints_scope(generator: TleCodeGenerator): __all__ = [ "dsa", -] \ No newline at end of file +] diff --git a/python/triton/experimental/tle/language/__init__.py b/python/triton/experimental/tle/language/__init__.py index 32ed8c87b..4acb1bfa9 100644 --- a/python/triton/experimental/tle/language/__init__.py +++ b/python/triton/experimental/tle/language/__init__.py @@ -1,3 +1,7 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd from . 
import dsa + +__all__ = [ + 'dsa', +] diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py index 9c71697fd..d70394c59 100644 --- a/python/triton/experimental/tle/language/builder.py +++ b/python/triton/experimental/tle/language/builder.py @@ -1,5 +1,6 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd + def create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name): delegate_method = getattr(delegate_builder, method_name) @@ -53,4 +54,4 @@ def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): "create_dsa_insert_slice", "create_dsa_subview", ] - attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) \ No newline at end of file + attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) diff --git a/python/triton/experimental/tle/language/dsa/README.md b/python/triton/experimental/tle/language/dsa/README.md index 6e50c79d4..4bb7b8f42 100644 --- a/python/triton/experimental/tle/language/dsa/README.md +++ b/python/triton/experimental/tle/language/dsa/README.md @@ -51,7 +51,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. 
## Testing ```bash -cd python/test/tle +cd python/test/tle python3 test_vec_add.py ``` @@ -59,4 +59,4 @@ python3 test_vec_add.py See other examples in `python/test/tle`: - `test_matmul.py` - GEMM implementation and pipeline usage -- `test_vec_mathOps.py` - Vector math operations, such as add, sub, mul, div \ No newline at end of file +- `test_vec_mathOps.py` - Vector math operations, such as add, sub, mul, div diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py index bc2e3aead..19a490b3f 100644 --- a/python/triton/experimental/tle/language/dsa/__init__.py +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -40,4 +40,5 @@ "insert_slice", "extract_element", "subview", + "ascend", ] diff --git a/python/triton/experimental/tle/language/dsa/ascend/__init__.py b/python/triton/experimental/tle/language/dsa/ascend/__init__.py index 3c7c09031..ee754717d 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/__init__.py +++ b/python/triton/experimental/tle/language/dsa/ascend/__init__.py @@ -5,7 +5,7 @@ L1, L0A, L0B, - L0C, + L0C, ) __all__ = [ @@ -13,5 +13,5 @@ "L1", "L0A", "L0B", - "L0C", + "L0C", ] diff --git a/python/triton/experimental/tle/language/dsa/ascend/core.py b/python/triton/experimental/tle/language/dsa/ascend/core.py index 14f215f16..79ebd93db 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/core.py +++ b/python/triton/experimental/tle/language/dsa/ascend/core.py @@ -1,5 +1,3 @@ -# Copyright 2026- Xcoresigma Technology Co., Ltd - from triton.language.extra.cann.extension.core import ascend_address_space UB = ascend_address_space.UB diff --git a/python/triton/experimental/tle/language/dsa/ascend/semantic.py b/python/triton/experimental/tle/language/dsa/ascend/semantic.py index a6b9487a8..a111d0159 100644 --- a/python/triton/experimental/tle/language/dsa/ascend/semantic.py +++ b/python/triton/experimental/tle/language/dsa/ascend/semantic.py @@ -1 +1 @@ 
-# Copyright 2026- Xcoresigma Technology Co., Ltd \ No newline at end of file +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py index fa963263e..f703fc447 100644 --- a/python/triton/experimental/tle/language/dsa/core.py +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -2,12 +2,7 @@ import triton.language.core as tl from triton.language import semantic as tl_semantic -from triton.language.core import ( - _constexpr_to_value, - tensor, - constexpr -) -from triton._C.libtriton import ir +from triton.language.core import (_constexpr_to_value, tensor, constexpr) from typing import List, TypeVar from functools import wraps @@ -20,12 +15,13 @@ TRITON_BUILTIN = "__triton_builtin__" TLE_BUILTIN = "__tle_builtin__" + def builtin(fn: T) -> T: """ Decorator for builtin functions to mark a function as a tle language builtin function. """ assert callable - + @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: @@ -38,6 +34,7 @@ def wrapper(*args, **kwargs): return wrapper + def is_builtin(fn) -> bool: """ Returns whether a function is a builtin function. @@ -114,6 +111,7 @@ def __iter__(self): def __next__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + class pipeline(range): """ Iterator that counts upward forever, with software pipeline semantics. @@ -121,6 +119,7 @@ class pipeline(range): This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) @@ -133,6 +132,7 @@ class parallel(range): :code:`triton.jit` functions. 
In addition, it indicates that there are no dependencies between loop iterations, allowing them to be executed in parallel. """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) @@ -145,13 +145,14 @@ def from_buffer_to_tensor_pointer(src: buffer, _builder=None) -> tl.tensor: block_type = tl.block_type(ele_type, shape) return tl.tensor(src.handle, block_type) + @builtin def copy(src, dst, shape, inter_no_alias=False, _builder=None): """Copy data from `src` to `dst` shaped by `shape`. :param inter_no_alias: If True, the copy is annotated as no aliasing between different iterations. """ - assert len(shape) != 0, f"Can't deduce copy extents from args" + assert len(shape) != 0, "Can't deduce copy extents from args" shape = _constexpr_to_value(shape) inter_no_alias = _constexpr_to_value(inter_no_alias) @@ -165,6 +166,7 @@ def add(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.add(input, other, result, _builder) + @builtin def sub(input, other, result, _builder=None): input = from_buffer_to_tensor_pointer(input, _builder=_builder) @@ -172,6 +174,7 @@ def sub(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.sub(input, other, result, _builder) + @builtin def mul(input, other, result, _builder=None): input = from_buffer_to_tensor_pointer(input, _builder=_builder) @@ -179,6 +182,7 @@ def mul(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.mul(input, other, result, _builder) + @builtin def div(input, other, result, _builder=None): input = from_buffer_to_tensor_pointer(input, _builder=_builder) @@ -186,6 +190,7 @@ def div(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.div(input, other, result, 
_builder) + @builtin def max(input, other, result, _builder=None): # elementwise binary vector maximum op @@ -194,6 +199,7 @@ def max(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.max(input, other, result, _builder) + @builtin def min(input, other, result, _builder=None): # elementwise binary vector minimum op @@ -202,15 +208,6 @@ def min(input, other, result, _builder=None): result = from_buffer_to_tensor_pointer(result, _builder=_builder) tle_semantic.min(input, other, result, _builder) -### @builtin -### def dot(inputA, inputB, result, size, initC, a_transpose=False, b_transpose=False, enable_hf32=False, _builder=None): -### initC = _constexpr_to_value(initC) -### a_transpose = _constexpr_to_value(a_transpose) -### b_transpose = _constexpr_to_value(b_transpose) -### enable_hf32 = _constexpr_to_value(enable_hf32) -### tle_semantic.dot(inputA, inputB, result, size, initC, a_transpose, b_transpose, enable_hf32, _builder) - - @builtin def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_space, _builder=None) -> buffer: @@ -253,6 +250,7 @@ def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder """ return tle_semantic.to_tensor(memref, writable, _builder, target_shape=target_shape) + @builtin def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], _builder=None) -> buffer: @@ -309,13 +307,15 @@ def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], return tle_semantic.subview(src, new_offsets, new_sizes, new_strides, _builder) + def hint(**kwargs): """Dummy function for AST parsing. 
Not executed during JIT compilation.""" raise RuntimeError("tle.hint() cannot be called directly.") @builtin -def insert_slice(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], _builder=None) -> tensor: +def insert_slice(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + _builder=None) -> tensor: """ Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. @@ -334,10 +334,7 @@ def insert_slice(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[in assert len(ful.shape) == len(sub.shape) assert (len(ful.shape) == len(sizes)) assert (len(ful.shape) == len(strides)) - new_offsets = [ - tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o - for o in offsets - ] + new_offsets = [tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] out = tle_semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) return out @@ -357,10 +354,7 @@ def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) :type strides: tuple of ints """ assert len(ful.shape) > 0 - new_offsets = [ - tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o - for o in offsets - ] + new_offsets = [tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] sub = tle_semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) return sub @@ -378,8 +372,5 @@ def extract_element(src, indice, _builder=None, _generator=None): :type indice: tuple of ints """ assert len(src.shape) > 0 - new_indice = [ - tl_semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i - for i in indice - ] - return tle_semantic.extract_element(src, new_indice, _builder) \ No newline at end of file + new_indice = [tl_semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] + return tle_semantic.extract_element(src, new_indice, _builder) diff 
--git a/python/triton/experimental/tle/language/dsa/semantic.py b/python/triton/experimental/tle/language/dsa/semantic.py index b295c7b97..c3d1686cc 100644 --- a/python/triton/experimental/tle/language/dsa/semantic.py +++ b/python/triton/experimental/tle/language/dsa/semantic.py @@ -1,13 +1,13 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd -from typing import List, Optional, Union, Tuple +from typing import List, Union from triton.language import core as tl from triton.language.semantic import ( - binary_op_type_checking_impl, -) + binary_op_type_checking_impl, ) from triton._C.libtriton import ir from .types import buffer, buffer_type, address_space + def wrap_tensor(x, scalar_ty, ret_shape): if ret_shape: res_ty = tl.block_type(scalar_ty, ret_shape) @@ -16,6 +16,7 @@ def wrap_tensor(x, scalar_ty, ret_shape): res_ty = scalar_ty return tl.tensor(x, res_ty) + def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # assert value.numel.value == 1, "only accepts size-1 tensor" if isinstance(value, tl.constexpr): @@ -35,54 +36,37 @@ def copy(src, dst, shape: List[Union[tl.constexpr, int]], inter_no_alias: bool, builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape], inter_no_alias) -### def to_tensor(buffer: tl.tensor, builder: ir.builder) -> tl.tensor: -### if not isinstance(buffer, tl.tensor): -### raise TypeError("buffer must be tensor of pointers") -### -### tensor_ty = buffer.type -### element_ty = tensor_ty.element_ty -### if not element_ty.is_ptr: -### raise TypeError("The basic elements of a buffer must be pointers") -### -### return tl.tensor(builder.dsa_to_tensor(buffer.handle), tensor_ty) -### -### def to_buffer(src: tl.tensor, builder: ir.builder) -> tl.tensor: -### if not isinstance(src, tl.tensor): -### raise TypeError("src of to_buffer must be tensor") -### -### return tl.tensor(builder.dsa_to_buffer(src.handle), src.type) - def add(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): 
input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_add(input.handle, other.handle, result.handle) + def sub(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_sub(input.handle, other.handle, result.handle) + def mul(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_mul(input.handle, other.handle, result.handle) + def div(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_div(input.handle, other.handle, result.handle) + def max(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_max(input.handle, result.handle) + def min(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): input, other = binary_op_type_checking_impl(input, other, builder, True, True) builder.create_dsa_min(input.handle, other.handle, result.handle) -### def dot(inputA: tl.tensor, inputB: tl.tensor, result: tl.tensor, size: List[int], initC: bool, a_transpose: bool, b_transpose: bool, enable_hf32: bool, builder: ir.builder): -### assert len(size) == 3, f"Please set the M、N、K value." 
-### -### builder.create_dsa_dot(inputA.handle, inputB.handle, result.handle, size, initC, a_transpose, b_transpose, enable_hf32) -def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: address_space, - builder: ir.builder) -> buffer: +def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: address_space, builder: ir.builder) -> buffer: shape = tl._unwrap_shape(shape) if not isinstance(shape, (tuple, list)): raise TypeError("shape must be list/tuple") @@ -107,7 +91,7 @@ def to_buffer( # if isinstance(bind_buffer, buffer): # builder.create_bind_buffer(tensor.handle, bind_buffer.handle) # return bind_buffer - if not (bind_buffer is None): + if bind_buffer is not None: raise ValueError("bind_buffer must be a buffer or None") address_space = tl._constexpr_to_value(address_space) addr_space_attr = (address_space.to_ir(builder) if address_space else builder.dsa_get_null_attr()) @@ -142,28 +126,32 @@ def to_tensor(memref: buffer, writable: bool, builder: ir.builder, target_shape= return tl.tensor(builder.dsa_to_tensor(memref_value, writable), tensor_type) -def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: - assert(len(ful.shape) == len(offsets)) - assert(len(ful.shape) == len(sizes)) - assert(len(ful.shape) == len(strides)) - assert(all([s>=1 for s in sizes])) - assert(all([s>=0 for s in strides])) +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) new_offsets = [o.handle for o in offsets] ret_type = tl.block_type(ful.type.scalar, ful.shape) out = builder.create_dsa_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) return 
tl.tensor(out, ret_type) -def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], builder: ir.builder) -> tl.tensor: - assert(len(ful.shape) == len(offsets)) - assert(len(ful.shape) == len(sizes)) - assert(len(ful.shape) == len(strides)) - assert(all([s>=1 for s in sizes])) - assert(all([s>=0 for s in strides])) + +def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) new_offsets = [o.handle for o in offsets] ret_type = tl.block_type(ful.type.scalar, sizes) out = builder.create_dsa_extract_slice(ful.handle, new_offsets, sizes, strides) return tl.tensor(out, ret_type) + def extract_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): if len(src.shape) != len(indice): raise ValueError("Indice's rank must be equal to src tensor's rank") @@ -202,4 +190,4 @@ def subview(src: buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], st # create buffer_type with strides buffer_ty = buffer_type(element_ty=src.dtype, shape=sizes_int, space=src.space, strides=result_memory_strides) - return buffer(result_handle, buffer_ty) \ No newline at end of file + return buffer(result_handle, buffer_ty) diff --git a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py index 7ece8d035..8c65f69d3 100644 --- a/python/triton/experimental/tle/language/dsa/types.py +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -1,10 +1,8 @@ -# Copyright 2026- Xcoresigma Technology Co., Ltd - from triton._C.libtriton import ir from typing import List import triton.language.core as tl -from triton.language.core import builtin + class address_space: """Represents a buffer's address space. 
diff --git a/python/tutorials/tle/01-sparse-flash-attn-tle.py b/python/tutorials/tle/01-sparse-flash-attn-tle.py new file mode 100644 index 000000000..d7d2bb62e --- /dev/null +++ b/python/tutorials/tle/01-sparse-flash-attn-tle.py @@ -0,0 +1,949 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +import torch +import torch_npu +import triton +import triton.language as tl +import numpy as np +from datetime import datetime +from triton.backends.ascend.testing import do_bench_npu +import triton.experimental.tle as tle +# import random + +np.random.seed(21) +DEVICE = "npu" +DEVICE_ID = 0 +torch.manual_seed(20) +torch_npu.npu.set_device(int(DEVICE_ID)) +torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) + +ascend_aiv_core_nums = triton.language.constexpr(24) + + +# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== +@triton.jit +def fused_pa_rope_to_sparse_kernel( + k_pa_ptr, + k_rope_pa_ptr, + v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] + block_table_ptr, # block_table [B, max_blocks] + sparse_indices_ptr, # sparse_indices [B, N, TOPK] + k_sparse_out_ptr, + v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] + stride_k_pa_bn, + stride_k_pa_bs, + stride_k_pa_n, + stride_k_pa_d, # K PA strides + stride_k_rope_pa_bn, + stride_k_rope_pa_bs, + stride_k_rope_pa_n, + stride_k_rope_pa_d, # K_rope PA strides + stride_v_pa_bn, + stride_v_pa_bs, + stride_v_pa_n, + stride_v_pa_d, # V PA strides + stride_bt_b, + stride_bt_blk, # block_table strides + stride_si_b, + stride_si_n, + stride_si_topk, # sparse_indices strides + stride_out_b, + stride_out_n, + stride_out_topk, + stride_out_d, # output strides + stride_v_b, + stride_v_n, + stride_v_topk, + stride_v_d, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + B: tl.constexpr, +): + """ + Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse + Input: K/V in PA_BSND format, K_rope in PA_BSND format + 
Output: K/V_sparse in BNSD format + """ + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # Process (b, n, topk) combinations + for b_idx in range(B): + b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now + for idx in range(pid, TOPK, num_programs): + # Get batch and sparse index from sparse_indices + n = 0 # KV_N = 1 + + # Load sparse index + sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) + + # Map sparse_idx to PA_BSND position + block_id = sparse_idx // BLOCK_SIZE # Which block + bs_offset = sparse_idx % BLOCK_SIZE # Offset within block + + # Get actual block ID from block_table + actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) + + # Compute PA_BSND offset for K + k_pa_offset = (actual_block_id * stride_k_pa_bn + bs_offset * stride_k_pa_bs + n * stride_k_pa_n) + + # Compute PA_BSND offset for K_rope + k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + bs_offset * stride_k_rope_pa_bs + + n * stride_k_rope_pa_n) + + # Compute PA_BSND offset for V + v_pa_offset = (actual_block_id * stride_v_pa_bn + bs_offset * stride_v_pa_bs + n * stride_v_pa_n) + # Load K vector (no rope part) + k_vec = tl.load(k_pa_ptr + k_pa_offset + tl.arange(0, BLOCK_DK) * stride_k_pa_d) + + # Load V vector + v_vec = tl.load(v_pa_ptr + v_pa_offset + tl.arange(0, BLOCK_DV) * stride_v_pa_d) + # Output to BNSD format: [B, N, TOPK, D] + out_offset = b * stride_out_b + n * stride_out_n + idx * stride_out_topk + out_offset_v = b * stride_v_b + n * stride_v_n + idx * stride_v_topk + + if BLOCK_DK_ROPE > 0: + # Load K_rope vector + full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE, ), 0.0, dtype=tl.float16) + k_rope_vec = tl.load(k_rope_pa_ptr + k_rope_pa_offset + + tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d) + full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0, ), sizes=(BLOCK_DK, ), strides=(1, )) + full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK, ), 
sizes=(BLOCK_DK_ROPE, ), + strides=(1, )) + tl.store(k_sparse_out_ptr + out_offset + tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, full_k) + else: + # No rope, store K directly + tl.store(k_sparse_out_ptr + out_offset + tl.arange(0, BLOCK_DK) * stride_out_d, k_vec) + + # Store V + tl.store(v_sparse_out_ptr + out_offset_v + tl.arange(0, BLOCK_DV) * stride_v_d, v_vec) + + +def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): + """ + Fused PA_BSND + Rope Concat -> BNSD Sparse conversion + + Args: + k_pa: Key in PA_BSND format [block_num, block_size, n, dk] + k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope + v_pa: Value in PA_BSND format [block_num, block_size, n, dv] + block_table: Block table [B, max_blocks] + sparse_indices: Sparse indices [B, N, TOPK] + block_size: Block size for PA format + + Returns: + k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] + v_sparse: Sparse value in BNSD format [B, N, TOPK, dv] + """ + block_num, _, n, dk = k_pa.shape + B = block_table.shape[0] + TOPK = sparse_indices.size(-1) + N = 1 # KV_N = 1 + _, _, _, dv = v_pa.shape + + has_rope = k_rope_pa is not None + dk_rope = k_rope_pa.shape[-1] if has_rope else 0 + dk_total = dk + dk_rope + + # Output BNSD format [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, device=DEVICE) + + # Grid: use 48 programs for parallelism + grid = (min(48, TOPK), ) + + # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] + # No squeeze needed - kernel expects [B, N, TOPK] format + sparse_indices_input = sparse_indices + if sparse_indices.dim() == 2: + # If already 2D [B, TOPK], reshape to [B, 1, TOPK] + sparse_indices_input = sparse_indices.unsqueeze(1) + + # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) + k_rope_pa_input = k_rope_pa if has_rope else k_pa + 
fused_pa_rope_to_sparse_kernel[grid](k_pa, k_rope_pa_input, v_pa, block_table, sparse_indices_input, + k_sparse, v_sparse, k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), + k_pa.stride(3), k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), + k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), v_pa.stride(0), + v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), block_table.stride(0), + block_table.stride(1), sparse_indices_input.stride(0), + sparse_indices_input.stride(1), sparse_indices_input.stride(2), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=dk, BLOCK_DV=dv, BLOCK_DK_ROPE=dk_rope, TOPK=TOPK, + BLOCK_SIZE=block_size, B=B) + + return k_sparse, v_sparse + + +@triton.jit +def gather_kv_bnsd_vec_kernel( + k_ptr, + v_ptr, + ind_ptr, + k_out_ptr, + v_out_ptr, + stride_kb, + stride_kn, + stride_ks, + stride_kd, + stride_vb, + stride_vn, + stride_vs, + stride_vd, + stride_ob, + stride_on, + stride_os, + stride_od, + stride_ovb, + stride_ovn, + stride_ovs, + stride_ovd, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + TOPK: tl.constexpr, + B: tl.constexpr, +): + end = TOPK // 48 * 48 + for b_idx in range(B): + # 分批处理所有TOPK个索引,每次48个 + for batch_start in range(0, end, 48): + pid_k = tl.program_id(0) + batch_start + + # 读 index + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + 
v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + # 处理余数部分(end到TOPK) + for batch_start in range(end, TOPK, 48): + pid_k = tl.program_id(0) + batch_start + + # 必须在计算pid_k之后检查边界 + if pid_k < TOPK: + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + +def triton_gather_kv_bnsd_vec(k, v, indices): + B, N, SK, Dk = k.shape # N=1 + B, N, SK, Dv = v.shape # N=1 + TOPK = indices.size(-1) + + # 输出保持 bnsd [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) + + grid = (48, ) # TOPK 个 program,每个搬 Dk/Dv 元素 + gather_kv_bnsd_vec_kernel[grid]( + k, + v, + indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] + k_sparse, + v_sparse, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + k_sparse.stride(0), + k_sparse.stride(1), + k_sparse.stride(2), + k_sparse.stride(3), + v_sparse.stride(0), + v_sparse.stride(1), + v_sparse.stride(2), + v_sparse.stride(3), + BLOCK_DK=Dk, + BLOCK_DV=Dv, + TOPK=TOPK, + B=B, + ) + return k_sparse, v_sparse + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + O, + scale_value, + stride_qb: tl.constexpr, + stride_qs: tl.constexpr, + stride_qn: tl.constexpr, + stride_qd: tl.constexpr, + stride_kb: tl.constexpr, + stride_kn: tl.constexpr, + stride_ks: tl.constexpr, + 
stride_kd: tl.constexpr, + stride_vb: tl.constexpr, + stride_vn: tl.constexpr, + stride_vs: tl.constexpr, + stride_vd: tl.constexpr, + stride_ob: tl.constexpr, + stride_os: tl.constexpr, + stride_on: tl.constexpr, + stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, + Q_D: tl.constexpr, + Q_S: tl.constexpr, + KV_S: tl.constexpr, + K_D: tl.constexpr, + V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N: tl.constexpr, + O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, +): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B * Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + # off_n = 0 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ob + off_s * stride_os + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE, ), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE, ), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(Q_N, Q_D), strides=(stride_qn, stride_qd), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, Q_D), + order=(1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0, 1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(KV_S, K_D), + strides=(stride_ks, stride_kd), + offsets=(0, 0), + block_shape=(blk_size, K_D), + order=(1, 0), + ) + 
v_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(KV_S, V_D), strides=(stride_vs, stride_vd), + offsets=(0, 0), block_shape=(blk_size, V_D), order=(1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0, 1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), + tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. + block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:, None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0, 1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc * coeff[:, None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base=O + o_offset, shape=(O_N, O_D), strides=(stride_on, stride_od), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, O_D), + order=(1, 0)) + # final normalize + acc = acc / logSum[:, None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + +@triton.jit +def _attn_fwd_fused_bsnd_to_tnd( + Q, + K, + V, + O, + scale_value, + stride_qb: tl.constexpr, + stride_qs: tl.constexpr, + stride_qn: tl.constexpr, + stride_qd: tl.constexpr, + stride_kb: tl.constexpr, + stride_kn: tl.constexpr, + stride_ks: tl.constexpr, + stride_kd: 
tl.constexpr, + stride_vb: tl.constexpr, + stride_vn: tl.constexpr, + stride_vs: tl.constexpr, + stride_vd: tl.constexpr, + stride_ot: tl.constexpr, + stride_on: tl.constexpr, + stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, + Q_D: tl.constexpr, + Q_S: tl.constexpr, + KV_S: tl.constexpr, + K_D: tl.constexpr, + V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N: tl.constexpr, + O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, +): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B * Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ot + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE, ), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE, ), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(Q_N, Q_D), strides=(stride_qn, stride_qd), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, Q_D), + order=(1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0, 1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(KV_S, K_D), + strides=(stride_ks, stride_kd), + offsets=(0, 0), + block_shape=(blk_size, K_D), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(KV_S, V_D), 
strides=(stride_vs, stride_vd), + offsets=(0, 0), block_shape=(blk_size, V_D), order=(1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0, 1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), + tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. + block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:, None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0, 1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc * coeff[:, None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base=O + o_offset, shape=(O_N, O_D), strides=(stride_on, stride_od), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, O_D), + order=(1, 0)) + # final normalize + acc = acc / logSum[:, None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, query, key, value, sparse_indices, scale_value, sparse_block_size=1, actual_seq_lengths_query=None, + actual_seq_lengths_kv=None, query_rope=None, key_rope=None, layout_query='BSND', layout_kv='BSND', + sparse_mode=0, block_table=None): + # Save original sparse_indices for PA_BSND 
case + sparse_indices_orig = sparse_indices.clone() + total_len = 0 + # Handle query layout transformation (TND -> BSND) + if layout_query == 'TND': + actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) + # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat + query, sparse_indices = trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, query.shape, + actual_seq_lengths_query) + else: + if query_rope is not None: + query = torch.cat([query, query_rope], dim=-1) + + # Handle KV layout and gather sparse K/V + if layout_kv == 'PA_BSND': + # Fused PA -> BNSD + rope concat + sparse gather + block_size = key.shape[1] # Get block_size from PA shape + # Use original sparse_indices [T, N, TOPK] for fused kernel + k_sparse, v_sparse = triton_fused_pa_rope_to_sparse(key, key_rope, value, block_table, sparse_indices_orig, + block_size) + # sparse_indices is already in BSND, needs permute to BNSD for downstream use + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + else: + # Original path for non-PA layouts + if key_rope is not None: + key = torch.cat([key, key_rope], dim=-1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + + k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + + k_sparse = k_sparse.contiguous() + v_sparse = v_sparse.contiguous() + enable_check_kv_sparse = 0 + if enable_check_kv_sparse: + key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) + key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) + value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) + if key_rope is not None: + key = torch.cat([key, key_rope], dim=-1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + 
print(f"k_sparse={k_sparse}") + print(f"k_sparse_ref={k_sparse_ref}") + print(f"v_sparse={v_sparse}") + print(f"v_sparse_ref={v_sparse_ref}") + assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + + # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + num_cores = ascend_aiv_core_nums + # sparse_size = sparse_indices_bnsd.shape[-1] # 4 + out_shape_bsnd = list(query.shape) + if query_rope is not None: + out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] + B, Q_S, Q_N, Q_D = query.shape + _, _, KV_S, K_D = k_sparse.shape + + if layout_query == 'TND': + # t = B*act_q_s + output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, + dtype=torch.float32) + _attn_fwd_fused_bsnd_to_tnd[(num_cores, )]( + query, k_sparse, v_sparse, output, scale_value, query.stride(0), query.stride(1), query.stride(2), + query.stride(3), k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), output.stride(0), + output.stride(1), output.stride(2), B=B, Q_N=Q_N, Q_D=Q_D, Q_S=Q_S, KV_S=KV_S, K_D=K_D, + V_D=v_sparse.shape[3], sparse_mode=sparse_mode, O_N=output.shape[1], O_D=output.shape[2], + actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_kv=actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16, multibuffer=False) + + else: + output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) + _attn_fwd[(num_cores, )](query, k_sparse, v_sparse, output, scale_value, query.stride(0), query.stride(1), + query.stride(2), 
query.stride(3), k_sparse.stride(0), k_sparse.stride(1), + k_sparse.stride(2), k_sparse.stride(3), v_sparse.stride(0), v_sparse.stride(1), + v_sparse.stride(2), v_sparse.stride(3), output.stride(0), output.stride(1), + output.stride(2), output.stride(3), B=B, Q_N=Q_N, Q_D=Q_D, Q_S=Q_S, KV_S=KV_S, + K_D=K_D, V_D=v_sparse.shape[3], sparse_mode=sparse_mode, O_N=output.shape[2], + O_D=output.shape[3], actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, blk_size=128, Q_BLOCK_SIZE=16) + output = output.permute(0, 2, 1, 3).contiguous() + + ctx.save_for_backward(query, k_sparse, v_sparse, output) + ctx.scale_value = scale_value + return output + + +def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): + block_num, block_size, n, d = pa_in.shape + b = len(actual_seq_lengths) + output = torch.empty((b, block_num * block_size // b, 1, d), dtype=pa_in.dtype).to(DEVICE) + for i in range(b): + for j in range(20): + output[i, j * block_size: (j + 1) * block_size, 0, :] = \ + pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) + return output + + +@triton.jit +def trans_tnd_to_bsnd_fused_kernel( + query_ptr, + query_rope_ptr, + sparse_ptr, + query_out_ptr, + sparse_out_ptr, # query_out 已经拼接了 rope + act_s, + stride_q_t, + stride_q_tn, + stride_q_td, + stride_qr_t, + stride_qr_tn, + stride_qr_td, + stride_s_t, + stride_s_tn, + stride_s_td, + stride_qob, + stride_qobs, + stride_qon, + stride_qod, # query_out strides + stride_sb, + stride_sbs, + stride_sbn, + stride_sbd, + B: tl.constexpr, + N: tl.constexpr, + D_QUERY: tl.constexpr, + D_ROPE: tl.constexpr, + D_SPARSE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_QUERY: tl.constexpr, + BLOCK_D_ROPE: tl.constexpr, + BLOCK_D_SPARSE: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # 计算 head 的总块数 + num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N + t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 + # 每个 pid 负责处理特定的 (batch, head_block) 组合 
+ for tn_id in range(B): + # sparse_indices 是单头的,只在第一个 head_block 处理一次 + if pid == 0: + sparse_block_ptr = tl.make_block_ptr(base=sparse_ptr + t_idx * stride_s_t, shape=(1, D_SPARSE), + strides=(stride_s_tn, stride_s_td), offsets=(0, 0), + block_shape=(1, D_SPARSE), order=(1, 0)) + sparse = tl.load(sparse_block_ptr) + + sparse_out_block_ptr = tl.make_block_ptr(base=sparse_out_ptr + t_idx * stride_sb, shape=(1, D_SPARSE), + strides=(stride_sbn, stride_sbd), offsets=(0, 0), + block_shape=(1, D_SPARSE), order=(1, 0)) + tl.store(sparse_out_block_ptr, sparse) + + # query 和 query_rope 是多头的,需要在 head 维度上分块处理 + for head_block_id in range(pid, num_head_blocks, num_programs): + n_offset = head_block_id * BLOCK_N + + # Load q and q_ro + q_block_ptr = tl.make_block_ptr(base=query_ptr + t_idx * stride_q_t, shape=(N, D_QUERY), + strides=(stride_q_tn, stride_q_td), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_QUERY), order=(1, 0)) + q_ro_block_ptr = tl.make_block_ptr(base=query_rope_ptr + t_idx * stride_qr_t, shape=(N, D_ROPE), + strides=(stride_qr_tn, stride_qr_td), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_ROPE), order=(1, 0)) + q = tl.load(q_block_ptr) + q_ro = tl.load(q_ro_block_ptr) + + # Combine query and query_rope using insert_slice, then store in one operation + full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) + full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) + full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) + + q_out_block_ptr = tl.make_block_ptr(base=query_out_ptr + t_idx * stride_qob, shape=(N, D_QUERY + D_ROPE), + strides=(stride_qon, stride_qod), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_QUERY + D_ROPE), order=(1, 0)) + tl.store(q_out_block_ptr, full_q) + t_idx = t_idx + tl.load(act_s + tn_id) + + +def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16, )): + """ + 融合版本的 
TND -> BSND 转换(包含 concat) + 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope + """ + t, n, d_query = shape + b = len(act_seq) + s = max(act_seq) + + # 获取各个 tensor 的维度 + d_rope = query_rope.shape[2] if query_rope is not None else 0 + d_sparse = sparse_indices.shape[2] + d_query_out = d_query + d_rope # 拼接后的维度 + + # 分配输出(query_out 已经包含 rope) + query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) + sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, device=sparse_indices.device) + assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" + # 启动 fused kernel + # 使用较小的 BLOCK_N 避免内存溢出 + block_n = min(16, n) + # 计算需要的核心数:使用多核心并行处理不同的头 + num_head_blocks = (n + block_n - 1) // block_n + num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 + + trans_tnd_to_bsnd_fused_kernel[ + num_programs, + ]( + query, + query_rope, + sparse_indices, + query_out, + sparse_out, + act_seq, + query.stride(0), + query.stride(1), + query.stride(2), + query_rope.stride(0), + query_rope.stride(1), + query_rope.stride(2), + sparse_indices.stride(0), + sparse_indices.stride(1), + sparse_indices.stride(2), + query_out.stride(0), + query_out.stride(1), + query_out.stride(2), + query_out.stride(3), + sparse_out.stride(0), + sparse_out.stride(1), + sparse_out.stride(2), + sparse_out.stride(3), + B=b, + N=n, + D_QUERY=d_query, + D_ROPE=d_rope, + D_SPARSE=d_sparse, + BLOCK_N=block_n, + BLOCK_D_QUERY=d_query, + BLOCK_D_ROPE=d_rope, + BLOCK_D_SPARSE=d_sparse, + ) + return query_out, sparse_out + + +def trans_tnd_actseq(seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().tolist() + list_len = len(seq) + output = [] + output = [seq[0]] + total_len = seq[0] + for i in range(list_len - 1): + new_item = seq[i + 1] - seq[i] + if new_item >= 0: + output.append(new_item) + total_len += new_item + else: + print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 
0") + return torch.tensor(output).to(DEVICE), total_len + + +def sparse_attention(query, key, value, sparse_indices, scale_value, sparse_block_size=1, actual_seq_lengths_query=None, + actual_seq_lengths_kv=None, query_rope=None, key_rope=None, layout_query='BSND', layout_kv='BSND', + sparse_mode=0, block_table=None): + return _attention.apply(query, key, value, sparse_indices, scale_value, sparse_block_size, actual_seq_lengths_query, + actual_seq_lengths_kv, query_rope, key_rope, layout_query, layout_kv, sparse_mode, + block_table) + + +def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope, sparse_size, scale_value, sparse_block_size, sparse_mode, block_size, + act_kv_s): + assert sparse_size <= KV_S + assert KV_N == 1 + assert sparse_mode == 0 or 3 + assert sparse_block_size == 1 + assert (B * KV_S) % block_size == 0 + assert D == 512 + assert D_rope == 0 or 64 + print("*batch_size=", B) + qkv_dtype = torch.float16 + #sparse_size = KV_S + query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, + device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + value = key.clone() + + # act_q_s = T // B # step + # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE) + # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复 + # sparse_indices = indices.to(torch.int32) + sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1) + sparse_indices = sparse_indices.to(torch.int32) + # print("sparse_indices=", sparse_indices) + actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE) + # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE) + actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, dtype=torch.int32, device=DEVICE) + print(actual_seq_lengths_kv) + block_table = torch.tensor([range(B * KV_S // block_size)], 
dtype=torch.int32, device=DEVICE).reshape(B, -1) + + if D_rope == 0: + query_rope = None + key_rope = None + else: + query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, + std=0.5).requires_grad_() + key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), dtype=qkv_dtype, + device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + + print("q.shape=", query.shape) + print("k.shape=", key.shape) + print("v.shape=", value.shape) + print("sparse_indices.shape=", sparse_indices.shape) + print("act_seq_query=", actual_seq_lengths_query) + print("act_seq_kv=", actual_seq_lengths_kv) + + triton_out = sparse_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + ) + npu_out = torch_npu.npu_sparse_flash_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + # attention_mode = 2, + ) + triton_out = triton_out.to(npu_out.dtype) + torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) + print("[PASSED]") + + # benchmarking + triton_time = do_bench_npu( + lambda: sparse_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + 
actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Triton SFA] Time: {triton_time:.4f} us") + + npu_time = do_bench_npu( + lambda: torch_npu.npu_sparse_flash_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + # attention_mode = 2, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") + + +if __name__ == "__main__": + print(torch_npu.__version__) + print("Test Real Case in DS-v3.2-Exp") + print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + i = 1 + print(f"====================第{i}次测试=================") + test_op(T=1, B=1, KV_S=2560, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=4, B=4, KV_S=6400, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=8, B=8, KV_S=48000, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=16, B=16, KV_S=48000, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) diff 
--git a/python/tutorials/tle/sparse_flash_attn_tle.py b/python/tutorials/tle/sparse_flash_attn_tle.py deleted file mode 100644 index fea4b1682..000000000 --- a/python/tutorials/tle/sparse_flash_attn_tle.py +++ /dev/null @@ -1,912 +0,0 @@ -# Copyright 2026- Xcoresigma Technology Co., Ltd - -import pytest -import torch -import torch_npu -import triton -import triton.language as tl -import numpy as np -from datetime import datetime -from triton.backends.ascend.testing import do_bench_npu -import triton.experimental.tle as tle -# import random - -np.random.seed(21) -DEVICE = "npu" -DEVICE_ID = 0 -torch.manual_seed(20) -torch_npu.npu.set_device(int(DEVICE_ID)) -torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) - -ascend_aiv_core_nums = triton.language.constexpr(24) - -# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== -@triton.jit -def fused_pa_rope_to_sparse_kernel( - k_pa_ptr, k_rope_pa_ptr, v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] - block_table_ptr, # block_table [B, max_blocks] - sparse_indices_ptr, # sparse_indices [B, N, TOPK] - k_sparse_out_ptr, v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] - stride_k_pa_bn, stride_k_pa_bs, stride_k_pa_n, stride_k_pa_d, # K PA strides - stride_k_rope_pa_bn, stride_k_rope_pa_bs, stride_k_rope_pa_n, stride_k_rope_pa_d, # K_rope PA strides - stride_v_pa_bn, stride_v_pa_bs, stride_v_pa_n, stride_v_pa_d, # V PA strides - stride_bt_b, stride_bt_blk, # block_table strides - stride_si_b, stride_si_n, stride_si_topk, # sparse_indices strides - stride_out_b, stride_out_n, stride_out_topk, stride_out_d, # output strides - stride_v_b, stride_v_n, stride_v_topk, stride_v_d, - BLOCK_DK: tl.constexpr, - BLOCK_DV: tl.constexpr, - BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope - TOPK: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - B: tl.constexpr, -): - """ - Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse - Input: K/V in PA_BSND format, K_rope in PA_BSND format - Output: K/V_sparse in BNSD format - """ - 
pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - # Process (b, n, topk) combinations - for b_idx in range(B): - b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now - for idx in range(pid, TOPK, num_programs): - # Get batch and sparse index from sparse_indices - n = 0 # KV_N = 1 - - # Load sparse index - sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) - - # Map sparse_idx to PA_BSND position - block_id = sparse_idx // BLOCK_SIZE # Which block - bs_offset = sparse_idx % BLOCK_SIZE # Offset within block - - # Get actual block ID from block_table - actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) - - # Compute PA_BSND offset for K - k_pa_offset = (actual_block_id * stride_k_pa_bn + - bs_offset * stride_k_pa_bs + - n * stride_k_pa_n) - - # Compute PA_BSND offset for K_rope - k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + - bs_offset * stride_k_rope_pa_bs + - n * stride_k_rope_pa_n) - - # Compute PA_BSND offset for V - v_pa_offset = (actual_block_id * stride_v_pa_bn + - bs_offset * stride_v_pa_bs + - n * stride_v_pa_n) - # Load K vector (no rope part) - k_vec = tl.load( - k_pa_ptr + k_pa_offset + - tl.arange(0, BLOCK_DK) * stride_k_pa_d - ) - - # Load V vector - v_vec = tl.load( - v_pa_ptr + v_pa_offset + - tl.arange(0, BLOCK_DV) * stride_v_pa_d - ) - # Output to BNSD format: [B, N, TOPK, D] - out_offset = b * stride_out_b + n * stride_out_n + idx * stride_out_topk - out_offset_v = b* stride_v_b + n *stride_v_n + idx*stride_v_topk - - if BLOCK_DK_ROPE > 0: - # Load K_rope vector - full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE,), 0.0, dtype=tl.float16) - k_rope_vec = tl.load( - k_rope_pa_ptr + k_rope_pa_offset + - tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d - ) - full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0,), sizes=(BLOCK_DK,), strides=(1,)) - full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK,), 
sizes=(BLOCK_DK_ROPE,), strides=(1,)) - tl.store( - k_sparse_out_ptr + out_offset + - tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, - full_k - ) - else: - # No rope, store K directly - tl.store( - k_sparse_out_ptr + out_offset + - tl.arange(0, BLOCK_DK) * stride_out_d, - k_vec - ) - - # Store V - tl.store( - v_sparse_out_ptr + out_offset_v + - tl.arange(0, BLOCK_DV) * stride_v_d, - v_vec - ) - - -def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): - """ - Fused PA_BSND + Rope Concat -> BNSD Sparse conversion - - Args: - k_pa: Key in PA_BSND format [block_num, block_size, n, dk] - k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope - v_pa: Value in PA_BSND format [block_num, block_size, n, dv] - block_table: Block table [B, max_blocks] - sparse_indices: Sparse indices [B, N, TOPK] - block_size: Block size for PA format - - Returns: - k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] - v_sparse: Sparse value in BNSD format [B, N, TOPK, dv] - """ - block_num, _, n, dk = k_pa.shape - B = block_table.shape[0] - TOPK = sparse_indices.size(-1) - N = 1 # KV_N = 1 - _, _, _, dv = v_pa.shape - - has_rope = k_rope_pa is not None - dk_rope = k_rope_pa.shape[-1] if has_rope else 0 - dk_total = dk + dk_rope - - # Output BNSD format [B, N, TOPK, D] - k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) - v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, device=DEVICE) - - # Grid: use 48 programs for parallelism - grid = (min(48, TOPK),) - - # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] - # No squeeze needed - kernel expects [B, N, TOPK] format - sparse_indices_input = sparse_indices - if sparse_indices.dim() == 2: - # If already 2D [B, TOPK], reshape to [B, 1, TOPK] - sparse_indices_input = sparse_indices.unsqueeze(1) - - # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) - k_rope_pa_input = k_rope_pa if 
has_rope else k_pa - fused_pa_rope_to_sparse_kernel[grid]( - k_pa, k_rope_pa_input, v_pa, - block_table, - sparse_indices_input, - k_sparse, v_sparse, - k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), k_pa.stride(3), - k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), - v_pa.stride(0), v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), - block_table.stride(0), block_table.stride(1), - sparse_indices_input.stride(0), sparse_indices_input.stride(1), sparse_indices_input.stride(2), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - BLOCK_DK=dk, - BLOCK_DV=dv, - BLOCK_DK_ROPE=dk_rope, - TOPK=TOPK, - BLOCK_SIZE=block_size, - B = B - ) - - return k_sparse, v_sparse - -@triton.jit -def gather_kv_bnsd_vec_kernel( - k_ptr, v_ptr, ind_ptr, - k_out_ptr, v_out_ptr, - stride_kb, stride_kn, stride_ks, stride_kd, - stride_vb, stride_vn, stride_vs, stride_vd, - stride_ob, stride_on, stride_os, stride_od, - stride_ovb, stride_ovn, stride_ovs, stride_ovd, - BLOCK_DK: tl.constexpr, - BLOCK_DV: tl.constexpr, - TOPK: tl.constexpr, - B: tl.constexpr, -): - end = TOPK // 48 * 48 - for b_idx in range(B): - # 分批处理所有TOPK个索引,每次48个 - for batch_start in range(0, end, 48): - pid_k = tl.program_id(0) + batch_start - - # 读 index - idx = tl.load(ind_ptr + pid_k) - - # 加载 K 向量 [BLOCK_DK] - 直接线性加载 - k_src_off = idx * stride_ks + b_idx * stride_kb - k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) - - # 加载 V 向量 [BLOCK_DV] - 直接线性加载 - v_src_off = idx * stride_vs + b_idx * stride_vb - v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) - - # 写回 K: [B, N, TOPK, Dk] - k_dst_off = pid_k * stride_os + b_idx * stride_ob - tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) - - # 写回 V: [B, N, TOPK, Dv] - v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb - tl.store(v_out_ptr 
+ v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) - - # 处理余数部分(end到TOPK) - for batch_start in range(end, TOPK, 48): - pid_k = tl.program_id(0) + batch_start - - # 必须在计算pid_k之后检查边界 - if pid_k < TOPK: - idx = tl.load(ind_ptr + pid_k) - - # 加载 K 向量 [BLOCK_DK] - 直接线性加载 - k_src_off = idx * stride_ks + b_idx * stride_kb - k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) - - # 加载 V 向量 [BLOCK_DV] - 直接线性加载 - v_src_off = idx * stride_vs + b_idx * stride_vb - v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) - - # 写回 K: [B, N, TOPK, Dk] - k_dst_off = pid_k * stride_os + b_idx * stride_ob - tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) - - # 写回 V: [B, N, TOPK, Dv] - v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb - tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) - -def triton_gather_kv_bnsd_vec(k, v, indices): - B, N, SK, Dk = k.shape # N=1 - B, N, SK, Dv = v.shape # N=1 - TOPK = indices.size(-1) - - # 输出保持 bnsd [B, N, TOPK, D] - k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) - v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) - - grid = (48,) # TOPK 个 program,每个搬 Dk/Dv 元素 - gather_kv_bnsd_vec_kernel[grid]( - k, v, indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] - k_sparse, v_sparse, - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - BLOCK_DK=Dk, - BLOCK_DV=Dv, - TOPK=TOPK, - B=B, - ) - return k_sparse, v_sparse - -@triton.jit -def _attn_fwd( - Q, K, V, O, scale_value, - stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, - stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, - stride_vb: tl.constexpr, 
stride_vn: tl.constexpr, stride_vs: tl.constexpr, stride_vd: tl.constexpr, - stride_ob: tl.constexpr, stride_os: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, - B: tl.constexpr, - Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, - KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, - sparse_mode: tl.constexpr, # 0 or 3 - O_N:tl.constexpr, O_D: tl.constexpr, - actual_seq_lengths_query, - actual_seq_lengths_kv, - blk_size: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - ): - # total b * n tasks - BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE - NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM - pid = tl.program_id(0) - num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) - - #最外层循环,沿b*n切 - for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 - off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 - off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 - off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 - # off_n = 0 - - q_offset = off_b * stride_qb + off_s * stride_qs - o_offset = off_b * stride_ob + off_s * stride_os - k_offset = off_b * stride_kb # KV_N = 1 - v_offset = off_b * stride_vb - - cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) - - for i in range(cur_act_s_q): - cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) - logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) - acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] - - # load q - q_block_ptr = tl.make_block_ptr(base = Q + q_offset, - shape = (Q_N, Q_D), - strides = (stride_qn, stride_qd), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, Q_D), - order = (1, 0)) - q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] - k_block_ptr = tl.make_block_ptr(base = K + k_offset, - shape = (KV_S, K_D), - strides = (stride_ks, stride_kd), - offsets = (0, 0), - block_shape = (blk_size, K_D), - order = (1, 0),) - v_block_ptr = tl.make_block_ptr(base = V + v_offset, - shape = 
(KV_S, V_D), - strides = (stride_vs, stride_vd), - offsets = (0, 0), - block_shape = (blk_size, V_D), - order = (1, 0)) - - for k_idx in range(KV_S // blk_size): - # load k - k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) - - # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] - qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] - # online softmax update - # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. - block_max = tl.max(qk, axis=1) # [q_block_size] - # align shapes to (q_block_size, 1) for broadcasting - # block_max = block_max[:, None] # [q_block_size, 1] - new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] - coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] - p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] - # logsum per row - logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] - - # update accumulator: compute per-row pv by summing over block dim - v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] - pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] - acc = acc*coeff[:,None] + pv # [q_block_size, V_D] - cur_max = new_max - - k_block_ptr = k_block_ptr.advance((blk_size, 0)) - v_block_ptr = v_block_ptr.advance((blk_size, 0)) - - o_block_ptr = tl.make_block_ptr(base = O + o_offset, - shape = (O_N, O_D), - strides = (stride_on, stride_od), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, O_D), - order = (1,0)) - # final normalize - acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] - tl.store(o_block_ptr, acc) - - - -@triton.jit -def _attn_fwd_fused_bsnd_to_tnd( - Q, K, V, O, scale_value, - stride_qb: tl.constexpr, stride_qs: tl.constexpr, stride_qn: tl.constexpr, stride_qd: tl.constexpr, - stride_kb: tl.constexpr, stride_kn: tl.constexpr, stride_ks: tl.constexpr, stride_kd: tl.constexpr, - stride_vb: tl.constexpr, stride_vn: tl.constexpr, 
stride_vs: tl.constexpr, stride_vd: tl.constexpr, - stride_ot: tl.constexpr, stride_on: tl.constexpr, stride_od: tl.constexpr, - B: tl.constexpr, - Q_N: tl.constexpr, Q_D: tl.constexpr, Q_S: tl.constexpr, - KV_S: tl.constexpr, K_D: tl.constexpr, V_D: tl.constexpr, - sparse_mode: tl.constexpr, # 0 or 3 - O_N:tl.constexpr, O_D: tl.constexpr, - actual_seq_lengths_query, - actual_seq_lengths_kv, - blk_size: tl.constexpr, - Q_BLOCK_SIZE: tl.constexpr, - ): - # total b * n tasks - BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE - NUM_BLOCKS = B *Q_S * BLOCK_QN_NUM - pid = tl.program_id(0) - num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) - - #最外层循环,沿b*n切 - for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 - off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 - off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 - off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 - - q_offset = off_b * stride_qb + off_s * stride_qs - o_offset = off_b * stride_ot - k_offset = off_b * stride_kb # KV_N = 1 - v_offset = off_b * stride_vb - - cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) - - for i in range(cur_act_s_q): - cur_max = tl.full((Q_BLOCK_SIZE,), float('-inf'), dtype=tl.float32) - logSum = tl.zeros((Q_BLOCK_SIZE,), dtype=tl.float32) - acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] - - # load q - q_block_ptr = tl.make_block_ptr(base = Q + q_offset, - shape = (Q_N, Q_D), - strides = (stride_qn, stride_qd), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, Q_D), - order = (1, 0)) - q_vec = tl.load(q_block_ptr, boundary_check=(0,1)) # [q_block_size, K_D] - k_block_ptr = tl.make_block_ptr(base = K + k_offset, - shape = (KV_S, K_D), - strides = (stride_ks, stride_kd), - offsets = (0, 0), - block_shape = (blk_size, K_D), - order = (1, 0),) - v_block_ptr = tl.make_block_ptr(base = V + v_offset, - shape = (KV_S, V_D), - strides = (stride_vs, stride_vd), - offsets = (0, 0), - block_shape = 
(blk_size, V_D), - order = (1, 0)) - - for k_idx in range(KV_S // blk_size): - # load k - k_vec = tl.load(k_block_ptr, boundary_check=(0,1)) - - # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] - qk = tl.dot(q_vec.to(tl.float16), tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] - # online softmax update - # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. - block_max = tl.max(qk, axis=1) # [q_block_size] - # align shapes to (q_block_size, 1) for broadcasting - # block_max = block_max[:, None] # [q_block_size, 1] - new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] - coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] - p = tl.math.exp(qk - new_max[:,None]) # [q_block_size, blk_size] - # logsum per row - logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] - - # update accumulator: compute per-row pv by summing over block dim - v_vec = tl.load(v_block_ptr, boundary_check=(0,1)) # [blk_size, V_D] - pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] - acc = acc*coeff[:,None] + pv # [q_block_size, V_D] - cur_max = new_max - - k_block_ptr = k_block_ptr.advance((blk_size, 0)) - v_block_ptr = v_block_ptr.advance((blk_size, 0)) - - o_block_ptr = tl.make_block_ptr(base = O + o_offset, - shape = (O_N, O_D), - strides = (stride_on, stride_od), - offsets = (off_n * Q_BLOCK_SIZE, 0), - block_shape = (Q_BLOCK_SIZE, O_D), - order = (1,0)) - # final normalize - acc = acc / logSum[:,None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] - tl.store(o_block_ptr, acc) - - - - -class _attention(torch.autograd.Function): - @staticmethod - def forward( - ctx, - query, - key, - value, - sparse_indices, - scale_value, - sparse_block_size = 1, - actual_seq_lengths_query = None, - actual_seq_lengths_kv = None, - query_rope = None, - key_rope = None, - layout_query = 'BSND', - layout_kv = 'BSND', - sparse_mode = 0, - block_table = None): - # Save original sparse_indices for PA_BSND case - 
sparse_indices_orig = sparse_indices.clone() - total_len = 0 - # Handle query layout transformation (TND -> BSND) - if layout_query == 'TND': - actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) - # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat - query, sparse_indices = trans_tnd_to_bsnd_fused( - query, query_rope, sparse_indices, query.shape, actual_seq_lengths_query - ) - else: - if query_rope != None: - query = torch.cat([query, query_rope], dim = -1) - - # Handle KV layout and gather sparse K/V - if layout_kv == 'PA_BSND': - # Fused PA -> BNSD + rope concat + sparse gather - block_size = key.shape[1] # Get block_size from PA shape - # Use original sparse_indices [T, N, TOPK] for fused kernel - k_sparse, v_sparse = triton_fused_pa_rope_to_sparse( - key, key_rope, value, block_table, sparse_indices_orig, block_size - ) - # sparse_indices is already in BSND, needs permute to BNSD for downstream use - sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() - else: - # Original path for non-PA layouts - if key_rope != None: - key = torch.cat([key, key_rope], dim = -1) - key_bnsd = key.permute(0, 2, 1, 3).contiguous() - value_bnsd = value.permute(0, 2, 1, 3).contiguous() - sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() - - k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) - - k_sparse = k_sparse.contiguous() - v_sparse = v_sparse.contiguous() - enable_check_kv_sparse = 0 - if enable_check_kv_sparse: - key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) - key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) - value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) - if key_rope != None: - key = torch.cat([key, key_rope], dim = -1) - key_bnsd = key.permute(0, 2, 1, 3).contiguous() - value_bnsd = value.permute(0, 2, 1, 3).contiguous() - k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) - 
print(f"k_sparse={k_sparse}") - print(f"k_sparse_ref={k_sparse_ref}") - print(f"v_sparse={v_sparse}") - print(f"v_sparse_ref={v_sparse_ref}") - assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" - assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" - - # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() - # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" - # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() - # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" - num_cores = ascend_aiv_core_nums - sparse_size = sparse_indices_bnsd.shape[-1] # 4 - out_shape_bsnd = list(query.shape) - if query_rope != None: - out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] - B, Q_S, Q_N, Q_D = query.shape - _, _, KV_S, K_D = k_sparse.shape - - if layout_query == 'TND': - # t = B*act_q_s - output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, dtype=torch.float32) - _attn_fwd_fused_bsnd_to_tnd[(num_cores,)]( - query, k_sparse, v_sparse, output, scale_value, - query.stride(0), query.stride(1), query.stride(2), query.stride(3), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - output.stride(0), output.stride(1), output.stride(2), - B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, - KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], - sparse_mode = sparse_mode, O_N = output.shape[1], O_D = output.shape[2], - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - blk_size=128, Q_BLOCK_SIZE=16,multibuffer=False - ) - - else: - output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) - _attn_fwd[(num_cores,)]( - query, k_sparse, v_sparse, output, scale_value, - query.stride(0), query.stride(1), 
query.stride(2), query.stride(3), - k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), - v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), - output.stride(0), output.stride(1), output.stride(2), output.stride(3), - B = B, Q_N = Q_N, Q_D = Q_D, Q_S = Q_S, - KV_S = KV_S, K_D = K_D, V_D = v_sparse.shape[3], - sparse_mode = sparse_mode, O_N = output.shape[2], O_D = output.shape[3], - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - blk_size=128, Q_BLOCK_SIZE=16 - ) - output = output.permute(0, 2, 1, 3).contiguous() - - ctx.save_for_backward(query, k_sparse, v_sparse, output) - ctx.scale_value = scale_value - return output - -def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): - block_num, block_size, n, d = pa_in.shape - b = len(actual_seq_lengths) - output = torch.empty((b, block_num * block_size // b, 1, d), dtype = pa_in.dtype).to(DEVICE) - for i in range(b): - for j in range(20): - output[i, j * block_size: (j + 1) * block_size, 0, :] = \ - pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) - return output - - -@triton.jit -def trans_tnd_to_bsnd_fused_kernel( - query_ptr, query_rope_ptr, sparse_ptr, - query_out_ptr, sparse_out_ptr, # query_out 已经拼接了 rope - act_s, - stride_q_t, stride_q_tn, stride_q_td, - stride_qr_t, stride_qr_tn, stride_qr_td, - stride_s_t, stride_s_tn, stride_s_td, - stride_qob, stride_qobs, stride_qon, stride_qod, # query_out strides - stride_sb, stride_sbs, stride_sbn, stride_sbd, - B: tl.constexpr, - N: tl.constexpr, - D_QUERY: tl.constexpr, - D_ROPE: tl.constexpr, - D_SPARSE: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D_QUERY: tl.constexpr, - BLOCK_D_ROPE: tl.constexpr, - BLOCK_D_SPARSE: tl.constexpr, -): - pid = tl.program_id(0) - num_programs = tl.num_programs(0) - - # 计算 head 的总块数 - num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N - t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 - # 每个 pid 负责处理特定的 
(batch, head_block) 组合 - for tn_id in range(B): - # sparse_indices 是单头的,只在第一个 head_block 处理一次 - if pid == 0: - sparse_block_ptr = tl.make_block_ptr(base = sparse_ptr + t_idx * stride_s_t, - shape = (1, D_SPARSE), - strides = (stride_s_tn, stride_s_td), - offsets = (0, 0), - block_shape = (1, D_SPARSE), - order = (1, 0)) - sparse = tl.load(sparse_block_ptr) - - sparse_out_block_ptr = tl.make_block_ptr(base = sparse_out_ptr + t_idx * stride_sb, - shape = (1, D_SPARSE), - strides = ( stride_sbn, stride_sbd), - offsets = (0, 0), - block_shape = (1, D_SPARSE), - order = (1, 0)) - tl.store(sparse_out_block_ptr, sparse) - - # query 和 query_rope 是多头的,需要在 head 维度上分块处理 - for head_block_id in range(pid, num_head_blocks, num_programs): - n_offset = head_block_id * BLOCK_N - - # Load q and q_ro - q_block_ptr = tl.make_block_ptr(base = query_ptr + t_idx * stride_q_t, - shape = (N, D_QUERY), - strides = (stride_q_tn, stride_q_td), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_QUERY), - order = (1, 0)) - q_ro_block_ptr = tl.make_block_ptr(base = query_rope_ptr + t_idx * stride_qr_t, - shape = (N, D_ROPE), - strides = (stride_qr_tn, stride_qr_td), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_ROPE), - order = (1, 0)) - q = tl.load(q_block_ptr) - q_ro = tl.load(q_ro_block_ptr) - - # Combine query and query_rope using insert_slice, then store in one operation - full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) - full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) - full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) - - q_out_block_ptr = tl.make_block_ptr(base = query_out_ptr + t_idx * stride_qob, - shape = (N, D_QUERY + D_ROPE), - strides = (stride_qon, stride_qod), - offsets = (n_offset, 0), - block_shape = (BLOCK_N, D_QUERY + D_ROPE), - order = (1, 0)) - tl.store(q_out_block_ptr, full_q) - t_idx = t_idx + tl.load(act_s + tn_id) 
- - -def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16,)): - """ - 融合版本的 TND -> BSND 转换(包含 concat) - 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope - """ - t, n, d_query = shape - b = len(act_seq) - s = max(act_seq) - - # 获取各个 tensor 的维度 - d_rope = query_rope.shape[2] if query_rope is not None else 0 - d_sparse = sparse_indices.shape[2] - d_query_out = d_query + d_rope # 拼接后的维度 - - # 分配输出(query_out 已经包含 rope) - query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) - sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, device=sparse_indices.device) - assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" - # 启动 fused kernel - # 使用较小的 BLOCK_N 避免内存溢出 - block_n = min(16, n) - # 计算需要的核心数:使用多核心并行处理不同的头 - num_head_blocks = (n + block_n - 1) // block_n - num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 - - trans_tnd_to_bsnd_fused_kernel[num_programs,]( - query, query_rope, sparse_indices, - query_out, sparse_out, - act_seq, - query.stride(0), query.stride(1), query.stride(2), - query_rope.stride(0), query_rope.stride(1), query_rope.stride(2), - sparse_indices.stride(0), sparse_indices.stride(1), sparse_indices.stride(2), - query_out.stride(0), query_out.stride(1), query_out.stride(2), query_out.stride(3), - sparse_out.stride(0), sparse_out.stride(1), sparse_out.stride(2), sparse_out.stride(3), - B=b, - N=n, - D_QUERY=d_query, - D_ROPE=d_rope, - D_SPARSE=d_sparse, - BLOCK_N=block_n, - BLOCK_D_QUERY=d_query, - BLOCK_D_ROPE=d_rope, - BLOCK_D_SPARSE=d_sparse, - ) - return query_out, sparse_out - - -def trans_tnd_actseq(seq): - if isinstance(seq, torch.Tensor): - seq = seq.cpu().tolist() - list_len = len(seq) - output = [] - output = [seq[0]] - total_len = seq[0] - for i in range(list_len - 1): - new_item = seq[i + 1] - seq[i] - if new_item >= 0: - output.append(new_item) - total_len += new_item - else: - 
print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 0") - return torch.tensor(output).to(DEVICE), total_len - -def sparse_attention(query, key, value, - sparse_indices, scale_value, sparse_block_size = 1, - actual_seq_lengths_query = None, actual_seq_lengths_kv = None, - query_rope = None, key_rope = None, - layout_query = 'BSND', layout_kv = 'BSND', - sparse_mode = 0, block_table = None): - return _attention.apply(query, key, value, - sparse_indices, scale_value, sparse_block_size, - actual_seq_lengths_query, actual_seq_lengths_kv, - query_rope, key_rope, - layout_query, layout_kv, - sparse_mode, block_table) - -def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope, - sparse_size, scale_value, - sparse_block_size, sparse_mode, block_size, act_kv_s): - assert sparse_size <= KV_S - assert KV_N == 1 - assert sparse_mode == 0 or 3 - assert sparse_block_size == 1 - assert (B * KV_S) % block_size == 0 - assert D == 512 - assert D_rope == 0 or 64 - print("*batch_size=",B) - qkv_dtype = torch.float16 - #sparse_size = KV_S - query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - value = key.clone() - - act_q_s = T // B # step - # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE) - # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复 - # sparse_indices = indices.to(torch.int32) - sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1) - sparse_indices = sparse_indices.to(torch.int32) - # print("sparse_indices=", sparse_indices) - actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE) - # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE) - actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, 
dtype=torch.int32, device=DEVICE) - print(actual_seq_lengths_kv) - block_table = torch.tensor([range(B * KV_S // block_size)], dtype=torch.int32, device=DEVICE).reshape(B, -1) - - if D_rope == 0: - query_rope = None - key_rope = None - else: - query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() - - print("q.shape=",query.shape) - print("k.shape=",key.shape) - print("v.shape=",value.shape) - print("sparse_indices.shape=",sparse_indices.shape) - print("act_seq_query=",actual_seq_lengths_query) - print("act_seq_kv=", actual_seq_lengths_kv) - - - triton_out = sparse_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv= 'PA_BSND', - sparse_mode = sparse_mode, - block_table= block_table, - ) - npu_out = torch_npu.npu_sparse_flash_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv = 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - # attention_mode = 2, - ) - triton_out = triton_out.to(npu_out.dtype) - torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) - print("[PASSED]") - - # benchmarking - triton_time = do_bench_npu(lambda:sparse_attention( - query = query, - key = key, - value = value, - sparse_indices = 
sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv= 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - ), clear_l2_cache=True, collect_prof=False) - print(f"[Triton SFA] Time: {triton_time:.4f} us") - - npu_time = do_bench_npu(lambda:torch_npu.npu_sparse_flash_attention( - query = query, - key = key, - value = value, - sparse_indices = sparse_indices, - scale_value = scale_value, - sparse_block_size = sparse_block_size, - actual_seq_lengths_query = actual_seq_lengths_query, - actual_seq_lengths_kv = actual_seq_lengths_kv, - query_rope = query_rope, - key_rope = key_rope, - layout_query = 'TND', - layout_kv = 'PA_BSND', - sparse_mode = sparse_mode, - block_table = block_table, - # attention_mode = 2, - ), clear_l2_cache=True, collect_prof=False) - print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") - -if __name__ == "__main__": - print(torch_npu.__version__) - print("Test Real Case in DS-v3.2-Exp") - print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - i = 1 - print(f"====================第{i}次测试=================") - test_op(T = 1, B = 1, KV_S = 2560, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 4, B = 4, KV_S = 6400, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 8, B = 8, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, 
act_kv_s = 2560) - i += 1 - print(f"====================第{i}次测试=================") - test_op(T = 16, B = 16, KV_S = 48000, Q_N = 128, KV_N = 1, D = 512, D_rope = 64, - sparse_size = 2048, scale_value = 0.5, sparse_block_size = 1, sparse_mode = 0, - block_size = 128, act_kv_s = 2560) diff --git a/third_party/ascend/AscendNPU-IR b/third_party/ascend/AscendNPU-IR deleted file mode 160000 index 5a3921f87..000000000 --- a/third_party/ascend/AscendNPU-IR +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5a3921f87197bad7f4c8037648c9935f205fae35 diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index 8cad0f1e5..f024cfa81 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1380,4 +1380,5 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpIn let hasVerifier = 1; } + #endif // Triton_OPS diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py index f763be03c..65d5968dd 100644 --- a/third_party/ascend/backend/testing.py +++ b/third_party/ascend/backend/testing.py @@ -87,6 +87,7 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None _rm_dic(keep_res, torch_path) return time_cost + # keep the original behavior to get the statistics for the specified kernel func def _collect_single(base_dir: str, key: str = None) -> float: if not os.path.exists(base_dir): @@ -126,6 +127,7 @@ def _collect_single(base_dir: str, key: str = None) -> float: return float("inf") + def _rm_dic(keep_res, torch_path): if keep_res: return diff --git a/third_party/tle/dsa/CMakeLists.txt b/third_party/tle/dsa/CMakeLists.txt index f9768c71f..669ebf813 100644 --- a/third_party/tle/dsa/CMakeLists.txt +++ b/third_party/tle/dsa/CMakeLists.txt @@ -12,7 +12,7 @@ if 
(TRITON_BUILD_PYTHON_MODULE) TleIR TritonIR ) - + find_package(Python3 REQUIRED COMPONENTS Development Interpreter) find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") include_directories(${Python3_INCLUDE_DIRS}) @@ -20,4 +20,4 @@ if (TRITON_BUILD_PYTHON_MODULE) link_directories(${Python3_LIBRARY_DIRS}) link_libraries(${Python3_LIBRARIES}) add_link_options(${Python3_LINK_OPTIONS}) -endif() \ No newline at end of file +endif() diff --git a/third_party/tle/dsa/dialect/include/CMakeLists.txt b/third_party/tle/dsa/dialect/include/CMakeLists.txt index 85a8512e3..83a37f590 100644 --- a/third_party/tle/dsa/dialect/include/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/CMakeLists.txt @@ -2,4 +2,4 @@ add_subdirectory(Analysis) add_subdirectory(Conversion) -add_subdirectory(IR) \ No newline at end of file +add_subdirectory(IR) diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h index 0c6388b4a..1c8c45c2e 100644 --- a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h @@ -1,8 +1,9 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. 
+// Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ -#define TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ +#define TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -10,7 +11,6 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -27,14 +27,14 @@ class CopyConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; + ConversionPatternRewriter &rewriter) const override; }; -} +} // namespace TleCopyConverter namespace mlir::triton::tle { void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns); + mlir::RewritePatternSet &patterns); } -#endif \ No newline at end of file +#endif diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h index 789231e6d..c08e02da3 100644 --- a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h @@ -6,11 +6,11 @@ #include "bishengir/Dialect/HIVM/IR/HIVM.h" #endif +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" 
@@ -22,96 +22,94 @@ using namespace mlir; template class BinaryMathConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto result = adaptor.getRes(); - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - - if (result.getType() != lhs.getType() || - result.getType() != rhs.getType()) { - op->emitError("Unexpected binary calculation type!"); - return failure(); - } - - auto binOp = rewriter.create( - loc, - TypeRange{}, - ValueRange{lhs, rhs}, - ValueRange{result} - ); - - rewriter.replaceOp(op, binOp); - return success(); + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = adaptor.getRes(); + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + if (result.getType() != lhs.getType() || + result.getType() != rhs.getType()) { + op->emitError("Unexpected binary calculation type!"); + return failure(); } + + auto binOp = rewriter.create(loc, TypeRange{}, ValueRange{lhs, rhs}, + ValueRange{result}); + + rewriter.replaceOp(op, binOp); + return success(); + } }; template class UnaryMathConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(MathOp op, PatternRewriter &rewriter) const override { - } + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(MathOp op, + PatternRewriter &rewriter) const override {} }; template class MatMulConverter : public OpConversionPattern { public: - static constexpr llvm::StringLiteral fixpipeAlreadyInserted = - "fixpipe_already_inserted"; - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(MathOp op, typename 
MathOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto inA = adaptor.getInA(); - auto inB = adaptor.getInB(); - auto res = adaptor.getRes(); - - auto sizeAttr = adaptor.getSize(); - if (sizeAttr.size() > 3) { - op->emitError("Unexpected matmul calculation size!"); - return failure(); - } - - auto mAttr = dyn_cast(sizeAttr[0]); - auto nAttr = dyn_cast(sizeAttr[1]); - auto kAttr = dyn_cast(sizeAttr[2]); - Value M = rewriter.create(loc, mAttr.getInt()); - Value N = rewriter.create(loc, nAttr.getInt()); - Value K = rewriter.create(loc, kAttr.getInt()); - - bool initC = adaptor.getInitC(); - auto initCValue = rewriter.create(loc, - /*value*/ initC, /*width*/ 1); - auto newOp = rewriter.create( - loc, - TypeRange{}, // result types - inA, // Matrix A - inB, // Matrix B - initCValue, // init condition - M, // M - K, // K - N, // N - res, // Matrix C - Value{}, // per channel bias - adaptor.getTraA() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose A - adaptor.getTraB() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose B - adaptor.getEnableHf32() ? 
rewriter.getUnitAttr() : UnitAttr{}// enable hf32 mode - ); - - newOp->setAttr(fixpipeAlreadyInserted, rewriter.getBoolAttr(true)); - rewriter.replaceOp(op, newOp); - return success(); + static constexpr llvm::StringLiteral fixpipeAlreadyInserted = + "fixpipe_already_inserted"; + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto inA = adaptor.getInA(); + auto inB = adaptor.getInB(); + auto res = adaptor.getRes(); + + auto sizeAttr = adaptor.getSize(); + if (sizeAttr.size() > 3) { + op->emitError("Unexpected matmul calculation size!"); + return failure(); } + + auto mAttr = dyn_cast(sizeAttr[0]); + auto nAttr = dyn_cast(sizeAttr[1]); + auto kAttr = dyn_cast(sizeAttr[2]); + Value M = rewriter.create(loc, mAttr.getInt()); + Value N = rewriter.create(loc, nAttr.getInt()); + Value K = rewriter.create(loc, kAttr.getInt()); + + bool initC = adaptor.getInitC(); + auto initCValue = + rewriter.create(loc, + /*value*/ initC, /*width*/ 1); + auto newOp = rewriter.create( + loc, TypeRange{}, // result types + inA, // Matrix A + inB, // Matrix B + initCValue, // init condition + M, // M + K, // K + N, // N + res, // Matrix C + Value{}, // per channel bias + adaptor.getTraA() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose A + adaptor.getTraB() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose B + adaptor.getEnableHf32() ? 
rewriter.getUnitAttr() : UnitAttr{} + // enable hf32 mode + ); + + newOp->setAttr(fixpipeAlreadyInserted, rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, newOp); + return success(); + } }; } // namespace TleMathConverter namespace mlir::triton::tle { void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns); + mlir::RewritePatternSet &patterns); } -#endif \ No newline at end of file +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt index 2bbc0d99e..919ff5a4e 100644 --- a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt @@ -16,4 +16,4 @@ set(LLVM_TARGET_DEFINITIONS TleAttrDefs.td) mlir_tablegen(TleAttrDefs.h.inc -gen-attrdef-decls) mlir_tablegen(TleAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(TleTableGen) \ No newline at end of file +add_public_tablegen_target(TleTableGen) diff --git a/third_party/tle/dsa/dialect/include/IR/Dialect.h b/third_party/tle/dsa/dialect/include/IR/Dialect.h index d7a07c85f..b8b241ff2 100644 --- a/third_party/tle/dsa/dialect/include/IR/Dialect.h +++ b/third_party/tle/dsa/dialect/include/IR/Dialect.h @@ -1,7 +1,7 @@ // Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_IR_DIALECT_H_ -#define TRITON_TLE_IR_DIALECT_H_ +#define TRITON_TLE_IR_DIALECT_H_ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -15,4 +15,4 @@ #define GET_OP_CLASSES #include "tle/dsa/dialect/include/IR/TleOps.h.inc" -#endif \ No newline at end of file +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleDialect.td b/third_party/tle/dsa/dialect/include/IR/TleDialect.td index ba85b281a..25898cd08 100644 --- a/third_party/tle/dsa/dialect/include/IR/TleDialect.td +++ b/third_party/tle/dsa/dialect/include/IR/TleDialect.td @@ -1,7 +1,7 @@ // Copyright 2026- Xcoresigma Technology Co., Ltd 
#ifndef TRITON_TLE_DIALECT -#define TRITON_TLE_DIALECT +#define TRITON_TLE_DIALECT include "mlir/IR/OpBase.td" diff --git a/third_party/tle/dsa/dialect/include/IR/TleOps.td b/third_party/tle/dsa/dialect/include/IR/TleOps.td index 984e2f59c..d0cefca38 100644 --- a/third_party/tle/dsa/dialect/include/IR/TleOps.td +++ b/third_party/tle/dsa/dialect/include/IR/TleOps.td @@ -1,7 +1,7 @@ // Copyright 2026- Xcoresigma Technology Co., Ltd #ifndef TRITON_TLE_OPS -#define TRITON_TLE_OPS +#define TRITON_TLE_OPS include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/CommonTypeConstraints.td" @@ -135,7 +135,7 @@ def TLE_DSAMinOp : TLE_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, /// let description = [{ /// $d = matrix_multiply($a, $b) + $c. /// }]; -/// +/// /// let arguments = ( /// ins /// AnyMemRef:$inA, @@ -147,7 +147,7 @@ def TLE_DSAMinOp : TLE_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, /// DefaultValuedAttr:$traB, /// DefaultValuedAttr:$enableHf32 /// ); -/// +/// /// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; /// } diff --git a/third_party/tle/dsa/dialect/lib/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/CMakeLists.txt index 85a8512e3..83a37f590 100644 --- a/third_party/tle/dsa/dialect/lib/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/CMakeLists.txt @@ -2,4 +2,4 @@ add_subdirectory(Analysis) add_subdirectory(Conversion) -add_subdirectory(IR) \ No newline at end of file +add_subdirectory(IR) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt index 21ea47f5d..af538eac7 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt @@ -1,3 +1,3 @@ # Copyright 2026- Xcoresigma Technology Co., Ltd -add_subdirectory(TleToLinalg) \ No newline at end of file +add_subdirectory(TleToLinalg) diff --git 
a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt index 167620ff5..482fcedec 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt @@ -7,4 +7,4 @@ add_triton_library(TleToLinalg DEPENDS TritonIR TleTableGen -) \ No newline at end of file +) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp index feb3a7c23..089cd8722 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #include "tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h" #if __has_include("bishengir/Dialect/HIVM/IR/HIVM.h") @@ -32,7 +32,7 @@ CopyConverter::CopyConverter(MLIRContext *context) LogicalResult CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { auto src = adaptor.getSrc(); auto dst = adaptor.getDst(); auto loc = op.getLoc(); @@ -44,8 +44,9 @@ CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, } llvm::SmallVector shapeValues; - for (auto shape: adaptor.getShape()) { - Value indexShape = rewriter.create(loc, rewriter.getIndexType(), shape); + for (auto shape : adaptor.getShape()) { + Value indexShape = rewriter.create( + loc, rewriter.getIndexType(), shape); shapeValues.push_back(indexShape); } @@ -75,26 +76,29 @@ CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, Operation *copyOp = nullptr; if (srcAddrSpace == hivm::AddressSpace::GM && - dstAddrSpace == 
hivm::AddressSpace::UB || + dstAddrSpace == hivm::AddressSpace::UB || srcAddrSpace == hivm::AddressSpace::UB && - dstAddrSpace == hivm::AddressSpace::GM) { + dstAddrSpace == hivm::AddressSpace::GM) { copyOp = rewriter.create(loc, srcSubView, dstSubView); } else if (srcAddrSpace == hivm::AddressSpace::GM && dstAddrSpace == hivm::AddressSpace::L1) { - copyOp = rewriter.create(loc, /*result_tensor=*/TypeRange{}, - /*src=*/srcSubView, /*dst=*/dstSubView, - /*dst_continuous=*/UnitAttr::get(rewriter.getContext())); + copyOp = rewriter.create( + loc, /*result_tensor=*/TypeRange{}, + /*src=*/srcSubView, /*dst=*/dstSubView, + /*dst_continuous=*/UnitAttr::get(rewriter.getContext())); } /// else if (srcAddrSpace == hivm::AddressSpace::L0C && /// dstAddrSpace == hivm::AddressSpace::GM) { /// copyOp = rewriter.create(loc, - /// /*result_tensor=*/TypeRange{}, /*src=*/srcSubView, /*dst=*/dstSubView, + /// /*result_tensor=*/TypeRange{}, /*src=*/srcSubView, + /// /*dst=*/dstSubView, /// /*enable_nz2nd=*/UnitAttr::get(rewriter.getContext()) /// // #ifdef BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES /// /*nullptr, - /// hivm::FixpipeDMAModeAttr::get(rewriter.getContext(), hivm::FixpipeDMAMode::NZ2ND), - /// nullptr, nullptr, nullptr, nullptr, nullptr*/ + /// hivm::FixpipeDMAModeAttr::get(rewriter.getContext(), + /// hivm::FixpipeDMAMode::NZ2ND), nullptr, nullptr, nullptr, nullptr, + /// nullptr*/ /// ); /// } else { @@ -111,7 +115,7 @@ CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, namespace mlir::triton::tle { void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); } -} \ No newline at end of file +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp 
b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp index 6937e1571..bcdece160 100644 --- a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2025 XCoreSigma Inc. All rights reserved. +// Copyright 2026- Xcoresigma Technology Co., Ltd #include "tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h" #include "tle/dsa/dialect/include/IR/Dialect.h" @@ -7,18 +7,30 @@ namespace TleMathConverter { using namespace mlir; using namespace triton::tle; -} +} // namespace TleMathConverter namespace mlir::triton::tle { void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns) { - patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); + mlir::RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); - /// patterns.add>(patterns.getContext()); + /// patterns.add>(patterns.getContext()); } -} \ No newline at end of file +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt index d370c8313..d7c74bbbd 100644 --- a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt +++ b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt @@ -10,4 +10,4 @@ add_triton_library(TleIR LINK_LIBS PUBLIC TritonIR MLIRIR -) \ No newline at end of file +) diff --git a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp index 5bf9562ee..32385c0e1 
100644 --- a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp +++ b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp @@ -1,7 +1,7 @@ // Copyright 2026- Xcoresigma Technology Co., Ltd -#include "mlir/Support/LLVM.h" #include "tle/dsa/dialect/include/IR/Dialect.h" +#include "mlir/Support/LLVM.h" #include "tle/dsa/dialect/include/IR/Dialect.cpp.inc" #define GET_ATTRDEF_CLASSES @@ -15,11 +15,11 @@ void TleDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST #include "tle/dsa/dialect/include/IR/TleAttrDefs.cpp.inc" - >(); + >(); addOperations< #define GET_OP_LIST #include "tle/dsa/dialect/include/IR/TleOps.cpp.inc" - >(); + >(); } -} \ No newline at end of file +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp index 5a3eeb16e..2a81191e1 100644 --- a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp +++ b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp @@ -4,5 +4,4 @@ #include "mlir/IR/Builders.h" #include "tle/dsa/dialect/include/IR/Dialect.h" -namespace mlir::triton::tle { -} \ No newline at end of file +namespace mlir::triton::tle {} diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc index 58dba49e5..b54b3d5a6 100644 --- a/third_party/tle/dsa/tle_ir.cc +++ b/third_party/tle/dsa/tle_ir.cc @@ -24,8 +24,7 @@ constexpr unsigned kIntegerAttrBitWidth = 64; struct DSAOpBuilder : public TritonOpBuilder {}; -void init_triton_tle(py::module &&m) -{ +void init_triton_tle(py::module &&m) { m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; registry.insert(); @@ -35,16 +34,17 @@ void init_triton_tle(py::module &&m) context.loadAllAvailableDialects(); }); - py::class_(m, "tle_builder", py::module_local(), py::dynamic_attr()) - .def(py::init()) - .def("dsa_get_null_attr", [](DSAOpBuilder &self) { return Attribute(); }) - .def("dsa_get_buffer_type", + py::class_( + m, "tle_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + 
.def("dsa_get_null_attr", [](DSAOpBuilder &self) { return Attribute(); }) + .def("dsa_get_buffer_type", [](DSAOpBuilder &self, std::vector &shape, Type &elementType, const Attribute &memorySpace) -> Type { return MemRefType::get(shape, elementType, MemRefLayoutAttrInterface{}, memorySpace); }) - .def("dsa_get_buffer_type_with_strides", + .def("dsa_get_buffer_type_with_strides", [](DSAOpBuilder &self, std::vector &shape, Type &elementType, const std::vector &strides, const Attribute &memorySpace) -> Type { @@ -53,134 +53,141 @@ void init_triton_tle(py::module &&m) self.getBuilder().getContext(), ShapedType::kDynamic, strides); return MemRefType::get(shape, elementType, layout, memorySpace); }) - .def("create_dsa_alloc", - [](DSAOpBuilder &self, Type memrefType) -> Value { - return self.create(mlir::cast(memrefType)); - }) - // Add copy op - .def("create_dsa_copy", - [](DSAOpBuilder &self, Value &src, Value &dst, std::vector &shape, bool inter_no_alias)-> void { - auto copyOp = self.create(src, dst, shape); - if (inter_no_alias) { - copyOp->setAttr("inter_no_alias", self.getBuilder().getBoolAttr(true)); - } - }) - // Add op - .def("create_dsa_add", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Sub op - .def("create_dsa_sub", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Mul op - .def("create_dsa_mul", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Div op - .def("create_dsa_div", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Max op - .def("create_dsa_max", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Min op - .def("create_dsa_min", - [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { - self.create(lhs, rhs, res); - }) - // Dot op - 
/// .def("create_dsa_dot", - /// [](DSAOpBuilder &self, Value &inA, Value &inB, Value &res, - /// std::vector &size, bool &initC, bool &traA, bool &traB, - /// bool &enable_hf32) -> void { - /// auto &builder = self.getBuilder(); - /// auto sizeAttr = builder.getI64ArrayAttr(size); + .def("create_dsa_alloc", + [](DSAOpBuilder &self, Type memrefType) -> Value { + return self.create( + mlir::cast(memrefType)); + }) + // Add copy op + .def("create_dsa_copy", + [](DSAOpBuilder &self, Value &src, Value &dst, + std::vector &shape, bool inter_no_alias) -> void { + auto copyOp = self.create(src, dst, shape); + if (inter_no_alias) { + copyOp->setAttr("inter_no_alias", + self.getBuilder().getBoolAttr(true)); + } + }) + // Add op + .def("create_dsa_add", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Sub op + .def("create_dsa_sub", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Mul op + .def("create_dsa_mul", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Div op + .def("create_dsa_div", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Max op + .def("create_dsa_max", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Min op + .def("create_dsa_min", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Dot op + /// .def("create_dsa_dot", + /// [](DSAOpBuilder &self, Value &inA, Value &inB, Value &res, + /// std::vector &size, bool &initC, bool &traA, bool + /// &traB, bool &enable_hf32) -> void { + /// auto &builder = self.getBuilder(); + /// auto sizeAttr = builder.getI64ArrayAttr(size); - /// // convert bool to boolattr. 
- /// auto initC_attr = builder.getBoolAttr(initC); - /// auto traA_attr = builder.getBoolAttr(traA); - /// auto traB_attr = builder.getBoolAttr(traB); - /// auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); + /// // convert bool to boolattr. + /// auto initC_attr = builder.getBoolAttr(initC); + /// auto traA_attr = builder.getBoolAttr(traA); + /// auto traB_attr = builder.getBoolAttr(traB); + /// auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); - /// self.create(inA, inB, res, sizeAttr, initC_attr, - /// traA_attr, traB_attr, enable_hf32_attr); - /// }) - .def("dsa_to_buffer", - [](DSAOpBuilder &self, Value &src, - const Attribute &addressSpace) -> Value { - auto tensorType = dyn_cast(src.getType()); - if (!tensorType) { - llvm::report_fatal_error("to_buffer: src must be tensor type"); - } - auto memrefType = MemRefType::get( - tensorType.getShape(), tensorType.getElementType(), - MemRefLayoutAttrInterface{}, addressSpace); - return self.create(memrefType, src); - }) - .def("dsa_to_tensor", - [](DSAOpBuilder &self, Value &src, bool writable) -> Value { - const auto &memrefType = mlir::cast(src.getType()); - auto hasAddressSpace = memrefType.getMemorySpace(); - if (hasAddressSpace) { - return self.create( - src, true, writable); - } - return self.create(src, true, writable); - }) - .def("create_dsa_extract_scalar", - [](DSAOpBuilder &self, Value &src, std::vector &indices) -> Value { - llvm::SmallVector arg_indices; - for (const auto &i : indices) { - auto iTy = i.getType(); - if (!iTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), i); - arg_indices.push_back(v); - } else { - arg_indices.push_back(i); - } - } - auto ret = self.create(src, arg_indices); - return ret; - }) + /// self.create(inA, inB, res, sizeAttr, + /// initC_attr, + /// traA_attr, traB_attr, enable_hf32_attr); + /// }) + .def("dsa_to_buffer", + [](DSAOpBuilder &self, Value &src, + const Attribute &addressSpace) -> Value { + auto tensorType = 
dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = MemRefType::get( + tensorType.getShape(), tensorType.getElementType(), + MemRefLayoutAttrInterface{}, addressSpace); + return self.create(memrefType, src); + }) + .def("dsa_to_tensor", + [](DSAOpBuilder &self, Value &src, bool writable) -> Value { + const auto &memrefType = mlir::cast(src.getType()); + auto hasAddressSpace = memrefType.getMemorySpace(); + if (hasAddressSpace) { + return self.create(src, true, + writable); + } + return self.create(src, true, writable); + }) + .def("create_dsa_extract_scalar", + [](DSAOpBuilder &self, Value &src, + std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) .def("create_dsa_extract_slice", - [](DSAOpBuilder &self, Value &ful, std::vector &offs_vec, - std::vector &sizs_vec, std::vector &strd_vec) -> Value { - llvm::SmallVector offsets; - for (const auto &o : offs_vec) { - auto oTy = o.getType(); - if (!oTy.isIndex()) { - auto v = self.create( - self.getBuilder().getIndexType(), o); - offsets.push_back(v); - } else { - offsets.push_back(o); - } - } - llvm::SmallVector sizes; - llvm::SmallVector retSizes; - for (const auto &s : sizs_vec) { - auto v = self.create(s); - sizes.push_back(v); - retSizes.push_back(s); - } - llvm::SmallVector strides; - for (const auto &s : strd_vec) { - auto v = self.create(s); - strides.push_back(v); - } - auto retTy = RankedTensorType::get(retSizes, - cast(ful.getType()).getElementType()); + [](DSAOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + 
auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); - return self.create(retTy, ful, offsets, sizes, strides); - }) + return self.create(retTy, ful, offsets, + sizes, strides); + }) .def("create_dsa_insert_slice", [](DSAOpBuilder &self, Value &ful, Value &sub, std::vector &offs_vec, std::vector &sizs_vec, @@ -285,7 +292,5 @@ void init_triton_tle(py::module &&m) return self.create(source, mixedOffsets, mixedSizes, mixedStrides); - }) - ; - -} \ No newline at end of file + }); +}