Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tsingmicro-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,4 @@ jobs:
python3 test_softmax.py >result-test_softmax.txt
python3 test_vec_add.py >result-test_vec_add.txt
python3 time1.py >result-time1.txt
python3 test_tle_dsa_noc_gemm_4096.py >result-noc_gemm.txt
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ elseif(FLAGTREE_BACKEND STREQUAL "aipu")
add_definitions(-D__NVIDIA__)
add_definitions(-D__AMD__)
elseif(FLAGTREE_BACKEND STREQUAL "tsingmicro")
set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}")
set(CMAKE_C_COMPILER clang-21)
set(CMAKE_CXX_COMPILER clang++-21)
set(CMAKE_LINKER lld-21)
Expand Down Expand Up @@ -285,6 +286,10 @@ if(TRITON_BUILD_PYTHON_MODULE)
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/dialect)

# Add TLE plugin
list(APPEND TRITON_PLUGIN_NAMES "tle")
add_subdirectory(third_party/tle)

get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
set(TRITON_LIBRARIES
Expand Down Expand Up @@ -460,6 +465,8 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
add_subdirectory(third_party/${CODEGEN_BACKEND})
endforeach()
add_subdirectory(third_party/proton/dialect)
# flagtree tle
add_subdirectory(third_party/tle)
endif()

find_package(Threads REQUIRED)
Expand Down
4 changes: 4 additions & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,10 @@ def get_packages():
"triton/backends",
"triton/tools",
"triton/tools/extra",
"triton/experimental",
"triton/experimental/tle",
"triton/experimental/tle/language",
"triton/experimental/tle/language/dsa",
]
if helper.flagtree_backend == "xpu":
packages.append("triton/language/extra/xpu")
Expand Down
85 changes: 1 addition & 84 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/LocationSnapshot.h"

#include "ir.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
Expand Down Expand Up @@ -56,90 +57,6 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
}
}

// A custom op builder that keeps track of the last location
class TritonOpBuilder {
public:
TritonOpBuilder(MLIRContext *context) {
builder = std::make_unique<OpBuilder>(context);
lastLoc = std::make_unique<Location>(builder->getUnknownLoc());
}

OpBuilder &getBuilder() { return *builder; }
MLIRContext *getContext() { return builder->getContext(); }

bool isLineInfoEnabled() { return lineInfoEnabled; }

void setLastLoc(Location loc) {
if (lineInfoEnabled)
lastLoc = std::make_unique<Location>(loc);
}

void setLastLoc(const std::string &fileName, int line, int column) {
auto context = builder->getContext();
setLastLoc(FileLineColLoc::get(context, fileName, line, column));
}

Location getLastLoc() {
assert(lastLoc);
return *lastLoc;
}

void setInsertionPointToStart(Block &block) {
if (!block.empty())
setLastLoc(block.begin()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToStart(&block);
}

void setInsertionPointToEnd(Block &block) {
if (!block.empty())
setLastLoc(block.back().getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToEnd(&block);
}

void setInsertionPointAfter(Operation &op) {
setLastLoc(op.getLoc());
builder->setInsertionPointAfter(&op);
}

void restoreInsertionPoint(OpBuilder::InsertPoint pt) {
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
setLastLoc(pt.getPoint()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->restoreInsertionPoint(pt);
}

template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
auto loc = getLastLoc();
return builder->create<OpTy>(loc, std::forward<Args>(args)...);
}

// Overload to create or fold a single result operation.
template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}

// Overload to create or fold a zero result operation.
template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}

private:
std::unique_ptr<OpBuilder> builder;
std::unique_ptr<Location> lastLoc;
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};

// Run the pass manager under a source manager diagnostic handler, which
// enables emitted MLIR diagnostics to directly reference Python source
// code. This diagnostic handler supports filtering diagnostic info by
Expand Down
91 changes: 91 additions & 0 deletions python/src/ir.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#pragma once

#include "mlir/IR/Builders.h"
#include "triton/Tools/Sys/GetEnv.hpp"

#include <cassert>
#include <memory>

// A custom op builder that keeps track of the last location.
class TritonOpBuilder {
public:
TritonOpBuilder(mlir::MLIRContext *context) {
builder = std::make_unique<mlir::OpBuilder>(context);
lastLoc = std::make_unique<mlir::Location>(builder->getUnknownLoc());
}

mlir::OpBuilder &getBuilder() { return *builder; }
mlir::MLIRContext *getContext() { return builder->getContext(); }

bool isLineInfoEnabled() { return lineInfoEnabled; }

void setLastLoc(mlir::Location loc) {
if (lineInfoEnabled)
lastLoc = std::make_unique<mlir::Location>(loc);
}

void setLastLoc(const std::string &fileName, int line, int column) {
auto context = builder->getContext();
setLastLoc(mlir::FileLineColLoc::get(context, fileName, line, column));
}

mlir::Location getLastLoc() {
assert(lastLoc);
return *lastLoc;
}

void setInsertionPointToStart(mlir::Block &block) {
if (!block.empty())
setLastLoc(block.begin()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToStart(&block);
}

void setInsertionPointToEnd(mlir::Block &block) {
if (!block.empty())
setLastLoc(block.back().getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToEnd(&block);
}

void setInsertionPointAfter(mlir::Operation &op) {
setLastLoc(op.getLoc());
builder->setInsertionPointAfter(&op);
}

void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) {
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
setLastLoc(pt.getPoint()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->restoreInsertionPoint(pt);
}

template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
auto loc = getLastLoc();
return builder->create<OpTy>(loc, std::forward<Args>(args)...);
}

template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<mlir::OpTrait::OneResult>(),
mlir::Value>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}

template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<mlir::OpTrait::ZeroResults>(), OpTy>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}

private:
std::unique_ptr<mlir::OpBuilder> builder;
std::unique_ptr<mlir::Location> lastLoc;
bool lineInfoEnabled =
!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};
1 change: 1 addition & 0 deletions python/triton/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# flagtree tle
43 changes: 43 additions & 0 deletions python/triton/experimental/tle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# flagtree tle
from .distributed import (
B,
P,
S,
ShardedTensor,
ShardingSpec,
device_mesh,
distributed_barrier,
distributed_dot,
make_sharded_tensor,
remote,
reshard,
shard_id,
sharding,
)

from . import language

# try:
# from . import raw
# except ModuleNotFoundError:
# raw = None

__all__ = [
"device_mesh",
"S",
"P",
"B",
"sharding",
"ShardingSpec",
"ShardedTensor",
"make_sharded_tensor",
"reshard",
"remote",
"shard_id",
"distributed_barrier",
"distributed_dot",
"language",
]

# if raw is not None:
# __all__.append("raw")
Loading
Loading