Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ascend-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ jobs:
--ignore=test_assume.py \
--ignore=test_index_select.py
popd
# flagtree tle test
pushd python/test/tle
python3 test_vec_add.py
python3 test_vec_add_2d.py
python3 test_vec_add_mix.py
python3 test_vec_mathOps.py
python3 test_tle_with_hints.py
popd

- name: FlagTree Editable Build And Test on Ascend
if: steps.check_files.outputs.only_docs_changed != 'true'
Expand Down
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,13 @@ if(TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/proton)
endif()

# add TLE plugin
# just support ascend backend so far
if(FLAGTREE_BACKEND STREQUAL "ascend")
list(APPEND TRITON_PLUGIN_NAMES "tle")
add_subdirectory(third_party/tle/dsa)
endif()

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
set(TRITON_LIBRARIES
Expand Down
7 changes: 7 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,16 @@ def get_packages():
"triton/runtime",
"triton/backends",
"triton/tools",

# for flagtree tle
"triton/experimental",
"triton/experimental/tle",
"triton/experimental/tle/language",
"triton/experimental/tle/language/dsa",
]
if helper.flagtree_backend and helper.flagtree_backend in helper.configs.language_extra_backends:
packages.append(f"triton/language/extra/{helper.get_device_name()}")
packages.append(f"triton/experimental/tle/language/dsa/{helper.get_device_name()}")
packages += helper.get_extra_packages()
packages += get_language_extra_packages()
packages += [f'triton/backends/{backend.name}' for backend in backends]
Expand Down
2 changes: 1 addition & 1 deletion python/setup_tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
commit_id="380b87122c88af131530903a702d5318ec59bb33",
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")),
"flir":
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
tools.Module(name="flir", url="https://github.com/flagos-ai/flir.git",
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")),
}

Expand Down
5 changes: 5 additions & 0 deletions python/setup_tools/utils/ascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_extra_install_packages():
"triton/extension",
"triton/extension/buffer",
"triton/extension/buffer/language",
"triton/experimental/tle/language/dsa/ascend",
]


Expand All @@ -24,6 +25,10 @@ def get_package_dir():
package_dict["triton/extension"] = ascend_ext_base
package_dict["triton/extension/buffer"] = f"{ascend_ext_base}/buffer"
package_dict["triton/extension/buffer/language"] = f"{ascend_ext_base}/buffer/language"

# flagtree tle ascend
flagtree_tle_ascend_base = "../python/triton/experimental/tle/language/dsa"
package_dict["triton/experimental/tle/language/dsa/ascend"] = f"{flagtree_tle_ascend_base}/ascend"
return package_dict


Expand Down
41 changes: 41 additions & 0 deletions python/test/tle/test_bind_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import triton
import triton.experimental.tle as tle
import triton.language as tl

from triton.compiler.compiler import ASTSource
from triton.compiler.code_generator import ast_to_ttir
from triton._C.libtriton import ir, tle as tle_ir
from triton._C.libtriton.ascend import ir as ascend_ir


class Options:
    # Minimal stand-in for the compiler options object that ast_to_ttir
    # expects. Presumably only these attributes are read during lowering
    # — confirm against the real backend options class if lowering fails.
    num_warps = 4
    num_stages = 3
    num_ctas = 1
    cluster_dims = (1, 1, 1)
    enable_fp_fusion = True
    debug = False


def compile_kernel(kernel, signature, constants):
    """Lower *kernel* to Triton MLIR and return the module as a string.

    A fresh MLIR context is created and the core, TLE and Ascend
    dialects are loaded into it before lowering, so all custom ops
    used by the kernel are recognized.
    """
    ast_src = ASTSource(kernel, signature, constants)
    ctx = ir.context()
    for dialect_loader in (ir, tle_ir, ascend_ir):
        dialect_loader.load_dialects(ctx)
    ttir_module = ast_to_ttir(kernel, ast_src, ctx, Options(), {}, {})
    return str(ttir_module)


