diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml index fbee99af4c..77831002a2 100644 --- a/.github/workflows/ascend-build-and-test.yml +++ b/.github/workflows/ascend-build-and-test.yml @@ -104,6 +104,14 @@ jobs: --ignore=test_assume.py \ --ignore=test_index_select.py popd + # flagtree tle test + pushd python/test/tle + python3 test_vec_add.py + python3 test_vec_add_2d.py + python3 test_vec_add_mix.py + python3 test_vec_mathOps.py + python3 test_tle_with_hints.py + popd - name: FlagTree Editable Build And Teston Ascend if: steps.check_files.outputs.only_docs_changed != 'true' diff --git a/CMakeLists.txt b/CMakeLists.txt index 534647a6d2..f62e2e14c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -291,6 +291,13 @@ if(TRITON_BUILD_PYTHON_MODULE) add_subdirectory(third_party/proton) endif() + # add TLE plugin + # just support ascend backend so far + if(FLAGTREE_BACKEND STREQUAL "ascend") + list(APPEND TRITON_PLUGIN_NAMES "tle") + add_subdirectory(third_party/tle/dsa) + endif() + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) set(TRITON_LIBRARIES diff --git a/python/setup.py b/python/setup.py index 9880184d3d..ac90f6a87b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -701,9 +701,16 @@ def get_packages(): "triton/runtime", "triton/backends", "triton/tools", + + # for flagtree tle + "triton/experimental", + "triton/experimental/tle", + "triton/experimental/tle/language", + "triton/experimental/tle/language/dsa", ] if helper.flagtree_backend and helper.flagtree_backend in helper.configs.language_extra_backends: packages.append(f"triton/language/extra/{helper.get_device_name()}") + packages.append(f"triton/experimental/tle/language/dsa/{helper.get_device_name()}") packages += helper.get_extra_packages() packages += get_language_extra_packages() packages += [f'triton/backends/{backend.name}' for backend in backends] diff --git 
a/python/setup_tools/utils/__init__.py b/python/setup_tools/utils/__init__.py index 4f93cd8cc9..a0fda077ba 100644 --- a/python/setup_tools/utils/__init__.py +++ b/python/setup_tools/utils/__init__.py @@ -10,7 +10,7 @@ commit_id="380b87122c88af131530903a702d5318ec59bb33", dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")), "flir": - tools.Module(name="flir", url="https://github.com/FlagTree/flir.git", + tools.Module(name="flir", url="https://github.com/flagos-ai/flir.git", dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")), } diff --git a/python/setup_tools/utils/ascend.py b/python/setup_tools/utils/ascend.py index fc3d8da2fa..957b90588c 100644 --- a/python/setup_tools/utils/ascend.py +++ b/python/setup_tools/utils/ascend.py @@ -15,6 +15,7 @@ def get_extra_install_packages(): "triton/extension", "triton/extension/buffer", "triton/extension/buffer/language", + "triton/experimental/tle/language/dsa/ascend", ] @@ -24,6 +25,10 @@ def get_package_dir(): package_dict["triton/extension"] = ascend_ext_base package_dict["triton/extension/buffer"] = f"{ascend_ext_base}/buffer" package_dict["triton/extension/buffer/language"] = f"{ascend_ext_base}/buffer/language" + + # flagtree tle ascend + flagtree_tle_ascend_base = "../python/triton/experimental/tle/language/dsa" + package_dict["triton/experimental/tle/language/dsa/ascend"] = f"{flagtree_tle_ascend_base}/ascend" return package_dict diff --git a/python/test/tle/test_bind_buffer.py b/python/test/tle/test_bind_buffer.py new file mode 100644 index 0000000000..b8cd1e7754 --- /dev/null +++ b/python/test/tle/test_bind_buffer.py @@ -0,0 +1,41 @@ +import triton +import triton.experimental.tle as tle +import triton.language as tl + +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir, tle as tle_ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +class Options: + num_warps = 4 + 
num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + tle_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def bind_buffer(): + # tle.dsa.ascend.UB is triton.language.extra.extension.cann.core.ascend_address_space.UB + buffer1 = tle.dsa.alloc(shape=[32, 32], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + tle.dsa.to_tensor(buffer1, writable=True) + + +if __name__ == "__main__": + print("=" * 60) + mlir = compile_kernel(bind_buffer, {}, {}) + print(mlir) diff --git a/python/test/tle/test_tle_with_hints.py b/python/test/tle/test_tle_with_hints.py new file mode 100644 index 0000000000..2f4ab5974d --- /dev/null +++ b/python/test/tle/test_tle_with_hints.py @@ -0,0 +1,65 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import triton +import torch_npu # noqa +import triton.language as tl +import triton.experimental.tle as tle + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # mem_addr_space is language.extra.cann.core.ascend_address_space + with tle.dsa.hint(inter_no_alias=True): + with tle.dsa.hint(test_k1=10): + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + with tle.dsa.hint(test_k2=20): + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return output + + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() + output_torch = x + y + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + from triton.backends.ascend.testing import do_bench_npu + bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False) + bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False) + print(f"torch time : {bench_torch:.2f}") + print(f"triton time: {bench_triton:.2f}") + + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_add.py b/python/test/tle/test_vec_add.py new file mode 100644 index 0000000000..3d0098b740 --- /dev/null +++ 
b/python/test/tle/test_vec_add.py @@ -0,0 +1,62 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import torch_npu # noqa +import triton +import triton.language as tl +import triton.experimental.tle as tle + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # mem_addr_space is language.extra.cann.core.ascend_address_space + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return output + + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() + output_torch = x + y + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + from triton.backends.ascend.testing import do_bench_npu + bench_torch = do_bench_npu(lambda: x + y, 
clear_l2_cache=True, keep_res=True, collect_prof=False) + bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False) + print(f"torch time : {bench_torch:.2f}") + print(f"triton time: {bench_triton:.2f}") + + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_add_2d.py b/python/test/tle/test_vec_add_2d.py new file mode 100644 index 0000000000..10ddd4ca4c --- /dev/null +++ b/python/test/tle/test_vec_add_2d.py @@ -0,0 +1,78 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import torch_npu # noqa +import triton +import triton.language as tl +import triton.experimental.tle as tle + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + n_cols, n_rows, BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. 
+ ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # 2D offsets + block_start_m = pid_m * BLOCK_SIZE + block_start_n = pid_n * BLOCK_SIZE + offs_m = block_start_m + tl.arange(0, BLOCK_SIZE) + offs_n = block_start_n + tl.arange(0, BLOCK_SIZE) + + # get address(row-major) + x_ptrs = x_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + y_ptrs = y_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + out_ptrs = output_ptr + offs_m[:, None] * n_cols + offs_n[None, :] + + a_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_rows - block_start_m + t1 = n_cols - block_start_n + tail_size_m = tl.minimum(t0, BLOCK_SIZE) + tail_size_n = tl.minimum(t1, BLOCK_SIZE) + + tle.dsa.copy(x_ptrs, a_ub, [tail_size_m, tail_size_n]) + tle.dsa.copy(y_ptrs, b_ub, [tail_size_m, tail_size_n]) + + tle.dsa.add(a_ub, b_ub, c_ub) + + tle.dsa.copy(c_ub, out_ptrs, [tail_size_m, tail_size_n]) + + +def custom_func(x: torch.Tensor, y: torch.Tensor, size: int): + output = torch.empty_like(x) + n_elements = output.numel() + BLOCK_SIZE = 16 + # grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + grid = (triton.cdiv(size, BLOCK_SIZE), triton.cdiv(size, BLOCK_SIZE)) + add_kernel[grid](x, y, output, n_elements, size - 1, size, BLOCK_SIZE) + return output + + +def test_add(): + torch.manual_seed(0) + size = 128 + x = torch.rand((size, size - 1), dtype=torch.float).npu() + y = torch.rand((size, size - 1), dtype=torch.float).npu() + output_torch = x + y + output_triton = custom_func(x, y, size) + print("============X===========") + print(x) + print("============Y===========") + print(y) + print("============outTorch===========") + print(output_torch) + print("============outTriton===========") + print(output_triton) + 
print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_add_mix.py b/python/test/tle/test_vec_add_mix.py new file mode 100644 index 0000000000..115ab4d848 --- /dev/null +++ b/python/test/tle/test_vec_add_mix.py @@ -0,0 +1,64 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import triton +import torch_npu # noqa +import triton.language as tl +import triton.experimental.tle as tle + + +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + tle.dsa.add(a_ub, b_ub, c_ub) + + c_val = tle.dsa.to_tensor(c_ub) + b_val = tle.dsa.to_tensor(b_ub) + + result = c_val - b_val + + #tl.store(output_ptr + offsets, result) + + d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB) + tle.dsa.copy(d_ub, output_ptr + offsets, [tail_size]) + + +def custom_func(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=128) + return 
output + + +def test_add(): + torch.manual_seed(0) + size = 1024 + x = torch.rand(size, dtype=torch.float).npu() + y = torch.rand(size, dtype=torch.float).npu() + output_torch = x + output_triton = custom_func(x, y) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + +if __name__ == "__main__": + test_add() diff --git a/python/test/tle/test_vec_mathOps.py b/python/test/tle/test_vec_mathOps.py new file mode 100644 index 0000000000..be84bc16bb --- /dev/null +++ b/python/test/tle/test_vec_mathOps.py @@ -0,0 +1,87 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd +import torch +import torch_npu # noqa +import triton +import triton.language as tl +import triton.experimental.tle as tle + + +@triton.jit +def run_test( + x_ptr, + y_ptr, + output_ptr, + n_elements, + OP_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + tle.dsa.copy(x_ptr + offsets, a_ub, [BLOCK_SIZE]) + tle.dsa.copy(y_ptr + offsets, b_ub, [BLOCK_SIZE]) + + if OP_ID == 0: # add + tle.dsa.add(a_ub, b_ub, c_ub) + elif OP_ID == 1: # sub + tle.dsa.sub(a_ub, b_ub, c_ub) + elif OP_ID == 2: # mul + tle.dsa.mul(a_ub, b_ub, c_ub) + elif OP_ID == 3: # div + tle.dsa.div(a_ub, b_ub, c_ub) + + tle.dsa.copy(c_ub, output_ptr + offsets, [BLOCK_SIZE]) + + +OP_REGISTRY = { + 'add': (0, torch.add), + 'sub': (1, torch.sub), + 'mul': (2, torch.mul), + 'div': (3, torch.div), +} + + +def common_test(op_name: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if op_name not in OP_REGISTRY: + raise ValueError(f"Unsupported op: {op_name}") + + op_id, _ = OP_REGISTRY[op_name] + output = 
torch.empty_like(x) + n_elements = output.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + run_test[grid]( + x, + y, + output, + n_elements, + OP_ID=op_id, + BLOCK_SIZE=128, + ) + return output + + +def test_binary_op(size: int = 1024, dtype=torch.float32): + x = torch.rand(size, dtype=dtype).npu() + y = torch.rand(size, dtype=dtype).npu() + y = y + 0.1 + + print(f"Testing {len(OP_REGISTRY)} operators with size={size}, dtype={dtype}") + + for op_name in OP_REGISTRY: + torch_fn = OP_REGISTRY[op_name][1] + triton_out = common_test(op_name, x, y) + torch_out = torch_fn(x, y) + + max_diff = torch.max(torch.abs(torch_out - triton_out)).item() + status = "SUCCESS" if max_diff < 1e-5 else "FAIL" + print(f"{status} {op_name:8}: max diff = {max_diff:.2e}") + + +if __name__ == "__main__": + test_binary_op(size=1024, dtype=torch.float32) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index d8ca58d8d1..e0838aad99 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,6 +15,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +import importlib def mangle_ty(ty): @@ -1111,6 +1112,15 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: + if fn.__name__ == 'copy': + # extract tle hints from the generator to identify if node in the tle hints scope + tle = importlib.import_module("..experimental.tle", package=__package__) + top_hints = tle.extract_tle_hints_scope(self) + + # Only apply to some builtins; currently, 'copy' is relevant. 
+ if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: + kws['inter_no_alias'] = top_hints['inter_no_alias'] + return fn(*args, **extra_kwargs, **kws) except Exception as e: # Normally when we raise a CompilationError, we raise it as diff --git a/python/triton/experimental/__init__.py b/python/triton/experimental/__init__.py new file mode 100644 index 0000000000..a111d0159b --- /dev/null +++ b/python/triton/experimental/__init__.py @@ -0,0 +1 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py new file mode 100644 index 0000000000..1af2373e42 --- /dev/null +++ b/python/triton/experimental/tle/__init__.py @@ -0,0 +1,124 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +from triton._C.libtriton import ir +from typing import Optional, Dict +from triton.runtime import JITFunction +from .language.builder import setup_unified_builder_with_tle_builder +import importlib +import ast +from typing_extensions import override + +try: + from triton._C.libtriton import tle as tle_ir +except ImportError: + raise RuntimeError("tle is not available") + +triton_compiler = importlib.import_module("triton.compiler", package=__package__) + + +def tle_patch_for_triton_compile(): + original_compile_fn = triton_compiler.compile + + def tle_compile(src, target=None, options=None): + # ir.context() will return a new MLIRContext each time, here should keep the same context + cur_context = ir.context() + tle_ir.load_dialects(cur_context) + + original_context_fn = ir.context + + def patched_context(): + return cur_context + + ir.context = patched_context + + try: + compiled_kernel = original_compile_fn(src, target, options) + finally: + ir.context = original_context_fn + + return compiled_kernel + + return tle_compile + + +code_generator = importlib.import_module("triton.compiler.code_generator", package=__package__) + + +class TleCodeGenerator(code_generator.CodeGenerator): 
+ + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + super().__init__(context, prototype, gscope, attributes, constants, function_name, jit_fn, options, codegen_fns, + module_map, module, is_kernel, function_types, noinline, file_name, begin_line) + self.tle_builder = tle_ir.tle_builder(context) + self.tle_builder.set_loc(file_name, begin_line, 0) + + # Stack to keep track of active `with`-hints (e.g., tle.hint(...)) + # Each entry is a dict mapping hint names to literal values. + self.with_hints = [] + + setup_unified_builder_with_tle_builder(self.builder, self.tle_builder) + + @override + def visit_With(self, node): + assert len(node.items) == 1 + context = node.items[0].context_expr + + # extract tle hints + hints = {} + if isinstance(context, ast.Call): + if isinstance(context.func, ast.Attribute) and context.func.attr == "hint": + for kw in context.keywords: + if not isinstance(kw.value, ast.Constant): + raise self._unsupported(node, + "keyword arguments to hint() are only supported for constant values") + hints[kw.arg] = kw.value.value + + # append hints to with_hints anyway, to indicate that we're in the with scope + self.with_hints.append(hints) + + super().visit_With(node) + + # pop hints to indicate that we're out of the with scope + self.with_hints.pop() + + +def extract_tle_hints_scope(generator: TleCodeGenerator): + """ + with tle.hints(inter_no_alias=True): + with xxxx: + with tle.hints(inter_no_alias=False): + ... + with xxx: + call_fn1(...) + call_fn(...) 
+ + when visit_Call for call_fn1, we can get the hints scope as follows: + [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}, {xxx}] + should get the parent scope hints 'inter_no_alias': False for call_fn1, after visit call_fn1, pop the scope + + when visit_Call for call_fn, we can get the hints scope as follows: + [{'inter_no_alias': True}, {xxx}, {'inter_no_alias': False}] + and now the hint scope is 'inter_no_alias': False' for call_fn, after visit call_fn, pop the scope + """ + if not generator.with_hints: + return {} + + # visit with_hints backward to find inter_no_alias hint + for i in range(len(generator.with_hints) - 1, -1, -1): + hints = generator.with_hints[i] + if "inter_no_alias" in hints: + return hints + + return {} + + +triton_compiler.compile = tle_patch_for_triton_compile() +code_generator.CodeGenerator = TleCodeGenerator + +from .language import dsa + +__all__ = [ + "dsa", +] diff --git a/python/triton/experimental/tle/language/__init__.py b/python/triton/experimental/tle/language/__init__.py new file mode 100644 index 0000000000..4acb1bfa91 --- /dev/null +++ b/python/triton/experimental/tle/language/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +from . 
import dsa + +__all__ = [ + 'dsa', +] diff --git a/python/triton/experimental/tle/language/builder.py b/python/triton/experimental/tle/language/builder.py new file mode 100644 index 0000000000..d70394c595 --- /dev/null +++ b/python/triton/experimental/tle/language/builder.py @@ -0,0 +1,57 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + + +def create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name): + delegate_method = getattr(delegate_builder, method_name) + + def wrapper(*args, **kwargs): + saved_ip = main_builder.get_insertion_point() + saved_loc = main_builder.get_loc() + delegate_builder.restore_insertion_point(saved_ip) + if saved_loc: + delegate_builder.set_loc(saved_loc) + result = delegate_method(*args, **kwargs) + main_builder.restore_insertion_point(saved_ip) + if saved_loc: + main_builder.set_loc(saved_loc) + return result + + wrapper.__name__ = method_name + wrapper.__doc__ = getattr(delegate_method, '__doc__', None) + return wrapper + + +def attach_builder_methods_with_tle_builder(main_builder, delegate_builder, method_names): + """Attach multiple methods from a delegate builder to the main builder.""" + for method_name in method_names: + wrapper = create_dsa_method_wrapper_with_tle_builder(main_builder, delegate_builder, method_name) + + if hasattr(main_builder, method_name): + raise AttributeError(f"Method '{method_name}' already exists in the main builder.") + setattr(main_builder, method_name, wrapper) + + +def setup_unified_builder_with_tle_builder(main_builder, buffer_builder): + """Set up a unified builder interface by attaching methods from specialized builders.""" + main_builder._buffer_builder = buffer_builder + buffer_methods = [ + 'create_dsa_alloc', + 'create_dsa_copy', + 'create_dsa_add', + 'create_dsa_sub', + 'create_dsa_mul', + 'create_dsa_div', + 'create_dsa_max', + 'create_dsa_min', + # 'create_dsa_dot', + 'dsa_to_buffer', + 'dsa_to_tensor', + 'dsa_get_null_attr', + 'dsa_get_buffer_type', + 
'dsa_get_buffer_type_with_strides', + "create_dsa_extract_scalar", + "create_dsa_extract_slice", + "create_dsa_insert_slice", + "create_dsa_subview", + ] + attach_builder_methods_with_tle_builder(main_builder, buffer_builder, buffer_methods) diff --git a/python/triton/experimental/tle/language/dsa/README.md b/python/triton/experimental/tle/language/dsa/README.md new file mode 100644 index 0000000000..4bb7b8f420 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/README.md @@ -0,0 +1,62 @@ +# TLE (Triton Language Extension) + +TLE is a language extension for Triton that exposes on-chip memory, pipeline compile hints and the accompanying calculation operations for high-performance computing. This extension is specifically optimized for Ascend 910B devices. + +## Features + +- **On-chip Memory Management**: `tle.dsa.alloc()` - Allocate memory on UB/L1/L0C +- **Data Movement**: `tle.dsa.copy()` - Efficient bidirectional copying between memory spaces +- **compute Operations**: `tle.dsa.add()` - Addition on UB +- **Pipeline Optimization**: `tle.dsa.pipeline()` - Hardware-aware pipeline iteration + +## Memory Scopes & Layouts for ascend + +- **Scopes**: `tle.dsa.ascend.UB` (UB memory), `tle.dsa.ascend.L1` (L1 memory), `tle.dsa.ascend.L0C` (L0C memory) + +## Quick Example + +```python +@triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Allocate UB memory + a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB) + + # Tail block processing + t0 = n_elements - block_start + tail_size = tl.minimum(t0, BLOCK_SIZE) + + # Copy data from GM to UB + tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size]) + tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size]) + + # Addition + tle.dsa.add(a_ub, b_ub, c_ub) + + # Copy result back to GM + tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size]) + +``` + +## Testing + +```bash +cd python/test/tle +python3 test_vec_add.py +``` + +## Learn More + +See other examples in `python/test/tle`: +- `test_matmul.py` - GEMM implementation and pipeline usage +- `test_vec_mathOps.py` - Vector math operations, such as add, sub, mul, div diff --git a/python/triton/experimental/tle/language/dsa/__init__.py b/python/triton/experimental/tle/language/dsa/__init__.py new file mode 100644 index 0000000000..19a490b3f6 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +from .core import ( + alloc, + copy, + pipeline, + parallel, + to_tensor, + to_buffer, + add, + sub, + mul, + div, + max, + min, + hint, + extract_slice, + insert_slice, + extract_element, + subview, +) + +from . 
import ascend + +__all__ = [ + "alloc", + "copy", + "pipeline", + "parallel", + "to_tensor", + "to_buffer", + "add", + "sub", + "mul", + "div", + "max", + "min", + "hint", + "extract_slice", + "insert_slice", + "extract_element", + "subview", + "ascend", +] diff --git a/python/triton/experimental/tle/language/dsa/ascend/__init__.py b/python/triton/experimental/tle/language/dsa/ascend/__init__.py new file mode 100644 index 0000000000..ee754717dd --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/ascend/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +from .core import ( + UB, + L1, + L0A, + L0B, + L0C, +) + +__all__ = [ + "UB", + "L1", + "L0A", + "L0B", + "L0C", +] diff --git a/python/triton/experimental/tle/language/dsa/ascend/core.py b/python/triton/experimental/tle/language/dsa/ascend/core.py new file mode 100644 index 0000000000..79ebd93dbc --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/ascend/core.py @@ -0,0 +1,7 @@ +from triton.language.extra.cann.extension.core import ascend_address_space + +UB = ascend_address_space.UB +L1 = ascend_address_space.L1 +L0A = ascend_address_space.L0A +L0B = ascend_address_space.L0B +L0C = ascend_address_space.L0C diff --git a/python/triton/experimental/tle/language/dsa/ascend/semantic.py b/python/triton/experimental/tle/language/dsa/ascend/semantic.py new file mode 100644 index 0000000000..a111d0159b --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/ascend/semantic.py @@ -0,0 +1 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/python/triton/experimental/tle/language/dsa/core.py b/python/triton/experimental/tle/language/dsa/core.py new file mode 100644 index 0000000000..f703fc4476 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/core.py @@ -0,0 +1,376 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +import triton.language.core as tl +from triton.language import semantic as tl_semantic +from triton.language.core 
# Copyright 2026- Xcoresigma Technology Co., Ltd

