Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/hopper-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ jobs:
python3 python/tutorials/hints/08/08-grouped-gemm.py --only_unit_test
python3 python/tutorials/hints/11/11-programmatic-dependent-launch.py --only_unit_test
# flagtree tle raw
python3 python/tutorials/tle/raw/01-vector-add.py
python3 python/tutorials/tle/raw/02-fused-softmax.py
python3 python/tutorials/tle/raw/03-matrix-multiplication.py
python3 python/tutorials/tle/raw/04-hello-world.py
python3 python/tutorials/tle/raw/05-topk.py
python3 python/tutorials/tle/raw/06-test-vassert.py
python3 python/tutorials/tle/raw/mlir/01-vector-add.py
python3 python/tutorials/tle/raw/mlir/02-fused-softmax.py
python3 python/tutorials/tle/raw/mlir/03-matrix-multiplication.py
python3 python/tutorials/tle/raw/mlir/04-hello-world.py
python3 python/tutorials/tle/raw/mlir/05-topk.py
python3 python/tutorials/tle/raw/mlir/06-test-vassert.py
3 changes: 2 additions & 1 deletion python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
Expand Down Expand Up @@ -363,7 +364,7 @@ void init_triton_ir(py::module &&m) {
math::MathDialect, arith::ArithDialect, scf::SCFDialect,
::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
LLVM::LLVMDialect, mlir::ub::UBDialect,
mlir::triton::gluon::GluonDialect,
mlir::triton::gluon::GluonDialect, DLTIDialect,
mlir::triton::tle::TleDialect // flagtree tle raw
>();
mlir::LLVM::registerInlinerInterface(registry);
Expand Down
7 changes: 4 additions & 3 deletions python/triton/experimental/tle/language/raw/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

@builtin
def call(func, outputs, inputs, _semantic=None):
results = _semantic.builder.create_tle_raw_region_by_llvm_func(f"{func.llvm}", func.fnname,
[output.handle for output in outputs],
[input.handle for input in inputs])
context = _semantic.builder.get_context()
llvm = func.make_llvm(context)
results = _semantic.builder.create_tle_raw_call(llvm, [output.handle for output in outputs],
[input.handle for input in inputs])
tensors = [tensor(result, output.type) for result, output in zip(results, outputs)]
if len(tensors) == 1:
return tensors[0]
Expand Down
12 changes: 0 additions & 12 deletions python/triton/experimental/tle/language/raw/semantic.py

This file was deleted.

3 changes: 3 additions & 0 deletions python/triton/experimental/tle/raw/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .runtime import CUDAJITFunction

__all__ = ["CUDAJITFunction"]
44 changes: 44 additions & 0 deletions python/triton/experimental/tle/raw/cuda/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations
import copy
import os
from pathlib import Path
import subprocess
from typing import Any, Dict, Final

from triton._C.libtriton import llvm # pyright: ignore[reportMissingImports]
from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports]

# TODO: We use cli tools to compile CUDA code temporarily, and plan to replace it with LLVM components Python bindings in the future.
CLANG = os.getenv("CLANG", "clang")