# Smoke-test kernel: allocate a 32x32 fp32 UB buffer and bind it to a
# writable tensor view via to_tensor.
@triton.jit
def bind_buffer():
    # tle.dsa.ascend.UB is triton.language.extra.extension.cann.core.ascend_address_space.UB
    buffer1 = tle.dsa.alloc(shape=[32, 32], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    # writable=True presumably makes the returned tensor view mutable — confirm against TLE docs.
    tle.dsa.to_tensor(buffer1, writable=True)


if __name__ == "__main__":
    print("=" * 60)
    # Compile with an empty signature/constant map and dump the MLIR.
    mlir = compile_kernel(bind_buffer, {}, {})
    print(mlir)
66 changes: 66 additions & 0 deletions python/test/tle/test_tle_with_hints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2026- Xcoresigma Technology Co., Ltd
import torch
import triton
import torch_npu # noqa
import triton.language as tl
import triton.experimental.tle as tle


# Vector add on UB buffers, wrapped in nested tle.dsa.hint scopes.
# Exercises hint propagation: operations inside each `with tle.dsa.hint(...)`
# block presumably carry the hint's key/value pairs into the lowered IR
# (test_k1/test_k2 look like test-only keys) — confirm against the TLE docs.
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
               y_ptr, # *Pointer* to second input vector.
               output_ptr, # *Pointer* to output vector.
               n_elements, # Size of the vector.
               BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # mem_addr_space is language.extra.cann.core.ascend_address_space
    with tle.dsa.hint(inter_no_alias=True):
        with tle.dsa.hint(test_k1=10):
            # Three on-chip UB scratch buffers for the two inputs and the result.
            a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
            b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
            c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)

    # Number of valid elements for this program (smaller in the last block).
    t0 = n_elements - block_start
    tail_size = tl.minimum(t0, BLOCK_SIZE)

    with tle.dsa.hint(test_k2=20):
        # Global memory -> UB copies, restricted to the valid tail region.
        tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size])
        tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size])

    tle.dsa.add(a_ub, b_ub, c_ub)
    # UB -> global memory copy of the computed sum.
    tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size])


def custom_func(x: torch.Tensor, y: torch.Tensor):
    """Run add_kernel over *x* and *y* and return the elementwise sum."""
    out = torch.empty_like(x)
    numel = out.numel()
    launch_grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']), )
    add_kernel[launch_grid](x, y, out, numel, BLOCK_SIZE=128)
    return out


def test_add():
    """Check the triton result against torch on NPU, then benchmark both."""
    torch.manual_seed(0)
    size = 1024
    x = torch.rand(size, dtype=torch.float).npu()
    y = torch.rand(size, dtype=torch.float).npu()
    output_torch = x + y
    output_triton = custom_func(x, y)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')

    from triton.backends.ascend.testing import do_bench_npu
    bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False)
    bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False)
    # Print timings rounded to two decimal places.
    print(f"torch time : {bench_torch:.2f}")
    print(f"triton time: {bench_triton:.2f}")


if __name__ == "__main__":
    # Allow running this test directly on an NPU host.
    test_add()
63 changes: 63 additions & 0 deletions python/test/tle/test_vec_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2026- Xcoresigma Technology Co., Ltd
import torch
import torch_npu # noqa
import triton
import triton.language as tl
import triton.experimental.tle as tle


# Elementwise vector add through explicit TLE UB buffers: each program
# copies its BLOCK_SIZE slice of x and y into on-chip UB buffers, adds
# them there, and copies the result back to global memory.
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
               y_ptr, # *Pointer* to second input vector.
               output_ptr, # *Pointer* to output vector.
               n_elements, # Size of the vector.
               BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # mem_addr_space is language.extra.cann.core.ascend_address_space
    a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)

    # Number of valid elements for this program; clamps the final block
    # so copies never run past n_elements.
    t0 = n_elements - block_start
    tail_size = tl.minimum(t0, BLOCK_SIZE)

    # Global memory -> UB copies of the valid region only.
    tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size])
    tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size])

    tle.dsa.add(a_ub, b_ub, c_ub)
    # UB -> global memory copy of the computed sum.
    tle.dsa.copy(c_ub, output_ptr + offsets, [tail_size])


