From d3643c2ce707c4709a9c9265c564e587ec2a0fc1 Mon Sep 17 00:00:00 2001 From: tsingmicro-public-e Date: Fri, 13 Mar 2026 14:56:52 +0800 Subject: [PATCH] [BACKEND] Add TLE (DSA) support and TLEToMK pipeline for tsingmicro * [TLE] third_party/tle DSA dialect and DsaToCore conversion * [BACKEND] TLEToMK, Tx81 recv/send, MK/Tx81/compiler integration * [PYTHON] triton.experimental.tle language and ir bindings * [BUILD] CMake, setup.py, wheel/build scripts * [CI] tsingmicro workflow and flaggems CI script * [EXAMPLE] tle DSA NOC GEMM example --------- Co-authored-by: tsingmicro-public-e --- .../workflows/tsingmicro-build-and-test.yml | 1 + CMakeLists.txt | 7 + python/setup.py | 4 + python/src/ir.cc | 85 +- python/src/ir.h | 91 +++ python/triton/experimental/__init__.py | 1 + python/triton/experimental/tle/__init__.py | 43 + python/triton/experimental/tle/distributed.py | 768 ++++++++++++++++++ .../experimental/tle/language/__init__.py | 10 + .../triton/experimental/tle/language/core.py | 58 ++ .../experimental/tle/language/dsa/__init__.py | 31 + .../experimental/tle/language/dsa/core.py | 310 +++++++ .../experimental/tle/language/dsa/semantic.py | 179 ++++ .../experimental/tle/language/dsa/types.py | 114 +++ third_party/tle/CMakeLists.txt | 9 + third_party/tle/REANME.md | 0 third_party/tle/include/CMakeLists.txt | 1 + .../tle-dsa/Conversion/DsaToCore/DsaToCore.h | 17 + .../include/tle-dsa/Dialect/IR/CMakeLists.txt | 14 + .../include/tle-dsa/Dialect/IR/DsaDialect.h | 33 + .../include/tle-dsa/Dialect/IR/DsaDialect.td | 68 ++ .../tle/include/tle-dsa/Dialect/IR/DsaOps.td | 78 ++ third_party/tle/lib/CMakeLists.txt | 2 + .../lib/Conversion/DsaToCore/CMakeLists.txt | 17 + .../lib/Conversion/DsaToCore/DsaToCore.cpp | 72 ++ third_party/tle/lib/Dialect/IR/CMakeLists.txt | 19 + third_party/tle/lib/Dialect/IR/DsaDialect.cpp | 39 + third_party/tle/python/CMakeLists.txt | 7 + third_party/tle/python/triton_tle_dsa.cc | 123 +++ third_party/tsingmicro/backend/compiler.py | 6 +- 
.../tsingmicro/bin/RegisterTritonDialects.h | 7 +- third_party/tsingmicro/crt/CMakeLists.txt | 9 + third_party/tsingmicro/crt/lib/Tx81/recv.c | 26 + third_party/tsingmicro/crt/lib/Tx81/send.c | 166 ++++ .../tle/test_tle_dsa_noc_gemm_4096.py | 176 ++++ .../magic-kernel/Conversion/CMakeLists.txt | 1 + .../Conversion/TLEToMK/CMakeLists.txt | 3 + .../magic-kernel/Conversion/TLEToMK/Passes.h | 22 + .../magic-kernel/Conversion/TLEToMK/Passes.td | 19 + .../magic-kernel/Conversion/TLEToMK/TLEToMK.h | 32 + .../magic-kernel/Dialect/IR/MagicKernelOps.td | 56 ++ .../tsingmicro-tx81/Dialect/IR/Tx81Ops.td | 83 ++ .../tsingmicro/lib/Conversion/CMakeLists.txt | 1 + .../lib/Conversion/LinalgToMK/LinalgToMK.cpp | 36 + .../lib/Conversion/MKToTx81/MKToTx81.cpp | 165 ++++ .../lib/Conversion/TLEToMK/CMakeLists.txt | 21 + .../TLEToMK/MKCommonBufferPlanningPass.cpp | 183 +++++ .../lib/Conversion/TLEToMK/TLEToMK.cpp | 425 ++++++++++ .../lib/Conversion/TLEToMK/TLEToMKPass.cpp | 56 ++ .../TritonToCoreDialects/CMakeLists.txt | 1 + .../TritonToCoreDialectsPass.cpp | 2 + .../lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp | 142 ++++ .../BufferizableOpInterfaceImpl.cpp | 62 ++ .../tsingmicro/scripts/build_tsingmicro.sh | 1 + .../tsingmicro/scripts/build_tx8_deps.sh | 9 +- .../scripts/ci/run_triton_flaggems_ci_test.sh | 2 +- .../tsingmicro/scripts/copy_config.conf | 1 + .../tsingmicro/scripts/publish/build_wheel.sh | 1 + 58 files changed, 3824 insertions(+), 91 deletions(-) create mode 100644 python/src/ir.h create mode 100644 python/triton/experimental/__init__.py create mode 100644 python/triton/experimental/tle/__init__.py create mode 100644 python/triton/experimental/tle/distributed.py create mode 100644 python/triton/experimental/tle/language/__init__.py create mode 100644 python/triton/experimental/tle/language/core.py create mode 100644 python/triton/experimental/tle/language/dsa/__init__.py create mode 100644 python/triton/experimental/tle/language/dsa/core.py create mode 100644 
python/triton/experimental/tle/language/dsa/semantic.py create mode 100644 python/triton/experimental/tle/language/dsa/types.py create mode 100644 third_party/tle/CMakeLists.txt create mode 100644 third_party/tle/REANME.md create mode 100644 third_party/tle/include/CMakeLists.txt create mode 100644 third_party/tle/include/tle-dsa/Conversion/DsaToCore/DsaToCore.h create mode 100644 third_party/tle/include/tle-dsa/Dialect/IR/CMakeLists.txt create mode 100644 third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.h create mode 100644 third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.td create mode 100644 third_party/tle/include/tle-dsa/Dialect/IR/DsaOps.td create mode 100644 third_party/tle/lib/CMakeLists.txt create mode 100644 third_party/tle/lib/Conversion/DsaToCore/CMakeLists.txt create mode 100644 third_party/tle/lib/Conversion/DsaToCore/DsaToCore.cpp create mode 100644 third_party/tle/lib/Dialect/IR/CMakeLists.txt create mode 100644 third_party/tle/lib/Dialect/IR/DsaDialect.cpp create mode 100644 third_party/tle/python/CMakeLists.txt create mode 100644 third_party/tle/python/triton_tle_dsa.cc create mode 100644 third_party/tsingmicro/crt/lib/Tx81/recv.c create mode 100644 third_party/tsingmicro/crt/lib/Tx81/send.c create mode 100644 third_party/tsingmicro/examples/tle/test_tle_dsa_noc_gemm_4096.py create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.h create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.td create mode 100644 third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/TLEToMK.h create mode 100644 third_party/tsingmicro/lib/Conversion/TLEToMK/CMakeLists.txt create mode 100644 third_party/tsingmicro/lib/Conversion/TLEToMK/MKCommonBufferPlanningPass.cpp create mode 100644 third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMK.cpp create mode 100644 
third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMKPass.cpp diff --git a/.github/workflows/tsingmicro-build-and-test.yml b/.github/workflows/tsingmicro-build-and-test.yml index 37665d3941..fd8f82458d 100644 --- a/.github/workflows/tsingmicro-build-and-test.yml +++ b/.github/workflows/tsingmicro-build-and-test.yml @@ -99,3 +99,4 @@ jobs: python3 test_softmax.py >result-test_softmax.txt python3 test_vec_add.py >result-test_vec_add.txt python3 time1.py >result-time1.txt + python3 test_tle_dsa_noc_gemm_4096.py >result-noc_gemm.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 05f6faaf1d..8cee92a3e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ elseif(FLAGTREE_BACKEND STREQUAL "aipu") add_definitions(-D__NVIDIA__) add_definitions(-D__AMD__) elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro") + set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}") set(CMAKE_C_COMPILER clang-21) set(CMAKE_CXX_COMPILER clang++-21) set(CMAKE_LINKER lld-21) @@ -285,6 +286,10 @@ if(TRITON_BUILD_PYTHON_MODULE) list(APPEND TRITON_PLUGIN_NAMES "proton") add_subdirectory(third_party/proton/dialect) + # Add TLE plugin + list(APPEND TRITON_PLUGIN_NAMES "tle") + add_subdirectory(third_party/tle) + get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) set(TRITON_LIBRARIES @@ -460,6 +465,8 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() add_subdirectory(third_party/proton/dialect) + # flagtree tle + add_subdirectory(third_party/tle) endif() find_package(Threads REQUIRED) diff --git a/python/setup.py b/python/setup.py index 63a4a268b0..0b197c5d6d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -741,6 +741,10 @@ def get_packages(): "triton/backends", "triton/tools", "triton/tools/extra", + "triton/experimental", + "triton/experimental/tle", + "triton/experimental/tle/language", + "triton/experimental/tle/language/dsa", ] if helper.flagtree_backend == "xpu": 
packages.append("triton/language/extra/xpu") diff --git a/python/src/ir.cc b/python/src/ir.cc index ee35ce834f..d1a86f043b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -23,6 +23,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/LocationSnapshot.h" +#include "ir.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -56,90 +57,6 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() { } } -// A custom op builder that keeps track of the last location -class TritonOpBuilder { -public: - TritonOpBuilder(MLIRContext *context) { - builder = std::make_unique(context); - lastLoc = std::make_unique(builder->getUnknownLoc()); - } - - OpBuilder &getBuilder() { return *builder; } - MLIRContext *getContext() { return builder->getContext(); } - - bool isLineInfoEnabled() { return lineInfoEnabled; } - - void setLastLoc(Location loc) { - if (lineInfoEnabled) - lastLoc = std::make_unique(loc); - } - - void setLastLoc(const std::string &fileName, int line, int column) { - auto context = builder->getContext(); - setLastLoc(FileLineColLoc::get(context, fileName, line, column)); - } - - Location getLastLoc() { - assert(lastLoc); - return *lastLoc; - } - - void setInsertionPointToStart(Block &block) { - if (!block.empty()) - setLastLoc(block.begin()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToStart(&block); - } - - void setInsertionPointToEnd(Block &block) { - if (!block.empty()) - setLastLoc(block.back().getLoc()); - else - setLastLoc(builder->getUnknownLoc()); - builder->setInsertionPointToEnd(&block); - } - - void setInsertionPointAfter(Operation &op) { - setLastLoc(op.getLoc()); - builder->setInsertionPointAfter(&op); - } - - void restoreInsertionPoint(OpBuilder::InsertPoint pt) { - if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) - setLastLoc(pt.getPoint()->getLoc()); - else - 
setLastLoc(builder->getUnknownLoc()); - builder->restoreInsertionPoint(pt); - } - - template OpTy create(Args &&...args) { - auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); - } - - // Overload to create or fold a single result operation. - template - std::enable_if_t(), Value> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - - // Overload to create or fold a zero result operation. - template - std::enable_if_t(), OpTy> - createOrFold(Args &&...args) { - auto loc = getLastLoc(); - return builder->createOrFold(loc, std::forward(args)...); - } - -private: - std::unique_ptr builder; - std::unique_ptr lastLoc; - bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); -}; - // Run the pass manager under a source manager diagnostic handler, which // enables emitted MLIR diagnostics to directly reference Python source // code. This diagnostic handler supports filtering diagnostic info by diff --git a/python/src/ir.h b/python/src/ir.h new file mode 100644 index 0000000000..cdb5257c40 --- /dev/null +++ b/python/src/ir.h @@ -0,0 +1,91 @@ +#pragma once + +#include "mlir/IR/Builders.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#include +#include + +// A custom op builder that keeps track of the last location. 
+class TritonOpBuilder { +public: + TritonOpBuilder(mlir::MLIRContext *context) { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + } + + mlir::OpBuilder &getBuilder() { return *builder; } + mlir::MLIRContext *getContext() { return builder->getContext(); } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + void setLastLoc(mlir::Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(mlir::FileLineColLoc::get(context, fileName, line, column)); + } + + mlir::Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(mlir::Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(mlir::Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + template + std::enable_if_t(), + mlir::Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + 
+private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = + !mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); +}; diff --git a/python/triton/experimental/__init__.py b/python/triton/experimental/__init__.py new file mode 100644 index 0000000000..e57e1f201e --- /dev/null +++ b/python/triton/experimental/__init__.py @@ -0,0 +1 @@ +# flagtree tle diff --git a/python/triton/experimental/tle/__init__.py b/python/triton/experimental/tle/__init__.py new file mode 100644 index 0000000000..873e1f288d --- /dev/null +++ b/python/triton/experimental/tle/__init__.py @@ -0,0 +1,43 @@ +# flagtree tle +from .distributed import ( + B, + P, + S, + ShardedTensor, + ShardingSpec, + device_mesh, + distributed_barrier, + distributed_dot, + make_sharded_tensor, + remote, + reshard, + shard_id, + sharding, +) + +from . import language + +# try: +# from . import raw +# except ModuleNotFoundError: +# raw = None + +__all__ = [ + "device_mesh", + "S", + "P", + "B", + "sharding", + "ShardingSpec", + "ShardedTensor", + "make_sharded_tensor", + "reshard", + "remote", + "shard_id", + "distributed_barrier", + "distributed_dot", + "language", +] + +# if raw is not None: +# __all__.append("raw") diff --git a/python/triton/experimental/tle/distributed.py b/python/triton/experimental/tle/distributed.py new file mode 100644 index 0000000000..4d51ef9341 --- /dev/null +++ b/python/triton/experimental/tle/distributed.py @@ -0,0 +1,768 @@ +# flagtree tle +from __future__ import annotations + +import copy +from dataclasses import dataclass +from itertools import product +from typing import Any, Iterable, Mapping, Sequence + +import triton.language.core as tl +import triton.language.semantic as _semantic_mod + + +def _prod(values: Iterable[int]) -> int: + result = 1 + for value in values: + result *= value + return result + + +def _as_positive_int(value: Any, label: str) -> int: + if not isinstance(value, int): + raise TypeError(f"{label} must be int, got 
{type(value).__name__}") + if value <= 0: + raise ValueError(f"{label} must be > 0, got {value}") + return value + + +class device_mesh: + """ + Logical view of a physical device topology. + """ + + def __init__( + self, + topology: Mapping[str, Any] | None = None, + *, + _shape: Sequence[int] | None = None, + _dim_names: Sequence[str] | None = None, + _physical_ids: Sequence[int] | None = None, + _launch_shape: Sequence[int] | None = None, + _launch_dim_names: Sequence[str] | None = None, + ): + if topology is None: + if _shape is None or _dim_names is None or _physical_ids is None: + raise ValueError("internal mesh constructor requires shape/names/physical ids") + self._shape = tuple(_shape) + self._dim_names = tuple(_dim_names) + self._physical_ids = tuple(_physical_ids) + self._launch_shape = tuple(_launch_shape if _launch_shape is not None else _shape) + self._launch_dim_names = tuple(_launch_dim_names if _launch_dim_names is not None else _dim_names) + return + + if not isinstance(topology, Mapping): + raise TypeError(f"topology must be a mapping, got {type(topology).__name__}") + if not topology: + raise ValueError("topology cannot be empty") + + shape = [] + dim_names = [] + for level_name, level_desc in topology.items(): + if not isinstance(level_name, str) or not level_name: + raise ValueError(f"invalid topology level name: {level_name!r}") + level_shape, level_names = self._parse_level(level_name, level_desc) + shape.extend(level_shape) + dim_names.extend(level_names) + + if len(set(dim_names)) != len(dim_names): + raise ValueError(f"dimension names must be unique, got {dim_names}") + + self._shape = tuple(shape) + self._dim_names = tuple(dim_names) + self._physical_ids = tuple(range(_prod(shape))) + self._launch_shape = self._shape + self._launch_dim_names = self._dim_names + + @staticmethod + def _parse_level(level_name: str, level_desc: Any) -> tuple[list[int], list[str]]: + if isinstance(level_desc, int): + return [_as_positive_int(level_desc, 
level_name)], [level_name] + if not isinstance(level_desc, (tuple, list)): + raise TypeError(f"topology[{level_name!r}] must be int or list/tuple of (name, size), " + f"got {type(level_desc).__name__}") + if not level_desc: + raise ValueError(f"topology[{level_name!r}] cannot be empty") + + shape = [] + names = [] + for item in level_desc: + if not isinstance(item, (tuple, list)) or len(item) != 2: + raise ValueError(f"topology[{level_name!r}] entries must be (name, size), got {item!r}") + dim_name, dim_size = item + if not isinstance(dim_name, str) or not dim_name: + raise ValueError(f"invalid dimension name in {level_name!r}: {dim_name!r}") + shape.append(_as_positive_int(dim_size, f"{level_name}.{dim_name}")) + names.append(dim_name) + return shape, names + + @property + def shape(self) -> tuple[int, ...]: + return self._shape + + @property + def ndim(self) -> int: + return len(self._shape) + + @property + def dim_names(self) -> tuple[str, ...]: + return self._dim_names + + @property + def physical_ids(self) -> tuple[int, ...]: + return self._physical_ids + + @property + def launch_shape(self) -> tuple[int, ...]: + return self._launch_shape + + @property + def launch_dim_names(self) -> tuple[str, ...]: + return self._launch_dim_names + + @property + def size(self) -> int: + return len(self._physical_ids) + + def flatten(self) -> "device_mesh": + return self.reshape(self.size) + + def reshape(self, *shape: int | Sequence[int]) -> "device_mesh": + if len(shape) == 1 and isinstance(shape[0], (tuple, list)): + new_shape = tuple(shape[0]) + else: + new_shape = tuple(shape) + if not new_shape: + raise ValueError("new shape cannot be empty") + new_shape = tuple(_as_positive_int(v, "shape dimension") for v in new_shape) + if _prod(new_shape) != self.size: + raise ValueError(f"cannot reshape mesh of size {self.size} into shape {new_shape}") + if len(new_shape) == self.ndim: + new_dim_names = self._dim_names + elif len(new_shape) == 1: + new_dim_names = ("flat", ) + else: 
+ new_dim_names = tuple(f"dim{i}" for i in range(len(new_shape))) + return device_mesh( + None, + _shape=new_shape, + _dim_names=new_dim_names, + _physical_ids=self._physical_ids, + _launch_shape=self._launch_shape, + _launch_dim_names=self._launch_dim_names, + ) + + def _normalize_key(self, key: Any) -> tuple[Any, ...]: + if not isinstance(key, tuple): + key = (key, ) + + if any(item is Ellipsis for item in key): + if key.count(Ellipsis) > 1: + raise IndexError("an index can only have a single ellipsis") + ellipsis_pos = key.index(Ellipsis) + missing = self.ndim - (len(key) - 1) + if missing < 0: + raise IndexError("too many indices for device_mesh") + key = key[:ellipsis_pos] + (slice(None), ) * missing + key[ellipsis_pos + 1:] + + if len(key) > self.ndim: + raise IndexError("too many indices for device_mesh") + + return key + (slice(None), ) * (self.ndim - len(key)) + + def _linear_index(self, coords: Sequence[int]) -> int: + index = 0 + for coord, dim_size in zip(coords, self._shape): + index = index * dim_size + coord + return index + + def __getitem__(self, key: Any) -> "device_mesh": + key = self._normalize_key(key) + selected_per_dim: list[list[int]] = [] + keep_dim: list[bool] = [] + + for dim_size, dim_key in zip(self._shape, key): + if isinstance(dim_key, int): + idx = dim_key + dim_size if dim_key < 0 else dim_key + if idx < 0 or idx >= dim_size: + raise IndexError(f"index {dim_key} out of range for dim size {dim_size}") + selected_per_dim.append([idx]) + keep_dim.append(False) + elif isinstance(dim_key, slice): + indices = list(range(*dim_key.indices(dim_size))) + if not indices: + raise ValueError("empty sub-mesh is not supported") + selected_per_dim.append(indices) + keep_dim.append(True) + else: + raise TypeError(f"device_mesh indices must be int/slice/ellipsis, got {type(dim_key).__name__}") + + new_shape = tuple(len(indices) for indices, keep in zip(selected_per_dim, keep_dim) if keep) + new_dim_names = tuple(dim_name for dim_name, keep in 
zip(self._dim_names, keep_dim) if keep) + + new_physical_ids = [] + for coords in product(*selected_per_dim): + new_physical_ids.append(self._physical_ids[self._linear_index(coords)]) + + return device_mesh( + None, + _shape=new_shape, + _dim_names=new_dim_names, + _physical_ids=tuple(new_physical_ids), + _launch_shape=self._launch_shape, + _launch_dim_names=self._launch_dim_names, + ) + + def __repr__(self): + return f"DeviceMesh(shape={self._shape}, names={self._dim_names})" + + +class _BroadcastSpec: + + def __repr__(self) -> str: + return "B" + + +B = _BroadcastSpec() + + +@dataclass(frozen=True) +class S: + axis: str | Sequence[str] + + +@dataclass(frozen=True) +class P: + axis: str | Sequence[str] + + +def _normalize_axis_group(spec: Any, label: str) -> tuple[str, ...]: + if spec is None or spec is B: + return tuple() + + if isinstance(spec, S): + spec = spec.axis + if isinstance(spec, P): + spec = spec.axis + + if isinstance(spec, str): + if not spec: + raise ValueError(f"{label} axis name cannot be empty") + return (spec, ) + + if isinstance(spec, (tuple, list)): + if not spec: + return tuple() + axes = [] + for axis in spec: + if not isinstance(axis, str) or not axis: + raise ValueError(f"{label} axis name must be non-empty str, got {axis!r}") + axes.append(axis) + if len(set(axes)) != len(axes): + raise ValueError(f"{label} axis names must be unique, got {axes}") + return tuple(axes) + + raise TypeError(f"{label} axis spec must be str/list/tuple/S/P/B, got {type(spec).__name__}") + + +def _normalize_partial_specs(partial: Any) -> tuple[str, ...]: + if partial is None: + return tuple() + if isinstance(partial, (str, S, P)): + partial = [partial] + if not isinstance(partial, (tuple, list)): + raise TypeError(f"partial must be a list/tuple, got {type(partial).__name__}") + + axes = [] + for item in partial: + axes.extend(_normalize_axis_group(item, "partial")) + if len(set(axes)) != len(axes): + raise ValueError(f"partial axes must be unique, got {axes}") + 
return tuple(axes) + + +@dataclass(frozen=True) +class ShardingSpec: + mesh: device_mesh + split: tuple[tuple[str, ...], ...] + partial: tuple[str, ...] + broadcast: tuple[str, ...] + + def axis_state(self, axis: str) -> str: + if axis in self.partial: + return "P" + for split_axes in self.split: + if axis in split_axes: + return "S" + return "B" + + +@dataclass(frozen=True) +class ShardedTensor: + handle: Any + sharding: ShardingSpec + shape: tuple[int, ...] | None = None + + +def sharding( + mesh: device_mesh, + split: Sequence[Any] | None = None, + partial: Sequence[Any] | None = None, +) -> ShardingSpec: + """ + Construct a sharding spec bound to a device mesh. + + This is annotation metadata today. Communication lowering is added in later + phases. + """ + if not isinstance(mesh, device_mesh): + raise TypeError(f"mesh must be device_mesh, got {type(mesh).__name__}") + + split_specs: list[tuple[str, ...]] = [] + if split is None: + split = tuple() + if not isinstance(split, (tuple, list)): + raise TypeError(f"split must be a list/tuple, got {type(split).__name__}") + for split_item in split: + split_specs.append(_normalize_axis_group(split_item, "split")) + + partial_axes = _normalize_partial_specs(partial) + + split_axes = [axis for split_item in split_specs for axis in split_item] + if len(set(split_axes)) != len(split_axes): + raise ValueError(f"split axes must be unique across tensor dims, got {split_axes}") + + split_set = set(split_axes) + partial_set = set(partial_axes) + + unknown = [axis for axis in split_axes + list(partial_axes) if axis not in mesh.dim_names] + if unknown: + raise ValueError(f"unknown mesh axis names: {unknown}; mesh axes are {mesh.dim_names}") + + overlap = split_set.intersection(partial_set) + if overlap: + raise ValueError(f"mesh axis cannot be both split and partial: {sorted(overlap)}") + + broadcast = tuple(axis for axis in mesh.dim_names if axis not in split_set and axis not in partial_set) + return ShardingSpec( + mesh=mesh, + 
split=tuple(split_specs), + partial=tuple(axis for axis in mesh.dim_names if axis in partial_set), + broadcast=broadcast, + ) + + +def make_sharded_tensor( + handle: Any, + sharding: ShardingSpec, + shape: Sequence[int] | None = None, +) -> ShardedTensor: + if not isinstance(sharding, ShardingSpec): + raise TypeError(f"sharding must be ShardingSpec, got {type(sharding).__name__}") + normalized_shape = None + if shape is not None: + if not isinstance(shape, (tuple, list)): + raise TypeError(f"shape must be list/tuple, got {type(shape).__name__}") + normalized_shape = tuple(_as_positive_int(v, "tensor shape") for v in shape) + if sharding.split and len(sharding.split) != len(normalized_shape): + raise ValueError(f"split rank ({len(sharding.split)}) must match tensor rank ({len(normalized_shape)})") + return ShardedTensor(handle=handle, sharding=sharding, shape=normalized_shape) + + +def reshard(tensor: ShardedTensor, spec: ShardingSpec) -> ShardedTensor: + """ + M4 entrypoint. Deferred by roadmap priority. + """ + raise NotImplementedError("reshard is deferred to M4") + + +def _shape_to_cluster_dims(shape: Sequence[int]) -> tuple[int, int, int]: + if not shape: + return (1, 1, 1) + dims = tuple(int(v) for v in shape) + if len(dims) == 1: + return (dims[0], 1, 1) + if len(dims) == 2: + return (dims[0], dims[1], 1) + if len(dims) == 3: + return dims + return (_prod(dims), 1, 1) + + +def _mesh_to_cluster_dims(mesh: device_mesh) -> tuple[int, int, int]: + # Prefer explicit cluster axes, then block axes, then fallback to full mesh. 
+ cluster_axes = [size for name, size in zip(mesh.launch_dim_names, mesh.launch_shape) if "cluster" in name] + if not cluster_axes: + cluster_axes = [size for name, size in zip(mesh.launch_dim_names, mesh.launch_shape) if "block" in name] + if not cluster_axes: + cluster_axes = list(mesh.launch_shape) + return _shape_to_cluster_dims(cluster_axes) + + +@dataclass(frozen=True) +class _BarrierGroupDescriptor: + kind: str + rank: int + shape: tuple[int, ...] + axes: tuple[int, ...] + mask: tuple[int, ...] + + +def _infer_submesh_barrier_group( + mesh: device_mesh, + cluster_dims: Sequence[int], +) -> _BarrierGroupDescriptor | None: + cluster_size = _prod(cluster_dims) + if mesh.size == cluster_size: + return None + if mesh.size > cluster_size: + raise ValueError(f"mesh size ({mesh.size}) exceeds inferred cluster size ({cluster_size})") + + launch_size = _prod(mesh.launch_shape) + if launch_size != cluster_size: + raise NotImplementedError( + "sub-mesh distributed_barrier currently requires launch mesh domain " + f"to match inferred cluster size; launch_size={launch_size}, cluster_size={cluster_size}") + + if not mesh.dim_names: + raise NotImplementedError("scalar sub-mesh barrier is not implemented yet; provide at least one sliced axis") + + launch_name_to_axis = {name: i for i, name in enumerate(mesh.launch_dim_names)} + if any(name not in launch_name_to_axis for name in mesh.dim_names): + raise NotImplementedError("sub-mesh barrier currently supports slicing-derived meshes with " + "axis names inherited from launch mesh") + + axes = tuple(int(launch_name_to_axis[name]) for name in mesh.dim_names) + if len(set(axes)) != len(axes): + raise ValueError(f"invalid subgroup axes (duplicate launch axes): {axes}") + + shape = tuple(int(v) for v in mesh.shape) + if not shape or any(v <= 0 for v in shape): + raise ValueError(f"invalid subgroup shape inferred from mesh: {shape}") + + mask = tuple(int(v) for v in mesh.physical_ids) + if not mask: + raise ValueError("sub-mesh 
barrier group mask cannot be empty") + if any(v < 0 or v >= cluster_size for v in mask): + raise ValueError("sub-mesh barrier group mask contains out-of-range cluster member ids: " + f"mask={mask}, cluster_size={cluster_size}") + + return _BarrierGroupDescriptor( + kind="submesh", + rank=len(shape), + shape=shape, + axes=axes, + mask=mask, + ) + + +def _apply_mesh_cluster_launch(mesh: device_mesh, _builder) -> tuple[int, int, int]: + cluster_dims = _mesh_to_cluster_dims(mesh) + options = getattr(_builder, "options", None) + if options is None: + return cluster_dims + + # The num_ctas == 1 constraint is NVIDIA CTA-cluster specific. + # On backends like TsingMicro the mesh describes tile communication + # topology, not CUDA CTA clusters, so only enforce when num_ctas > 1 + # (i.e. the backend actively uses multi-CTA grouping). + num_ctas = int(getattr(options, "num_ctas", 1)) + if num_ctas > 1: + raise ValueError("mesh-driven cluster launch requires num_ctas=1; cluster size is inferred from mesh") + + if hasattr(options, "cluster_dims"): + existing = tuple(getattr(options, "cluster_dims", (1, 1, 1))) + if existing != (1, 1, 1) and existing != cluster_dims: + raise ValueError(f"conflicting cluster_dims: existing={existing}, inferred_from_mesh={cluster_dims}") + object.__setattr__(options, "cluster_dims", cluster_dims) + return cluster_dims + + +def _resolve_launch_axis(mesh: device_mesh, axis: str | int) -> int: + if isinstance(axis, int): + ndim = len(mesh.launch_shape) + axis_idx = axis + ndim if axis < 0 else axis + if axis_idx < 0 or axis_idx >= ndim: + raise IndexError(f"axis index {axis} out of range for launch ndim {ndim}") + return axis_idx + + if isinstance(axis, str): + if axis not in mesh.launch_dim_names: + raise ValueError(f"unknown mesh axis {axis!r}; available launch axes: {mesh.launch_dim_names}") + return mesh.launch_dim_names.index(axis) + + raise TypeError(f"axis must be int or str, got {type(axis).__name__}") + + +@tl.builtin +def shard_id( + mesh: 
@tl.builtin
def distributed_barrier(mesh: device_mesh | None = None, _builder=None):
    """
    M3 entrypoint: cluster synchronization primitive.

    Synchronizes the programs of the current launch cluster. When ``mesh`` is
    provided, cluster launch dims are derived from it and, if a strict
    sub-mesh barrier group can be inferred, only that subgroup is synchronized
    (requires a group-aware TLE builder).

    :param mesh: optional ``device_mesh`` describing the cluster topology.
        ``None`` performs a whole-cluster barrier.
    :param _builder: JIT IR builder (injected by the frontend).
    :raises TypeError: if ``mesh`` is neither ``device_mesh`` nor ``None``.
    :raises NotImplementedError: if a sub-mesh group was inferred but the
        builder lacks (group-aware) ``create_distributed_barrier`` support.
    """
    mesh = tl._unwrap_if_constexpr(mesh)
    if mesh is not None and not isinstance(mesh, device_mesh):
        raise TypeError(f"mesh must be device_mesh or None, got {type(mesh).__name__}")
    subgroup = None
    if mesh is not None:
        cluster_dims = _apply_mesh_cluster_launch(mesh, _builder)
        # May return None when the mesh covers the whole cluster.
        subgroup = _infer_submesh_barrier_group(mesh, cluster_dims)
    if subgroup is not None:
        # Sub-mesh sync is only expressible through the dedicated TLE op.
        if not hasattr(_builder, "create_distributed_barrier"):
            raise NotImplementedError("sub-mesh distributed_barrier requires TLE builder support; "
                                      f"inferred subgroup descriptor: rank={subgroup.rank}, "
                                      f"shape={subgroup.shape}, axes={subgroup.axes}, size={len(subgroup.mask)}")
        try:
            _builder.create_distributed_barrier(
                subgroup.kind,
                list(subgroup.shape),
                list(subgroup.axes),
                list(subgroup.mask),
            )
            return None
        except TypeError as exc:
            # An older extension exposes a zero-arg create_distributed_barrier;
            # calling it with group arguments raises TypeError.
            # NOTE(review): this also swallows TypeErrors raised *inside* a
            # group-aware builder implementation — confirm that is acceptable.
            raise NotImplementedError(
                "sub-mesh distributed_barrier requires rebuilt TLE extension with "
                "group-aware create_distributed_barrier(group_kind, group_shape, group_axes, group_mask); "
                f"inferred subgroup descriptor: rank={subgroup.rank}, "
                f"shape={subgroup.shape}, axes={subgroup.axes}, size={len(subgroup.mask)}") from exc
    if hasattr(_builder, "create_distributed_barrier"):
        _builder.create_distributed_barrier()
    else:
        # Compatibility fallback for environments where the C++ extension
        # has not been rebuilt yet.
        _builder.create_barrier()
    return None
def _normalize_remote_shard_id(
    shard_id: Any,
    scope: device_mesh | None,
) -> int:
    """Normalize *shard_id* (flat int or per-axis coordinates) to a flat id.

    Tuple/list coordinates are linearized row-major against ``scope.shape``;
    that form requires ``scope`` to be a ``device_mesh``.
    """
    shard_id = tl._unwrap_if_constexpr(shard_id)
    scope = tl._unwrap_if_constexpr(scope)

    # Flat id: only non-negativity to check.
    if isinstance(shard_id, int):
        if shard_id < 0:
            raise ValueError(f"shard_id must be >= 0, got {shard_id}")
        return shard_id

    # Coordinate form: validate shape/type, then linearize against the mesh.
    if not isinstance(shard_id, (tuple, list)):
        raise TypeError(f"shard_id must be int or tuple/list of ints, got {type(shard_id).__name__}")
    if not shard_id:
        raise ValueError("shard_id tuple cannot be empty")
    if not all(isinstance(v, int) for v in shard_id):
        raise TypeError(f"shard_id tuple must contain ints, got {shard_id!r}")

    if scope is None:
        raise ValueError("tuple shard_id requires scope=device_mesh to linearize coordinates")
    if not isinstance(scope, device_mesh):
        raise TypeError(f"scope must be device_mesh when shard_id is tuple, got {type(scope).__name__}")
    if len(shard_id) != scope.ndim:
        raise ValueError(f"tuple shard_id rank mismatch: got {len(shard_id)}, expected {scope.ndim}")

    flat = 0
    for coord, extent in zip(shard_id, scope.shape):
        if coord < 0 or coord >= extent:
            raise ValueError(f"shard_id coordinate {coord} out of range for dim size {extent}")
        flat = flat * extent + coord
    return flat


def _is_buffered_tensor_like(value: Any) -> bool:
    """Duck-typed test for tle ``buffered_tensor`` without importing the class."""
    if isinstance(value, tl.tensor):
        return False
    if value.__class__.__name__ != "buffered_tensor":
        return False
    return hasattr(value, "handle") and hasattr(value, "type")


def _normalize_compile_time_remote_shard_id(
    shard_id: int | tuple[int, ...] | list[int],
    scope: device_mesh | None,
) -> int:
    """Like ``_normalize_remote_shard_id`` but rejects ids outside int32 range."""
    flat = _normalize_remote_shard_id(shard_id, scope)
    if flat > 0x7FFFFFFF:
        raise ValueError(f"linearized shard_id {flat} exceeds int32 range")
    return flat
def _normalize_runtime_remote_shard_id_tensor(shard_id_tensor: tl.tensor) -> tl.tensor:
    """Require a scalar int32 tensor and return it unchanged."""
    if not shard_id_tensor.dtype.is_int() or shard_id_tensor.dtype.primitive_bitwidth != 32:
        raise TypeError("runtime shard_id must be a scalar int32 tensor/value")
    if shard_id_tensor.shape:
        raise ValueError("runtime shard_id must be scalar (shape=())")
    return shard_id_tensor


def _create_remote_pointers_tensor(
    tensor: tl.tensor,
    shard_id_tensor: tl.tensor,
    _builder,
) -> tl.tensor | None:
    """Emit a remote_pointers op; return ``None`` if the builder lacks it.

    ``None`` signals the caller to fall back to the attribute-marking
    compatibility path for older TLE extensions.
    """
    remote_type = tensor.type.to_ir(_builder)
    try:
        remote_op = _builder.create_remote_pointers(
            remote_type,
            tensor.handle,
            shard_id_tensor.handle,
        )
    except AttributeError:
        # Builder predates create_remote_pointers.
        return None
    return tl.tensor(remote_op.get_result(0), tensor.type)


def _remote_pointer(
    tensor: tl.tensor,
    shard_id,
    scope: device_mesh | None = None,
    _builder=None,
) -> tl.tensor:
    """Internal path: rebase a shared-memory pointer tensor onto a remote shard.

    :param tensor: pointer tensor in address space 3 (shared memory).
    :param shard_id: compile-time int / coordinate tuple, or a runtime scalar
        int32 value.
    :param scope: mesh used to linearize tuple shard ids.
    :raises TypeError: for non-tensor / non-pointer inputs or bad shard ids.
    :raises ValueError: for non-shared-memory pointers or out-of-range ids.
    """
    if not isinstance(tensor, tl.tensor):
        raise TypeError(f"tensor must be tl.tensor, got {type(tensor).__name__}")
    if not tensor.dtype.is_ptr():
        raise TypeError("remote(pointer, ...) internal path requires a pointer tensor")
    if tensor.dtype.address_space != 3:
        raise ValueError("remote(pointer, ...) internal path requires shared-memory pointers (addrspace=3)")

    # Compile-time constant shard id path.
    if isinstance(shard_id, (int, tuple, list)):
        linear_shard_id = _normalize_compile_time_remote_shard_id(shard_id, scope)
        # Prefer explicit remote_pointers op so remote metadata survives
        # downstream layout/materialization rewrites.
        shard_id_tensor = _semantic_mod.to_tensor(int(linear_shard_id), _builder)
        shard_id_tensor = _normalize_runtime_remote_shard_id_tensor(shard_id_tensor)
        remote_ptr = _create_remote_pointers_tensor(tensor, shard_id_tensor, _builder)
        if remote_ptr is not None:
            return remote_ptr

        # Compatibility fallback for older TLE extensions.
        tensor.handle.set_attr("tle.remote_cta_id", _builder.get_int32_attr(int(linear_shard_id)))
        return tensor

    # Runtime shard id path. This materializes a TLE op that carries the
    # runtime i32 shard id through lowering.
    shard_id_tensor = shard_id if isinstance(shard_id, tl.tensor) else _semantic_mod.to_tensor(shard_id, _builder)
    shard_id_tensor = _normalize_runtime_remote_shard_id_tensor(shard_id_tensor)

    # Preferred path: keep remote semantics through a dedicated TLE op so the
    # shard-id survives local_pointers lowering.
    remote_ptr = _create_remote_pointers_tensor(tensor, shard_id_tensor, _builder)
    if remote_ptr is not None:
        return remote_ptr

    # Compatibility fallback for older TLE extensions.
    # Represent runtime shard_id with a marked addptr op. The lowering rewrites
    # pointer arithmetic to use the original base pointer and consumes the
    # runtime i32 from addptr's offset operand as cluster CTA id.
    remote_ptr = _semantic_mod.add(tensor, shard_id_tensor, _builder)
    remote_ptr.handle.set_attr("tle.remote_shard_id_carrier", _builder.get_unit_attr())
    return remote_ptr
@tl.builtin
def remote(
    tensor,
    shard_id,
    scope: device_mesh | None = None,
    _builder=None,
):
    """
    M3 entrypoint: mark distributed access target.

    Supported input:
      - tle buffered_tensor: returns a remote-marked buffered tensor; caller
        should then use `tleg.local_ptr(...)` to materialize remote pointers.

    `shard_id` is the target block id inside the current thread block cluster.
    When `scope` is provided, launch cluster dimensions are inferred from that
    mesh and this mode requires `num_ctas=1` (one program maps to one block).

    :raises TypeError: when ``scope``/``tensor`` have the wrong type or a
        runtime shard id is used without a JIT builder.
    :raises ValueError: when remote(...) is applied twice to the same buffer.
    """
    shard_id = tl._unwrap_if_constexpr(shard_id)
    scope = tl._unwrap_if_constexpr(scope)
    if scope is not None and not isinstance(scope, device_mesh):
        raise TypeError(f"scope must be device_mesh or None, got {type(scope).__name__}")
    if scope is not None:
        _apply_mesh_cluster_launch(scope, _builder)

    # Buffered tensor path: carry remote metadata and let `local_ptr` materialize
    # remote pointers later.
    if _is_buffered_tensor_like(tensor):
        # Reject double application on either the value or its type carrier.
        if (hasattr(tensor, "_tle_remote_shard_id") or hasattr(tensor, "_tle_remote_scope")
                or hasattr(tensor.type, "_tle_remote_shard_id") or hasattr(tensor.type, "_tle_remote_scope")):
            raise ValueError("remote(buffered_tensor, ...) cannot be applied twice; "
                             "materialize pointer views with tleg.local_ptr(remote_buffer, indices)")
        if isinstance(shard_id, (int, tuple, list)):
            # Compile-time id: linearize and range-check now.
            shard_id = _normalize_compile_time_remote_shard_id(shard_id, scope)
        else:
            # Runtime id: must end up as a scalar int32 tensor.
            shard_id_tensor = shard_id if isinstance(shard_id, tl.tensor) else None
            if shard_id_tensor is None:
                if _builder is None:
                    raise TypeError("runtime shard_id for remote(buffered_tensor, ...) must be scalar int32 "
                                    "and requires JIT context for materialization")
                shard_id_tensor = _semantic_mod.to_tensor(shard_id, _builder)
            shard_id = _normalize_runtime_remote_shard_id_tensor(shard_id_tensor)
        # Keep remote metadata on buffered_tensor.type so it survives value
        # reconstruction in JIT interpreter paths (value-level attrs can drop).
        remote_buffer = copy.copy(tensor)
        remote_type = copy.copy(tensor.type)
        try:
            setattr(remote_type, "_tle_remote_shard_id", shard_id)
            setattr(remote_type, "_tle_remote_scope", scope)
            remote_buffer.type = remote_type
        except AttributeError:
            # Type object may be immutable for unit-test stubs.
            pass
        # Keep value-level metadata as a secondary carrier to maximize
        # compatibility with existing JIT object reconstruction paths.
        setattr(remote_buffer, "_tle_remote_shard_id", shard_id)
        setattr(remote_buffer, "_tle_remote_scope", scope)
        return remote_buffer

    if isinstance(tensor, tl.tensor):
        raise TypeError("remote(...) only accepts tle.buffered_tensor; "
                        "use remote(buffered_tensor, shard_id, scope) + local_ptr(...)")
    raise TypeError(f"tensor must be tle.buffered_tensor, got {type(tensor).__name__}")


def distributed_dot(a: ShardedTensor, b: ShardedTensor, c: ShardedTensor | None = None):
    # Placeholder for the M5 milestone; kept so the public API surface is stable.
    raise NotImplementedError("distributed_dot is deferred to M5")
@tl.builtin
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
         volatile=False, is_async=False, _semantic=None):
    """
    Return a tensor of data whose values are loaded from memory at location defined by `pointer`:

    (1) If `pointer` is a single element pointer, a scalar is be loaded. In
    this case:

    - `mask` and `other` must also be scalars,
    - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
    - `boundary_check` and `padding_option` must be empty.

    (2) If `pointer` is an N-dimensional tensor of pointers, an
    N-dimensional tensor is loaded. In this case:

    - `mask` and `other` are implicitly broadcast to `pointer.shape`,
    - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
    - `boundary_check` and `padding_option` must be empty.

    (3) If `pointer` is a block pointer defined by `make_block_ptr`, a
    tensor is loaded. In this case:

    - `mask` and `other` must be `None`, and
    - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.

    :param pointer: Pointer to the data to be loaded
    :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
    :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
        (must be `None` with block pointers)
    :type mask: Block of `triton.int1`, optional
    :param other: if `mask[idx]` is false, return `other[idx]`
    :type other: Block, optional
    :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
    :type boundary_check: tuple of ints, optional
    :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
    :param cache_modifier: changes cache option in NVIDIA PTX
    :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
        cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
        and ".cv" means don't cache and fetch again. See the PTX ISA ``cache operator`` documentation for more details.
    :param eviction_policy: changes eviction policy in NVIDIA PTX
    :type eviction_policy: str, optional
    :param volatile: changes volatile option in NVIDIA PTX
    :type volatile: bool, optional
    :param is_async: recorded on the load op as the ``tt.load.async``
        attribute; presumably consumed by TLE/DSA lowering to issue the load
        asynchronously — TODO confirm against the backend passes.
    :type is_async: bool, optional
    """
    # Delegate to the stock tl.load, then annotate the resulting op so the
    # async intent survives into the IR.
    x = tl.load(pointer, mask=mask, other=other, boundary_check=boundary_check, padding_option=padding_option,
                cache_modifier=cache_modifier, eviction_policy=eviction_policy, volatile=volatile, _semantic=_semantic)
    x.handle.set_attr("tt.load.async", _semantic.builder.get_bool_attr(is_async))
    return x
class pipeline(range):
    """
    Iterator that counts upward forever, with parallel execution semantics.

    Drop-in variant of :class:`triton.language.core.range` for use inside
    :code:`triton.jit` functions. It accepts the same arguments
    (``arg1``/``arg2``/``step`` plus the ``num_stages`` and
    ``loop_unroll_factor`` compiler hints) and forwards them unchanged;
    presumably the distinct type lets DSA passes recognize pipelined loops —
    TODO confirm against the lowering.

    NOTE(review): the previous docstring described a ``bind_sub_block``
    parameter and 910B vector cores, neither of which exists in this
    signature — it appears to have been copied from another backend.
    """

    def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None):
        super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor)


@tl.builtin
def memory_space(input, space, _builder=None):
    """
    Annotate a tensor with a target memory-space tag.

    The attribute ``tt.memory_space`` is propagated through the IR and can be
    consumed by downstream DSA passes (e.g. ``--dsa-memory-to-core``) to make
    allocation / placement decisions.

    Args:
        input: Tensor to annotate.
        space: Memory-space name string, e.g. ``"spm"`` or ``"shared_memory"``.

    Returns:
        The same ``input`` object; annotation is a side effect. Silently a
        no-op when there is no builder or the value has no attr-capable handle.
    """
    space = tl._unwrap_if_constexpr(space)
    if _builder is not None and hasattr(input, 'handle') and hasattr(input.handle, 'set_attr'):
        input.handle.set_attr("tt.memory_space", _builder.get_string_attr(str(space)))
    return input
+ """ + space = tl._unwrap_if_constexpr(space) + if _builder is not None and hasattr(input, 'handle') and hasattr(input.handle, 'set_attr'): + input.handle.set_attr("tt.memory_space", _builder.get_string_attr(str(space))) + return input + + +@tl.builtin +def alloc( + shape: tuple, + dtype: tl.dtype, + layout: Optional[object] = None, + scope: tle.scope = None, + _builder=None, +) -> tle.buffered_tensor: + """ + Allocate local memory buffer + + Args: + shape: Buffer shape + dtype: Data type + layout: Memory layout encoding (optional) + scope: Storage type (default to shared memory) + _semantic: Semantic analyzer (internal use) + + Returns: + Allocated buffer tensor + + Raises: + ValueError: When parameters are invalid + RuntimeError: When allocation fails + """ + from .semantic import DSASemantic + + if _builder is None: + raise ValueError("alloc must be used inside @triton.jit") + if layout is not None: + raise ValueError("alloc(): layout parameter is not yet support for DSA backend") + + # --- Validate inputs via semantic layer --- + unwrapped_shape = DSASemantic.validate_alloc_shape(shape) + elem_dtype = DSASemantic.validate_alloc_dtype(dtype) + resolved_scope = DSASemantic.validate_alloc_scope(scope) + + elem_ir_ty = elem_dtype.to_ir(_builder) + + if not hasattr(_builder, "create_dsa_alloc"): + raise RuntimeError("builder missing create_dsa_alloc for DSA alloc") + + alloc_value = _builder.create_dsa_alloc(list(unwrapped_shape), elem_ir_ty) + buf_ty = tle.buffered_tensor_type(unwrapped_shape, elem_dtype, resolved_scope) + return tle.buffered_tensor(alloc_value, buf_ty) + + +class CopyDirection(Enum): + """Copy direction enum for data transfer operations""" + GM_TO_LOCAL = "GMTOLOCAL" # Global memory to local memory + LOCAL_TO_GM = "LOCALTOGM" # Local memory to global memory + + +@tl.builtin +def copy( + src, + dst, + shape, + offsets: Sequence[constexpr | tensor] = None, + _builder=None, +) -> None: + """ + Copy data between global memory (GM) and local 
def _expand_index_to_shape(index: tl.tensor, shape: Sequence[int], axis: int, _builder) -> tl.tensor:
    """Broadcast a 1-D index tensor to *shape*, placing it along *axis*."""
    idx = index
    # Prepend singleton dims before the axis, append after, then broadcast.
    for _ in builtins.range(axis):
        idx = tl.expand_dims(idx, 0, _builder=_builder)
    for _ in builtins.range(len(shape) - axis - 1):
        idx = tl.expand_dims(idx, len(idx.shape), _builder=_builder)
    return tl.broadcast_to(idx, *shape, _builder=_builder)


def _make_full_indices(buffer: tle.buffered_tensor, _builder) -> tuple[tl.tensor, ...]:
    """Build per-axis index tensors covering the whole buffer."""
    shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in buffer.type.shape)
    indices = []
    for axis, dim in enumerate(shape):
        idx = tl.arange(0, dim, _builder=_builder)
        idx = _expand_index_to_shape(idx, shape, axis, _builder)
        indices.append(idx)
    return tuple(indices)