from typing import List, TypeVar
from functools import wraps

T = TypeVar("T")

# Marker attributes: the Triton code generator looks for TRITON_BUILTIN,
# tle's own introspection (is_builtin) looks for TLE_BUILTIN.
TRITON_BUILTIN = "__triton_builtin__"
TLE_BUILTIN = "__tle_builtin__"


def builtin(fn: T) -> T:
    """
    Decorator marking ``fn`` as a tle language builtin function.

    The returned wrapper enforces that the function is only invoked from a
    ``@triton.jit`` context, where the code generator injects ``_builder``.

    :param fn: the callable to mark as a builtin.
    :return: the wrapped callable, tagged with the builtin marker attributes.
    :raises ValueError: (from the wrapper) when called without ``_builder``.
    """
    # Bug fix: the original `assert callable` tested the *built-in function*
    # `callable` itself (always truthy) instead of validating the argument.
    assert callable(fn), "@builtin can only decorate a callable"

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if "_builder" not in kwargs or kwargs["_builder"] is None:
            raise ValueError("Did you forget to add @triton.jit ? "
                             "(`_builder` argument must be provided outside of JIT functions.)")
        return fn(*args, **kwargs)

    setattr(wrapper, TRITON_BUILTIN, True)
    setattr(wrapper, TLE_BUILTIN, True)

    return wrapper