def custom_func(x: torch.Tensor, y: torch.Tensor):
    """Launch add_kernel on *x* and *y*; return their elementwise sum."""
    result = torch.empty_like(x)
    total = result.numel()

    def grid(meta):
        return (triton.cdiv(total, meta['BLOCK_SIZE']), )

    add_kernel[grid](x, y, result, total, BLOCK_SIZE=128)
    return result


def test_add():
    """Check the triton result against torch on NPU, then benchmark both."""
    torch.manual_seed(0)
    size = 1024
    x = torch.rand(size, dtype=torch.float).npu()
    y = torch.rand(size, dtype=torch.float).npu()
    output_torch = x + y
    output_triton = custom_func(x, y)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')

    from triton.backends.ascend.testing import do_bench_npu
    bench_torch = do_bench_npu(lambda: x + y, clear_l2_cache=True, keep_res=True, collect_prof=False)
    bench_triton = do_bench_npu(lambda: custom_func(x, y), clear_l2_cache=True, keep_res=True, collect_prof=False)
    # Print timings rounded to two decimal places.
    print(f"torch time : {bench_torch:.2f}")
    print(f"triton time: {bench_triton:.2f}")


if __name__ == "__main__":
    # Allow running this test directly on an NPU host.
    test_add()
78 changes: 78 additions & 0 deletions python/test/tle/test_vec_add_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2026- Xcoresigma Technology Co., Ltd
import torch
import torch_npu # noqa
import triton
import triton.language as tl
import triton.experimental.tle as tle