@tl.builtin
def local_ptr(
    buffer: tle.buffered_tensor,
    indices: Optional[Sequence] = None,
    _builder=None,
    _generator=None,
) -> tl.tensor:
    """
    Materialize shared-memory pointers that cover the given buffered tensor.

    Args:
        buffer: Local memory buffer tensor returned by ``tle.alloc``.
        indices: Tuple of integer index tensors. The tuple length must equal
            the rank of ``buffer`` and every tensor must have the same shape.
            The output pointer tensor will have that same shape. All-scalar
            indices yield a single scalar pointer instead.

    Returns:
        Tensor of pointers suitable for ``tl.load``/``tl.store``. Buffers
        previously marked by ``remote(...)`` are additionally rebased via a
        remote_pointers op.

    Raises:
        ValueError: for invalid buffer/index arguments.
        RuntimeError: when the builder lacks the required DSA ops.
    """
    if not isinstance(buffer, tle.buffered_tensor):
        raise ValueError(f"Buffer parameter must be tle.buffered_tensor, but got {type(buffer)}")

    if _builder is None:
        raise ValueError("local_ptr must be used inside @triton.jit")

    # Preferred metadata source: buffered_tensor.type (survives JIT value
    # reconstruction). Keep value attrs as backward-compatibility fallback.
    remote_shard_id = getattr(buffer.type, "_tle_remote_shard_id", None)
    _ = getattr(buffer.type, "_tle_remote_scope", None)
    if remote_shard_id is None:
        remote_shard_id = getattr(buffer, "_tle_remote_shard_id", None)
        _ = getattr(buffer, "_tle_remote_scope", None)
    remote_buffer_marker = remote_shard_id is not None

    indices = tl._unwrap_if_constexpr(indices)
    if indices is None:
        raise ValueError("local_ptr indices must be provided as a tuple of tensors")
    # Accept tl.tuple, builtin tuple, or list uniformly.
    if isinstance(indices, tl.tuple):
        indices_tuple = tuple(indices.values)
    elif isinstance(indices, (tuple, list)):
        indices_tuple = tuple(indices)
    else:
        raise ValueError("local_ptr indices must be a tuple or list of tensors")

    buffer_shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in buffer.type.shape)
    if len(indices_tuple) != len(buffer_shape):
        raise ValueError(f"local_ptr indices must provide {len(buffer_shape)} tensors, got {len(indices_tuple)}")

    # Classify each index as scalar vs block; block indices must share a shape.
    idx_tensors: list[tensor] = []
    view_shape: Optional[tuple[int, ...]] = None
    scalar_index_flags: list[bool] = []
    for idx in indices_tuple:
        idx_tensor = idx if isinstance(idx, tensor) else semantic.to_tensor(idx, _builder)
        if not idx_tensor.dtype.is_int():
            raise ValueError("local_ptr indices must use integer dtypes")
        is_scalar_index = not idx_tensor.type.is_block()
        scalar_index_flags.append(is_scalar_index)
        if is_scalar_index:
            idx_tensors.append(idx_tensor)
            continue
        if view_shape is None:
            view_shape = tuple(idx_tensor.shape)
        elif tuple(idx_tensor.shape) != view_shape:
            raise ValueError("local_ptr indices must have identical shapes")
        idx_tensors.append(idx_tensor)

    if not idx_tensors:
        raise ValueError("local_ptr indices cannot be empty")
    all_scalar_indices = all(scalar_index_flags)
    any_scalar_indices = any(scalar_index_flags)
    # Mixing scalar and block indices is rejected rather than broadcast.
    if any_scalar_indices and not all_scalar_indices:
        raise ValueError("local_ptr indices must be either all scalar or all tensors with identical shapes")
    if not all_scalar_indices and view_shape is None:
        view_shape = tuple()

    ptr_dtype = tl.pointer_type(buffer.type.element_ty)
    insert_block = _builder.get_insertion_block()
    if insert_block is None:
        raise RuntimeError("TLE local_ptr called without an insertion block")
    # Scalar indices produce a scalar pointer; block indices a pointer block.
    if all_scalar_indices:
        result_ty = ptr_dtype
        result_ir = ptr_dtype.to_ir(_builder)
    else:
        result_ty = tl.block_type(ptr_dtype, list(view_shape))
        result_ir = result_ty.to_ir(_builder)
    handles = [idx.handle for idx in idx_tensors]
    if not hasattr(_builder, "create_dsa_local_pointers"):
        raise RuntimeError("builder missing create_dsa_local_pointers for DSA local_ptr")
    local_ptr_op = _builder.create_dsa_local_pointers(result_ir, buffer.handle, *handles)

    result_tensor = tl.tensor(local_ptr_op.get_result(0), result_ty)

    if remote_buffer_marker:
        # Rebase the freshly built pointers onto the remote shard.
        if all_scalar_indices:
            raise ValueError("local_ptr does not yet support scalar indices on remote buffers")
        if not hasattr(_builder, "create_dsa_remote_pointers"):
            raise RuntimeError("builder missing create_dsa_remote_pointers for remote buffers")
        shard_val = (remote_shard_id.handle if isinstance(remote_shard_id, tl.tensor) else semantic.to_tensor(
            remote_shard_id, _builder).handle)
        remote_op = _builder.create_dsa_remote_pointers(result_ir, result_tensor.handle, shard_val)
        result_tensor = tl.tensor(remote_op.get_result(0), result_ty)

    return result_tensor