class CUDAJITFunction(object):
    """JIT wrapper that lowers an external CUDA source file to LLVM IR text.

    The wrapped Python function ``fn`` is only a placeholder used for
    registration; the actual implementation is the CUDA device code read
    from ``file`` and compiled on demand by :meth:`make_llvm`.
    """

    def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fn: Final[Any] = fn
        # Read the CUDA source eagerly so later compilation does not depend
        # on the file still existing on disk.
        self.code: Final[str] = file.read_text()
        self.__triton_builtin__: Final[bool] = True

    def __deepcopy__(self, memo: Dict[int, Any]) -> CUDAJITFunction:
        # BUG FIX: the previous implementation forwarded ``self.pipeline``
        # and ``self.context`` to the constructor — attributes this class
        # never sets (they belong to MLIRJITFunction) — so any deepcopy
        # raised AttributeError. Clone through __new__ instead, which also
        # avoids repeating the file I/O done by __init__.
        clone = self.__class__.__new__(self.__class__)
        memo[id(self)] = clone
        clone.fn = copy.deepcopy(self.fn, memo)
        clone.code = self.code
        clone.__triton_builtin__ = True
        return clone

    def make_llvm(self, mlir_context) -> str:
        """Compile the stored CUDA source with clang and return LLVM IR text.

        ``mlir_context`` is the MLIR context the parsed module is attached to.
        Raises RuntimeError when clang fails, surfacing its diagnostics.
        """
        build = subprocess.run(
            [
                CLANG,
                "-x",
                "cuda",
                "--cuda-device-only",
                "-emit-llvm",
                "-S",
                "-",  # read source from stdin
                "-o",
                "-",  # write IR to stdout
            ],
            input=self.code.encode(),
            capture_output=True,
        )
        # Surface compiler diagnostics instead of feeding empty/partial
        # stdout to parse_llvm_ir on failure.
        if build.returncode != 0:
            raise RuntimeError(f"{CLANG} failed to compile CUDA source:\n{build.stderr.decode()}")
        llvm_context = llvm.context()
        module = parse_llvm_ir(build.stdout.decode(), llvm_context, mlir_context)
        return f"{module}"
4 changes: 2 additions & 2 deletions python/triton/experimental/tle/raw/mlir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .runtime import EdslMLIRJITFunction
from .runtime import MLIRJITFunction
from .utils import vprintf, vassert

__all__ = ["EdslMLIRJITFunction", "vprintf", "vassert"]
__all__ = ["MLIRJITFunction", "vprintf", "vassert"]
9 changes: 2 additions & 7 deletions python/triton/experimental/tle/raw/mlir/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,10 @@
from mlir.dialects import func

from .utils import ExternalCall
from ..utils import UnknownSymbolError


class UnknownSymbolError(Exception):

def __init__(self, name: str, *args, **kwargs) -> None:
super().__init__(f"unknown symbol {name}", *args, **kwargs)


class EdslMLIRCodeGenerator(ast.NodeVisitor):
class MLIRCodeGenerator(ast.NodeVisitor):