# 2D elementwise add through TLE UB buffers. The matrix is row-major with
# `n_rows` rows and `n_cols` columns; each program handles one
# BLOCK_SIZE x BLOCK_SIZE tile.
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input matrix.
               y_ptr, # *Pointer* to second input matrix.
               output_ptr, # *Pointer* to output matrix.
               n_elements, # Total element count of the matrix.
               n_cols, n_rows, BLOCK_SIZE: tl.constexpr, # Tile edge length.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # 2D offsets of this program's tile.
    block_start_m = pid_m * BLOCK_SIZE
    block_start_n = pid_n * BLOCK_SIZE
    offs_m = block_start_m + tl.arange(0, BLOCK_SIZE)
    offs_n = block_start_n + tl.arange(0, BLOCK_SIZE)

    # Row-major addressing: the row stride is the true column count.
    x_ptrs = x_ptr + offs_m[:, None] * n_cols + offs_n[None, :]
    y_ptrs = y_ptr + offs_m[:, None] * n_cols + offs_n[None, :]
    out_ptrs = output_ptr + offs_m[:, None] * n_cols + offs_n[None, :]

    a_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    b_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    c_ub = tle.dsa.alloc([BLOCK_SIZE, BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)

    # BUGFIX: clamp each tile against the matrix extents (n_rows/n_cols),
    # not n_elements. n_elements is the total element count, so the old
    # subtraction never shrank edge tiles and the copies ran out of bounds
    # on shapes not divisible by BLOCK_SIZE.
    tail_size_m = tl.minimum(n_rows - block_start_m, BLOCK_SIZE)
    tail_size_n = tl.minimum(n_cols - block_start_n, BLOCK_SIZE)

    # Global memory -> UB copies of the valid sub-tile only.
    tle.dsa.copy(x_ptrs, a_ub, [tail_size_m, tail_size_n])
    tle.dsa.copy(y_ptrs, b_ub, [tail_size_m, tail_size_n])

    tle.dsa.add(a_ub, b_ub, c_ub)

    # UB -> global memory copy of the computed sum.
    tle.dsa.copy(c_ub, out_ptrs, [tail_size_m, tail_size_n])


def custom_func(x: torch.Tensor, y: torch.Tensor, size: int):
    """Launch the 2D add kernel over *x* and *y*; return x + y.

    *size* is kept for interface compatibility; the actual extents are
    taken from ``x.shape`` so the kernel sees the true row stride.
    """
    output = torch.empty_like(x)
    n_elements = output.numel()
    BLOCK_SIZE = 16
    # BUGFIX: the original call passed (size, size - 1), i.e. (rows, cols),
    # into the (n_cols, n_rows) parameters — the arguments were swapped, so
    # addressing used a row stride one larger than the real column count.
    n_rows, n_cols = x.shape
    grid = (triton.cdiv(n_rows, BLOCK_SIZE), triton.cdiv(n_cols, BLOCK_SIZE))
    add_kernel[grid](x, y, output, n_elements, n_cols, n_rows, BLOCK_SIZE)
    return output


def test_add():
    """Compare the 2D triton add against torch on a non-square matrix."""
    torch.manual_seed(0)
    size = 128
    # Deliberately non-square (128 x 127) so edge tiles exercise the
    # kernel's tail handling.
    x = torch.rand((size, size - 1), dtype=torch.float).npu()
    y = torch.rand((size, size - 1), dtype=torch.float).npu()
    output_torch = x + y
    output_triton = custom_func(x, y, size)
    print("============X===========")
    print(x)
    print("============Y===========")
    print(y)
    print("============outTorch===========")
    print(output_torch)
    print("============outTriton===========")
    print(output_triton)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')


if __name__ == "__main__":
    # Allow running this test directly on an NPU host.
    test_add()
64 changes: 64 additions & 0 deletions python/test/tle/test_vec_add_mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026- Xcoresigma Technology Co., Ltd
import torch
import triton
import torch_npu # noqa
import triton.language as tl
import triton.experimental.tle as tle


# Mix TLE buffer ops with ordinary Triton tensor arithmetic: compute
# c = a + b in UB, convert c and b back to tl tensors, subtract, and move
# the tensor back into a UB buffer for the copy-out. Since
# (a + b) - b == a, the net effect is output == x.
@triton.jit
def add_kernel(x_ptr, # *Pointer* to first input vector.
               y_ptr, # *Pointer* to second input vector.
               output_ptr, # *Pointer* to output vector.
               n_elements, # Size of the vector.
               BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    a_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    b_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
    c_ub = tle.dsa.alloc([BLOCK_SIZE], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)

    # Number of valid elements for this program (smaller in the last block).
    t0 = n_elements - block_start
    tail_size = tl.minimum(t0, BLOCK_SIZE)

    tle.dsa.copy(x_ptr + offsets, a_ub, [tail_size])
    tle.dsa.copy(y_ptr + offsets, b_ub, [tail_size])

    tle.dsa.add(a_ub, b_ub, c_ub)

    # Buffer -> tensor round trip so plain tl arithmetic can be applied.
    c_val = tle.dsa.to_tensor(c_ub)
    b_val = tle.dsa.to_tensor(b_ub)

    # (a + b) - b == a, i.e. the first input.
    result = c_val - b_val

    # Store via to_buffer + dsa.copy (rather than tl.store) to exercise
    # the tensor -> buffer conversion path.
    d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB)
    tle.dsa.copy(d_ub, output_ptr + offsets, [tail_size])


def custom_func(x: torch.Tensor, y: torch.Tensor):
    """Launch the mixed buffer/tensor kernel on *x* and *y*."""
    out = torch.empty_like(x)
    count = out.numel()

    def launch_grid(meta):
        return (triton.cdiv(count, meta['BLOCK_SIZE']), )

    add_kernel[launch_grid](x, y, out, count, BLOCK_SIZE=128)
    return out


def test_add():
    """Check the mixed-ops kernel: its result should equal x, not x + y."""
    torch.manual_seed(0)
    size = 1024
    x = torch.rand(size, dtype=torch.float).npu()
    y = torch.rand(size, dtype=torch.float).npu()
    # The kernel computes (a + b) - b, so the expected output is x itself.
    output_torch = x
    output_triton = custom_func(x, y)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')


if __name__ == "__main__":
    # Allow running this test directly on an NPU host.
    test_add()
Loading
Loading