class DSASemanticError(Exception):
    """Raised when a DSA operation fails semantic validation."""
    pass


# Data types supported by the TsingMicro DSA backend for buffer allocation.
_SUPPORTED_ALLOC_DTYPES = frozenset([
    tl.float32,
    tl.float16,
    tl.bfloat16,
    tl.int8,
    tl.int16,
    tl.int32,
    tl.int64,
    tl.uint8,
    tl.uint16,
    tl.uint32,
    tl.uint64,
])


class DSASemantic:
    """Semantic analyzer for DSA TLE operations.

    Every ``validate_*`` method either returns silently (possibly with a
    normalized value) or raises :class:`DSASemanticError` carrying a
    human-readable message, so frontend builtins fail before MLIR lowering.
    """

    # ------------------------------------------------------------------
    # alloc() validation
    # ------------------------------------------------------------------

    @staticmethod
    def validate_alloc_shape(shape: Sequence) -> Tuple[int, ...]:
        """Validate and normalise *shape* for ``alloc()``.

        Returns the unwrapped shape tuple on success.
        """
        if not isinstance(shape, (tuple, list)):
            # Accept any other iterable by materializing it first.
            if not hasattr(shape, "__iter__"):
                raise DSASemanticError(f"alloc: shape must be a tuple or list, got {type(shape).__name__}")
            shape = tuple(shape)

        def _checked_dim(position, dim):
            dim = tl._unwrap_if_constexpr(dim)
            if not isinstance(dim, int) or dim <= 0:
                raise DSASemanticError(f"alloc: shape[{position}] must be a positive integer, got {dim!r}")
            return dim

        return tuple(_checked_dim(i, d) for i, d in enumerate(shape))

    @staticmethod
    def validate_alloc_dtype(dtype: tl.dtype) -> tl.dtype:
        """Validate *dtype* for ``alloc()``."""
        dtype = tl._unwrap_if_constexpr(dtype)
        if not isinstance(dtype, tl.dtype):
            raise DSASemanticError(f"alloc: dtype must be a tl.dtype instance, got {type(dtype).__name__}")
        if dtype not in _SUPPORTED_ALLOC_DTYPES:
            supported = ", ".join(str(d) for d in sorted(_SUPPORTED_ALLOC_DTYPES, key=str))
            raise DSASemanticError(f"alloc: unsupported dtype {dtype}. Supported types: {supported}")
        return dtype

    @staticmethod
    def validate_alloc_scope(scope) -> tle.scope:
        """Validate *scope* for ``alloc()``; ``None`` selects the SPM default."""
        if scope is None:
            return tle.spm  # default
        if isinstance(scope, tle.scope):
            return scope
        raise DSASemanticError(f"alloc: scope must be a tle.scope instance, got {type(scope).__name__}")

    # ------------------------------------------------------------------
    # copy() validation
    # ------------------------------------------------------------------

    @staticmethod
    def validate_copy_operands(src, dst) -> str:
        """Validate *src*/*dst* types for ``copy()`` and return a direction tag.

        Returns one of ``"SPM_TO_SPM"``, ``"GM_TO_SPM"``, ``"SPM_TO_GM"``.

        Raises :class:`DSASemanticError` if the combination is unsupported.
        """
        src_is_buf = isinstance(src, tle.buffered_tensor)
        dst_is_buf = isinstance(dst, tle.buffered_tensor)

        if src_is_buf and dst_is_buf:
            return "SPM_TO_SPM"
        if dst_is_buf:
            if not isinstance(src, tl.tensor):
                raise DSASemanticError(f"copy: src must be tl.tensor or buffered_tensor, got {type(src).__name__}")
            return "GM_TO_SPM"
        if src_is_buf:
            if not isinstance(dst, tl.tensor):
                raise DSASemanticError(f"copy: dst must be tl.tensor or buffered_tensor, got {type(dst).__name__}")
            return "SPM_TO_GM"
        raise DSASemanticError("copy: at least one operand must be a buffered_tensor. "
                               f"Got src={type(src).__name__}, dst={type(dst).__name__}")

    @staticmethod
    def validate_copy_dtype_compat(src_dtype, dst_dtype) -> None:
        """Check that element types of *src* and *dst* are compatible."""
        if src_dtype == dst_dtype:
            return
        raise DSASemanticError(f"copy: element type mismatch – src has {src_dtype}, dst has {dst_dtype}")

    # ------------------------------------------------------------------
    # local_ptr() validation
    # ------------------------------------------------------------------

    @staticmethod
    def validate_local_ptr_buffer(buffer) -> None:
        """Validate that *buffer* is a proper ``buffered_tensor``."""
        if not isinstance(buffer, tle.buffered_tensor):
            raise DSASemanticError(f"local_ptr: buffer must be a buffered_tensor, got {type(buffer).__name__}")
        if buffer.type.shape is None:
            raise DSASemanticError("local_ptr: buffer shape is None (deferred shapes not yet supported)")

    @staticmethod
    def validate_local_ptr_indices(
        indices: Sequence,
        buffer_rank: int,
    ) -> None:
        """Validate *indices* for ``local_ptr()``.

        Checks:
          - indices length matches buffer rank
          - all indices are integer-typed
          - indices are either all scalar or all tensor with matching shapes
        """
        if indices is None:
            raise DSASemanticError("local_ptr: indices must be provided as a tuple of tensors")
        if len(indices) != buffer_rank:
            raise DSASemanticError(f"local_ptr: expected {buffer_rank} index tensors, got {len(indices)}")

        reference_shape: Optional[tuple] = None
        saw_scalar = False
        saw_block = False

        for position, idx in enumerate(indices):
            if not isinstance(idx, tl.tensor):
                raise DSASemanticError(f"local_ptr: indices[{position}] must be a tl.tensor, "
                                       f"got {type(idx).__name__}")
            if not idx.dtype.is_int():
                raise DSASemanticError(f"local_ptr: indices[{position}] must have integer dtype, "
                                       f"got {idx.dtype}")
            if not idx.type.is_block():
                saw_scalar = True
                continue
            saw_block = True
            if reference_shape is None:
                reference_shape = tuple(idx.shape)
            elif tuple(idx.shape) != reference_shape:
                raise DSASemanticError(f"local_ptr: index tensor shape mismatch at dim {position}: "
                                       f"expected {reference_shape}, got {tuple(idx.shape)}")

        if saw_scalar and saw_block:
            raise DSASemanticError("local_ptr: indices must be either all scalar or all "
                                   "tensor with identical shapes (mixed not allowed)")