def __init__(
self,
Expand Down
32 changes: 23 additions & 9 deletions python/triton/experimental/tle/raw/mlir/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,29 @@
from mlir import ir
from mlir.passmanager import PassManager

from .codegen import EdslMLIRCodeGenerator
from .codegen import MLIRCodeGenerator


class EdslMLIRJITFunction(object):
class MLIRJITFunction(object):

def __init__(self, fn: Any, pipeline: List[str], context: Optional[ir.Context] = None, *args, **kwargs) -> None:
def __init__(self, fn: Any, pipeline: Optional[List[str]] = None, context: Optional[ir.Context] = None, *args,
**kwargs) -> None:
super().__init__(*args, **kwargs)
self.fn: Final[Any] = fn
self.pipeline: Final[List[str]] = [*pipeline]
self.pipeline: Final[List[str]] = ([*pipeline] if pipeline is not None else [
"convert-scf-to-cf",
"finalize-memref-to-llvm",
"convert-arith-to-llvm",
"convert-cf-to-llvm",
"convert-func-to-llvm",
"convert-index-to-llvm",
"convert-nvvm-to-llvm",
"cse",
])
self.context: Final[ir.Context] = ir.Context() if context is None else context
self.__triton_builtin__: Final[bool] = True

def __deepcopy__(self, memo: Dict[int, Any]) -> EdslMLIRJITFunction:
def __deepcopy__(self, memo: Dict[int, Any]) -> MLIRJITFunction:
return self.__class__(copy.deepcopy(self.fn, memo), copy.deepcopy(self.pipeline, memo), self.context)

@cached_property
Expand All @@ -40,16 +50,16 @@ def globals(self) -> Dict[str, Any]:
return {k: v for k, v in self.fn.__globals__.items() if not k.startswith("__")}

@cached_property
def codegen(self) -> EdslMLIRCodeGenerator:
return EdslMLIRCodeGenerator(self.absfilename, {}, self.globals, self.context)
def codegen(self) -> MLIRCodeGenerator:
return MLIRCodeGenerator(self.absfilename, {}, self.globals, self.context)

@property
def ir(self) -> ir.Module:
mod: ir.Module = self.codegen.visit(self.ast)
return mod

@cached_property
def llvm(self) -> ir.Module:
@property
def ll(self) -> ir.Module:
mod: ir.Module = self.ir
with self.context:
pm: PassManager = PassManager()
Expand All @@ -59,6 +69,10 @@ def llvm(self) -> ir.Module:
pm.run(mod.operation)
return mod

@cached_property
def llvm(self) -> str:
return f"{self.ll}"

@cached_property
def src(self) -> str:
return inspect.getsource(self.fn)
12 changes: 6 additions & 6 deletions python/triton/experimental/tle/raw/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mlir.dialects import arith, func, llvm, scf

if TYPE_CHECKING:
from .codegen import EdslMLIRCodeGenerator
from .codegen import MLIRCodeGenerator


class ExternalCall(object):
Expand All @@ -26,16 +26,16 @@ def build(self) -> func.FuncOp:
...

@abstractmethod
def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp:
def call(self, codegen: MLIRCodeGenerator) -> func.CallOp:
...

def decl(self, codegen: EdslMLIRCodeGenerator) -> func.FuncOp:
def decl(self, codegen: MLIRCodeGenerator) -> func.FuncOp:
with ir.InsertionPoint.at_block_begin(codegen.module.body):
funcop: func.FuncOp = codegen.decls.get(self.keyword) or self.build()
codegen.decls[self.keyword] = funcop
return funcop

def global_string(self, val: str, codegen: EdslMLIRCodeGenerator) -> llvm.GlobalOp:
def global_string(self, val: str, codegen: MLIRCodeGenerator) -> llvm.GlobalOp:
hdigest = blake2s(val.encode('utf-8'), digest_size=16)
key: str = f"globalstr{base64.urlsafe_b64encode(hdigest.digest()).decode('ascii').rstrip('=')}"
with ir.InsertionPoint.at_block_begin(codegen.module.body):
Expand All @@ -59,7 +59,7 @@ def build(self) -> func.FuncOp:
[ir.IntegerType.get_signless(32)]), visibility="private")

@override
def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp:
def call(self, codegen: MLIRCodeGenerator) -> func.CallOp:
[format, *args] = self.args
funcop: func.FuncOp = self.decl(codegen)
format: llvm.GlobalOp = self.global_string(format, codegen)
Expand Down Expand Up @@ -99,7 +99,7 @@ def build(self) -> func.FuncOp:
visibility="private")

@override
def call(self, codegen: EdslMLIRCodeGenerator) -> Any:
def call(self, codegen: MLIRCodeGenerator) -> Any:
func_op = self.decl(codegen)

true_const = arith.constant(ir.IntegerType.get_signless(1), 1)
Expand Down
17 changes: 9 additions & 8 deletions python/triton/experimental/tle/raw/runtime.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from .mlir import EdslMLIRJITFunction
from typing import List
from .cuda import CUDAJITFunction
from .mlir import MLIRJITFunction

registry = {"mlir": EdslMLIRJITFunction}
registry = {"cuda": CUDAJITFunction, "mlir": MLIRJITFunction}


def dialect(*, name: str, pipeline: List[str] = [
"convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm",
"convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse"
]):
def dialect(
*,
name: str,
**kwargs,
):

def decorator(fn):
edsl = registry[name](fn, pipeline=pipeline)
edsl = registry[name](fn, **kwargs)
return edsl

return decorator
3 changes: 3 additions & 0 deletions python/triton/experimental/tle/raw/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .exception import UnknownSymbolError