def is_builtin(fn) -> bool:
    """
    Returns whether ``fn`` was marked as a tle builtin via :func:`builtin`.
    """
    return getattr(fn, TLE_BUILTIN, False)
Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + :param warp_specialize: Enable automatic warp specialization on the loop. + The compiler will attempt to partition memory, MMA, and vector + operations in the loop into separate async partitions. This will + increase the total number of warps required by the kernel. + :param disable_licm: Tells the compiler it shouldn't hoist loop invariant + code outside the loop. This is often useful to avoid creating long liveranges + within a loop. + + Note that warp specialization is only supported on Blackwell GPUs and + only works on simple matmul loops. Support for arbitrary loops will be + expanded over time. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + self.warp_specialize = warp_specialize + self.disable_licm = disable_licm + + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + + +class pipeline(range): + """ + Iterator that counts upward forever, with software pipeline semantics. 
+ + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + + +class parallel(range): + """ + Iterator that counts upward forever, with parallel execution semantics. + + This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it indicates that there are no dependencies between loop iterations, + allowing them to be executed in parallel. + """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + + +@builtin +def from_buffer_to_tensor_pointer(src: buffer, _builder=None) -> tl.tensor: + buffer_ty = src.type + ele_type = buffer_ty.element_ty + shape = buffer_ty.shape + block_type = tl.block_type(ele_type, shape) + return tl.tensor(src.handle, block_type) + + +@builtin +def copy(src, dst, shape, inter_no_alias=False, _builder=None): + """Copy data from `src` to `dst` shaped by `shape`. + + :param inter_no_alias: If True, the copy is annotated as no aliasing between different iterations. 
+ """ + assert len(shape) != 0, "Can't deduce copy extents from args" + + shape = _constexpr_to_value(shape) + inter_no_alias = _constexpr_to_value(inter_no_alias) + tle_semantic.copy(src, dst, shape, inter_no_alias, _builder) + + +@builtin +def add(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.add(input, other, result, _builder) + + +@builtin +def sub(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.sub(input, other, result, _builder) + + +@builtin +def mul(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.mul(input, other, result, _builder) + + +@builtin +def div(input, other, result, _builder=None): + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.div(input, other, result, _builder) + + +@builtin +def max(input, other, result, _builder=None): + # elementwise binary vector maximum op + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.max(input, other, result, _builder) + + +@builtin +def min(input, other, result, _builder=None): + # elementwise binary vector minimum op + input = from_buffer_to_tensor_pointer(input, _builder=_builder) + 
other = from_buffer_to_tensor_pointer(other, _builder=_builder) + result = from_buffer_to_tensor_pointer(result, _builder=_builder) + tle_semantic.min(input, other, result, _builder) + + +@builtin +def alloc(shape: List[tl.constexpr], dtype: tl.dtype, mem_addr_space: address_space, _builder=None) -> buffer: + """ + Allocates a region of local memory with the specified shape and type. + + :param etype: the element type of the buffer. + :type etype: tl.dtype + :param shape: A list of non-negative integers representing the shape of the buffer. + :type shape: List[tl.constexpr] + :param _address_space: (Optional) backend-specific local memory address space + :type _address_space: bl.address_space + """ + assert (mem_addr_space is not None) + return tle_semantic.alloc(dtype, shape, mem_addr_space, _builder) + + +@builtin +def to_buffer(tensor: tl.tensor, space: address_space = None, bind_buffer: buffer = None, _builder=None) -> buffer: + """ + Convert a tensor to a buffer. + + :param tensor: the tensor to convert. + :type tensor: tl.tensor + :param space: the address space for the buffer (optional). + :type space: address_space + """ + return tle_semantic.to_buffer(tensor, space, bind_buffer, _builder) + + +@builtin +def to_tensor(memref: buffer, writable: bool = True, target_shape=None, _builder=None) -> tl.tensor: + """ + Create a tl.tensor from a bl.buffer. + + :param memref: the input bl.buffer object. + :memref type: bl.buffer + :param writable: If set true, the resultant tensor is considered "writable" during bufferization. + :type writable: bool + """ + return tle_semantic.to_tensor(memref, writable, _builder, target_shape=target_shape) + + +@builtin +def subview(src: buffer, offsets: List[tl.constexpr], sizes: List[tl.constexpr], strides: List[tl.constexpr], + _builder=None) -> buffer: + ''' + Creates a subview of the source buffer with the specified offsets, sizes, and strides. + + :param src: The source buffer to create a subview from. 
+ :type src: buffer + :param offsets: A list of non-negative integers representing the offsets in each dimension. + :type offsets: List[tl.constexpr] + :param sizes: A list of non-negative integers representing the sizes in each dimension. + :type sizes: List[tl.constexpr] + :param strides: A list of non-negative integers representing the strides in each dimension. + :type strides: List[tl.constexpr] + :return: A new buffer representing the subview of the source buffer. + :rtype: buffer + ''' + # Validate that sizes and strides contain only constexpr values + new_sizes = [] + for i, size in enumerate(sizes): + if isinstance(size, int): + # Convert regular integers to constexpr + new_sizes.append(tl.constexpr(size)) + elif isinstance(size, tl.constexpr): + new_sizes.append(size) + else: + raise TypeError(f"sizes[{i}] must be constexpr, got {type(size).__name__}") + + new_strides = [] + for i, stride in enumerate(strides): + if isinstance(stride, int): + # Convert regular integers to constexpr + new_strides.append(tl.constexpr(stride)) + elif isinstance(stride, tl.constexpr): + new_strides.append(stride) + else: + raise TypeError(f"strides[{i}] must be constexpr, got {type(stride).__name__}") + + new_offsets = [] + for offset in offsets: + if isinstance(offset, tl.constexpr): + # Check that constexpr offset values cannot be negative + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(tl_semantic.to_tensor(offset, _builder)) + elif isinstance(offset, int): + # Convert regular integers to constexpr and then to tensor + if offset < 0: + raise ValueError(f"Offset value must be non-negative, got {offset}") + new_offsets.append(tl_semantic.to_tensor(tl.constexpr(offset), _builder)) + else: + # Assume it's already a tensor + new_offsets.append(offset) + + return tle_semantic.subview(src, new_offsets, new_sizes, new_strides, _builder) + + +def hint(**kwargs): + """Dummy function for AST parsing. 
Not executed during JIT compilation.""" + raise RuntimeError("tle.hint() cannot be called directly.") + + +@builtin +def insert_slice(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + _builder=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. + :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + new_offsets = [tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + out = tle_semantic.insert_slice(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. + :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + assert len(ful.shape) > 0 + new_offsets = [tl_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + sub = tle_semantic.extract_slice(ful, new_offsets, sizes, strides, _builder) + return sub + + +@builtin +def extract_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. 
def wrap_tensor(x, scalar_ty, ret_shape):
    """Wrap an IR value into a tl.tensor; an empty/None ret_shape yields a scalar."""
    if ret_shape:
        res_ty = tl.block_type(scalar_ty, ret_shape)
    else:
        # 0d-tensor -> scalar
        res_ty = scalar_ty
    return tl.tensor(x, res_ty)


def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    """
    Coerce ``value`` (a tl.constexpr, a plain Python int, or an integer
    tl.tensor) into a size-1 tl.tensor of ``dtype``.

    :raises TypeError: if ``value`` is a tensor with a non-integer dtype.
    """
    if isinstance(value, tl.constexpr):
        value = builder.get_int32(value)
        return tl.tensor(value, dtype)

    # Bug fix: plain Python ints are valid extents (see copy()'s
    # List[Union[tl.constexpr, int]] hint) but previously fell through to
    # `value.dtype` and raised AttributeError.
    if isinstance(value, int):
        return tl.tensor(builder.get_int32(value), dtype)

    if value.dtype.is_int():
        return tl.tensor(value.handle, dtype)

    # Bug fix: previously this path silently returned None.
    raise TypeError(f"cannot build a scalar constant from a tensor of dtype {value.dtype}")
+ """ + shape = [scalar_constant(x, tl.int32, builder) for x in shape] + builder.create_dsa_copy(src.handle, dst.handle, [s.handle for s in shape], inter_no_alias) + + +def add(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_add(input.handle, other.handle, result.handle) + + +def sub(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_sub(input.handle, other.handle, result.handle) + + +def mul(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_mul(input.handle, other.handle, result.handle) + + +def div(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_div(input.handle, other.handle, result.handle) + + +def max(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_max(input.handle, result.handle) + + +def min(input: tl.tensor, other: tl.tensor, result: tl.tensor, builder: ir.builder): + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + builder.create_dsa_min(input.handle, other.handle, result.handle) + + +def alloc(etype: tl.dtype, shape: List[tl.constexpr], address_space: address_space, builder: ir.builder) -> buffer: + shape = tl._unwrap_shape(shape) + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + etype = tl._constexpr_to_value(etype) + address_space = tl._constexpr_to_value(address_space) + element_ty_ir = etype.to_ir(builder) + addr_space_attr = 
(address_space.to_ir(builder) if address_space else builder.dsa_get_null_attr()) + memref_ty = builder.dsa_get_buffer_type(shape, element_ty_ir, addr_space_attr) + handle = builder.create_dsa_alloc(memref_ty) + buffer_ty = buffer_type(element_ty=etype, shape=shape, space=address_space) + return buffer(handle, buffer_ty) + + +def to_buffer( + tensor: tl.tensor, + address_space: address_space, + bind_buffer: buffer, + builder: ir.builder, +) -> buffer: + if not isinstance(tensor.shape, (tuple, list)) or not tensor.shape: + raise TypeError("scalar type cannot be converted to buffer") + # if isinstance(bind_buffer, buffer): + # builder.create_bind_buffer(tensor.handle, bind_buffer.handle) + # return bind_buffer + if bind_buffer is not None: + raise ValueError("bind_buffer must be a buffer or None") + address_space = tl._constexpr_to_value(address_space) + addr_space_attr = (address_space.to_ir(builder) if address_space else builder.dsa_get_null_attr()) + handle = builder.dsa_to_buffer(tensor.handle, addr_space_attr) + buffer_ty = buffer_type(element_ty=tensor.dtype, shape=tensor.shape, space=address_space) + return buffer(handle, buffer_ty) + + +def to_tensor(memref: buffer, writable: bool, builder: ir.builder, target_shape=None) -> tl.tensor: + if not isinstance(memref, buffer): + raise TypeError("memref must be buffer") + + need_convert_layout = False + shape = memref.shape + if target_shape: + need_convert_layout = True + shape = tl._unwrap_shape(target_shape) + assert shape != memref.shape, "target shape is the same as source shape" + if not isinstance(shape, (tuple, list)): + raise TypeError("shape must be list/tuple") + tensor_type = tl.block_type(memref.dtype, shape) + + memref_value = memref.handle + if need_convert_layout: + buffer_ty = buffer_type( + element_ty=memref.dtype, + shape=shape, + space=memref.space, + ) + memref_value = builder.create_convert_layout(memref_value, buffer_ty.to_ir(builder)) + + return tl.tensor(builder.dsa_to_tensor(memref_value, 
writable), tensor_type) + + +def insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_dsa_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + + +def extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tl.tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + new_offsets = [o.handle for o in offsets] + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_dsa_extract_slice(ful.handle, new_offsets, sizes, strides) + return tl.tensor(out, ret_type) + + +def extract_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + new_indice = [i.handle for i in indice] + result = builder.create_dsa_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + + +def subview(src: buffer, offsets: List[tl.tensor], sizes: List[tl.constexpr], strides: List[tl.constexpr], + builder: ir.builder) -> buffer: + + new_offsets = [offset.handle for offset in offsets] + sizes_int = tl._unwrap_shape(sizes) + strides_int = tl._unwrap_shape(strides) + + result_handle = builder.create_dsa_subview(src.handle, new_offsets, sizes_int, strides_int) + + # calculate the memory layout strides of the source buffer + if src.strides: + # use the 
strides of the source buffer + src_memory_strides = src.strides + else: + # calculate the default row-major strides + src_memory_strides = [] + stride = 1 + for dim_size in reversed(src.shape): + if dim_size < 0: + raise ValueError("Cannot compute strides for buffer with dynamic dimensions") + src_memory_strides.insert(0, stride) + stride *= dim_size + + result_memory_strides = [] + for src_stride, subview_stride in zip(src_memory_strides, strides_int): + result_memory_strides.append(src_stride * subview_stride) + + # create buffer_type with strides + buffer_ty = buffer_type(element_ty=src.dtype, shape=sizes_int, space=src.space, strides=result_memory_strides) + return buffer(result_handle, buffer_ty) diff --git a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py new file mode 100644 index 0000000000..8c65f69d32 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -0,0 +1,101 @@ +from triton._C.libtriton import ir + +from typing import List +import triton.language.core as tl + + +class address_space: + """Represents a buffer's address space. + + The :code:`address_space` of a buffer is a target-specific attribute. 
+ """ + + def to_ir(self, builder: ir.builder) -> ir.type: + raise NotImplementedError("Abstract address_space cannot be converted to ir") + + +class buffer_type(tl.dtype): + + def __init__(self, element_ty: tl.dtype, shape: List, space: address_space = None, strides: List = None): + self.element_ty = element_ty + self.shape = shape if isinstance(shape, list) else list(shape) + self.space = space + self.strides = strides if strides is not None else [] + self.name = self._make_name() + + def _make_name(self): + res = '' + + def to_ir(self, builder: ir.builder) -> ir.type: + element_ty_ir = self.element_ty.to_ir(builder) + addr_space_attr = self.space.to_ir(builder) if self.space else builder.get_null_attr() + + # use the method with strides if strides is not empty + if self.strides: + return builder.dsa_get_buffer_type_with_strides(self.shape, element_ty_ir, self.strides, addr_space_attr) + else: + return builder.dsa_get_buffer_ty(self.shape, element_ty_ir, addr_space_attr) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def __eq__(self, other) -> bool: + if not isinstance(other, buffer_type): + return False + return (self.element_ty == other.element_ty and self.shape == other.shape and self.space == other.space + and self.strides == other.strides) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +# ----------------------- +# buffer +# ----------------------- + + +class buffer(tl._value): + """Represents a region of memory. + + :code:`buffer` is the fundamental data structure for Triton programs using + the buffer language extension. Most functions in + :py:mod:`triton.extension.buffer.language` operate on and return buffers. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + .. 
rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, buffer_ty: buffer_type): + """Not called by user code.""" + super().__init__(handle) + self.type = buffer_ty + self.dtype = buffer_ty.element_ty.scalar + self.shape = buffer_ty.shape + self.space = buffer_ty.space + self.strides = buffer_ty.strides + + def __str__(self) -> str: + # ex. "<16x32xfloat32, address_space>" + res = '<' + 'x'.join(str(s) for s in self.shape) + 'x' + str(self.dtype) + if self.space: + res += ', ' + str(self.space) + return res + '>' diff --git a/python/tutorials/tle/01-sparse-flash-attn-tle.py b/python/tutorials/tle/01-sparse-flash-attn-tle.py new file mode 100644 index 0000000000..d7d2bb62e8 --- /dev/null +++ b/python/tutorials/tle/01-sparse-flash-attn-tle.py @@ -0,0 +1,949 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +import torch +import torch_npu +import triton +import triton.language as tl +import numpy as np +from datetime import datetime +from triton.backends.ascend.testing import do_bench_npu +import triton.experimental.tle as tle +# import random + +np.random.seed(21) +DEVICE = "npu" +DEVICE_ID = 0 +torch.manual_seed(20) +torch_npu.npu.set_device(int(DEVICE_ID)) +torch.set_printoptions(sci_mode=False, precision=4, linewidth=300) + +ascend_aiv_core_nums = triton.language.constexpr(24) + + +# ===== Fused PA + Rope Concat + BNSD + Gather Kernel ===== +@triton.jit +def fused_pa_rope_to_sparse_kernel( + k_pa_ptr, + k_rope_pa_ptr, + v_pa_ptr, # PA_BSND input [block_num, block_size, n, d] + block_table_ptr, # block_table [B, max_blocks] + sparse_indices_ptr, # sparse_indices [B, N, TOPK] + k_sparse_out_ptr, + v_sparse_out_ptr, # BNSD output [B, N, TOPK, d] + stride_k_pa_bn, + stride_k_pa_bs, + stride_k_pa_n, + stride_k_pa_d, # K PA strides + stride_k_rope_pa_bn, + 
stride_k_rope_pa_bs, + stride_k_rope_pa_n, + stride_k_rope_pa_d, # K_rope PA strides + stride_v_pa_bn, + stride_v_pa_bs, + stride_v_pa_n, + stride_v_pa_d, # V PA strides + stride_bt_b, + stride_bt_blk, # block_table strides + stride_si_b, + stride_si_n, + stride_si_topk, # sparse_indices strides + stride_out_b, + stride_out_n, + stride_out_topk, + stride_out_d, # output strides + stride_v_b, + stride_v_n, + stride_v_topk, + stride_v_d, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_DK_ROPE: tl.constexpr, # 0 if no rope + TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + B: tl.constexpr, +): + """ + Fused kernel: PA_BSND + Rope Concat -> BNSD Sparse + Input: K/V in PA_BSND format, K_rope in PA_BSND format + Output: K/V_sparse in BNSD format + """ + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # Process (b, n, topk) combinations + for b_idx in range(B): + b = b_idx # sparse_indices is [B, N, TOPK], assume B=1 for now + for idx in range(pid, TOPK, num_programs): + # Get batch and sparse index from sparse_indices + n = 0 # KV_N = 1 + + # Load sparse index + sparse_idx = tl.load(sparse_indices_ptr + b * stride_si_b + n * stride_si_n + idx * stride_si_topk) + + # Map sparse_idx to PA_BSND position + block_id = sparse_idx // BLOCK_SIZE # Which block + bs_offset = sparse_idx % BLOCK_SIZE # Offset within block + + # Get actual block ID from block_table + actual_block_id = tl.load(block_table_ptr + b * stride_bt_b + block_id * stride_bt_blk) + + # Compute PA_BSND offset for K + k_pa_offset = (actual_block_id * stride_k_pa_bn + bs_offset * stride_k_pa_bs + n * stride_k_pa_n) + + # Compute PA_BSND offset for K_rope + k_rope_pa_offset = (actual_block_id * stride_k_rope_pa_bn + bs_offset * stride_k_rope_pa_bs + + n * stride_k_rope_pa_n) + + # Compute PA_BSND offset for V + v_pa_offset = (actual_block_id * stride_v_pa_bn + bs_offset * stride_v_pa_bs + n * stride_v_pa_n) + # Load K vector (no rope part) + k_vec = tl.load(k_pa_ptr + k_pa_offset + 
tl.arange(0, BLOCK_DK) * stride_k_pa_d) + + # Load V vector + v_vec = tl.load(v_pa_ptr + v_pa_offset + tl.arange(0, BLOCK_DV) * stride_v_pa_d) + # Output to BNSD format: [B, N, TOPK, D] + out_offset = b * stride_out_b + n * stride_out_n + idx * stride_out_topk + out_offset_v = b * stride_v_b + n * stride_v_n + idx * stride_v_topk + + if BLOCK_DK_ROPE > 0: + # Load K_rope vector + full_k = tl.full((BLOCK_DK + BLOCK_DK_ROPE, ), 0.0, dtype=tl.float16) + k_rope_vec = tl.load(k_rope_pa_ptr + k_rope_pa_offset + + tl.arange(0, BLOCK_DK_ROPE) * stride_k_rope_pa_d) + full_k = tle.dsa.insert_slice(full_k, k_vec, offsets=(0, ), sizes=(BLOCK_DK, ), strides=(1, )) + full_k = tle.dsa.insert_slice(full_k, k_rope_vec, offsets=(BLOCK_DK, ), sizes=(BLOCK_DK_ROPE, ), + strides=(1, )) + tl.store(k_sparse_out_ptr + out_offset + tl.arange(0, BLOCK_DK + BLOCK_DK_ROPE) * stride_out_d, full_k) + else: + # No rope, store K directly + tl.store(k_sparse_out_ptr + out_offset + tl.arange(0, BLOCK_DK) * stride_out_d, k_vec) + + # Store V + tl.store(v_sparse_out_ptr + out_offset_v + tl.arange(0, BLOCK_DV) * stride_v_d, v_vec) + + +def triton_fused_pa_rope_to_sparse(k_pa, k_rope_pa, v_pa, block_table, sparse_indices, block_size): + """ + Fused PA_BSND + Rope Concat -> BNSD Sparse conversion + + Args: + k_pa: Key in PA_BSND format [block_num, block_size, n, dk] + k_rope_pa: Key rope in PA_BSND format [block_num, block_size, n, d_rope], None if no rope + v_pa: Value in PA_BSND format [block_num, block_size, n, dv] + block_table: Block table [B, max_blocks] + sparse_indices: Sparse indices [B, N, TOPK] + block_size: Block size for PA format + + Returns: + k_sparse: Sparse key in BNSD format [B, N, TOPK, dk+d_rope] + v_sparse: Sparse value in BNSD format [B, N, TOPK, dv] + """ + block_num, _, n, dk = k_pa.shape + B = block_table.shape[0] + TOPK = sparse_indices.size(-1) + N = 1 # KV_N = 1 + _, _, _, dv = v_pa.shape + + has_rope = k_rope_pa is not None + dk_rope = k_rope_pa.shape[-1] if has_rope else 0 
+ dk_total = dk + dk_rope + + # Output BNSD format [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, dk_total), dtype=k_pa.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, dv), dtype=v_pa.dtype, device=DEVICE) + + # Grid: use 48 programs for parallelism + grid = (min(48, TOPK), ) + + # sparse_indices input format: [T, N, TOPK] or [B, N, TOPK] + # No squeeze needed - kernel expects [B, N, TOPK] format + sparse_indices_input = sparse_indices + if sparse_indices.dim() == 2: + # If already 2D [B, TOPK], reshape to [B, 1, TOPK] + sparse_indices_input = sparse_indices.unsqueeze(1) + + # Set k_rope_pa to k_pa if no rope (dummy pointer, won't be accessed) + k_rope_pa_input = k_rope_pa if has_rope else k_pa + fused_pa_rope_to_sparse_kernel[grid](k_pa, k_rope_pa_input, v_pa, block_table, sparse_indices_input, + k_sparse, v_sparse, k_pa.stride(0), k_pa.stride(1), k_pa.stride(2), + k_pa.stride(3), k_rope_pa_input.stride(0), k_rope_pa_input.stride(1), + k_rope_pa_input.stride(2), k_rope_pa_input.stride(3), v_pa.stride(0), + v_pa.stride(1), v_pa.stride(2), v_pa.stride(3), block_table.stride(0), + block_table.stride(1), sparse_indices_input.stride(0), + sparse_indices_input.stride(1), sparse_indices_input.stride(2), + k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), + BLOCK_DK=dk, BLOCK_DV=dv, BLOCK_DK_ROPE=dk_rope, TOPK=TOPK, + BLOCK_SIZE=block_size, B=B) + + return k_sparse, v_sparse + + +@triton.jit +def gather_kv_bnsd_vec_kernel( + k_ptr, + v_ptr, + ind_ptr, + k_out_ptr, + v_out_ptr, + stride_kb, + stride_kn, + stride_ks, + stride_kd, + stride_vb, + stride_vn, + stride_vs, + stride_vd, + stride_ob, + stride_on, + stride_os, + stride_od, + stride_ovb, + stride_ovn, + stride_ovs, + stride_ovd, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, + TOPK: tl.constexpr, + B: tl.constexpr, +): + end = TOPK // 48 * 48 + for b_idx in range(B): + # 
分批处理所有TOPK个索引,每次48个 + for batch_start in range(0, end, 48): + pid_k = tl.program_id(0) + batch_start + + # 读 index + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + # 处理余数部分(end到TOPK) + for batch_start in range(end, TOPK, 48): + pid_k = tl.program_id(0) + batch_start + + # 必须在计算pid_k之后检查边界 + if pid_k < TOPK: + idx = tl.load(ind_ptr + pid_k) + + # 加载 K 向量 [BLOCK_DK] - 直接线性加载 + k_src_off = idx * stride_ks + b_idx * stride_kb + k_val = tl.load(k_ptr + k_src_off + tl.arange(0, BLOCK_DK) * stride_kd) + + # 加载 V 向量 [BLOCK_DV] - 直接线性加载 + v_src_off = idx * stride_vs + b_idx * stride_vb + v_val = tl.load(v_ptr + v_src_off + tl.arange(0, BLOCK_DV) * stride_vd) + + # 写回 K: [B, N, TOPK, Dk] + k_dst_off = pid_k * stride_os + b_idx * stride_ob + tl.store(k_out_ptr + k_dst_off + tl.arange(0, BLOCK_DK) * stride_od, k_val) + + # 写回 V: [B, N, TOPK, Dv] + v_dst_off = pid_k * stride_ovs + b_idx * stride_ovb + tl.store(v_out_ptr + v_dst_off + tl.arange(0, BLOCK_DV) * stride_ovd, v_val) + + +def triton_gather_kv_bnsd_vec(k, v, indices): + B, N, SK, Dk = k.shape # N=1 + B, N, SK, Dv = v.shape # N=1 + TOPK = indices.size(-1) + + # 输出保持 bnsd [B, N, TOPK, D] + k_sparse = torch.empty((B, N, TOPK, Dk), dtype=k.dtype, device=DEVICE) + v_sparse = torch.empty((B, N, TOPK, Dv), dtype=v.dtype, device=DEVICE) + + grid = (48, ) # TOPK 个 program,每个搬 Dk/Dv 元素 + 
gather_kv_bnsd_vec_kernel[grid]( + k, + v, + indices.squeeze(0), # [B, N, SK, D] -> [N, SK, D] + k_sparse, + v_sparse, + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + k_sparse.stride(0), + k_sparse.stride(1), + k_sparse.stride(2), + k_sparse.stride(3), + v_sparse.stride(0), + v_sparse.stride(1), + v_sparse.stride(2), + v_sparse.stride(3), + BLOCK_DK=Dk, + BLOCK_DV=Dv, + TOPK=TOPK, + B=B, + ) + return k_sparse, v_sparse + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + O, + scale_value, + stride_qb: tl.constexpr, + stride_qs: tl.constexpr, + stride_qn: tl.constexpr, + stride_qd: tl.constexpr, + stride_kb: tl.constexpr, + stride_kn: tl.constexpr, + stride_ks: tl.constexpr, + stride_kd: tl.constexpr, + stride_vb: tl.constexpr, + stride_vn: tl.constexpr, + stride_vs: tl.constexpr, + stride_vd: tl.constexpr, + stride_ob: tl.constexpr, + stride_os: tl.constexpr, + stride_on: tl.constexpr, + stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, + Q_D: tl.constexpr, + Q_S: tl.constexpr, + KV_S: tl.constexpr, + K_D: tl.constexpr, + V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N: tl.constexpr, + O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, +): + # total b * n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B * Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + # off_n = 0 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ob + off_s * stride_os + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q 
= tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE, ), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE, ), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(Q_N, Q_D), strides=(stride_qn, stride_qd), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, Q_D), + order=(1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0, 1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(KV_S, K_D), + strides=(stride_ks, stride_kd), + offsets=(0, 0), + block_shape=(blk_size, K_D), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(KV_S, V_D), strides=(stride_vs, stride_vd), + offsets=(0, 0), block_shape=(blk_size, V_D), order=(1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0, 1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), + tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. 
+ block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:, None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0, 1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc * coeff[:, None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base=O + o_offset, shape=(O_N, O_D), strides=(stride_on, stride_od), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, O_D), + order=(1, 0)) + # final normalize + acc = acc / logSum[:, None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + +@triton.jit +def _attn_fwd_fused_bsnd_to_tnd( + Q, + K, + V, + O, + scale_value, + stride_qb: tl.constexpr, + stride_qs: tl.constexpr, + stride_qn: tl.constexpr, + stride_qd: tl.constexpr, + stride_kb: tl.constexpr, + stride_kn: tl.constexpr, + stride_ks: tl.constexpr, + stride_kd: tl.constexpr, + stride_vb: tl.constexpr, + stride_vn: tl.constexpr, + stride_vs: tl.constexpr, + stride_vd: tl.constexpr, + stride_ot: tl.constexpr, + stride_on: tl.constexpr, + stride_od: tl.constexpr, + B: tl.constexpr, + Q_N: tl.constexpr, + Q_D: tl.constexpr, + Q_S: tl.constexpr, + KV_S: tl.constexpr, + K_D: tl.constexpr, + V_D: tl.constexpr, + sparse_mode: tl.constexpr, # 0 or 3 + O_N: tl.constexpr, + O_D: tl.constexpr, + actual_seq_lengths_query, + actual_seq_lengths_kv, + blk_size: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, +): + # total b * 
n tasks + BLOCK_QN_NUM = Q_N // Q_BLOCK_SIZE + NUM_BLOCKS = B * Q_S * BLOCK_QN_NUM + pid = tl.program_id(0) + num_cores = min(ascend_aiv_core_nums, NUM_BLOCKS) + + #最外层循环,沿b*n切 + for block_idx in range(pid, NUM_BLOCKS, num_cores): # 并行 + off_b = (block_idx // (Q_S * BLOCK_QN_NUM)).to(tl.int32) #当前任务在第几个b块中 + off_s = ((block_idx // BLOCK_QN_NUM) % Q_S).to(tl.int32) #当前任务在第几个s块中 + off_n = (block_idx % BLOCK_QN_NUM).to(tl.int32) #当前任务在第几个n块中 + + q_offset = off_b * stride_qb + off_s * stride_qs + o_offset = off_b * stride_ot + k_offset = off_b * stride_kb # KV_N = 1 + v_offset = off_b * stride_vb + + cur_act_s_q = tl.load(actual_seq_lengths_query + off_b) + + for i in range(cur_act_s_q): + cur_max = tl.full((Q_BLOCK_SIZE, ), float('-inf'), dtype=tl.float32) + logSum = tl.zeros((Q_BLOCK_SIZE, ), dtype=tl.float32) + acc = tl.zeros((Q_BLOCK_SIZE, V_D), dtype=tl.float32) # 升维到[q_block_size, V_D] + + # load q + q_block_ptr = tl.make_block_ptr(base=Q + q_offset, shape=(Q_N, Q_D), strides=(stride_qn, stride_qd), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, Q_D), + order=(1, 0)) + q_vec = tl.load(q_block_ptr, boundary_check=(0, 1)) # [q_block_size, K_D] + k_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(KV_S, K_D), + strides=(stride_ks, stride_kd), + offsets=(0, 0), + block_shape=(blk_size, K_D), + order=(1, 0), + ) + v_block_ptr = tl.make_block_ptr(base=V + v_offset, shape=(KV_S, V_D), strides=(stride_vs, stride_vd), + offsets=(0, 0), block_shape=(blk_size, V_D), order=(1, 0)) + + for k_idx in range(KV_S // blk_size): + # load k + k_vec = tl.load(k_block_ptr, boundary_check=(0, 1)) + + # 使用dot加速:[blk_size, K_D] @ [K_D] -> [q_block_size, blk_size] + qk = tl.dot(q_vec.to(tl.float16), + tl.trans(k_vec).to(tl.float16)) * scale_value # [q_block_size, blk_size] + # online softmax update + # Triton's tl.max doesn't accept keyword 'dim'; use positional axis. 
+ block_max = tl.max(qk, axis=1) # [q_block_size] + # align shapes to (q_block_size, 1) for broadcasting + # block_max = block_max[:, None] # [q_block_size, 1] + new_max = tl.maximum(cur_max, block_max) # [q_block_size, 1] + coeff = tl.math.exp(cur_max - new_max) # [q_block_size, 1] + p = tl.math.exp(qk - new_max[:, None]) # [q_block_size, blk_size] + # logsum per row + logSum = logSum * coeff + tl.sum(p, axis=1) # [q_block_size, 1] + + # update accumulator: compute per-row pv by summing over block dim + v_vec = tl.load(v_block_ptr, boundary_check=(0, 1)) # [blk_size, V_D] + pv = tl.dot(p.to(tl.float16), v_vec) # [q_block_size, V_D] + acc = acc * coeff[:, None] + pv # [q_block_size, V_D] + cur_max = new_max + + k_block_ptr = k_block_ptr.advance((blk_size, 0)) + v_block_ptr = v_block_ptr.advance((blk_size, 0)) + + o_block_ptr = tl.make_block_ptr(base=O + o_offset, shape=(O_N, O_D), strides=(stride_on, stride_od), + offsets=(off_n * Q_BLOCK_SIZE, 0), block_shape=(Q_BLOCK_SIZE, O_D), + order=(1, 0)) + # final normalize + acc = acc / logSum[:, None] # [q_block_size, V_D] / [q_block_size,1] -> [q_block_size, V_D] + tl.store(o_block_ptr, acc) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, query, key, value, sparse_indices, scale_value, sparse_block_size=1, actual_seq_lengths_query=None, + actual_seq_lengths_kv=None, query_rope=None, key_rope=None, layout_query='BSND', layout_kv='BSND', + sparse_mode=0, block_table=None): + # Save original sparse_indices for PA_BSND case + sparse_indices_orig = sparse_indices.clone() + total_len = 0 + # Handle query layout transformation (TND -> BSND) + if layout_query == 'TND': + actual_seq_lengths_query, total_len = trans_tnd_actseq(actual_seq_lengths_query) + # ✅ 融合版本:一次 kernel 调用处理所有 tensor + concat + query, sparse_indices = trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, query.shape, + actual_seq_lengths_query) + else: + if query_rope is not None: + query = torch.cat([query, 
query_rope], dim=-1) + + # Handle KV layout and gather sparse K/V + if layout_kv == 'PA_BSND': + # Fused PA -> BNSD + rope concat + sparse gather + block_size = key.shape[1] # Get block_size from PA shape + # Use original sparse_indices [T, N, TOPK] for fused kernel + k_sparse, v_sparse = triton_fused_pa_rope_to_sparse(key, key_rope, value, block_table, sparse_indices_orig, + block_size) + # sparse_indices is already in BSND, needs permute to BNSD for downstream use + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + else: + # Original path for non-PA layouts + if key_rope is not None: + key = torch.cat([key, key_rope], dim=-1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + sparse_indices_bnsd = sparse_indices.permute(0, 2, 1, 3).contiguous() + + k_sparse, v_sparse = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + + k_sparse = k_sparse.contiguous() + v_sparse = v_sparse.contiguous() + enable_check_kv_sparse = 0 + if enable_check_kv_sparse: + key = pa_to_bsnd(key, block_table, actual_seq_lengths_kv) + key_rope = pa_to_bsnd(key_rope, block_table, actual_seq_lengths_kv) + value = pa_to_bsnd(value, block_table, actual_seq_lengths_kv) + if key_rope is not None: + key = torch.cat([key, key_rope], dim=-1) + key_bnsd = key.permute(0, 2, 1, 3).contiguous() + value_bnsd = value.permute(0, 2, 1, 3).contiguous() + k_sparse_ref, v_sparse_ref = triton_gather_kv_bnsd_vec(key_bnsd, value_bnsd, sparse_indices_bnsd) + print(f"k_sparse={k_sparse}") + print(f"k_sparse_ref={k_sparse_ref}") + print(f"v_sparse={v_sparse}") + print(f"v_sparse_ref={v_sparse_ref}") + assert torch.allclose(k_sparse, k_sparse_ref, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + assert torch.allclose(v_sparse, v_sparse_ref, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" 
+ + # expected_k = key_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(k_sparse, expected_k, rtol=1e-5, atol=1e-5), "K_sparse mismatch!" + # expected_v = value_bnsd[:, :, :sparse_size, :].contiguous() + # assert torch.allclose(v_sparse, expected_v, rtol=1e-5, atol=1e-5), "V_sparse mismatch!" + num_cores = ascend_aiv_core_nums + # sparse_size = sparse_indices_bnsd.shape[-1] # 4 + out_shape_bsnd = list(query.shape) + if query_rope is not None: + out_shape_bsnd[-1] = out_shape_bsnd[-1] - query_rope.shape[-1] + B, Q_S, Q_N, Q_D = query.shape + _, _, KV_S, K_D = k_sparse.shape + + if layout_query == 'TND': + # t = B*act_q_s + output = torch.empty((total_len, out_shape_bsnd[2], out_shape_bsnd[3]), device=query.device, + dtype=torch.float32) + _attn_fwd_fused_bsnd_to_tnd[(num_cores, )]( + query, k_sparse, v_sparse, output, scale_value, query.stride(0), query.stride(1), query.stride(2), + query.stride(3), k_sparse.stride(0), k_sparse.stride(1), k_sparse.stride(2), k_sparse.stride(3), + v_sparse.stride(0), v_sparse.stride(1), v_sparse.stride(2), v_sparse.stride(3), output.stride(0), + output.stride(1), output.stride(2), B=B, Q_N=Q_N, Q_D=Q_D, Q_S=Q_S, KV_S=KV_S, K_D=K_D, + V_D=v_sparse.shape[3], sparse_mode=sparse_mode, O_N=output.shape[1], O_D=output.shape[2], + actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_kv=actual_seq_lengths_kv, + blk_size=128, Q_BLOCK_SIZE=16, multibuffer=False) + + else: + output = torch.empty(out_shape_bsnd, device=query.device, dtype=torch.float32) + _attn_fwd[(num_cores, )](query, k_sparse, v_sparse, output, scale_value, query.stride(0), query.stride(1), + query.stride(2), query.stride(3), k_sparse.stride(0), k_sparse.stride(1), + k_sparse.stride(2), k_sparse.stride(3), v_sparse.stride(0), v_sparse.stride(1), + v_sparse.stride(2), v_sparse.stride(3), output.stride(0), output.stride(1), + output.stride(2), output.stride(3), B=B, Q_N=Q_N, Q_D=Q_D, Q_S=Q_S, KV_S=KV_S, + K_D=K_D, V_D=v_sparse.shape[3], 
sparse_mode=sparse_mode, O_N=output.shape[2], + O_D=output.shape[3], actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, blk_size=128, Q_BLOCK_SIZE=16) + output = output.permute(0, 2, 1, 3).contiguous() + + ctx.save_for_backward(query, k_sparse, v_sparse, output) + ctx.scale_value = scale_value + return output + + +def pa_to_bsnd(pa_in, block_table, actual_seq_lengths): + block_num, block_size, n, d = pa_in.shape + b = len(actual_seq_lengths) + output = torch.empty((b, block_num * block_size // b, 1, d), dtype=pa_in.dtype).to(DEVICE) + for i in range(b): + for j in range(20): + output[i, j * block_size: (j + 1) * block_size, 0, :] = \ + pa_in[block_table[i][j], :, 0, :].reshape(block_size, d) + return output + + +@triton.jit +def trans_tnd_to_bsnd_fused_kernel( + query_ptr, + query_rope_ptr, + sparse_ptr, + query_out_ptr, + sparse_out_ptr, # query_out 已经拼接了 rope + act_s, + stride_q_t, + stride_q_tn, + stride_q_td, + stride_qr_t, + stride_qr_tn, + stride_qr_td, + stride_s_t, + stride_s_tn, + stride_s_td, + stride_qob, + stride_qobs, + stride_qon, + stride_qod, # query_out strides + stride_sb, + stride_sbs, + stride_sbn, + stride_sbd, + B: tl.constexpr, + N: tl.constexpr, + D_QUERY: tl.constexpr, + D_ROPE: tl.constexpr, + D_SPARSE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_QUERY: tl.constexpr, + BLOCK_D_ROPE: tl.constexpr, + BLOCK_D_SPARSE: tl.constexpr, +): + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + + # 计算 head 的总块数 + num_head_blocks = (N + BLOCK_N - 1) // BLOCK_N + t_idx = tl.full((), 0, dtype=tl.int64) # TODO: 需要正确的 token 映射 + # 每个 pid 负责处理特定的 (batch, head_block) 组合 + for tn_id in range(B): + # sparse_indices 是单头的,只在第一个 head_block 处理一次 + if pid == 0: + sparse_block_ptr = tl.make_block_ptr(base=sparse_ptr + t_idx * stride_s_t, shape=(1, D_SPARSE), + strides=(stride_s_tn, stride_s_td), offsets=(0, 0), + block_shape=(1, D_SPARSE), order=(1, 0)) + sparse = tl.load(sparse_block_ptr) + + 
sparse_out_block_ptr = tl.make_block_ptr(base=sparse_out_ptr + t_idx * stride_sb, shape=(1, D_SPARSE), + strides=(stride_sbn, stride_sbd), offsets=(0, 0), + block_shape=(1, D_SPARSE), order=(1, 0)) + tl.store(sparse_out_block_ptr, sparse) + + # query 和 query_rope 是多头的,需要在 head 维度上分块处理 + for head_block_id in range(pid, num_head_blocks, num_programs): + n_offset = head_block_id * BLOCK_N + + # Load q and q_ro + q_block_ptr = tl.make_block_ptr(base=query_ptr + t_idx * stride_q_t, shape=(N, D_QUERY), + strides=(stride_q_tn, stride_q_td), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_QUERY), order=(1, 0)) + q_ro_block_ptr = tl.make_block_ptr(base=query_rope_ptr + t_idx * stride_qr_t, shape=(N, D_ROPE), + strides=(stride_qr_tn, stride_qr_td), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_ROPE), order=(1, 0)) + q = tl.load(q_block_ptr) + q_ro = tl.load(q_ro_block_ptr) + + # Combine query and query_rope using insert_slice, then store in one operation + full_q = tl.zeros((BLOCK_N, D_QUERY + D_ROPE), dtype=query_out_ptr.dtype.element_ty) + full_q = tle.dsa.insert_slice(full_q, q, offsets=(0, 0), sizes=(BLOCK_N, D_QUERY), strides=(1, 1)) + full_q = tle.dsa.insert_slice(full_q, q_ro, offsets=(0, D_QUERY), sizes=(BLOCK_N, D_ROPE), strides=(1, 1)) + + q_out_block_ptr = tl.make_block_ptr(base=query_out_ptr + t_idx * stride_qob, shape=(N, D_QUERY + D_ROPE), + strides=(stride_qon, stride_qod), offsets=(n_offset, 0), + block_shape=(BLOCK_N, D_QUERY + D_ROPE), order=(1, 0)) + tl.store(q_out_block_ptr, full_q) + t_idx = t_idx + tl.load(act_s + tn_id) + + +def trans_tnd_to_bsnd_fused(query, query_rope, sparse_indices, shape, act_seq, grid=(16, )): + """ + 融合版本的 TND -> BSND 转换(包含 concat) + 一次性处理 query, query_rope, sparse_indices,并拼接 query + query_rope + """ + t, n, d_query = shape + b = len(act_seq) + s = max(act_seq) + + # 获取各个 tensor 的维度 + d_rope = query_rope.shape[2] if query_rope is not None else 0 + d_sparse = sparse_indices.shape[2] + d_query_out = d_query + d_rope # 拼接后的维度 
+ + # 分配输出(query_out 已经包含 rope) + query_out = torch.empty((b, s, n, d_query_out), dtype=query.dtype, device=query.device) + sparse_out = torch.empty((b, s, 1, d_sparse), dtype=sparse_indices.dtype, device=sparse_indices.device) + assert sparse_indices.shape[1] == 1, "sparse_indices second dim must be 1 when MLA" + # 启动 fused kernel + # 使用较小的 BLOCK_N 避免内存溢出 + block_n = min(16, n) + # 计算需要的核心数:使用多核心并行处理不同的头 + num_head_blocks = (n + block_n - 1) // block_n + num_programs = min(ascend_aiv_core_nums, num_head_blocks) # 最多使用24个核心 + + trans_tnd_to_bsnd_fused_kernel[ + num_programs, + ]( + query, + query_rope, + sparse_indices, + query_out, + sparse_out, + act_seq, + query.stride(0), + query.stride(1), + query.stride(2), + query_rope.stride(0), + query_rope.stride(1), + query_rope.stride(2), + sparse_indices.stride(0), + sparse_indices.stride(1), + sparse_indices.stride(2), + query_out.stride(0), + query_out.stride(1), + query_out.stride(2), + query_out.stride(3), + sparse_out.stride(0), + sparse_out.stride(1), + sparse_out.stride(2), + sparse_out.stride(3), + B=b, + N=n, + D_QUERY=d_query, + D_ROPE=d_rope, + D_SPARSE=d_sparse, + BLOCK_N=block_n, + BLOCK_D_QUERY=d_query, + BLOCK_D_ROPE=d_rope, + BLOCK_D_SPARSE=d_sparse, + ) + return query_out, sparse_out + + +def trans_tnd_actseq(seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().tolist() + list_len = len(seq) + output = [] + output = [seq[0]] + total_len = seq[0] + for i in range(list_len - 1): + new_item = seq[i + 1] - seq[i] + if new_item >= 0: + output.append(new_item) + total_len += new_item + else: + print(f"[ERROR]trans_tnd_actseq: Wrong input actseq:{seq}, in loop {i}, item {new_item} < 0") + return torch.tensor(output).to(DEVICE), total_len + + +def sparse_attention(query, key, value, sparse_indices, scale_value, sparse_block_size=1, actual_seq_lengths_query=None, + actual_seq_lengths_kv=None, query_rope=None, key_rope=None, layout_query='BSND', layout_kv='BSND', + sparse_mode=0, block_table=None): + 
return _attention.apply(query, key, value, sparse_indices, scale_value, sparse_block_size, actual_seq_lengths_query, + actual_seq_lengths_kv, query_rope, key_rope, layout_query, layout_kv, sparse_mode, + block_table) + + +def test_op(T, B, KV_S, Q_N, KV_N, D, D_rope, sparse_size, scale_value, sparse_block_size, sparse_mode, block_size, + act_kv_s): + assert sparse_size <= KV_S + assert KV_N == 1 + assert sparse_mode == 0 or 3 + assert sparse_block_size == 1 + assert (B * KV_S) % block_size == 0 + assert D == 512 + assert D_rope == 0 or 64 + print("*batch_size=", B) + qkv_dtype = torch.float16 + #sparse_size = KV_S + query = torch.empty((T, Q_N, D), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + key = torch.empty((B * KV_S // block_size, block_size, KV_N, D), dtype=qkv_dtype, + device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + value = key.clone() + + # act_q_s = T // B # step + # rand_vals = torch.rand(T, KV_N, act_kv_s, device=DEVICE) + # _, indices = torch.topk(rand_vals, sparse_size, dim=-1) #sparse_indices不重复 + # sparse_indices = indices.to(torch.int32) + sparse_indices = torch.arange(sparse_size, device=DEVICE, dtype=torch.int32).view(1, 1, -1).expand(T, KV_N, -1) + sparse_indices = sparse_indices.to(torch.int32) + # print("sparse_indices=", sparse_indices) + actual_seq_lengths_query = torch.arange(1, B + 1, dtype=torch.int32, device=DEVICE) + # actual_seq_lengths_query = torch.tensor([1]).reshape(B).to(torch.int32).to(DEVICE) + actual_seq_lengths_kv = torch.tensor([act_kv_s] * B, dtype=torch.int32, device=DEVICE) + print(actual_seq_lengths_kv) + block_table = torch.tensor([range(B * KV_S // block_size)], dtype=torch.int32, device=DEVICE).reshape(B, -1) + + if D_rope == 0: + query_rope = None + key_rope = None + else: + query_rope = torch.empty((T, Q_N, D_rope), dtype=qkv_dtype, device=DEVICE).normal_(mean=0.0, + std=0.5).requires_grad_() + key_rope = torch.empty((B * KV_S // block_size, block_size, KV_N, D_rope), 
dtype=qkv_dtype, + device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_() + + print("q.shape=", query.shape) + print("k.shape=", key.shape) + print("v.shape=", value.shape) + print("sparse_indices.shape=", sparse_indices.shape) + print("act_seq_query=", actual_seq_lengths_query) + print("act_seq_kv=", actual_seq_lengths_kv) + + triton_out = sparse_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + ) + npu_out = torch_npu.npu_sparse_flash_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + # attention_mode = 2, + ) + triton_out = triton_out.to(npu_out.dtype) + torch.testing.assert_close(triton_out, npu_out, rtol=1e-2, atol=1e-2, equal_nan=True) + print("[PASSED]") + + # benchmarking + triton_time = do_bench_npu( + lambda: sparse_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Triton SFA] Time: {triton_time:.4f} us") + + npu_time = do_bench_npu( + lambda: 
torch_npu.npu_sparse_flash_attention( + query=query, + key=key, + value=value, + sparse_indices=sparse_indices, + scale_value=scale_value, + sparse_block_size=sparse_block_size, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_kv, + query_rope=query_rope, + key_rope=key_rope, + layout_query='TND', + layout_kv='PA_BSND', + sparse_mode=sparse_mode, + block_table=block_table, + # attention_mode = 2, + ), clear_l2_cache=True, collect_prof=False) + print(f"[Torch-NPU SFA] Time: {npu_time:.4f} us") + + +if __name__ == "__main__": + print(torch_npu.__version__) + print("Test Real Case in DS-v3.2-Exp") + print(f"time is {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + i = 1 + print(f"====================第{i}次测试=================") + test_op(T=1, B=1, KV_S=2560, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=4, B=4, KV_S=6400, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=8, B=8, KV_S=48000, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) + i += 1 + print(f"====================第{i}次测试=================") + test_op(T=16, B=16, KV_S=48000, Q_N=128, KV_N=1, D=512, D_rope=64, sparse_size=2048, scale_value=0.5, + sparse_block_size=1, sparse_mode=0, block_size=128, act_kv_s=2560) diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 172ba90b44..2747e689cf 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -7,6 
+7,7 @@ import textwrap from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +import importlib import triton.language.extra.cann.extension as extension from triton.extension.buffer.language import core as bl @@ -967,7 +968,8 @@ def visit_For(self, node): warp_specialize = False disable_licm = False bind_sub_block = None - if IteratorClass in [language.range, extension.parallel]: + tle = importlib.import_module("triton.experimental.tle", package=__package__) + if IteratorClass in [language.range, extension.parallel, tle.dsa.pipeline, tle.dsa.parallel]: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now @@ -976,6 +978,9 @@ def visit_For(self, node): ub = iterator.end step = iterator.step num_stages = iterator.num_stages + if num_stages is not None and num_stages > 2: + raise AssertionError('Only `range` iterator supports num_stages <= 2') + loop_unroll_factor = iterator.loop_unroll_factor disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer flatten = iterator.flatten @@ -1063,7 +1068,9 @@ def visit_For(self, node): for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) if disable_licm: for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) - if (IteratorClass is extension.parallel): + + tle = importlib.import_module("triton.experimental.tle", package=__package__) + if (IteratorClass is extension.parallel or IteratorClass is tle.dsa.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) self.scf_stack.append(node) @@ -1189,6 +1196,14 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: + if fn.__name__ == 'copy': + # extract tle hints from the generator to identify if node in the tle hints scope + tle = importlib.import_module("triton.experimental.tle", package=__package__) + top_hints = tle.extract_tle_hints_scope(self) + + # Only apply to some builtins; currently, 
'copy' is relevant. + if 'inter_no_alias' in top_hints and 'inter_no_alias' not in kws: + kws['inter_no_alias'] = top_hints['inter_no_alias'] ret = fn(*args, **extra_kwargs, **kws) # Sync the builder's location before return. ip, last_loc = self._get_insertion_point_and_loc(_builder) diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py index 97e36ca2f4..65d5968dd7 100644 --- a/third_party/ascend/backend/testing.py +++ b/third_party/ascend/backend/testing.py @@ -26,7 +26,7 @@ import triton.runtime as runtime -def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None, keep_res=False): +def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None, keep_res=False, collect_prof=True): import torch import torch_npu @@ -80,11 +80,54 @@ def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None if clear_l2_cache: del buffer - time_cost = _collect_prof_result(torch_path, funcs, warmup, active) + if collect_prof: + time_cost = _collect_prof_result(torch_path, funcs, warmup, active) + else: + time_cost = _collect_single(torch_path) _rm_dic(keep_res, torch_path) return time_cost +# keep the original behavior to get the statistics for the specified kernel func +def _collect_single(base_dir: str, key: str = None) -> float: + if not os.path.exists(base_dir): + return float("inf") + + import pandas as pd + + for root, _, files in os.walk(base_dir): + for file in files: + if file != "op_statistic.csv": + continue + target_file = os.path.join(root, file) + df = pd.read_csv(target_file) + print(df) + if key is not None: + key_rows = df[df["OP Type"].str.startswith(key, na=False)] + if not key_rows.empty: + return key_rows["Avg Time(us)"].values[0] + return float("inf") + else: + # default: read the first row except header + # return df.loc[0, "Avg Time(us)"] + # default: extract ZerosLike time (L2 cache clear operation) + filter_cond = df["OP Type"].str.contains(r"^ZerosLike$", 
case=False, na=False) + filter_df = df[filter_cond] + if not filter_df.empty: + zeroslike_time = filter_df.iloc[0]['Avg Time(us)'] + print("Clear L2 cache time:", zeroslike_time) + + # Calculate total time of all operators excluding ZerosLike + non_zeroslike_df = df[~df["OP Type"].str.contains(r"^ZerosLike$", case=False, na=False)] + all_ops_total_time = non_zeroslike_df['Avg Time(us)'].sum() + all_ops_total_time = round(all_ops_total_time, 3) + print("All ops total time:", all_ops_total_time) + + return all_ops_total_time + + return float("inf") + + def _rm_dic(keep_res, torch_path): if keep_res: return diff --git a/third_party/tle/dsa/CMakeLists.txt b/third_party/tle/dsa/CMakeLists.txt new file mode 100644 index 0000000000..669ebf813f --- /dev/null +++ b/third_party/tle/dsa/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/dialect/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/dialect/include) +add_subdirectory(dialect) + +if (TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonTLE + ${CMAKE_CURRENT_SOURCE_DIR}/tle_ir.cc + + LINK_LIBS + TleIR + TritonIR + ) + + find_package(Python3 REQUIRED COMPONENTS Development Interpreter) + find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}") + include_directories(${Python3_INCLUDE_DIRS}) + include_directories(${pybind11_INCLUDE_DIR}) + link_directories(${Python3_LIBRARY_DIRS}) + link_libraries(${Python3_LIBRARIES}) + add_link_options(${Python3_LINK_OPTIONS}) +endif() diff --git a/third_party/tle/dsa/dialect/CMakeLists.txt b/third_party/tle/dsa/dialect/CMakeLists.txt new file mode 100644 index 0000000000..e918a4d150 --- /dev/null +++ b/third_party/tle/dsa/dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +include_directories(${PROJECT_SOURCE_DIR}/python/src) 
+add_subdirectory(include) +add_subdirectory(lib) diff --git a/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..a111d0159b --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Analysis/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/include/CMakeLists.txt b/third_party/tle/dsa/dialect/include/CMakeLists.txt new file mode 100644 index 0000000000..83a37f590f --- /dev/null +++ b/third_party/tle/dsa/dialect/include/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +add_subdirectory(Analysis) +add_subdirectory(Conversion) +add_subdirectory(IR) diff --git a/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..a111d0159b --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h new file mode 100644 index 0000000000..1c8c45c2e7 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h @@ -0,0 +1,40 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#ifndef TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ +#define TRITON_TLE_CONVERSION_DSA_COPY_CONVERTER_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" + +#include 
"tle/dsa/dialect/include/IR/Dialect.h" + +namespace TleCopyConverter { + +using namespace mlir; + +class CopyConverter : public OpConversionPattern { + +public: + explicit CopyConverter(MLIRContext *context); + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace TleCopyConverter + +namespace mlir::triton::tle { +void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +} + +#endif diff --git a/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h new file mode 100644 index 0000000000..c08e02da3e --- /dev/null +++ b/third_party/tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h @@ -0,0 +1,115 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd +#ifndef TRITON_TLE_CONVERSION_MATH_CONVERTER_H +#define TRITON_TLE_CONVERSION_MATH_CONVERTER_H + +#if __has_include("bishengir/Dialect/HIVM/IR/HIVM.h") +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#endif + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace TleMathConverter { + +using namespace mlir; + +template +class BinaryMathConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = adaptor.getRes(); + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + + if (result.getType() != 
lhs.getType() || + result.getType() != rhs.getType()) { + op->emitError("Unexpected binary calculation type!"); + return failure(); + } + + auto binOp = rewriter.create(loc, TypeRange{}, ValueRange{lhs, rhs}, + ValueRange{result}); + + rewriter.replaceOp(op, binOp); + return success(); + } +}; + +template +class UnaryMathConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(MathOp op, + PatternRewriter &rewriter) const override {} +}; + +template +class MatMulConverter : public OpConversionPattern { +public: + static constexpr llvm::StringLiteral fixpipeAlreadyInserted = + "fixpipe_already_inserted"; + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto inA = adaptor.getInA(); + auto inB = adaptor.getInB(); + auto res = adaptor.getRes(); + + auto sizeAttr = adaptor.getSize(); + if (sizeAttr.size() > 3) { + op->emitError("Unexpected matmul calculation size!"); + return failure(); + } + + auto mAttr = dyn_cast(sizeAttr[0]); + auto nAttr = dyn_cast(sizeAttr[1]); + auto kAttr = dyn_cast(sizeAttr[2]); + Value M = rewriter.create(loc, mAttr.getInt()); + Value N = rewriter.create(loc, nAttr.getInt()); + Value K = rewriter.create(loc, kAttr.getInt()); + + bool initC = adaptor.getInitC(); + auto initCValue = + rewriter.create(loc, + /*value*/ initC, /*width*/ 1); + auto newOp = rewriter.create( + loc, TypeRange{}, // result types + inA, // Matrix A + inB, // Matrix B + initCValue, // init condition + M, // M + K, // K + N, // N + res, // Matrix C + Value{}, // per channel bias + adaptor.getTraA() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose A + adaptor.getTraB() ? rewriter.getUnitAttr() : UnitAttr{}, // transpose B + adaptor.getEnableHf32() ? 
rewriter.getUnitAttr() : UnitAttr{} + // enable hf32 mode + ); + + newOp->setAttr(fixpipeAlreadyInserted, rewriter.getBoolAttr(true)); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +} // namespace TleMathConverter + +namespace mlir::triton::tle { +void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +} +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt new file mode 100644 index 0000000000..919ff5a4ea --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TleOps.td) +mlir_tablegen(TleOps.h.inc -gen-op-decls) +mlir_tablegen(TleOps.cpp.inc -gen-op-defs) +add_mlir_doc(TleOps TleOps dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS TleDialect.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=tle) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=tle) +add_mlir_doc(TleDialect TleDialect dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS TleAttrDefs.td) +mlir_tablegen(TleAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TleAttrDefs.cpp.inc -gen-attrdef-defs) + +add_public_tablegen_target(TleTableGen) diff --git a/third_party/tle/dsa/dialect/include/IR/Dialect.h b/third_party/tle/dsa/dialect/include/IR/Dialect.h new file mode 100644 index 0000000000..b8b241ff21 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/Dialect.h @@ -0,0 +1,18 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#ifndef TRITON_TLE_IR_DIALECT_H_ +#define TRITON_TLE_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "tle/dsa/dialect/include/IR/Dialect.h.inc" + +#define GET_ATTRDEF_CLASSES +#include 
"tle/dsa/dialect/include/IR/TleAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "tle/dsa/dialect/include/IR/TleOps.h.inc" + +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td new file mode 100644 index 0000000000..01e577718a --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleAttrDefs.td @@ -0,0 +1,11 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#ifndef TRITON_TLE_ATTR_DEFS +#define TRITON_TLE_ATTR_DEFS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "tle/dsa/dialect/include/IR/TleDialect.td" + + +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleDialect.td b/third_party/tle/dsa/dialect/include/IR/TleDialect.td new file mode 100644 index 0000000000..25898cd081 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleDialect.td @@ -0,0 +1,23 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#ifndef TRITON_TLE_DIALECT +#define TRITON_TLE_DIALECT + +include "mlir/IR/OpBase.td" + +def Tle_Dialect : Dialect { + let name = "tle"; + let cppNamespace = "::mlir::triton::tle"; + let description = [{ + Triton Language Extension Dialect. 
+ }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "triton::TritonDialect", + ]; + + let usePropertiesForAttributes = 1; +} + +#endif diff --git a/third_party/tle/dsa/dialect/include/IR/TleOps.td b/third_party/tle/dsa/dialect/include/IR/TleOps.td new file mode 100644 index 0000000000..d0cefca385 --- /dev/null +++ b/third_party/tle/dsa/dialect/include/IR/TleOps.td @@ -0,0 +1,154 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#ifndef TRITON_TLE_OPS +#define TRITON_TLE_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/CommonTypeConstraints.td" +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "tle/dsa/dialect/include/IR/TleDialect.td" + +// +// Op Base +// +class TLE_Op traits = []>: + Op { +} + +// +// Copy OP +// +def TLE_DSACopyOp : TLE_Op<"dsa_copy", [Pure, MemoryEffects<[MemWrite]>]> { + let summary = "self-defined copy operation"; + let description = [{ + 'tle.dsa_copy' triton copy op is designed to copy data between memory regions. + Example: + ```mlir + tle.dsa_copy %src, %dst, %shape : tensor<128xf32> + ``` + }]; + let arguments = (ins + AnyType:$src, + AnyType:$dst, + Variadic:$shape + ); + + //assemble + let assemblyFormat = "$src `,` $dst `,` $shape attr-dict `:` type($src) `,` type($dst) `,` `[`type($shape)`]`"; +} + +// +// Add Op +// +def TLE_DSAAddOp : TLE_Op<"dsa_add", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined add operation"; + let description = [{ + `tle.dsa_add` triton dsa_add op is designed to performs element-wise addition. 
+ }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Sub op +// +def TLE_DSASubOp : TLE_Op<"dsa_sub", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined sub operation"; + let description = [{ + `tle.dsa_sub` triton dsa_sub op is designed to performs element-wise addition. + }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Mul op +// +def TLE_DSAMulOp : TLE_Op<"dsa_mul", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined mul operation"; + let description = [{ + `tle.dsa_mul` triton dsa_mul op is designed to performs element-wise addition. + }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Div op +// +def TLE_DSADivOp : TLE_Op<"dsa_div", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined div operation"; + let description = [{ + `tle.dsa_div` triton dsa_div op is designed to performs element-wise addition. + }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Max op +// +def TLE_DSAMaxOp : TLE_Op<"dsa_max", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined Max operation"; + let description = [{ + `tle.dsa_max` triton dsa_max op is designed to performs element-wise addition. 
+ }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +// +// Min op +// +def TLE_DSAMinOp : TLE_Op<"dsa_min", [Pure, MemoryEffects<[MemWrite]>, + SameOperandsShape, SameOperandsElementType]> { + let summary = "self-defined Min operation"; + let description = [{ + `tle.dsa_min` triton dsa_min op is designed to performs element-wise addition. + }]; + let arguments = (ins AnyMemRef:$lhs, AnyMemRef:$rhs, AnyMemRef:$res); + + let assemblyFormat = "$lhs `,` $rhs `,` $res attr-dict `:` type($lhs) `,` type($rhs) `,` type($res)"; +} + +/// // +/// // Dot op +/// // +/// def TLE_DSADotOp : TLE_Op<"dsa_dot", [Pure, +/// MemoryEffects<[MemWrite]>, +/// SameOperandsElementType, +/// DotLike]> { +/// let summary = "self-defined Dot operation"; +/// let description = [{ +/// $d = matrix_multiply($a, $b) + $c. +/// }]; +/// +/// let arguments = ( +/// ins +/// AnyMemRef:$inA, +/// AnyMemRef:$inB, +/// AnyMemRef:$res, +/// I64ArrayAttr:$size, +/// DefaultValuedAttr:$initC, +/// DefaultValuedAttr:$traA, +/// DefaultValuedAttr:$traB, +/// DefaultValuedAttr:$enableHf32 +/// ); +/// +/// let assemblyFormat = "$inA `,` $inB `,` $res attr-dict `:` type($inA) `,` type($inB) `,` type($res)"; +/// } + +#endif diff --git a/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000000..a111d0159b --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Analysis/CMakeLists.txt @@ -0,0 +1 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd diff --git a/third_party/tle/dsa/dialect/lib/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/CMakeLists.txt new file mode 100644 index 0000000000..83a37f590f --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/CMakeLists.txt @@ -0,0 +1,5 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +add_subdirectory(Analysis) 
+add_subdirectory(Conversion) +add_subdirectory(IR) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..af538eac7a --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/CMakeLists.txt @@ -0,0 +1,3 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +add_subdirectory(TleToLinalg) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt new file mode 100644 index 0000000000..482fcedecd --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +add_triton_library(TleToLinalg + DSACopyConverter.cpp + MathConverter.cpp + + DEPENDS + TritonIR + TleTableGen +) diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp new file mode 100644 index 0000000000..089cd8722d --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/DSACopyConverter.cpp @@ -0,0 +1,121 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#include "tle/dsa/dialect/include/Conversion/TleToLinalg/DSACopyConverter.h" +#if __has_include("bishengir/Dialect/HIVM/IR/HIVM.h") +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#endif + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace TleCopyConverter { + +using namespace mlir; + +memref::SubViewOp makeSubViewOp(Value src, + const llvm::SmallVector &sizes, + const Location &loc, + ConversionPatternRewriter &rewriter) { + auto srcType = cast(src.getType()); + SmallVector offsets(srcType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), + rewriter.getIndexAttr(1)); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, 
strides); + return rewriter.create(loc, dyn_cast(dstType), + src, offsets, sizes, strides); +} + +CopyConverter::CopyConverter(MLIRContext *context) + : OpConversionPattern(context) {} + +LogicalResult +CopyConverter::matchAndRewrite(triton::tle::DSACopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto src = adaptor.getSrc(); + auto dst = adaptor.getDst(); + auto loc = op.getLoc(); + + if (!dyn_cast(src.getType()) || + !dyn_cast(dst.getType())) { + op.emitError("Unexpected copy type!"); + return failure(); + } + + llvm::SmallVector shapeValues; + for (auto shape : adaptor.getShape()) { + Value indexShape = rewriter.create( + loc, rewriter.getIndexType(), shape); + shapeValues.push_back(indexShape); + } + + // create copyOp + auto srcSubView = makeSubViewOp(src, shapeValues, loc, rewriter); + auto dstSubView = makeSubViewOp(dst, shapeValues, loc, rewriter); + + MemRefType srcMemRefTy = cast(srcSubView.getType()); + MemRefType dstMemRefTy = cast(dstSubView.getType()); + + // Extract AddressSpace from MemRefType. + auto getAddressSpace = [](MemRefType ty) -> hivm::AddressSpace { + auto attr = ty.getMemorySpace(); + if (!attr) { + // The default memory attribute is GM. 
+ return hivm::AddressSpace::GM; + } + auto addrSpaceAttr = dyn_cast(attr); + if (!addrSpaceAttr) { + return hivm::AddressSpace::GM; + } + return addrSpaceAttr.getAddressSpace(); + }; + + hivm::AddressSpace srcAddrSpace = getAddressSpace(srcMemRefTy); + hivm::AddressSpace dstAddrSpace = getAddressSpace(dstMemRefTy); + + Operation *copyOp = nullptr; + if (srcAddrSpace == hivm::AddressSpace::GM && + dstAddrSpace == hivm::AddressSpace::UB || + srcAddrSpace == hivm::AddressSpace::UB && + dstAddrSpace == hivm::AddressSpace::GM) { + copyOp = rewriter.create(loc, srcSubView, dstSubView); + } else if (srcAddrSpace == hivm::AddressSpace::GM && + dstAddrSpace == hivm::AddressSpace::L1) { + copyOp = rewriter.create( + loc, /*result_tensor=*/TypeRange{}, + /*src=*/srcSubView, /*dst=*/dstSubView, + /*dst_continuous=*/UnitAttr::get(rewriter.getContext())); + } + /// else if (srcAddrSpace == hivm::AddressSpace::L0C && + /// dstAddrSpace == hivm::AddressSpace::GM) { + /// copyOp = rewriter.create(loc, + /// /*result_tensor=*/TypeRange{}, /*src=*/srcSubView, + /// /*dst=*/dstSubView, + /// /*enable_nz2nd=*/UnitAttr::get(rewriter.getContext()) + + /// // #ifdef BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES + /// /*nullptr, + /// hivm::FixpipeDMAModeAttr::get(rewriter.getContext(), + /// hivm::FixpipeDMAMode::NZ2ND), nullptr, nullptr, nullptr, nullptr, + /// nullptr*/ + /// ); + /// } + else { + op.emitError("Not implemented!"); + return failure(); + } + + copyOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, copyOp); + return success(); +} + +} // namespace TleCopyConverter + +namespace mlir::triton::tle { +void populateTleCopyOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp new file mode 100644 index 
0000000000..bcdece1602 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/Conversion/TleToLinalg/MathConverter.cpp @@ -0,0 +1,36 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd +#include "tle/dsa/dialect/include/Conversion/TleToLinalg/MathConverter.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" + +namespace TleMathConverter { + +using namespace mlir; +using namespace triton::tle; + +} // namespace TleMathConverter + +namespace mlir::triton::tle { +void populateTleMathOpConversionPatterns(mlir::TypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + patterns.add>( + patterns.getContext()); + + /// patterns.add>(patterns.getContext()); +} +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt new file mode 100644 index 0000000000..d7c74bbbd0 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +# Copyright 2026- Xcoresigma Technology Co., Ltd + +add_triton_library(TleIR + Dialect.cpp + TleOps.cpp + + DEPENDS + TleTableGen + + LINK_LIBS PUBLIC + TritonIR + MLIRIR +) diff --git a/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp new file mode 100644 index 0000000000..32385c0e11 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/Dialect.cpp @@ -0,0 +1,25 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#include "tle/dsa/dialect/include/IR/Dialect.h" +#include "mlir/Support/LLVM.h" +#include "tle/dsa/dialect/include/IR/Dialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "tle/dsa/dialect/include/IR/TleAttrDefs.cpp.inc" + +#define GET_OP_CLASSES +#include "tle/dsa/dialect/include/IR/TleOps.cpp.inc" + +namespace mlir::triton::tle { +void 
TleDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "tle/dsa/dialect/include/IR/TleAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "tle/dsa/dialect/include/IR/TleOps.cpp.inc" + >(); +} +} // namespace mlir::triton::tle diff --git a/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp new file mode 100644 index 0000000000..2a81191e18 --- /dev/null +++ b/third_party/tle/dsa/dialect/lib/IR/TleOps.cpp @@ -0,0 +1,7 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "tle/dsa/dialect/include/IR/Dialect.h" + +namespace mlir::triton::tle {} diff --git a/third_party/tle/dsa/tle_ir.cc b/third_party/tle/dsa/tle_ir.cc new file mode 100644 index 0000000000..b54b3d5a61 --- /dev/null +++ b/third_party/tle/dsa/tle_ir.cc @@ -0,0 +1,296 @@ +// Copyright 2026- Xcoresigma Technology Co., Ltd + +#include +#include + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#include "tle/dsa/dialect/include/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "ir.h" + +using namespace mlir; +namespace py = pybind11; + +constexpr unsigned kIntegerAttrBitWidth = 64; + +struct DSAOpBuilder : public TritonOpBuilder {}; + +void init_triton_tle(py::module &&m) { + m.def("load_dialects", [](MLIRContext &context) { + DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_( + m, "tle_builder", py::module_local(), py::dynamic_attr()) + .def(py::init()) + .def("dsa_get_null_attr", [](DSAOpBuilder &self) { return 
Attribute(); }) + .def("dsa_get_buffer_type", + [](DSAOpBuilder &self, std::vector &shape, + Type &elementType, const Attribute &memorySpace) -> Type { + return MemRefType::get(shape, elementType, + MemRefLayoutAttrInterface{}, memorySpace); + }) + .def("dsa_get_buffer_type_with_strides", + [](DSAOpBuilder &self, std::vector &shape, + Type &elementType, const std::vector &strides, + const Attribute &memorySpace) -> Type { + // create a layout with strides, using dynamic offset + auto layout = StridedLayoutAttr::get( + self.getBuilder().getContext(), ShapedType::kDynamic, strides); + return MemRefType::get(shape, elementType, layout, memorySpace); + }) + .def("create_dsa_alloc", + [](DSAOpBuilder &self, Type memrefType) -> Value { + return self.create( + mlir::cast(memrefType)); + }) + // Add copy op + .def("create_dsa_copy", + [](DSAOpBuilder &self, Value &src, Value &dst, + std::vector &shape, bool inter_no_alias) -> void { + auto copyOp = self.create(src, dst, shape); + if (inter_no_alias) { + copyOp->setAttr("inter_no_alias", + self.getBuilder().getBoolAttr(true)); + } + }) + // Add op + .def("create_dsa_add", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Sub op + .def("create_dsa_sub", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Mul op + .def("create_dsa_mul", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Div op + .def("create_dsa_div", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Max op + .def("create_dsa_max", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Min op + .def("create_dsa_min", + [](DSAOpBuilder &self, Value &lhs, Value &rhs, Value &res) -> void { + self.create(lhs, rhs, res); + }) + // Dot op + /// .def("create_dsa_dot", + /// 
[](DSAOpBuilder &self, Value &inA, Value &inB, Value &res, + /// std::vector &size, bool &initC, bool &traA, bool + /// &traB, bool &enable_hf32) -> void { + /// auto &builder = self.getBuilder(); + /// auto sizeAttr = builder.getI64ArrayAttr(size); + + /// // convert bool to boolattr. + /// auto initC_attr = builder.getBoolAttr(initC); + /// auto traA_attr = builder.getBoolAttr(traA); + /// auto traB_attr = builder.getBoolAttr(traB); + /// auto enable_hf32_attr = builder.getBoolAttr(enable_hf32); + + /// self.create(inA, inB, res, sizeAttr, + /// initC_attr, + /// traA_attr, traB_attr, enable_hf32_attr); + /// }) + .def("dsa_to_buffer", + [](DSAOpBuilder &self, Value &src, + const Attribute &addressSpace) -> Value { + auto tensorType = dyn_cast(src.getType()); + if (!tensorType) { + llvm::report_fatal_error("to_buffer: src must be tensor type"); + } + auto memrefType = MemRefType::get( + tensorType.getShape(), tensorType.getElementType(), + MemRefLayoutAttrInterface{}, addressSpace); + return self.create(memrefType, src); + }) + .def("dsa_to_tensor", + [](DSAOpBuilder &self, Value &src, bool writable) -> Value { + const auto &memrefType = mlir::cast(src.getType()); + auto hasAddressSpace = memrefType.getMemorySpace(); + if (hasAddressSpace) { + return self.create(src, true, + writable); + } + return self.create(src, true, writable); + }) + .def("create_dsa_extract_scalar", + [](DSAOpBuilder &self, Value &src, + std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + .def("create_dsa_extract_slice", + [](DSAOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : 
offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, + sizes, strides); + }) + .def("create_dsa_insert_slice", + [](DSAOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + .def("create_dsa_subview", + [](DSAOpBuilder &self, Value source, std::vector &offsets, + const std::vector &sizes, + const std::vector &strides) -> Value { + SmallVector mixedOffsets; + auto *context = self.getBuilder().getContext(); + auto &builder = self.getBuilder(); + + // Get source memref type for validation + auto sourceType = mlir::cast(source.getType()); + int64_t rank = sourceType.getRank(); + // Verify the number of parameters + if (offsets.size() != rank || 
sizes.size() != rank || + strides.size() != rank) { + throw std::runtime_error("Number of offsets, sizes, and strides " + "must match memref rank"); + } + + for (const auto &offset : offsets) { + auto indexType = builder.getIndexType(); + if (offset.getType() != indexType) { + Value offset_val = + self.create(indexType, offset); + mixedOffsets.push_back(offset_val); + } else { + mixedOffsets.push_back(offset); + } + } + + SmallVector mixedSizes; + SmallVector mixedStrides; + for (int64_t i = 0; i < rank; ++i) { + int64_t size = sizes[i]; + int64_t stride = strides[i]; + int64_t srcDim = sourceType.getDimSize(i); + + // verify sizes cannot be negative or zero + if (size <= 0) { + throw std::runtime_error("Expected sizes to be positive"); + } + + // verify strides cannot be negative or zero + if (stride <= 0) { + throw std::runtime_error("Expected strides to be positive"); + } + + // getDimSize() returns -1 (ShapedType::kDynamic) for dynamic + // dimensions + if (!ShapedType::isDynamic(srcDim)) { + // verify the subview size does not exceed the source dimension + if (size > srcDim) { + throw std::runtime_error( + "Subview size cannot exceed source dimension size"); + } + + // verify strides cannot exceed the source dimension size + if (stride > srcDim) { + throw std::runtime_error( + "Stride cannot exceed source dimension size"); + } + } + + mixedSizes.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), size)); + mixedStrides.push_back(IntegerAttr::get( + IntegerType::get(context, kIntegerAttrBitWidth), stride)); + } + + return self.create(source, mixedOffsets, + mixedSizes, mixedStrides); + }); +}