+ """ + if not isinstance(shape, (tuple, list)): + if hasattr(shape, "__iter__"): + shape = tuple(shape) + else: + raise DSASemanticError(f"alloc: shape must be a tuple or list, got {type(shape).__name__}") + + unwrapped = [] + for i, dim in enumerate(shape): + dim = tl._unwrap_if_constexpr(dim) + if not isinstance(dim, int) or dim <= 0: + raise DSASemanticError(f"alloc: shape[{i}] must be a positive integer, got {dim!r}") + unwrapped.append(dim) + return tuple(unwrapped) + + @staticmethod + def validate_alloc_dtype(dtype: tl.dtype) -> tl.dtype: + """Validate *dtype* for ``alloc()``.""" + dtype = tl._unwrap_if_constexpr(dtype) + if not isinstance(dtype, tl.dtype): + raise DSASemanticError(f"alloc: dtype must be a tl.dtype instance, got {type(dtype).__name__}") + if dtype not in _SUPPORTED_ALLOC_DTYPES: + supported = ", ".join(str(d) for d in sorted(_SUPPORTED_ALLOC_DTYPES, key=str)) + raise DSASemanticError(f"alloc: unsupported dtype {dtype}. Supported types: {supported}") + return dtype + + @staticmethod + def validate_alloc_scope(scope) -> tle.scope: + """Validate *scope* for ``alloc()``.""" + if scope is None: + return tle.spm # default + if not isinstance(scope, tle.scope): + raise DSASemanticError(f"alloc: scope must be a tle.scope instance, got {type(scope).__name__}") + return scope + + # ------------------------------------------------------------------ + # copy() validation + # ------------------------------------------------------------------ + + @staticmethod + def validate_copy_operands(src, dst) -> str: + """Validate *src*/*dst* types for ``copy()`` and return a direction tag. + + Returns one of ``"SPM_TO_SPM"``, ``"GM_TO_SPM"``, ``"SPM_TO_GM"``. + + Raises :class:`DSASemanticError` if the combination is unsupported. 
+ """ + src_is_buf = isinstance(src, tle.buffered_tensor) + dst_is_buf = isinstance(dst, tle.buffered_tensor) + + if src_is_buf and dst_is_buf: + return "SPM_TO_SPM" + if (not src_is_buf) and dst_is_buf: + if not isinstance(src, tl.tensor): + raise DSASemanticError(f"copy: src must be tl.tensor or buffered_tensor, got {type(src).__name__}") + return "GM_TO_SPM" + if src_is_buf and (not dst_is_buf): + if not isinstance(dst, tl.tensor): + raise DSASemanticError(f"copy: dst must be tl.tensor or buffered_tensor, got {type(dst).__name__}") + return "SPM_TO_GM" + raise DSASemanticError("copy: at least one operand must be a buffered_tensor. " + f"Got src={type(src).__name__}, dst={type(dst).__name__}") + + @staticmethod + def validate_copy_dtype_compat(src_dtype, dst_dtype) -> None: + """Check that element types of *src* and *dst* are compatible.""" + if src_dtype != dst_dtype: + raise DSASemanticError(f"copy: element type mismatch – src has {src_dtype}, dst has {dst_dtype}") + + # ------------------------------------------------------------------ + # local_ptr() validation + # ------------------------------------------------------------------ + + @staticmethod + def validate_local_ptr_buffer(buffer) -> None: + """Validate that *buffer* is a proper ``buffered_tensor``.""" + if not isinstance(buffer, tle.buffered_tensor): + raise DSASemanticError(f"local_ptr: buffer must be a buffered_tensor, got {type(buffer).__name__}") + if buffer.type.shape is None: + raise DSASemanticError("local_ptr: buffer shape is None (deferred shapes not yet supported)") + + @staticmethod + def validate_local_ptr_indices( + indices: Sequence, + buffer_rank: int, + ) -> None: + """Validate *indices* for ``local_ptr()``. 
+ + Checks: + - indices length matches buffer rank + - all indices are integer-typed + - indices are either all scalar or all tensor with matching shapes + """ + if indices is None: + raise DSASemanticError("local_ptr: indices must be provided as a tuple of tensors") + if len(indices) != buffer_rank: + raise DSASemanticError(f"local_ptr: expected {buffer_rank} index tensors, got {len(indices)}") + + view_shape: Optional[tuple] = None + has_scalar = False + has_tensor = False + + for i, idx in enumerate(indices): + if not isinstance(idx, tl.tensor): + raise DSASemanticError(f"local_ptr: indices[{i}] must be a tl.tensor, " + f"got {type(idx).__name__}") + if not idx.dtype.is_int(): + raise DSASemanticError(f"local_ptr: indices[{i}] must have integer dtype, " + f"got {idx.dtype}") + is_scalar = not idx.type.is_block() + if is_scalar: + has_scalar = True + else: + has_tensor = True + if view_shape is None: + view_shape = tuple(idx.shape) + elif tuple(idx.shape) != view_shape: + raise DSASemanticError(f"local_ptr: index tensor shape mismatch at dim {i}: " + f"expected {view_shape}, got {tuple(idx.shape)}") + + if has_scalar and has_tensor: + raise DSASemanticError("local_ptr: indices must be either all scalar or all " + "tensor with identical shapes (mixed not allowed)") diff --git a/python/triton/experimental/tle/language/dsa/types.py b/python/triton/experimental/tle/language/dsa/types.py new file mode 100644 index 0000000000..ba162f8737 --- /dev/null +++ b/python/triton/experimental/tle/language/dsa/types.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, List, Tuple + +from triton.language.core import base_type, base_value, dtype, _unwrap_if_constexpr + + +@dataclass(frozen=True) +class scope: + """ + Simple storage descriptor for DSA buffers. + + This is intentionally backend-agnostic. 
`name` / `value` / `memory_space` + are carried through as metadata only; concrete lowering is handled by the + DSA dialect and backend. + """ + + name: str + value: str + memory_space: str + + def __repr__(self) -> str: + return self.name + + +# DSA storage scopes. +local = scope("local", "local", "local") + +# Scratch Pad Memory – the on-chip SRAM exposed by TsingMicro TX8. +# This is the primary storage scope for DSA kernels and serves the same +# conceptual role as NVIDIA shared memory (smem). +spm = scope("spm", "spm", "spm") + + +class buffered_tensor(base_value): + """ + Symbolic handle to a buffer allocated via DSA. + + This is a thin wrapper over an IR value plus a `buffered_tensor_type` + describing shape / element dtype / memory space. + """ + + def __init__(self, handle: Any, ty: "buffered_tensor_type"): + self.handle = handle + self.type = ty + self.shape = ty.shape + self.dtype = ty.element_ty + + def _flatten_ir(self, handles: List[Any]) -> None: + handles.append(self.handle) + + +class buffered_tensor_type(base_type): + """ + Frontend description of a DSA buffer. 
+ + - `shape`: logical block shape (may be None for deferred shapes) + - `element_ty`: scalar dtype + - `storage`: abstract storage scope (currently `local`) + - `memory_space`: backend-visible memory space string, defaults to + `storage.memory_space` + """ + + def __init__( + self, + shape, + element_ty: dtype, + storage: scope | None = None, + memory_space: str = "", + ): + if shape is None: + self.shape = None + else: + shape = _unwrap_if_constexpr(shape) + self.shape = tuple(int(_unwrap_if_constexpr(x)) for x in shape) + self.element_ty = _unwrap_if_constexpr(element_ty) + self.storage = storage if storage is not None else local + self.memory_space = (str(_unwrap_if_constexpr(memory_space)) if memory_space else self.storage.memory_space) + + @property + def scalar(self) -> dtype: + return self.element_ty + + def __eq__(self, other: object) -> bool: + if not isinstance(other, buffered_tensor_type): + return False + return ( + self.shape, + self.element_ty, + self.storage, + self.memory_space, + ) == ( + other.shape, + other.element_ty, + other.storage, + other.memory_space, + ) + + def __repr__(self) -> str: + shape = "?" if self.shape is None else "x".join(map(str, self.shape)) + return f"buffered_tensor_type<{shape}, {self.element_ty}, {self.memory_space}>" + + def _unflatten_ir(self, handles: List[Any], cursor: int) -> Tuple[base_value, int]: + value = buffered_tensor(handles[cursor], self) + # Preserve remote metadata if present on the type. 
+ if hasattr(self, "_tle_remote_shard_id"): + shard_id = getattr(self, "_tle_remote_shard_id") + scope = getattr(self, "_tle_remote_scope", None) + setattr(value, "_tle_remote_shard_id", shard_id) + setattr(value, "_tle_remote_scope", scope) + setattr(value.type, "_tle_remote_shard_id", shard_id) + setattr(value.type, "_tle_remote_scope", scope) + return value, cursor + 1 diff --git a/third_party/tle/CMakeLists.txt b/third_party/tle/CMakeLists.txt new file mode 100644 index 0000000000..08d4c270ad --- /dev/null +++ b/third_party/tle/CMakeLists.txt @@ -0,0 +1,9 @@ +# Include and generated-file paths for all subdirs +set(TLE_INCLUDE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/include") +set(TLE_INCLUDE_BINARY "${CMAKE_CURRENT_BINARY_DIR}/include") + +include_directories(${TLE_INCLUDE_SOURCE}) +include_directories(${TLE_INCLUDE_BINARY}) +add_subdirectory(include) +add_subdirectory(lib) +add_subdirectory(python) diff --git a/third_party/tle/REANME.md b/third_party/tle/REANME.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/third_party/tle/include/CMakeLists.txt b/third_party/tle/include/CMakeLists.txt new file mode 100644 index 0000000000..1688254946 --- /dev/null +++ b/third_party/tle/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(tle-dsa/Dialect/IR) diff --git a/third_party/tle/include/tle-dsa/Conversion/DsaToCore/DsaToCore.h b/third_party/tle/include/tle-dsa/Conversion/DsaToCore/DsaToCore.h new file mode 100644 index 0000000000..f121aa6db8 --- /dev/null +++ b/third_party/tle/include/tle-dsa/Conversion/DsaToCore/DsaToCore.h @@ -0,0 +1,17 @@ +#ifndef TLE_DSA_CONVERSION_DSATOMCORE_H +#define TLE_DSA_CONVERSION_DSATOMCORE_H + +#include + +namespace mlir { +class Pass; +} // namespace mlir + +namespace mlir::dsa { + +std::unique_ptr createDsaMemoryToCorePass(); +void registerDsaMemoryToCorePass(); + +} // namespace mlir::dsa + +#endif // TLE_DSA_CONVERSION_DSATOMCORE_H diff --git a/third_party/tle/include/tle-dsa/Dialect/IR/CMakeLists.txt 
b/third_party/tle/include/tle-dsa/Dialect/IR/CMakeLists.txt new file mode 100644 index 0000000000..92f59801cb --- /dev/null +++ b/third_party/tle/include/tle-dsa/Dialect/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS DsaDialect.td) +mlir_tablegen(DsaOpsDialect.h.inc -gen-dialect-decls -dialect=dsa) +mlir_tablegen(DsaOpsDialect.cpp.inc -gen-dialect-defs -dialect=dsa) +add_public_tablegen_target(TleDsaDialectIncGen) + +set(LLVM_TARGET_DEFINITIONS DsaDialect.td) +mlir_tablegen(DsaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=dsa) +mlir_tablegen(DsaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=dsa) +add_public_tablegen_target(TleDsaTypesIncGen) + +set(LLVM_TARGET_DEFINITIONS DsaOps.td) +mlir_tablegen(DsaOps.h.inc -gen-op-decls) +mlir_tablegen(DsaOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(TleDsaOpsIncGen) diff --git a/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.h b/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.h new file mode 100644 index 0000000000..4dd49464ec --- /dev/null +++ b/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.h @@ -0,0 +1,33 @@ +//===- DsaDialect.h - TLE DSA dialect ---------------------------*- C++ -*-===// +// +// Template dialect for TLE-Struct style DSA extensions. +// +//===----------------------------------------------------------------------===// + +#ifndef TLE_DSA_DIALECT_IR_DSADIALECT_H +#define TLE_DSA_DIALECT_IR_DSADIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +// DsaOps.td uses TT_Tensor / TT_Ptr / TT_Int type constraints from the +// Triton dialect, so the generated verifiers need these types visible. 
+#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "tle-dsa/Dialect/IR/DsaOpsDialect.h.inc" + +namespace mlir { +class PatternRewriter; +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "tle-dsa/Dialect/IR/DsaOpsTypes.h.inc" + +#define GET_OP_CLASSES +#include "tle-dsa/Dialect/IR/DsaOps.h.inc" + +#endif // TLE_DSA_DIALECT_IR_DSADIALECT_H diff --git a/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.td b/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.td new file mode 100644 index 0000000000..b314470491 --- /dev/null +++ b/third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.td @@ -0,0 +1,68 @@ +//===- DsaDialect.td - TLE DSA dialect --------------------*- tablegen -*-===// +// +// Template dialect for TLE-Struct style DSA extensions. +// +//===----------------------------------------------------------------------===// + +#ifndef TLE_DSA_DIALECT +#define TLE_DSA_DIALECT + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Dialect definition. +//===----------------------------------------------------------------------===// + +def Dsa_Dialect : Dialect { + let name = "dsa"; + let summary = "TLE DSA struct dialect (template)"; + let cppNamespace = "::mlir::dsa"; + let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + void registerTypes(); + }]; +} + +//===----------------------------------------------------------------------===// +// Type definitions. +//===----------------------------------------------------------------------===// + +class Dsa_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def DsaBufferType : Dsa_Type<"Buffer", "buffer"> { + let summary = "Opaque local buffer handle"; + let description = [{ + This is a minimal, parameterized buffer handle type for DSA extensions. 
+ The backend defines the lowering semantics (e.g. mapping to SRAM/UB/NZ, etc). + }]; + + let parameters = (ins + "Type":$elementType, + OptionalParameter<"Attribute">:$memorySpace + ); + + let builders = [ + TypeBuilder<(ins + "Type":$elementType, + CArg<"Attribute", "nullptr">:$memorySpace), [{ + return Base::get($_ctxt, elementType, memorySpace); + }]> + ]; + let skipDefaultBuilders = 1; + + let assemblyFormat = "`<` $elementType (`,` $memorySpace^)? `>`"; +} + +//===----------------------------------------------------------------------===// +// Base op definition. +//===----------------------------------------------------------------------===// + +class Dsa_Op traits = []> : + Op; + +#endif // TLE_DSA_DIALECT diff --git a/third_party/tle/include/tle-dsa/Dialect/IR/DsaOps.td b/third_party/tle/include/tle-dsa/Dialect/IR/DsaOps.td new file mode 100644 index 0000000000..925fcb19e2 --- /dev/null +++ b/third_party/tle/include/tle-dsa/Dialect/IR/DsaOps.td @@ -0,0 +1,78 @@ +//===- DsaOps.td - TLE DSA dialect ops --------------------*- tablegen -*-===// +// +// Template ops for TLE-Struct style DSA extensions. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TLE_DSA_OPS +#define TLE_DSA_OPS + +include "tle-dsa/Dialect/IR/DsaDialect.td" +include "mlir/Dialect/LLVMIR/LLVMTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/CommonTypeConstraints.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" + +//===----------------------------------------------------------------------===// +// dsa.alloc +//===----------------------------------------------------------------------===// + +def Dsa_AllocOp : Dsa_Op<"alloc", [MemoryEffects<[MemAlloc]>]> { + let summary = "Allocate a DSA local buffer"; + let arguments = (ins DenseI64ArrayAttr:$shape); + let results = (outs AnyRankedOrUnrankedMemRef:$result); + let assemblyFormat = [{ + $shape attr-dict `:` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// dsa.copy +//===----------------------------------------------------------------------===// + +def Dsa_CopyOp : Dsa_Op<"copy", + [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "Copy between DSA local buffers"; + let arguments = (ins AnyRankedOrUnrankedMemRef:$src, AnyRankedOrUnrankedMemRef:$dst); + let results = (outs); + let assemblyFormat = [{ + $src `,` $dst attr-dict `:` type($src) `,` type($dst) + }]; +} + +def DsaLocalPointerResultType : AnyTypeOf<[TT_Tensor, TT_Ptr]>; +def DsaLocalPointerIndexType : AnyTypeOf<[TT_Tensor, TT_Int]>; +def DsaRemotePointerType : AnyTypeOf<[TT_Tensor, TT_Ptr]>; +def DsaRemoteShardIdType : AnyTypeOf<[TT_Tensor, TT_Int]>; + +//===----------------------------------------------------------------------===// +// dsa.local_pointers / dsa.remote_pointers / dsa.distributed_barrier +// These are DSA-side structural counterparts of the official TLE distributed +// and pointer-building ops. They let the frontend move toward the same +// structured programming model without binding the design to GPU dialect names. 
+//===----------------------------------------------------------------------===// + +def Dsa_LocalPointersOp : Dsa_Op<"local_pointers", [Pure]> { + let arguments = (ins AnyRankedOrUnrankedMemRef:$src, + Variadic:$indices); + let results = (outs DsaLocalPointerResultType:$result); +} + +def Dsa_DistributedBarrierOp : Dsa_Op<"distributed_barrier", + [MemoryEffects<[MemRead, MemWrite]>]> { + let arguments = (ins + OptionalAttr:$group_kind, + OptionalAttr:$group_rank, + OptionalAttr:$group_shape, + OptionalAttr:$group_axes, + OptionalAttr:$group_mask + ); + let assemblyFormat = "attr-dict"; +} + +def Dsa_RemotePointersOp : Dsa_Op<"remote_pointers", [Pure]> { + let arguments = (ins DsaRemotePointerType:$src, DsaRemoteShardIdType:$shard_id); + let results = (outs DsaRemotePointerType:$result); +} + +#endif // TLE_DSA_OPS diff --git a/third_party/tle/lib/CMakeLists.txt b/third_party/tle/lib/CMakeLists.txt new file mode 100644 index 0000000000..1037d1a72d --- /dev/null +++ b/third_party/tle/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Dialect/IR) +add_subdirectory(Conversion/DsaToCore) diff --git a/third_party/tle/lib/Conversion/DsaToCore/CMakeLists.txt b/third_party/tle/lib/Conversion/DsaToCore/CMakeLists.txt new file mode 100644 index 0000000000..6676386232 --- /dev/null +++ b/third_party/tle/lib/Conversion/DsaToCore/CMakeLists.txt @@ -0,0 +1,17 @@ +add_triton_library(TleDsaToCore + DsaToCore.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransformUtils + MLIRRewrite + TleDsaIR + TritonIR +) + +target_include_directories(TleDsaToCore + PUBLIC + ${TLE_INCLUDE_SOURCE} + ${TLE_INCLUDE_BINARY} +) diff --git a/third_party/tle/lib/Conversion/DsaToCore/DsaToCore.cpp b/third_party/tle/lib/Conversion/DsaToCore/DsaToCore.cpp new file mode 100644 index 0000000000..a5556c37e2 --- /dev/null +++ b/third_party/tle/lib/Conversion/DsaToCore/DsaToCore.cpp @@ -0,0 +1,72 @@ +//===- DsaToCore.cpp - Lower dsa memory ops to core dialects ----*- C++ -*-===// +// +// Lower 
dsa.alloc/copy into standard memref ops +// ops before one-shot-bufferize. +// +//===----------------------------------------------------------------------===// + +#include "tle-dsa/Conversion/DsaToCore/DsaToCore.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tle-dsa/Dialect/IR/DsaDialect.h" + +using namespace mlir; + +namespace { + +struct DsaAllocToMemRefPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::dsa::AllocOp op, + PatternRewriter &rewriter) const override { + auto memrefTy = dyn_cast(op.getResult().getType()); + if (!memrefTy) + return failure(); + // tx81-memref-to-llvm expects integer/default memref address spaces. + // Canonicalize away non-integer memory-space attrs (e.g. "local"). + if (Attribute ms = memrefTy.getMemorySpace(); ms && !isa(ms)) { + memrefTy = MemRefType::get(memrefTy.getShape(), memrefTy.getElementType(), + memrefTy.getLayout()); + } + rewriter.replaceOpWithNewOp(op, memrefTy); + return success(); + } +}; + +struct DsaCopyToMemRefPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::dsa::CopyOp op, + PatternRewriter &rewriter) const override { + rewriter.create(op.getLoc(), op.getSrc(), op.getDst()); + rewriter.eraseOp(op); + return success(); + } +}; + +struct DsaMemoryToCorePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DsaMemoryToCorePass) + StringRef getArgument() const final { return "dsa-memory-to-core"; } + StringRef getDescription() const final { + return "Lower dsa.alloc/copy to memref"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add( + &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace 
mlir::dsa { +std::unique_ptr createDsaMemoryToCorePass() { + return std::make_unique(); +} + +void registerDsaMemoryToCorePass() { PassRegistration(); } +} // namespace mlir::dsa diff --git a/third_party/tle/lib/Dialect/IR/CMakeLists.txt b/third_party/tle/lib/Dialect/IR/CMakeLists.txt new file mode 100644 index 0000000000..870dc009af --- /dev/null +++ b/third_party/tle/lib/Dialect/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +add_triton_library(TleDsaIR + DsaDialect.cpp + + DEPENDS + TleDsaDialectIncGen + TleDsaTypesIncGen + TleDsaOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + TritonIR +) + +target_include_directories(TleDsaIR + PUBLIC + ${TLE_INCLUDE_SOURCE} + ${TLE_INCLUDE_BINARY} +) diff --git a/third_party/tle/lib/Dialect/IR/DsaDialect.cpp b/third_party/tle/lib/Dialect/IR/DsaDialect.cpp new file mode 100644 index 0000000000..fe80e1a1ce --- /dev/null +++ b/third_party/tle/lib/Dialect/IR/DsaDialect.cpp @@ -0,0 +1,39 @@ +//===- DsaDialect.cpp - TLE DSA dialect -------------------------*- C++ -*-===// +// +// Template dialect for TLE-Struct style DSA extensions. 
+// +//===----------------------------------------------------------------------===// + +#include "tle-dsa/Dialect/IR/DsaDialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +namespace mlir::dsa { + +void DsaDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "tle-dsa/Dialect/IR/DsaOps.cpp.inc" + >(); + registerTypes(); +} + +void DsaDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "tle-dsa/Dialect/IR/DsaOpsTypes.cpp.inc" + >(); +} + +} // namespace mlir::dsa + +#include "tle-dsa/Dialect/IR/DsaOpsDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "tle-dsa/Dialect/IR/DsaOps.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "tle-dsa/Dialect/IR/DsaOpsTypes.cpp.inc" diff --git a/third_party/tle/python/CMakeLists.txt b/third_party/tle/python/CMakeLists.txt new file mode 100644 index 0000000000..ea37f1851f --- /dev/null +++ b/third_party/tle/python/CMakeLists.txt @@ -0,0 +1,7 @@ +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonTleDsaTemplate + ${CMAKE_CURRENT_SOURCE_DIR}/triton_tle_dsa.cc + LINK_LIBS TleDsaIR + ) + target_link_libraries(TritonTleDsaTemplate PRIVATE Python3::Module pybind11::headers) +endif() diff --git a/third_party/tle/python/triton_tle_dsa.cc b/third_party/tle/python/triton_tle_dsa.cc new file mode 100644 index 0000000000..bf9385f69d --- /dev/null +++ b/third_party/tle/python/triton_tle_dsa.cc @@ -0,0 +1,123 @@ +//===- triton_tle_dsa.cc - TLE DSA builder injection -------------*- C++ +//-*-===// +// +// Template pybind that injects DSA dialect ops into TritonOpBuilder. 
+// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +#include "ir.h" +#include "tle-dsa/Dialect/IR/DsaDialect.h" + +namespace py = pybind11; +using namespace mlir; + +namespace dsa = mlir::dsa; + +static void init_triton_tle_ir(py::module m) { + (void)m; + auto core_ir = py::module::import("triton._C.libtriton.ir"); + auto builder_cls = + core_ir.attr("builder").cast>(); + + builder_cls + .def("create_dsa_alloc", + [](TritonOpBuilder &self, py::object shapeObj, + py::object elementTyObj) -> Value { + self.getContext()->getOrLoadDialect(); + auto &b = self.getBuilder(); + std::vector dims; + if (py::isinstance(shapeObj)) { + dims.push_back(py::cast(shapeObj)); + } else { + py::iterable shape = + py::reinterpret_borrow(shapeObj); + dims.reserve(py::len(shape)); + for (py::handle dim : shape) + dims.push_back(py::cast(dim)); + } + auto shapeAttr = DenseI64ArrayAttr::get(b.getContext(), dims); + Type elementTy = py::cast(elementTyObj); + auto bufTy = MemRefType::get(dims, elementTy); + auto op = self.getBuilder().create(self.getLastLoc(), + bufTy, shapeAttr); + return op.getResult(); + }) + .def("create_dsa_copy", + [](TritonOpBuilder &self, Value src, Value dst) -> void { + self.getContext()->getOrLoadDialect(); + self.getBuilder().create(self.getLastLoc(), src, dst); + }) + .def("create_dsa_local_pointers", + [](TritonOpBuilder &self, Type resultTy, Value src, + py::args args) -> OpState { + self.getContext()->getOrLoadDialect(); + llvm::SmallVector indices; + indices.reserve(args.size()); + for (const auto &arg : args) + indices.push_back(py::cast(arg)); + return self.create(resultTy, src, indices); + }) + .def("create_dsa_remote_pointers", + [](TritonOpBuilder &self, Type resultTy, Value src, + Value shardId) -> 
OpState { + self.getContext()->getOrLoadDialect(); + return self.create(resultTy, src, shardId); + }) + .def("create_dsa_distributed_barrier", + [](TritonOpBuilder &self, const std::string &groupKind, + const std::vector &groupShape, + const std::vector &groupAxes, + const std::vector &groupMask) -> void { + self.getContext()->getOrLoadDialect(); + auto &builder = self.getBuilder(); + auto *ctx = builder.getContext(); + StringAttr kindAttr; + IntegerAttr rankAttr; + DenseI32ArrayAttr shapeAttr; + DenseI32ArrayAttr axesAttr; + DenseI32ArrayAttr maskAttr; + + if (!groupKind.empty()) { + kindAttr = builder.getStringAttr(groupKind); + rankAttr = builder.getI32IntegerAttr( + static_cast(groupShape.size())); + shapeAttr = DenseI32ArrayAttr::get(ctx, groupShape); + axesAttr = DenseI32ArrayAttr::get(ctx, groupAxes); + if (!groupMask.empty()) + maskAttr = DenseI32ArrayAttr::get(ctx, groupMask); + } + + self.create( + kindAttr, rankAttr, shapeAttr, axesAttr, maskAttr); + }); +} + +// void init_triton_tle(py::module &&m, const char *submodule_name = nullptr) { +// if (submodule_name && *submodule_name != '\0') +// m = m.def_submodule(submodule_name); +// py::module local_m = std::move(m); +// local_m.def("load_dialects", [](mlir::MLIRContext &context) { +// context.getOrLoadDialect(); +// }); +// init_triton_tle_ir(std::move(local_m)); +// } + +void init_triton_tle(py::module &&m) { + py::module local_m = std::move(m); + + local_m.def("load_dialects", [](mlir::MLIRContext &context) { + context.getOrLoadDialect(); + }); + + init_triton_tle_ir(std::move(local_m)); +} diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py index b6b4aa22bc..8580641586 100644 --- a/third_party/tsingmicro/backend/compiler.py +++ b/third_party/tsingmicro/backend/compiler.py @@ -118,9 +118,9 @@ def _ttir_to_coreir(mod): coreir_to_mk_mode = "--core-dialects-to-mk=precision-priority" args = [ - triton_opt_path, src_path, "--triton-to-core-dialects", 
"--linalg-tiling", f"{coreir_to_mk_mode}", - "--linalg-fusion", "--legalize-tensor-form-loops", "--one-shot-bufferize", - "--convert-bufferization-to-memref", "--cse", "--canonicalize" + triton_opt_path, src_path, "--triton-to-core-dialects", "--tle-to-mk", "--dsa-memory-to-core", + "--linalg-tiling", f"{coreir_to_mk_mode}", "--linalg-fusion", "--legalize-tensor-form-loops", + "--one-shot-bufferize", "--convert-bufferization-to-memref", "--cse", "--canonicalize" ] if os.getenv("TRITON_DEBUG", "0") == "1": args.append("--mlir-print-debuginfo") diff --git a/third_party/tsingmicro/bin/RegisterTritonDialects.h b/third_party/tsingmicro/bin/RegisterTritonDialects.h index bc253cceed..fadfcea826 100644 --- a/third_party/tsingmicro/bin/RegisterTritonDialects.h +++ b/third_party/tsingmicro/bin/RegisterTritonDialects.h @@ -33,7 +33,10 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "magic-kernel/Conversion/TLEToMK/Passes.h" #include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "third_party/tle/include/tle-dsa/Conversion/DsaToCore/DsaToCore.h" +#include "third_party/tle/include/tle-dsa/Dialect/IR/DsaDialect.h" #include "triton-shared/Conversion/ConvertTritonPtr/Passes.h" #include "triton-shared/Conversion/ReconcilePtrCasts/Passes.h" #include "triton-shared/Conversion/StructuredToMemref/Passes.h" @@ -83,6 +86,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonPasses(); mlir::triton::gpu::registerTritonGPUPasses(); mlir::registerLinalgPasses(); + mlir::dsa::registerDsaMemoryToCorePass(); + mlir::triton::registerTLEToMKPass(); mlir::registerTritonNvidiaGPUPasses(); mlir::test::registerTestAliasPass(); mlir::test::registerTestAlignmentPass(); @@ -208,5 +213,5 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::tensor::TensorDialect, mlir::memref::MemRefDialect, mlir::affine::AffineDialect, mlir::bufferization::BufferizationDialect, 
mlir::mk::MagicKernelDialect, mlir::tx::Tx81Dialect, - mlir::addr::AddressDialect>(); + mlir::addr::AddressDialect, mlir::dsa::DsaDialect>(); } diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt index f879f691f1..0998f20a3f 100644 --- a/third_party/tsingmicro/crt/CMakeLists.txt +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -10,6 +10,14 @@ if(NOT DEFINED TARGET) endif() endif() +if(NOT DEFINED TX8_YOC_RT_THREAD_SMP) + if(DEFINED ENV{TX8_YOC_RT_THREAD_SMP}) + set(TX8_YOC_RT_THREAD_SMP $ENV{TX8_YOC_RT_THREAD_SMP}) + else() + message(FATAL_ERROR "TX8_YOC_RT_THREAD_SMP environment variable is not defined") + endif() +endif() + if(NOT DEFINED XUANTIE_NAME) if(DEFINED ENV{XUANTIE_NAME}) set(XUANTIE_NAME $ENV{XUANTIE_NAME}) @@ -53,6 +61,7 @@ include_directories(${TX8_DEPS_ROOT}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/${TARGET}) include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${TX8_YOC_RT_THREAD_SMP}/interface/op_fw_sim_if/peripheral/include/) # Set build type default if(NOT CMAKE_BUILD_TYPE) diff --git a/third_party/tsingmicro/crt/lib/Tx81/recv.c b/third_party/tsingmicro/crt/lib/Tx81/recv.c new file mode 100644 index 0000000000..a45606e24a --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/recv.c @@ -0,0 +1,26 @@ +//===------------------------ recv.c --------------------------------------===// +// +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Recv, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +// #include "instr_adapter_plat.h" + +#include "direct_dte_and_fsm.h" +#include "tx81.h" +#include "tx81_spm.h" +#include +#include + +uint32_t __get_pid(uint32_t); + +// Blockingly receive data from a source tile into a destination buffer. +// Returns the destination buffer address. 
+void __Recv(int64_t chip_x, int64_t chip_y, int64_t die_id, int64_t tile_id, + void *dst, uint32_t elem_bytes, uint32_t data_size) { + // TODO + return; +} diff --git a/third_party/tsingmicro/crt/lib/Tx81/send.c b/third_party/tsingmicro/crt/lib/Tx81/send.c new file mode 100644 index 0000000000..31608408af --- /dev/null +++ b/third_party/tsingmicro/crt/lib/Tx81/send.c @@ -0,0 +1,166 @@ +//===------------------------ send.c --------------------------------------===// +// +// +//===----------------------------------------------------------------------===// +// +// Runtime API of MLIR operation tx::Send, see Tx81Ops.td for detail. +// +//===----------------------------------------------------------------------===// + +// #include "instr_adapter_plat.h" +#include "direct_dte_and_fsm.h" +#include "tx81.h" +#include "tx81_spm.h" +#include +#include +#include + +#define MAX_TILE_NUM 16 + +const uint32_t tile_physical_relation[MAX_TILE_NUM] = { + 0, 1, 2, 3, 7, 11, 15, 14, 13, 12, 8, 9, 10, 6, 5, 4}; + +// 获取物理连接最近的下一个tile id +int32_t getNextNearestTileId(uint32_t tileId) { + for (int i = 0; i < MAX_TILE_NUM; i++) { + if (tile_physical_relation[i] == tileId) { + if (i != (MAX_TILE_NUM - 1)) { + return tile_physical_relation[++i]; + } else { + return tile_physical_relation[0]; + } + } + } + + return -1; +} + +// 获取物理连接最近的上一个tile id +int32_t getPrevNearestTileId(uint32_t tileId) { + for (int i = 0; i < MAX_TILE_NUM; i++) { + if (tile_physical_relation[i] == tileId) { + if (i != 0) { + return tile_physical_relation[--i]; + } else { + return tile_physical_relation[MAX_TILE_NUM - 1]; + } + } + } + + return -1; +} + +void tile_sync_by_spm_single_direction(int32_t tile_this, int32_t tile_a, + int32_t tile_x, int32_t tile_y, + uint32_t this_sync_spm_index, + uint32_t other_sync_spm_index) { + // this_debug_spm_ptr 只用于板端记录调试信息 + volatile uint32_t *this_debug_spm_ptr = + (volatile uint32_t *)get_spm_memory_mapping(SINGLE_SPM_SYNC_DEBUG_ADDR); + this_debug_spm_ptr[0] = tile_this; + 
this_debug_spm_ptr[1] = tile_a; + this_debug_spm_ptr[2] = tile_x; + this_debug_spm_ptr[3] = tile_y; + this_debug_spm_ptr[4] = this_sync_spm_index; + this_debug_spm_ptr[5] = other_sync_spm_index; + + uint64_t tile_a_spm = + get_tile_spm_addr_base(tile_a, tile_x, tile_y) + SINGLE_SPM_SYNC_ADDR; + *(uint64_t *)(this_debug_spm_ptr + 6) = tile_a_spm; + volatile uint32_t *this_spm_ptr = + (volatile uint32_t *)get_spm_memory_mapping(SINGLE_SPM_SYNC_ADDR); + *(uint64_t *)(this_debug_spm_ptr + 8) = (uint64_t)this_spm_ptr; + volatile uint32_t *tile_a_spm_ptr = (volatile uint32_t *)(tile_a_spm); + + this_debug_spm_ptr[10] = 0; // 写对端开始 + tile_a_spm_ptr[other_sync_spm_index] = 1; // forward + + this_debug_spm_ptr[10] = 1; // 写对端结束 + + while (!this_spm_ptr[this_sync_spm_index]) { + } + + this_spm_ptr[this_sync_spm_index] = 0; // forward +} + +#define SCFG_TILE_ID_ADDR 0x6A0058 // KUIPER_ADDR_MAP_REG_BASE 0x6A0000 + +uint32_t __get_pid(uint32_t); + +int initTileId(uint32_t tileId, uint32_t rowLength) { + // init 1D tile-id + *(volatile uint32_t *)(get_spm_memory_mapping(TILE_ID_ADDR)) = tileId; + // init 2D logic id + *(volatile uint32_t *)(get_spm_memory_mapping(LOGIC_ID_ADDR)) = + *(volatile uint32_t *)(SCFG_TILE_ID_ADDR); + // init row-length + *(volatile uint32_t *)(get_spm_memory_mapping(ROW_LENGTH_ADDR)) = rowLength; + *(volatile uint32_t *)(get_spm_memory_mapping(INNER_CHIP_ERROR_CODE)) = 0; + + // __LOG__(KCORE_LOG_DEBUG, "logic_id:0x%x, tileId:%u, rowLength:%u\n", + // *(volatile uint32_t *)(get_spm_memory_mapping(LOGIC_ID_ADDR)), + // tileId, rowLength); + return 0; +} + +// Asynchronously send data to a destination tile. +// The operation reads from the given source buffer and sends it to the remote +// tile. The operation is non-blocking and returns immediately. 
+void __Send(int64_t chipX, int64_t chipY, int64_t dieId, int64_t tileId, + void *restrict dst, void *restrict src, uint32_t elem_bytes, + uint64_t data_size) { + uint32_t coreIndex = __get_pid(0); // 全局tile id + initTileId(coreIndex, 4); + // __EP_LOG__(0, "+++++++++ Send444 dst:%lx, src: %lx, cur_tileId:%d, + // nextTileId: %d, data_size: %d\n", dst, src, coreIndex, tileId, data_size); + (void)chipX; + (void)chipY; + (void)dieId; + int64_t nextTileId = getNextNearestTileId(coreIndex); + int64_t preTileId = getPrevNearestTileId(coreIndex); + // tile_sync_by_spm_single_direction(coreIndex, preTileId, 4, 4, 0, 0); + const TsmOperatorPointer *intrinsic = g_intrinsic(); + + int fringFsmId = DIRECT_DTE_FSM_ID_0; + int remottFringFsmId = DIRECT_DTE_FSM_ID_0; + + void *fdteNode = direct_dte_attach(0); + void *fringFsmHd = direct_fsm_monitor_init(fringFsmId, 0, data_size, 1); + + TsmStream *stream = (TsmStream *)(intrinsic->stream_pointer); + uint64_t nextTileBaseAddr = get_tile_spm_addr_base(nextTileId, 4, 4); + + stream->wait_finish(); + + // __EP_LOG__(0, "fdte info base: %ld, dst: %ld, src: %ld, remote_fsm_id: %d, + // data_size: %d, dst_tileId: %d, this_tileId: %d\n", + // nextTileBaseAddr, nextTileBaseAddr + (uint64_t)dst, (uint64_t)src, + // remottFringFsmId, data_size, tileId, coreIndex); + DirectDTESendInfo fdteInfo = {.src_addr = (uint64_t)src, + .dst_addr = nextTileBaseAddr + (uint64_t)dst, + .length = data_size, + .remote_fsm_id = remottFringFsmId, + .mode = 0, // unicast + .dst_tile = nextTileId, + .tile_this = coreIndex, + .dte_node = fdteNode}; + set_direct_fsm_monitor_dst_addr(fringFsmId, nextTileBaseAddr + (uint64_t)dst); + tile_sync_by_spm_single_direction(coreIndex, preTileId, 4, 4, 0, 0); + direct_dte_send_async(&fdteInfo); // 把当前数据异步发送给下一个tile + // __EP_LOG__(KCORE_LOG_DEBUG, "send data to next tile: %u, current + // tile:%u.\n", + // getNextNearestTileId(coreIndex), coreIndex); + + direct_fsm_monitor_receive(coreIndex, preTileId, + fringFsmHd); // 
阻塞接收前一个tile发送的数据 + // __EP_LOG__(KCORE_LOG_DEBUG, + // "receive data from coreIndex: %u, current tile:%u.\n", + // getPrevNearestTileId(coreIndex), coreIndex); + direct_dte_wait_done(&fdteInfo); // 等待异步发送完成 + TsmWaitfinish(); + + tile_sync_by_spm_single_direction(coreIndex, preTileId, 4, 4, 0, 0); + direct_dte_release(fdteNode); + direct_fsm_monitor_deinit(fringFsmHd); + // __EP_LOG__(0, "-------- Send\n") +} diff --git a/third_party/tsingmicro/examples/tle/test_tle_dsa_noc_gemm_4096.py b/third_party/tsingmicro/examples/tle/test_tle_dsa_noc_gemm_4096.py new file mode 100644 index 0000000000..b7ed891c2f --- /dev/null +++ b/third_party/tsingmicro/examples/tle/test_tle_dsa_noc_gemm_4096.py @@ -0,0 +1,176 @@ +import imp +import torch +import triton +import triton.language as tl +from triton.experimental import tle + +TILE_NUM = 16 +M = 4096 +K = 1024 +N = 4096 +BLOCK_M = M // TILE_NUM +BLOCK_K = K +SUB_N = N // TILE_NUM + +TILE_PHYSICAL_RELATION = [0, 1, 2, 3, 7, 11, 15, 14, 13, 12, 8, 9, 10, 6, 5, 4] + +MESH = tle.device_mesh( + None, + _shape=(TILE_NUM, ), + _dim_names=("tile", ), + _physical_ids=tuple(TILE_PHYSICAL_RELATION), +) + + +@triton.jit +def dsa_shift_n_gemm_kernel( + A_ptr, + B_ptr, + C_ptr, + send_next_tile_lut_ptr, + ring_index_lut_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + SUB_N: tl.constexpr, + TILE_NUM: tl.constexpr, +): + pid = tl.program_id(0) + send_next_tile = tl.load(send_next_tile_lut_ptr + pid) + ring_index = tl.load(ring_index_lut_ptr + pid) + + offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K) + + a_ptrs = A_ptr + offs_m[:, None] * K + offs_k[None, :] + a = tl.load(a_ptrs) + + shard_idx = ring_index + offs_sub_n = shard_idx * SUB_N + tl.arange(0, SUB_N) + b_ptrs = B_ptr + offs_k[:, None] * N + offs_sub_n[None, :] + b_init = tl.load(b_ptrs) + + send_buf = tle.language.dsa.alloc((BLOCK_K, SUB_N), tl.float16) + recv_buf = 
tle.language.dsa.alloc((BLOCK_K, SUB_N), tl.float16) + + offs_buf_k = tl.arange(0, BLOCK_K)[:, None] + tl.zeros((1, SUB_N), dtype=tl.int32) + offs_buf_n = tl.arange(0, SUB_N)[None, :] + tl.zeros((BLOCK_K, 1), dtype=tl.int32) + + send_ptr = tle.language.dsa.local_ptr(send_buf, [offs_buf_k, offs_buf_n]) + recv_ptr = tle.language.dsa.local_ptr(recv_buf, [offs_buf_k, offs_buf_n]) + + remote_recv_buf = tle.remote(recv_buf, send_next_tile) + remote_recv_ptr = tle.language.dsa.local_ptr(remote_recv_buf, [offs_buf_k, offs_buf_n]) + + tl.store(send_ptr, b_init) + + for step in range(TILE_NUM): + b_cur = tl.load(send_ptr) + c_part = tl.dot(a, b_cur, out_dtype=tl.float32) + + offs_n = shard_idx * SUB_N + tl.arange(0, SUB_N) + c_ptrs = C_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(c_ptrs, c_part.to(tl.float16)) + + if step < TILE_NUM - 1: + tl.store(remote_recv_ptr, tl.load(send_ptr)) + # tle.distributed_barrier(MESH) + tl.store(send_ptr, tl.load(recv_ptr)) + # tle.distributed_barrier(MESH) + + shard_idx = tl.where(shard_idx == 0, TILE_NUM - 1, shard_idx - 1) + + +def build_ring_luts(mesh, device): + phys = mesh.physical_ids + n = mesh.size + send_next = torch.empty(n, dtype=torch.int32) + ring_index = torch.empty(n, dtype=torch.int32) + for i in range(n): + cur = phys[i] + nxt = phys[(i + 1) % n] + send_next[cur] = nxt + ring_index[cur] = i + return send_next.to(device), ring_index.to(device) + + +def run(): + device = triton.runtime.driver.active.get_active_torch_device() + a = torch.randn((M, K), device=device, dtype=torch.float16) + b = torch.randn((K, N), device=device, dtype=torch.float16) + c = torch.empty((M, N), device=device, dtype=torch.float16) + + send_next_lut, ring_index_lut = build_ring_luts(MESH, device) + + grid = (TILE_NUM, ) + dsa_shift_n_gemm_kernel[grid]( + a, + b, + c, + send_next_lut, + ring_index_lut, + M=M, + N=N, + K=K, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + SUB_N=SUB_N, + TILE_NUM=TILE_NUM, + ) + a_f32 = a.cpu().float() + b_f32 = 
b.cpu().float() + c_f32 = c.cpu().float() + ref = torch.matmul(a_f32, b_f32) + + max_diff = (c_f32 - ref).abs().max().item() + passed = torch.allclose(c_f32, ref, atol=1e-1, rtol=1e-1) + + print(f"Shift-N Ring-GEMM: M={M}, N={N}, K={K}, TILE_NUM={TILE_NUM}") + print(f"BLOCK_M={BLOCK_M}, BLOCK_K={BLOCK_K}, SUB_N={SUB_N}") + print(f"Physical ring: {TILE_PHYSICAL_RELATION}") + print(f"max_abs_diff = {max_diff:.6f}") + + if passed: + print("PASS") + else: + print("FAIL") + diff = (c_f32 - ref).abs() + idx = diff.argmax().item() + r, col = idx // N, idx % N + print(f" worst @ ({r},{col}): got={c_f32[r,col]:.4f} ref={ref[r,col]:.4f}") + + # import flag_gems + # with flag_gems.use_gems(): + # ref_out = torch.mm(a, b) + # # Compare on CPU to avoid unsupported torch.testing ops on TXDA backend. + # res_out = c.detach().cpu().to(torch.float32) + # golden_cpu = ref_out.detach().cpu().to(torch.float32) + # # ref = torch.matmul(a_cpu, b_cpu) + # max_abs = (res_out - golden_cpu).abs().max().item() + + # # diff = (c_cpu - ref).abs() + # # flat_idx = diff.argmax().item() + # # row = flat_idx // diff.shape[1] + # # col = flat_idx % diff.shape[1] + # # print(f"[DEBUG] split-k max_abs_diff={max_abs}") + # # print(f"[DEBUG] split-k worst_idx=({row}, {col})") + # # print(f"[DEBUG] c_cpu[{row},{col}]={c_cpu[row, col].item()}") + # # print(f"[DEBUG] ref [{row},{col}]={ref[row, col].item()}") + # # print("[DEBUG] c_cpu[0:4, 0:8]=") + # # print(c_cpu[0:4, 0:8]) + # # print("[DEBUG] ref[0:4, 0:8]=") + # # print(ref[0:4, 0:8]) + + # if not torch.allclose(res_out, golden_cpu, atol=1e-3, rtol=1e-2): + # raise AssertionError(f"Mismatch: max_abs_diff={max_abs}") + # print( + # f"PASS: M={M}, N={N}, K={K}, BLOCK_M={BLOCK_M}, " + # f"BLOCK_K={BLOCK_K}, TILE_NUM={TILE_NUM}, " + # f"mode=ring, max_abs_diff={max_abs}" + # ) + + +if __name__ == "__main__": + run() diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt 
b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt index 316cac76a0..ad2c8e27d4 100644 --- a/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(LinalgToMK) add_subdirectory(CoreDialectsToMK) add_subdirectory(LegalizeTensorFormLoops) +add_subdirectory(TLEToMK) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/CMakeLists.txt b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/CMakeLists.txt new file mode 100644 index 0000000000..883e421e8e --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TLEToMK) +add_public_tablegen_target(TLEToMKConversionPassIncGen) diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.h b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.h new file mode 100644 index 0000000000..3e95c8b18d --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.h @@ -0,0 +1,22 @@ +//===------------------- Passes.h -----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TLE_TO_MK_CONVERSION_PASSES_H +#define TLE_TO_MK_CONVERSION_PASSES_H + +#include "magic-kernel/Conversion/TLEToMK/TLEToMK.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "magic-kernel/Conversion/TLEToMK/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif // TLE_TO_MK_CONVERSION_PASSES_H diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.td b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.td new file mode 100644 index 0000000000..d2b3022870 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/Passes.td @@ -0,0 +1,19 @@ +//===------------------- Passes.td ----------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#ifndef TLE_TO_MK_CONVERSION_PASSES +#define TLE_TO_MK_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TLEToMK : Pass<"tle-to-mk", "mlir::ModuleOp"> { + let summary = "Convert TLE communication operations into magic kernel operations"; + + let options = []; +} + +#endif diff --git a/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/TLEToMK.h b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/TLEToMK.h new file mode 100644 index 0000000000..5218545976 --- /dev/null +++ b/third_party/tsingmicro/include/magic-kernel/Conversion/TLEToMK/TLEToMK.h @@ -0,0 +1,32 @@ +//===------------------- TLEToMK.h -------------------------*- C++ -*---===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering TLE communication ops into mk ops. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ZTC_CONVERSION_TLE_TO_MK_H
+#define ZTC_CONVERSION_TLE_TO_MK_H
+
+#include "magic-kernel/Dialect/IR/MagicKernelDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_DECL
+#include "magic-kernel/Conversion/TLEToMK/Passes.h.inc"
+
+void populateTLEToMKConversionPatterns(RewritePatternSet &patterns);
+
+// std::unique_ptr<OperationPass<ModuleOp>> createTLEToMKPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // ZTC_CONVERSION_TLE_TO_MK_H
diff --git a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
index bfb0b7260e..c6b371e6b6 100644
--- a/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
+++ b/third_party/tsingmicro/include/magic-kernel/Dialect/IR/MagicKernelOps.td
@@ -49,6 +49,9 @@ def IntTensorOrMemref :
 def FPTensorOrMemref :
   AnyTypeOf<[MemRefOf<[AnyFloat]>, RankedTensorOf<[AnyFloat]>], "", "::mlir::ShapedType">;
 
+def MKCommAddrLike :
+  AnyTypeOf<[I64, TensorOrMemref], "comm-addr-like">;
+
 //
 // Interfaces
 //
@@ -505,6 +508,59 @@ def BarrierOp : MKOp<"barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+// =============================================================================
+// Communication Ops (TLE)
+// =============================================================================
+
+def RemoteStoreOp : MKOp<"remote_store", [MemoryEffects<[MemRead, MemWrite]>]> {
+  let summary = "Store local tensor/memref to a remote tile address";
+
+  let description = [{
+    Asynchronously store data from the current tile to a remote destination
+    address. The operation reads from the given source buffer (tensor/memref)
+    and writes it to the remote tile.
+ }]; + + let arguments = ( + ins + I64:$remote_chip_id_x, // X-coordinate of the remote chip ID + I64:$remote_chip_id_y, // Y-coordinate of the remote chip ID + I64:$remote_die_id, // ID of the remote die + I64:$remote_tile_id, // ID of the remote tile + MKCommAddrLike:$dst_addr, // Remote destination base address (or placeholder) + Arg:$src // Source buffer (read from here) + ); + + let results = (outs); +} + +def RemoteLoadOp : MKOp<"remote_load", [DestinationStyleOpInterface, MemoryEffects<[MemWrite]>]> { + let summary = "Load data from remote tile into a destination buffer"; + + let description = [{ + Destination-style remote load: receive data from a source tile and write it + into the given destination buffer. Returns a tensor view of the received + data (same shape and element type as the destination). + }]; + + let arguments = ( + ins + I64:$remote_chip_id_x, // X-coordinate of the remote chip ID + I64:$remote_chip_id_y, // Y-coordinate of the remote chip ID + I64:$remote_die_id, // ID of the remote die + I64:$remote_tile_id, // ID of the remote tile + Arg:$dst // Destination buffer (receive into here) + ); + + let results = (outs Variadic:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { + return getDstMutable(); + } + }]; +} + def PrintOp : MKOp<"print", [MemoryEffects<[MemWrite]>, DestinationStyleOpInterface]> { let summary = "Print at most a single scalar or 1D TensorOrMemref on each line"; diff --git a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td index 16cd6940f4..289351086a 100644 --- a/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td +++ b/third_party/tsingmicro/include/tsingmicro-tx81/Dialect/IR/Tx81Ops.td @@ -1192,4 +1192,87 @@ def AtomicBarrierOutOp : Tx81Op<"atomic_barrier_out", [MemoryEffects<[MemAlloc, let assemblyFormat = "attr-dict"; } +// 
============================================================================= +// 4.22. TsmTile Communication (Recv/Send) +// ============================================================================= + +def RemoteBufferOp : Tx81Op<"remote_buffer", [Pure]> { + let summary = "Build a remote buffer descriptor"; + + let description = [{ + Build a lightweight remote buffer descriptor from remote coordinates and a + destination base address. This op is used to carry `remote(buffer)` semantics + through Tx81 lowering before it is consumed by remote_store. + }]; + + let arguments = ( + ins + I64:$remote_chip_id_x, + I64:$remote_chip_id_y, + I64:$remote_die_id, + I64:$remote_tile_id, + MemRefOrInt:$dst + ); + + let results = (outs I64:$remote_dst); + + let assemblyFormat = [{ + $remote_chip_id_x `,` $remote_chip_id_y `,` $remote_die_id `,` $remote_tile_id `,` $dst attr-dict `:` type($remote_chip_id_x) `,` type($remote_chip_id_y) `,` type($remote_die_id) `,` type($remote_tile_id) `,` type($dst) + }]; +} + +def RemoteLoadOp : Tx81Op<"remote_load", [MemoryEffects<[MemWrite]>]> { + let summary = "Load data from a source tile into a destination buffer"; + + let description = [{ + Receive data from a source tile and write it into the given destination + buffer. 
+ }]; + + let arguments = ( + ins + I64:$remote_chip_id_x, // X-coordinate of the remote chip ID + I64:$remote_chip_id_y, // Y-coordinate of the remote chip ID + I64:$remote_die_id, // ID of the remote die + I64:$remote_tile_id, // ID of the remote tile + MemRefOrInt:$dst, // Destination buffer address (receive into here) + I32:$elem_bytes, // Element size in bytes + I64:$data_size // Total data size in bytes + ); + + let results = (outs); + + let assemblyFormat = [{ + $remote_chip_id_x `,` $remote_chip_id_y `,` $remote_die_id `,` $remote_tile_id `,` $dst `,` $elem_bytes `,` $data_size attr-dict `:` type($remote_chip_id_x) `,` type($remote_chip_id_y) `,` type($remote_die_id) `,` type($remote_tile_id) `,` type($dst) `,` type($elem_bytes) `,` type($data_size) + }]; +} + +def RemoteStoreOp : Tx81Op<"remote_store", [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "Store data to a destination tile"; + + let description = [{ + Store data from the current tile to a destination tile. The operation reads + from the given source buffer and stores it into the remote destination + address. 
+ }]; + + let arguments = ( + ins + I64:$remote_chip_id_x, // X-coordinate of the remote chip ID + I64:$remote_chip_id_y, // Y-coordinate of the remote chip ID + I64:$remote_die_id, // ID of the remote die + I64:$remote_tile_id, // ID of the remote tile + MemRefOrInt:$dst, // Remote destination base address + MemRefOrInt:$src, // Source buffer address (read from here) + I32:$elem_bytes, // Element size in bytes + I64:$data_size // Total data size in bytes + ); + + let results = (outs); + + let assemblyFormat = [{ + $remote_chip_id_x `,` $remote_chip_id_y `,` $remote_die_id `,` $remote_tile_id `,` $dst `,` $src `,` $elem_bytes `,` $data_size attr-dict `:` type($remote_chip_id_x) `,` type($remote_chip_id_y) `,` type($remote_die_id) `,` type($remote_tile_id) `,` type($dst) `,` type($src) `,` type($elem_bytes) `,` type($data_size) + }]; +} + #endif // TSINGMICRO_TX81_OPS diff --git a/third_party/tsingmicro/lib/Conversion/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt index fc3811b1c3..88d5903a03 100644 --- a/third_party/tsingmicro/lib/Conversion/CMakeLists.txt +++ b/third_party/tsingmicro/lib/Conversion/CMakeLists.txt @@ -12,6 +12,7 @@ add_subdirectory(LinalgTiling) add_subdirectory(LinalgFusion) add_subdirectory(AllocateSharedMemory) add_subdirectory(ExportKernelSymbols) +add_subdirectory(TLEToMK) add_subdirectory(UnstructuredToMK) add_subdirectory(ReconcilePtrCasts) diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp index 0d762a8bca..72c4323036 100644 --- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp @@ -3730,6 +3730,41 @@ struct AssertOpConverter : public OpConversionPattern { bool assertToCf = false; }; + +/// Convert a dense tensor arith.constant to linalg.fill(scalar, tensor.empty). 
+/// This is the missing pattern referenced by the comment in LinalgToMKPass: +/// "Lower dense constant to linalg.fill" +struct DenseConstantToFillPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter &rewriter) const override { + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return failure(); + auto denseAttr = dyn_cast(op.getValue()); + if (!denseAttr) + return failure(); + if (!isa(denseAttr.getElementType())) + return failure(); + if (!denseAttr.isSplat()) + return failure(); + + auto loc = op.getLoc(); + auto elemType = resultType.getElementType(); + auto splatValue = denseAttr.getSplatValue(); + Value scalar = rewriter.create( + loc, elemType, cast(splatValue)); + Value empty = + rewriter.create(loc, resultType.getShape(), elemType); + Value fill = + rewriter.create(loc, scalar, empty).getResult(0); + rewriter.replaceOp(op, fill); + return success(); + } +}; + } // namespace void mlir::triton::populateLinalgToMKPreProcessPatterns( @@ -3795,4 +3830,5 @@ void mlir::triton::populateLinalgToMKConversionPatterns( patterns.add(patterns.getContext()); // After NormalizeReduceInitToIdentityPattern and si-to-fp patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp index f882a16f46..afd2ae6d17 100644 --- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp @@ -1784,6 +1784,169 @@ struct BarrierConversion : public OpConversionPattern { } }; +struct RemoteLoadConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mk::RemoteLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + Location loc = op.getLoc(); + + // mk.remote_load has 5 operands: + // 4 I64 coords (indices 0-3) + 1 dst (index 4). + Value dstVal = adaptor.getOperands()[4]; + + // Compute elem_bytes and data_size (in bytes) from the dst shaped type. + // Prefer compile-time constants for static shapes; fall back to runtime + // computation using extracted sizes if needed. + Type dstOrigTy = op.getDst().getType(); + ShapedType shapedTy = dyn_cast(dstOrigTy); + if (!shapedTy) + return rewriter.notifyMatchFailure( + op, "mk.remote_load dst must be shaped type"); + + int64_t elemBytesConst = + static_cast(getElemByte(shapedTy.getElementType())); + Value elemBytesI32 = rewriter.create( + loc, elemBytesConst, rewriter.getI32Type()); + + Value dataSizeI64; + if (shapedTy.hasStaticShape()) { + int64_t numElems = shapedTy.getNumElements(); + int64_t totalBytes = numElems * elemBytesConst; + dataSizeI64 = rewriter.create( + loc, totalBytes, rewriter.getI64Type()); + } else { + // Dynamic shape: compute element count from runtime sizes. + // Requires memref operand to extract metadata. + if (!isa(dstVal.getType())) + return rewriter.notifyMatchFailure( + op, "dynamic-shaped remote_load requires memref dst"); + auto [basePtr, sizes, strides] = createMetadata(rewriter, loc, dstVal); + (void)basePtr; + (void)strides; + Value elemCount = calculateElemCount(rewriter, loc, sizes); + Value elemCountI64 = rewriter.create( + loc, rewriter.getI64Type(), elemCount); + Value elemBytesI64 = rewriter.create( + loc, rewriter.getI64Type(), elemBytesI32); + dataSizeI64 = rewriter.create(loc, elemCountI64.getType(), + elemCountI64, elemBytesI64); + } + + // Convert dst memref to address (I64) + Value dstAddr = createAddressFromMemref(rewriter, loc, dstVal); + + // Create tx.remote_load operation. 
+ rewriter.create( + loc, + adaptor.getOperands()[0], // remote_chip_id_x + adaptor.getOperands()[1], // remote_chip_id_y + adaptor.getOperands()[2], // remote_die_id + adaptor.getOperands()[3], // remote_tile_id + dstAddr, // dst (I64 address) + elemBytesI32, // elem_bytes (I32) + dataSizeI64 // data_size (I64) + ); + + // mk.remote_load has results; tx.remote_load is void. Replace results with + // dst. (converted) dst operand value, which represents the destination + // buffer. + if (op->getNumResults() > 0) { + SmallVector repl; + repl.reserve(op->getNumResults()); + for (unsigned i = 0; i < op->getNumResults(); ++i) + repl.push_back(dstVal); + rewriter.replaceOp(op, repl); + } else { + rewriter.eraseOp(op); + } + + return success(); + } +}; + +struct RemoteStoreConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mk::RemoteStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // mk.remote_store has 6 operands: + // 4 I64 coords (indices 0-3) + 1 dst_addr (index 4) + 1 src (index 5) + Value dstAddrVal = adaptor.getOperands()[4]; + Value srcVal = adaptor.getOperands()[5]; + + // Compute elem_bytes and data_size (in bytes) from the src shaped type. 
+ Type srcOrigTy = op.getSrc().getType(); + ShapedType shapedTy = dyn_cast(srcOrigTy); + if (!shapedTy) + return rewriter.notifyMatchFailure( + op, "mk.remote_store src must be shaped type"); + + int64_t elemBytesConst = + static_cast(getElemByte(shapedTy.getElementType())); + Value elemBytesI32 = rewriter.create( + loc, elemBytesConst, rewriter.getI32Type()); + + Value dataSizeI64; + if (shapedTy.hasStaticShape()) { + int64_t numElems = shapedTy.getNumElements(); + int64_t totalBytes = numElems * elemBytesConst; + dataSizeI64 = rewriter.create( + loc, totalBytes, rewriter.getI64Type()); + } else { + if (!isa(srcVal.getType())) + return rewriter.notifyMatchFailure( + op, "dynamic-shaped remote_store requires memref src"); + auto [basePtr, sizes, strides] = createMetadata(rewriter, loc, srcVal); + (void)basePtr; + (void)strides; + Value elemCount = calculateElemCount(rewriter, loc, sizes); + Value elemCountI64 = rewriter.create( + loc, rewriter.getI64Type(), elemCount); + Value elemBytesI64 = rewriter.create( + loc, rewriter.getI64Type(), elemBytesI32); + dataSizeI64 = rewriter.create(loc, elemCountI64.getType(), + elemCountI64, elemBytesI64); + } + + // Convert src memref to address (I64) + Value srcAddr = createAddressFromMemref(rewriter, loc, srcVal); + + // Convert dst "addr-like" to I64 address. + Value dstAddr; + if (dstAddrVal.getType().isInteger(64)) { + dstAddr = dstAddrVal; + } else if (isa(dstAddrVal.getType())) { + dstAddr = createAddressFromMemref(rewriter, loc, dstAddrVal); + } else { + return rewriter.notifyMatchFailure( + op, "mk.remote_store dst_addr must be i64 or memref at MKToTx81"); + } + + // Create tx.remote_store operation directly with the destination address. 
+ rewriter.create( + loc, + adaptor.getOperands()[0], // remote_chip_id_x + adaptor.getOperands()[1], // remote_chip_id_y + adaptor.getOperands()[2], // remote_die_id + adaptor.getOperands()[3], // remote_tile_id + dstAddr, // dst (I64 address) + srcAddr, // src (I64 address) + elemBytesI32, // elem_bytes (I32) + dataSizeI64 // data_size (I64) + ); + + // mk.remote_store has no results, just erase it + rewriter.eraseOp(op); + + return success(); + } +}; + struct PrintConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -2203,6 +2366,8 @@ void mlir::triton::populateMKToTx81ConversionPatterns( BarrierConversion, BarrierConversion, PrintConversion, + RemoteStoreConversion, + RemoteLoadConversion, AtomicRMWOpConversion, AtomicCASOpConversion>( patterns.getContext()); diff --git a/third_party/tsingmicro/lib/Conversion/TLEToMK/CMakeLists.txt b/third_party/tsingmicro/lib/Conversion/TLEToMK/CMakeLists.txt new file mode 100644 index 0000000000..23187c92a4 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TLEToMK/CMakeLists.txt @@ -0,0 +1,21 @@ +add_triton_library(TLEToMagicKernel + TLEToMK.cpp + TLEToMKPass.cpp + + DEPENDS + MagicKernelTableGen + TLEToMKConversionPassIncGen + TleDsaDialectIncGen + TleDsaTypesIncGen + TleDsaOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + MLIRArithDialect + TritonIR + TritonStructuredIR + TleDsaIR +) diff --git a/third_party/tsingmicro/lib/Conversion/TLEToMK/MKCommonBufferPlanningPass.cpp b/third_party/tsingmicro/lib/Conversion/TLEToMK/MKCommonBufferPlanningPass.cpp new file mode 100644 index 0000000000..76916a29e8 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TLEToMK/MKCommonBufferPlanningPass.cpp @@ -0,0 +1,183 @@ +//===---------------- MKCommBufferPlanningPass.cpp ---------------------===// +// +// Insert local SPM buffers for paired mk.recv/mk.send that share the same +// placeholder base address. 
The placeholder base is represented by the i64 +// src_addr/dst_addr operands. +// +// Strategy: +// - Find pairs where recv.src_addr and send.dst_addr are the same SSA value. +// - Replace the shared placeholder "addr" with two distinct buffers: +// one buffer for send's remote dst, one buffer for recv's remote src. +// - The buffers are created as tensor.empty (or memref.alloc if already +// bufferized). +// +//===--------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/TLEToMK/TLEToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/TLEToMK/Passes.h.inc" + +namespace { + +static int64_t alignUp(int64_t v, int64_t a) { return (v + a - 1) / a * a; } + +static std::optional getStaticBytes(ShapedType ty) { + if (!ty || !ty.hasStaticShape()) + return std::nullopt; + auto elemTy = ty.getElementType(); + if (!elemTy.isIntOrFloat()) + return std::nullopt; + int64_t elemBytes = elemTy.getIntOrFloatBitWidth() / 8; + if (elemBytes <= 0) + return std::nullopt; + return ty.getNumElements() * elemBytes; +} + +/// Return a canonical "placeholder root" for an addr-like value. +/// This lets us match send/recv pairs even if the i64 addr was computed by +/// distinct ptr_to_int/extract ops. +static Value getPlaceholderRoot(Value addrLike) { + Value v = addrLike; + // Peel trivial casts (best-effort). 
+ if (auto cast = v.getDefiningOp()) { + if (!cast.getOperands().empty()) + v = cast.getOperands().front(); + } + + // If it's an integer address derived from a triton ptr, use the ptr source. + if (v.getType().isInteger(64)) { + if (auto p2i = v.getDefiningOp()) + v = p2i.getSrc(); + } + + // If it comes from extracting element [0,0,...] from a tensor of ptrs, use + // the tensor-of-ptrs as the root. + if (auto ex = v.getDefiningOp()) { + // Only treat it as placeholder root if indices are all constants. + // (We expect [0,0,...] here.) + v = ex.getTensor(); + } + + return v; +} + +static Value createEmptyLikeShaped(OpBuilder &b, Location loc, ShapedType ty) { + if (auto t = dyn_cast(ty)) { + return b.create(loc, t.getShape(), t.getElementType()); + } + if (auto m = dyn_cast(ty)) { + return b.create(loc, m); + } + return Value(); +} + +struct MKCommBufferPlanningPass + : public MKCommBufferPlanningBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + + module.walk([&](triton::FuncOp func) { + // Map base -> first send/recv. + DenseMap rootToSend; + DenseMap rootToRecv; + + func.walk([&](Operation *op) { + if (auto send = dyn_cast(op)) { + Value root = getPlaceholderRoot(send.getDstAddr()); + rootToSend.try_emplace(root, send); + } else if (auto recv = dyn_cast(op)) { + Value root = getPlaceholderRoot(recv.getDst()); + rootToRecv.try_emplace(root, recv); + } + }); + + OpBuilder b(func.getContext()); + for (auto &it : rootToRecv) { + Value root = it.first; + auto recv = it.second; + auto sendIt = rootToSend.find(root); + if (sendIt == rootToSend.end()) + continue; + auto send = sendIt->second; + + Location loc = recv.getLoc(); + + // Create a single shared buffer as close as possible while still + // dominating both send and recv. 
+ DominanceInfo dom(func); + Block *sendBlock = send->getBlock(); + Block *recvBlock = recv->getBlock(); + Block *insBlock = dom.findNearestCommonDominator(sendBlock, recvBlock); + if (!insBlock) + insBlock = &func.getBody().front(); + + if (insBlock == sendBlock && insBlock == recvBlock) { + // Same block: insert before the earlier op. + Operation *insPt = send->isBeforeInBlock(recv) ? send.getOperation() + : recv.getOperation(); + b.setInsertionPoint(insPt); + } else { + // Different blocks: insert at end of common dominator block + // (before terminator if any). + b.setInsertionPointToEnd(insBlock); + if (!insBlock->empty() && + insBlock->back().hasTrait()) + b.setInsertionPoint(&insBlock->back()); + } + + auto sendSrcTy = dyn_cast(send.getSrc().getType()); + auto recvTy = recv.getNumResults() > 0 + ? dyn_cast(recv->getResult(0).getType()) + : dyn_cast(recv.getDst().getType()); + if (!sendSrcTy || !recvTy) + continue; + + // For scheme C we expect send/recv to communicate same shape/type. + if (sendSrcTy.getElementType() != recvTy.getElementType() || + sendSrcTy.getShape() != recvTy.getShape()) + continue; + + Value sharedBuf = createEmptyLikeShaped(b, loc, sendSrcTy); + if (!sharedBuf) + continue; + + // Replace the shared placeholder root with the shared buffer. 
+ // Operand layout: + // mk.send: 4 coords + dst_addr + src + // mk.recv: 4 coords + dst_key + dst + send->setOperand(4, sharedBuf); + recv->setOperand(4, sharedBuf); // dst_key + } + }); + } +}; + +} // namespace + +std::unique_ptr triton::createMKCommBufferPlanningPass() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMK.cpp b/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMK.cpp new file mode 100644 index 0000000000..96da375611 --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMK.cpp @@ -0,0 +1,425 @@ +//===------------------- TLEToMK.cpp -----------------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/TLEToMK/TLEToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "tle/include/tle-dsa/Dialect/IR/DsaDialect.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" + +#define DEBUG_TYPE "tle-to-mk" + +using namespace mlir; +using namespace triton; +using namespace mk; +using namespace tts; + +namespace { + +static constexpr llvm::StringLiteral kRemoteShardCarrierAttr = + "tle.remote_shard_id_carrier"; + +static bool isConstantZeroIndex(Value v) { + if (auto cst = v.getDefiningOp()) + return cst.value() == 0; + if (auto cst = v.getDefiningOp()) { + if (!isa(cst.getType())) + return false; + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getValue().isZero(); + } + return false; +} + 
+static bool areAllZeroIndices(ValueRange indices) { + return llvm::all_of(indices, isConstantZeroIndex); +} + +static bool isBeforeOrAtInSameBlock(Operation *a, Operation *b) { + return a && b && a->getBlock() == b->getBlock() && + (a == b || a->isBeforeInBlock(b)); +} + +static Value getOrCreateScalarPtr(PatternRewriter &rewriter, Location loc, + Value ptrLike, Operation *useAnchor) { + if (!isa(ptrLike.getType())) + return ptrLike; + + for (Operation *user : ptrLike.getUsers()) { + auto ex = dyn_cast(user); + if (!ex) + continue; + if (ex.getTensor() != ptrLike) + continue; + if (!areAllZeroIndices(ex.getIndices())) + continue; + if (!useAnchor || isBeforeOrAtInSameBlock(ex.getOperation(), useAnchor)) + return ex.getResult(); + } + + auto ranked = cast(ptrLike.getType()); + SmallVector idxs; + idxs.reserve(ranked.getRank()); + for (int i = 0; i < ranked.getRank(); ++i) + idxs.push_back(rewriter.create(loc, 0)); + return rewriter.create(loc, ptrLike, idxs); +} + +static Value getOrCreatePtrToIntI64(PatternRewriter &rewriter, Location loc, + Value scalarPtr, Operation *useAnchor) { + for (Operation *user : scalarPtr.getUsers()) { + auto p2i = dyn_cast(user); + if (!p2i) + continue; + if (p2i.getSrc() != scalarPtr) + continue; + if (p2i.getType() != rewriter.getI64Type()) + continue; + if (!useAnchor || isBeforeOrAtInSameBlock(p2i.getOperation(), useAnchor)) + return p2i.getResult(); + } + + return rewriter.create(loc, rewriter.getI64Type(), + scalarPtr); +} + +/// Extract a flat i64 base-address from a pointer-like value. +/// +/// When \p ptrLike is the result of a \c dsa.local_pointers op we go straight +/// to the underlying memref, avoiding the creation of any \c !tt.ptr typed +/// intermediate values. 
+static Value getOrCreatePtrLikeAddrI64(PatternRewriter &rewriter, Location loc, + Value ptrLike, Operation *useAnchor) { + // --- Fast path: dsa.local_pointers → extract base from memref directly --- + if (auto localPtrOp = ptrLike.getDefiningOp()) { + OpBuilder::InsertionGuard g(rewriter); + // Insert right after the local_pointers op so that the new ops dominate + // all users. + if (localPtrOp->getNextNode()) + rewriter.setInsertionPoint(localPtrOp->getNextNode()); + else + rewriter.setInsertionPointAfter(localPtrOp); + auto idxTy = rewriter.getIndexType(); + auto i64Ty = rewriter.getI64Type(); + Value baseIndex = rewriter.create( + loc, idxTy, localPtrOp.getSrc()); + return rewriter.create(loc, i64Ty, baseIndex); + } + + // --- Original path: Triton pointer value --- + OpBuilder::InsertionGuard g(rewriter); + Block *block = rewriter.getInsertionBlock(); + if (auto def = ptrLike.getDefiningOp()) { + if (block && def->getBlock() == block) + rewriter.setInsertionPointAfter(def); + } else if (block) { + rewriter.setInsertionPointToStart(block); + } + + Value scalarPtr = getOrCreateScalarPtr(rewriter, loc, ptrLike, useAnchor); + return getOrCreatePtrToIntI64(rewriter, loc, scalarPtr, useAnchor); +} + +static Value castIntegerLikeToI64(PatternRewriter &rewriter, Location loc, + Value v) { + auto i64Ty = rewriter.getI64Type(); + Type ty = v.getType(); + if (ty == i64Ty) + return v; + if (isa(ty)) + return rewriter.create(loc, i64Ty, v); + if (auto intTy = dyn_cast(ty)) { + if (intTy.getWidth() < 64) + return rewriter.create(loc, i64Ty, v); + if (intTy.getWidth() > 64) + return rewriter.create(loc, i64Ty, v); + return v; + } + return Value(); +} + +static Value peelShardScalar(Value shardLike) { + if (auto splat = shardLike.getDefiningOp()) + return splat.getSrc(); + return shardLike; +} + +static LogicalResult getCoordsFromShardIdValue(PatternRewriter &rewriter, + Location loc, Value shardIdLike, + SmallVector &coords) { + Value shardId = 
peelShardScalar(shardIdLike); + Value tileId = castIntegerLikeToI64(rewriter, loc, shardId); + if (!tileId) + return failure(); + Value four = + rewriter.create(loc, rewriter.getI64IntegerAttr(4)); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value chipX = rewriter.create(loc, tileId, four); + Value chipY = rewriter.create(loc, tileId, four); + coords = {chipX, chipY, zero, tileId}; + return success(); +} + +static LogicalResult extractRemoteInfoFromPtr(PatternRewriter &rewriter, + Location loc, Value ptrLike, + SmallVector &coords, + Value &basePtrLike) { + if (auto remotePtrOp = ptrLike.getDefiningOp()) { + if (failed(getCoordsFromShardIdValue(rewriter, loc, + remotePtrOp.getShardId(), coords))) + return failure(); + basePtrLike = remotePtrOp.getSrc(); + return success(); + } + if (auto addPtr = ptrLike.getDefiningOp(); + addPtr && addPtr->hasAttr(kRemoteShardCarrierAttr)) { + if (failed(getCoordsFromShardIdValue(rewriter, loc, addPtr.getOffset(), + coords))) + return failure(); + basePtrLike = addPtr.getPtr(); + return success(); + } + return failure(); +} + +// ===----------------------------------------------------------------------===// +// Barrier +// ===----------------------------------------------------------------------===// + +struct DsaDistributedBarrierToMkPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::dsa::DistributedBarrierOp op, + PatternRewriter &rewriter) const override { + rewriter.create(op.getLoc()); + rewriter.eraseOp(op); + return success(); + } +}; + +// ===----------------------------------------------------------------------===// +// Remote load / store (dsa.remote_pointers → mk.remote_load/store) +// ===----------------------------------------------------------------------===// + +struct DsaRemoteLoadToMkPattern : public OpRewritePattern { + explicit DsaRemoteLoadToMkPattern(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/2) {} + + 
LogicalResult matchAndRewrite(triton::LoadOp loadOp, + PatternRewriter &rewriter) const override { + Location loc = loadOp.getLoc(); + SmallVector recvCoords; + Value basePtrLike = loadOp.getPtr(); + if (failed(extractRemoteInfoFromPtr(rewriter, loc, loadOp.getPtr(), + recvCoords, basePtrLike))) + return failure(); + + auto resultType = dyn_cast(loadOp.getResult().getType()); + if (!resultType) + return loadOp->emitRemark( + "remote load currently expects ranked tensor result"); + for (int64_t s : resultType.getShape()) { + if (ShapedType::isDynamic(s)) + return loadOp->emitRemark( + "remote load with dynamic shape not supported"); + } + + Value dstBuffer = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + auto recvOp = rewriter.create( + loc, resultType, recvCoords[0], recvCoords[1], recvCoords[2], + recvCoords[3], dstBuffer); + rewriter.replaceOp(loadOp, recvOp.getResults().front()); + return success(); + } +}; + +struct DsaRemoteStoreToMkPattern : public OpRewritePattern { + explicit DsaRemoteStoreToMkPattern(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/2) {} + + LogicalResult matchAndRewrite(triton::StoreOp storeOp, + PatternRewriter &rewriter) const override { + Location loc = storeOp.getLoc(); + SmallVector sendCoords; + Value basePtrLike = storeOp.getPtr(); + if (failed(extractRemoteInfoFromPtr(rewriter, loc, storeOp.getPtr(), + sendCoords, basePtrLike))) + return failure(); + + if (storeOp.getMask()) + return storeOp->emitRemark("masked remote store not supported"); + + Value dstAddrI64 = getOrCreatePtrLikeAddrI64(rewriter, loc, basePtrLike, + storeOp.getOperation()); + rewriter.create(loc, sendCoords[0], sendCoords[1], + sendCoords[2], sendCoords[3], dstAddrI64, + storeOp.getValue()); + rewriter.eraseOp(storeOp); + return success(); + } +}; + +// ===----------------------------------------------------------------------===// +// Local load / store (dsa.local_pointers + tt.load/store → memref ops) +// +// Instead of 
lowering dsa.local_pointers to Triton pointer arithmetic +// (tt.splat/tt.addptr with tensor>), we directly convert the +// load/store users to memref-level operations. This avoids producing +// !tt.ptr element types that downstream triton-to-core-dialects cannot +// convert to valid memref types. +// ===----------------------------------------------------------------------===// + +/// tt.load whose pointer comes from dsa.local_pointers → +/// bufferization.to_tensor of the underlying memref. +struct DsaLocalLoadToMemrefPattern : public OpRewritePattern { + explicit DsaLocalLoadToMemrefPattern(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/3) {} + + LogicalResult matchAndRewrite(triton::LoadOp loadOp, + PatternRewriter &rewriter) const override { + // Only match loads whose pointer is produced by dsa.local_pointers. + auto localPtrOp = + loadOp.getPtr().getDefiningOp(); + if (!localPtrOp) + return failure(); + + auto memrefTy = dyn_cast(localPtrOp.getSrc().getType()); + if (!memrefTy) + return failure(); + + auto resultTy = dyn_cast(loadOp.getResult().getType()); + if (!resultTy) + return failure(); + + // Build a tensor type from the memref shape + element type. + auto tensorTy = + RankedTensorType::get(memrefTy.getShape(), memrefTy.getElementType()); + + // Shapes must agree (the common DSA pattern uses identity indices). + if (tensorTy.getShape() != resultTy.getShape()) + return loadOp->emitRemark( + "local load shape mismatch between memref and result tensor"); + + // Element type may differ if an implicit cast is present (e.g. f32→f16). + // For now we require them to match. + if (memrefTy.getElementType() != resultTy.getElementType()) + return loadOp->emitRemark( + "local load element type mismatch between memref and result tensor"); + + // Replace with: bufferization.to_tensor %memref + // writable=true because the SPM buffer is mutable. 
+ auto toTensor = rewriter.create( + loadOp.getLoc(), localPtrOp.getSrc(), + /*restrict=*/true, /*writable=*/true); + rewriter.replaceOp(loadOp, toTensor.getResult()); + return success(); + } +}; + +/// tt.store whose pointer comes from dsa.local_pointers → +/// bufferization.to_memref + memref.copy into the underlying SPM buffer. +struct DsaLocalStoreToMemrefPattern : public OpRewritePattern { + explicit DsaLocalStoreToMemrefPattern(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/3) {} + + LogicalResult matchAndRewrite(triton::StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto localPtrOp = + storeOp.getPtr().getDefiningOp(); + if (!localPtrOp) + return failure(); + + auto destMemrefTy = dyn_cast(localPtrOp.getSrc().getType()); + if (!destMemrefTy) + return failure(); + + Value val = storeOp.getValue(); + auto valTy = dyn_cast(val.getType()); + if (!valTy) + return failure(); + + // Shapes must match. + if (valTy.getShape() != destMemrefTy.getShape()) + return storeOp->emitRemark( + "local store shape mismatch between value tensor and SPM memref"); + + // Element types must match (no implicit cast support yet). + if (valTy.getElementType() != destMemrefTy.getElementType()) + return storeOp->emitRemark( + "local store element type mismatch between value and SPM memref"); + + Location loc = storeOp.getLoc(); + + // Materialise the tensor value as a memref, then copy into the SPM buffer. + // Use a contiguous memref type for the intermediate to_memref result. 
+ auto srcMemrefTy = + MemRefType::get(valTy.getShape(), valTy.getElementType()); + auto srcMemref = + rewriter.create(loc, srcMemrefTy, val); + rewriter.create(loc, srcMemref, localPtrOp.getSrc()); + rewriter.eraseOp(storeOp); + return success(); + } +}; + +// ===----------------------------------------------------------------------===// +// Remote pointers fallback (kept for edge cases) +// ===----------------------------------------------------------------------===// + +struct DsaRemotePointersToTritonPattern + : public OpRewritePattern { + explicit DsaRemotePointersToTritonPattern(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(mlir::dsa::RemotePointersOp op, + PatternRewriter &rewriter) const override { + Value offset = op.getShardId(); + if (auto srcTy = dyn_cast(op.getSrc().getType())) { + auto shardTy = dyn_cast(offset.getType()); + if (!shardTy || shardTy.getShape() != srcTy.getShape()) { + auto offsetTy = + RankedTensorType::get(srcTy.getShape(), offset.getType()); + offset = + rewriter.create(op.getLoc(), offsetTy, offset); + } + } + auto addPtr = rewriter.create(op.getLoc(), op.getType(), + op.getSrc(), offset); + addPtr->setAttr(kRemoteShardCarrierAttr, rewriter.getUnitAttr()); + rewriter.replaceOp(op, addPtr.getResult()); + return success(); + } +}; + +} // namespace + +void mlir::triton::populateTLEToMKConversionPatterns( + RewritePatternSet &patterns) { + // Highest benefit (3): local load/store → memref ops. + // These MUST fire before any pattern that would produce !tt.ptr types. + patterns.add( + patterns.getContext()); + + // Benefit 2: remote load/store → mk ops. + patterns.add( + patterns.getContext()); + + // Benefit 1: remaining remote_pointers / barrier. 
+ patterns + .add( + patterns.getContext()); +} diff --git a/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMKPass.cpp new file mode 100644 index 0000000000..0ce9fc439b --- /dev/null +++ b/third_party/tsingmicro/lib/Conversion/TLEToMK/TLEToMKPass.cpp @@ -0,0 +1,56 @@ +//===------------------- TLEToMKPass.cpp -------------------------===// +// +// Copyright (C) 2020-2025 Terapines Technology (Wuhan) Co., Ltd +// All rights reserved. +// +//===----------------------------------------------------------------------===// +// +// Lowering TLE communication ops to backend dialects +// +//===----------------------------------------------------------------------===// + +#include "magic-kernel/Conversion/TLEToMK/TLEToMK.h" +#include "magic-kernel/Dialect/IR/MagicKernelDialect.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "magic-kernel/Conversion/TLEToMK/Passes.h.inc" + +namespace { + +class TLEToMKPass : public TLEToMKBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + RewritePatternSet patterns(&getContext()); + populateTLEToMKConversionPatterns(patterns); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr triton::createTLEToMK() { + return std::make_unique(); +} diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt 
b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt index 600562994e..cf37f358ac 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/CMakeLists.txt @@ -30,6 +30,7 @@ add_triton_library(TritonToCoreDialects UnstructuredToMemref TritonToStructured TritonToUnstructured + TLEToMagicKernel ConvertTritonPtr ReconcilePtrCasts TritonPtrToMemref diff --git a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp index 2fc52eca17..bbe7401af8 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToCoreDialects/TritonToCoreDialectsPass.cpp @@ -19,6 +19,8 @@ #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" +#include "magic-kernel/Conversion/TLEToMK/TLEToMK.h" + #include "triton-shared/Conversion/UnstructuredToMK/UnstructuredToMK.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" diff --git a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp index 0bd1f7106b..667dc9eac7 100644 --- a/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp +++ b/third_party/tsingmicro/lib/Conversion/Tx81ToLLVM/Tx81ToLLVM.cpp @@ -65,6 +65,8 @@ const char wdma1dFuncName[] = "__Wdma1d"; const char rdmaFuncName[] = "__Rdma"; const char wdmaFuncName[] = "__Wdma"; const char memcpyFuncName[] = "__Memcpy"; +const char recvFuncName[] = "__Recv"; +const char sendFuncName[] = "__Send"; const char addVVFuncName[] = "__AddVV"; const char subVVFuncName[] = "__SubVV"; const char mulVVFuncName[] = "__MulVV"; @@ -596,6 +598,144 @@ struct 
Rdma1dOpConversion : public OpConversionPattern { } }; +// Resolve tx.remote_buffer to its destination address. +struct RemoteBufferOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::RemoteBufferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getOperands()[4]); + return success(); + } +}; + +// Convert tx.remote_load to LLVM call to __Recv function +struct RemoteLoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::RemoteLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Recv runtime function if not already declared + // Signature: + // void __Recv(int64_t chip_x, int64_t chip_y, int64_t die_id, + // int64_t tile_id, void* dst, + // uint32_t elem_bytes, uint64_t data_size) + auto i8PtrTy = LLVM::LLVMPointerType::get(ctx); + auto i64Ty = rewriter.getI64Type(); + auto i32Ty = rewriter.getI32Type(); + auto voidTy = LLVM::LLVMVoidType::get(ctx); + + // Types for function declaration + SmallVector argTypes = { + i64Ty, // remote_chip_id_x + i64Ty, // remote_chip_id_y + i64Ty, // remote_die_id + i64Ty, // remote_tile_id + i8PtrTy, // dst + i32Ty, // elem_bytes + i64Ty // data_size + }; + + // Declare the function with void return type + Value funcPtr = triton::declareTx81Function(module, rewriter, loc, + recvFuncName, voidTy, argTypes); + + // Get the operands and convert dst to i8* + Value chipX = adaptor.getOperands()[0]; + Value chipY = adaptor.getOperands()[1]; + Value dieId = adaptor.getOperands()[2]; + Value tileId = adaptor.getOperands()[3]; + Value dstAddr = adaptor.getOperands()[4]; + Value elemBytes = adaptor.getOperands()[5]; + 
Value dataSize = adaptor.getOperands()[6]; + + // Convert destination address (i64) directly to pointer. + Value dst = rewriter.create(loc, i8PtrTy, dstAddr); + + // Create the call to __Recv (void function, so empty TypeRange) + rewriter.create( + loc, TypeRange{}, recvFuncName, + ValueRange{chipX, chipY, dieId, tileId, dst, elemBytes, dataSize}); + + // tx.remote_load has no results, just erase it + rewriter.eraseOp(op); + + return success(); + } +}; + +// Convert tx.remote_store to LLVM call to __Send function +struct RemoteStoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tx::RemoteStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + // Get the module for function declarations + auto module = op->getParentOfType(); + + // Declare the __Send runtime function if not already declared + // Signature: + // void __Send(int64_t chip_x, int64_t chip_y, int64_t die_id, + // int64_t tile_id, void* dst, void* src, + // uint32_t elem_bytes, uint64_t data_size) + auto i8PtrTy = LLVM::LLVMPointerType::get(ctx); + auto i64Ty = rewriter.getI64Type(); + auto i32Ty = rewriter.getI32Type(); + auto voidTy = LLVM::LLVMVoidType::get(ctx); + + // Types for function declaration + SmallVector argTypes = { + i64Ty, // remote_chip_id_x + i64Ty, // remote_chip_id_y + i64Ty, // remote_die_id + i64Ty, // remote_tile_id + i8PtrTy, // dst + i8PtrTy, // src + i32Ty, // elem_bytes + i64Ty // data_size + }; + + // Declare the function with void return type + Value funcPtr = triton::declareTx81Function(module, rewriter, loc, + sendFuncName, voidTy, argTypes); + + // Get the operands and convert dst/src to i8* + Value chipX = adaptor.getOperands()[0]; + Value chipY = adaptor.getOperands()[1]; + Value dieId = adaptor.getOperands()[2]; + Value tileId = adaptor.getOperands()[3]; + Value dstAddr = 
adaptor.getOperands()[4]; + Value src = adaptor.getOperands()[5]; + Value elemBytes = adaptor.getOperands()[6]; + Value dataSize = adaptor.getOperands()[7]; + + // Convert destination and source addresses (i64) directly to pointers. + Value dst = rewriter.create(loc, i8PtrTy, dstAddr); + src = rewriter.create(loc, i8PtrTy, src); + + // Create the call to __Send (void function, so empty TypeRange) + rewriter.create( + loc, TypeRange{}, sendFuncName, + ValueRange{chipX, chipY, dieId, tileId, dst, src, elemBytes, dataSize}); + + // tx.remote_store has no results, just erase it + rewriter.eraseOp(op); + + return success(); + } +}; + template struct RdmaWdmaOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -2504,6 +2644,8 @@ class Tx81ToLLVMPass : public Tx81ToLLVMBase { MemsetOpConversion, GetProgramIDConversion, BarrierConversion, + RemoteStoreOpConversion, + RemoteLoadOpConversion, AssertConversion>( context); // clang-format on diff --git a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp index d51e5debbc..d47106654f 100644 --- a/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/third_party/tsingmicro/lib/Dialect/MagicKernel/Transforms/BufferizableOpInterfaceImpl.cpp @@ -225,6 +225,67 @@ struct BitCastOpInterface } }; +struct SendOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + auto sendOp = cast(op); + // mk.send reads the local src buffer. The dst_addr is "addr-like" and + // should not be considered a memory read. 
+ return &opOperand == &sendOp.getSrcMutable(); + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + auto sendOp = cast(op); + + // Nothing to do. This op is already bufferized. + if (!isa(sendOp.getDstAddr().getType()) && + !isa(sendOp.getSrc().getType())) + return success(); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(sendOp); + + SmallVector newOperands(sendOp->getOperands().begin(), + sendOp->getOperands().end()); + + if (isa(sendOp.getDstAddr().getType())) { + FailureOr dstBuffer = + getBuffer(rewriter, sendOp.getDstAddr(), options); + if (failed(dstBuffer)) + return failure(); + newOperands[sendOp.getDstAddrMutable().getOperandNumber()] = *dstBuffer; + } + + if (isa(sendOp.getSrc().getType())) { + FailureOr srcBuffer = + getBuffer(rewriter, sendOp.getSrc(), options); + if (failed(srcBuffer)) + return failure(); + newOperands[sendOp.getSrcMutable().getOperandNumber()] = *srcBuffer; + } + + OperationState state(sendOp->getLoc(), sendOp->getName(), newOperands, + TypeRange{}, sendOp->getAttrs()); + Operation *newOp = Operation::create(state); + rewriter.insert(newOp); + rewriter.eraseOp(sendOp); + return success(); + } +}; + /// Helper structure that iterates over all mkOps in `OpTys` and registers /// the `BufferizableOpInterface` with each of them. 
template struct MKOpInterfaceHelper { @@ -263,5 +324,6 @@ void mlir::mk::registerBufferizableOpInterfaceExternalModels( MKOpInterfaceHelper::registerOpInterface(ctx); MKOpInterfaceHelper::registerOpInterface(ctx); mk::BitcastOp::attachInterface(*ctx); + mk::RemoteStoreOp::attachInterface(*ctx); }); } diff --git a/third_party/tsingmicro/scripts/build_tsingmicro.sh b/third_party/tsingmicro/scripts/build_tsingmicro.sh index 2910246318..96c0238bfb 100755 --- a/third_party/tsingmicro/scripts/build_tsingmicro.sh +++ b/third_party/tsingmicro/scripts/build_tsingmicro.sh @@ -121,6 +121,7 @@ fi export LLVM_SYSPATH=$LLVM export TX8_DEPS_ROOT=$TX8_DEPS_ROOT +export TX8_YOC_RT_THREAD_SMP=$TX8_DEPS_ROOT/tx8-yoc-rt-thread-smp export FLAGTREE_BACKEND=tsingmicro # debug diff --git a/third_party/tsingmicro/scripts/build_tx8_deps.sh b/third_party/tsingmicro/scripts/build_tx8_deps.sh index c4fd780f80..d9e82d43cc 100755 --- a/third_party/tsingmicro/scripts/build_tx8_deps.sh +++ b/third_party/tsingmicro/scripts/build_tx8_deps.sh @@ -151,7 +151,7 @@ CONFIG_FILE="$script_dir/copy_config.conf" # 显示帮助信息 show_help() { - echo "usage: $0 mode [tx81fw url]" + echo "usage: $0 mode [tx81fw url] [tx_profiler url] [tx8-yoc-rt-thread-smp url]" echo "support mode:" echo " build_flagtree_tx8_deps" echo " build_tx8_deps" @@ -167,7 +167,12 @@ show_help() { echo " default: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81-profiling/master/profiling_tool_v5.5.0_release_2025-1124_.tar.gz" echo " ..." echo "" - echo "example: $0 build_tx8_deps" + echo "tx8-yoc-rt-thread-smp url:" + echo " eg: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx8-yoc-rt-thread-smp-202603031631-88bfb9.tar.gz" + echo " default: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx8-yoc-rt-thread-smp-202603031631-88bfb9.tar.gz" + echo " ..." 
+ echo "" + echo "example: bash triton/third_party/tsingmicro/scripts/build_tx8_deps.sh build_dev http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx81fw_202602261758_b72af3.tar.gz http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81-profiling/master/profiling_tool_v5.6.0_release_2026-0228_.tar.gz http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx8-yoc-rt-thread-smp-202603031631-88bfb9.tar.gz" } # 检查参数数量 diff --git a/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh b/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh index 5fc7272b53..05113112fa 100755 --- a/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh +++ b/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh @@ -48,7 +48,7 @@ quick_mode=0 skip_device= precision_priority=1 -tx8_depends_name=tx8_depends_dev_20260112_201902 +tx8_depends_name=tx8_depends_dev_20260309_173649 torch_txda_name=torch_txda+txops-20251230-03541ed8+71a1e5a txda_skip_ops="repeat_interleave.self_int,pad,to.dtype,uniform_,sort.values_stable,contiguous,resolve_conj" txda_fallback_cpu_ops="random_,quantile,_local_scalar_dense,arange,unfold,index,le,all,ge,pad,to,gather_backward,zero_,view_as_real,resolve_neg,embedding_backward,sort,repeat_interleave,rsub,hstack,vstack,min,uniform_,abs,ne,eq,mul,bitwise_and,masked_select,max,ceil,div,gt,lt,sum,scatter,where,resolve_conj,isclose,isfinite,tile,equal,gather,contiguous" diff --git a/third_party/tsingmicro/scripts/copy_config.conf b/third_party/tsingmicro/scripts/copy_config.conf index ddc96d2499..c192d5043b 100644 --- a/third_party/tsingmicro/scripts/copy_config.conf +++ b/third_party/tsingmicro/scripts/copy_config.conf @@ -7,6 +7,7 @@ file:download/triton-tx8fw/**/include/components/oplib_tx81/riscv/riscv/include/ dir:download/triton-tx8fw/**/tx81-intrisic/instr_tx81/include dir:download/tx8fw-xuantie-sdk/Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2 file:download/*.pdf 
+dir:download/tx8-yoc-rt-thread-smp/tx8-yoc-rt-thread-smp*/*,tx8-yoc-rt-thread-smp [block3:build_dev] dir:download/tx_profiler/profiling_tool*/*,profiling_tool diff --git a/third_party/tsingmicro/scripts/publish/build_wheel.sh b/third_party/tsingmicro/scripts/publish/build_wheel.sh index 1169e4a9a5..6d3b375c57 100755 --- a/third_party/tsingmicro/scripts/publish/build_wheel.sh +++ b/third_party/tsingmicro/scripts/publish/build_wheel.sh @@ -118,6 +118,7 @@ export LLVM_SYSPATH=$LLVM export TX8_DEPS_ROOT=$TX8_DEPS_ROOT # synchronous temporary solution: add waitfinish after every cintrinsic exec export ENABLE_SYNCHRONOUS_INTRINSIC=1 +export TX8_YOC_RT_THREAD_SMP=$TX8_DEPS_ROOT/tx8-yoc-rt-thread-smp cd python python3 -m pip wheel . --no-build-isolation -v --verbos