__all__ = ["UnknownSymbolError"]
4 changes: 4 additions & 0 deletions python/triton/experimental/tle/raw/utils/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class UnknownSymbolError(Exception):
    """Raised when code generation encounters a name it cannot resolve."""

    def __init__(self, name: str, *args, **kwargs) -> None:
        message = f"unknown symbol {name}"
        super().__init__(message, *args, **kwargs)
9 changes: 9 additions & 0 deletions python/tutorials/tle/raw/cuda/01-vector-add.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Element-wise vector add: C[i] = A[i] + B[i] for i in [0, N).
// Pointers are annotated with CUDA address space 1 (global memory).
__device__ void VectorAdd(__attribute__((address_space(1))) float *C,
                          __attribute__((address_space(1))) const float *A,
                          __attribute__((address_space(1))) const float *B,
                          const int N) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid-stride loop: stride by the whole grid, not a single block.
  // The previous stride of blockDim.x made every block redundantly sweep
  // the entire range (results were still correct only because the write
  // is idempotent, at O(N * gridDim.x) total work).
  const int stride = blockDim.x * gridDim.x;
  for (int i = idx; i < N; i += stride) {
    C[i] = A[i] + B[i];
  }
}
41 changes: 41 additions & 0 deletions python/tutorials/tle/raw/cuda/01-vector-add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pathlib import Path

import torch
import triton
import triton.language as tl
from triton.experimental.tle.raw import dialect
import triton.experimental.tle.language.raw as tle_raw

DEVICE = triton.runtime.driver.active.get_active_torch_device()


# Bind the adjacent CUDA source file as a TLE "raw" dialect function.
# The Python body is intentionally empty: calls to `edsl` are routed to
# the device function compiled from 01-vector-add.cu.
@dialect(name="cuda", file=Path(__file__).parent / "01-vector-add.cu")
def edsl(*args, **kwargs):
    ...


@triton.jit
def add_kernel(
    x_ptr,  # pointer to the first input vector
    y_ptr,  # pointer to the second input vector
    output_ptr,  # pointer to the output vector
    n_elements,  # total number of elements to add
    BLOCK_SIZE: tl.constexpr,  # launch-time block size; unused in this body
):
    # Delegate the whole computation to the external CUDA device function.
    # No TLE outputs are declared ([]); the kernel writes through output_ptr.
    tle_raw.call(edsl, [], [output_ptr, x_ptr, y_ptr, n_elements])


def add(x: torch.Tensor, y: torch.Tensor):
    """Return the element-wise sum of x and y via the TLE raw CUDA kernel."""
    result = torch.empty_like(x)
    # All tensors must already live on the active Triton device.
    assert x.device == DEVICE and y.device == DEVICE and result.device == DEVICE
    total = result.numel()

    def grid(meta):
        # One program per BLOCK_SIZE-sized slice of the input.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]), )

    add_kernel[grid](x, y, result, total, BLOCK_SIZE=1024)
    return result


if __name__ == "__main__":
    # Smoke test: compare the TLE CUDA kernel result against eager torch addition.
    x = torch.randn(2048, device=DEVICE)
    y = torch.randn(2048, device=DEVICE)
    z = add(x, y)
    assert torch.allclose(x + y, z), (x + y, z)
3 changes: 2 additions & 1 deletion third_party/tle/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ if(TRITON_BUILD_PYTHON_MODULE)
${CMAKE_CURRENT_SOURCE_DIR}/triton_tle_raw.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/lib/Protocol.cpp
LINK_LIBS
MLIRTargetLLVMIRImport
TritonTLETransforms
TritonIR)
target_link_libraries(TritonTLE PRIVATE Python3::Module pybind11::headers)
target_link_libraries(TritonTLE PRIVATE Python3::Module pybind11::headers MLIRTargetLLVMIRImport)
endif()

add_subdirectory(test)
Loading
Loading