diff --git a/.github/workflows/hopper-build-and-test.yml b/.github/workflows/hopper-build-and-test.yml index 8e73e52b6..59d436e91 100644 --- a/.github/workflows/hopper-build-and-test.yml +++ b/.github/workflows/hopper-build-and-test.yml @@ -158,9 +158,9 @@ jobs: python3 python/tutorials/hints/08/08-grouped-gemm.py --only_unit_test python3 python/tutorials/hints/11/11-programmatic-dependent-launch.py --only_unit_test # flagtree tle raw - python3 python/tutorials/tle/raw/01-vector-add.py - python3 python/tutorials/tle/raw/02-fused-softmax.py - python3 python/tutorials/tle/raw/03-matrix-multiplication.py - python3 python/tutorials/tle/raw/04-hello-world.py - python3 python/tutorials/tle/raw/05-topk.py - python3 python/tutorials/tle/raw/06-test-vassert.py + python3 python/tutorials/tle/raw/mlir/01-vector-add.py + python3 python/tutorials/tle/raw/mlir/02-fused-softmax.py + python3 python/tutorials/tle/raw/mlir/03-matrix-multiplication.py + python3 python/tutorials/tle/raw/mlir/04-hello-world.py + python3 python/tutorials/tle/raw/mlir/05-topk.py + python3 python/tutorials/tle/raw/mlir/06-test-vassert.py diff --git a/python/src/ir.cc b/python/src/ir.cc index 369113c07..391dcfbca 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -9,6 +9,7 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" @@ -363,7 +364,7 @@ void init_triton_ir(py::module &&m) { math::MathDialect, arith::ArithDialect, scf::SCFDialect, ::mlir::gpu::GPUDialect, cf::ControlFlowDialect, LLVM::LLVMDialect, mlir::ub::UBDialect, - mlir::triton::gluon::GluonDialect, + mlir::triton::gluon::GluonDialect, DLTIDialect, mlir::triton::tle::TleDialect // flagtree tle raw >(); mlir::LLVM::registerInlinerInterface(registry); diff --git a/python/triton/experimental/tle/language/raw/core.py b/python/triton/experimental/tle/language/raw/core.py index 171f554e2..fe5a2c85b 100644 --- a/python/triton/experimental/tle/language/raw/core.py +++ b/python/triton/experimental/tle/language/raw/core.py @@ -4,9 +4,10 @@ @builtin def call(func, outputs, inputs, _semantic=None): - results = _semantic.builder.create_tle_raw_region_by_llvm_func(f"{func.llvm}", func.fnname, - [output.handle for output in outputs], - [input.handle for input in inputs]) + context = _semantic.builder.get_context() + llvm = func.make_llvm(context) + results = _semantic.builder.create_tle_raw_call(llvm, [output.handle for output in outputs], + [input.handle for input in inputs]) tensors = [tensor(result, output.type) for result, output in zip(results, outputs)] if len(tensors) == 1: return tensors[0] diff --git a/python/triton/experimental/tle/language/raw/semantic.py b/python/triton/experimental/tle/language/raw/semantic.py deleted file mode 100644 index 8f28a67f2..000000000 --- a/python/triton/experimental/tle/language/raw/semantic.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import TypeVar - -from triton._C.libtriton import ir -from triton.language.semantic import TritonSemantic - -TensorTy = TypeVar("TensorTy") - - -class TLERawSemantic(TritonSemantic[TensorTy]): - - def __init__(self, builder: ir.builder, *args, **kwargs) -> None: - super().__init__(builder, *args, **kwargs) diff --git a/python/triton/experimental/tle/raw/cuda/__init__.py b/python/triton/experimental/tle/raw/cuda/__init__.py new file mode 100644 index 000000000..3c84a6dfb --- /dev/null +++ b/python/triton/experimental/tle/raw/cuda/__init__.py @@ -0,0 +1,3 @@ +from .runtime import CUDAJITFunction + +__all__ = ["CUDAJITFunction"] diff --git a/python/triton/experimental/tle/raw/cuda/runtime.py b/python/triton/experimental/tle/raw/cuda/runtime.py new file mode 100644 index 000000000..b912e49ef --- /dev/null +++ b/python/triton/experimental/tle/raw/cuda/runtime.py @@ -0,0 +1,44 @@ +from __future__ import annotations +import copy +import os +from pathlib import Path +import subprocess +from typing import Any, Dict, Final + +from triton._C.libtriton import llvm # pyright: ignore[reportMissingImports] +from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports] + +# TODO: We use cli tools to compile CUDA code temporarily, and plan to replace it with LLVM components Python bindings in the future. +CLANG = os.getenv("CLANG", "clang") + + +class CUDAJITFunction(object): + + def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.fn: Final[Any] = fn + self.code: Final[str] = file.read_text() + self.__triton_builtin__: Final[bool] = True + + def __deepcopy__(self, memo: Dict[int, Any]) -> CUDAJITFunction: + return self.__class__(copy.deepcopy(self.fn, memo), copy.deepcopy(self.pipeline, memo), self.context) + + def make_llvm(self, mlir_context) -> str: + build = subprocess.run( + [ + CLANG, + "-x", + "cuda", + "--cuda-device-only", + "-emit-llvm", + "-S", + "-", + "-o", + "-", + ], + input=self.code.encode(), + capture_output=True, + ) + llvm_context = llvm.context() + module = parse_llvm_ir(build.stdout.decode(), llvm_context, mlir_context) + return f"{module}" diff --git a/python/triton/experimental/tle/raw/mlir/__init__.py b/python/triton/experimental/tle/raw/mlir/__init__.py index 447446151..5431db683 100644 --- a/python/triton/experimental/tle/raw/mlir/__init__.py +++ b/python/triton/experimental/tle/raw/mlir/__init__.py @@ -1,4 +1,4 @@ -from .runtime import EdslMLIRJITFunction +from .runtime import MLIRJITFunction from .utils import vprintf, vassert -__all__ = ["EdslMLIRJITFunction", "vprintf", "vassert"] +__all__ = ["MLIRJITFunction", "vprintf", "vassert"] diff --git a/python/triton/experimental/tle/raw/mlir/codegen.py b/python/triton/experimental/tle/raw/mlir/codegen.py index 46c88faa7..874628c02 100644 --- a/python/triton/experimental/tle/raw/mlir/codegen.py +++ b/python/triton/experimental/tle/raw/mlir/codegen.py @@ -6,15 +6,10 @@ from mlir.dialects import func from .utils import ExternalCall +from ..utils import UnknownSymbolError -class UnknownSymbolError(Exception): - - def __init__(self, name: str, *args, **kwargs) -> None: - super().__init__(f"unknown symbol {name}", *args, **kwargs) - - -class EdslMLIRCodeGenerator(ast.NodeVisitor): +class MLIRCodeGenerator(ast.NodeVisitor): def __init__( self, diff --git a/python/triton/experimental/tle/raw/mlir/runtime.py b/python/triton/experimental/tle/raw/mlir/runtime.py index 190947045..e60c2841d 100644 --- a/python/triton/experimental/tle/raw/mlir/runtime.py +++ b/python/triton/experimental/tle/raw/mlir/runtime.py @@ -8,19 +8,29 @@ from mlir import ir from mlir.passmanager import PassManager -from .codegen import EdslMLIRCodeGenerator +from .codegen import MLIRCodeGenerator -class EdslMLIRJITFunction(object): +class MLIRJITFunction(object): - def __init__(self, fn: Any, pipeline: List[str], context: Optional[ir.Context] = None, *args, **kwargs) -> None: + def __init__(self, fn: Any, pipeline: Optional[List[str]] = None, context: Optional[ir.Context] = None, *args, + **kwargs) -> None: super().__init__(*args, **kwargs) self.fn: Final[Any] = fn - self.pipeline: Final[List[str]] = [*pipeline] + self.pipeline: Final[List[str]] = ([*pipeline] if pipeline is not None else [ + "convert-scf-to-cf", + "finalize-memref-to-llvm", + "convert-arith-to-llvm", + "convert-cf-to-llvm", + "convert-func-to-llvm", + "convert-index-to-llvm", + "convert-nvvm-to-llvm", + "cse", + ]) self.context: Final[ir.Context] = ir.Context() if context is None else context self.__triton_builtin__: Final[bool] = True - def __deepcopy__(self, memo: Dict[int, Any]) -> EdslMLIRJITFunction: + def __deepcopy__(self, memo: Dict[int, Any]) -> MLIRJITFunction: return self.__class__(copy.deepcopy(self.fn, memo), copy.deepcopy(self.pipeline, memo), self.context) @cached_property @@ -40,16 +50,16 @@ def globals(self) -> Dict[str, Any]: return {k: v for k, v in self.fn.__globals__.items() if not k.startswith("__")} @cached_property - def codegen(self) -> EdslMLIRCodeGenerator: - return EdslMLIRCodeGenerator(self.absfilename, {}, self.globals, self.context) + def codegen(self) -> MLIRCodeGenerator: + return MLIRCodeGenerator(self.absfilename, {}, self.globals, self.context) @property def ir(self) -> ir.Module: mod: ir.Module = self.codegen.visit(self.ast) return mod - @cached_property - def llvm(self) -> ir.Module: + @property + def ll(self) -> ir.Module: mod: ir.Module = self.ir with self.context: pm: PassManager = PassManager() @@ -59,6 +69,10 @@ def llvm(self) -> ir.Module: pm.run(mod.operation) return mod + @cached_property + def llvm(self) -> str: + return f"{self.ll}" + @cached_property def src(self) -> str: return inspect.getsource(self.fn) diff --git a/python/triton/experimental/tle/raw/mlir/utils.py b/python/triton/experimental/tle/raw/mlir/utils.py index e77ed7048..0be9a5514 100644 --- a/python/triton/experimental/tle/raw/mlir/utils.py +++ b/python/triton/experimental/tle/raw/mlir/utils.py @@ -11,7 +11,7 @@ from mlir.dialects import arith, func, llvm, scf if TYPE_CHECKING: - from .codegen import EdslMLIRCodeGenerator + from .codegen import MLIRCodeGenerator class ExternalCall(object): @@ -26,16 +26,16 @@ def build(self) -> func.FuncOp: ... @abstractmethod - def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp: + def call(self, codegen: MLIRCodeGenerator) -> func.CallOp: ... - def decl(self, codegen: EdslMLIRCodeGenerator) -> func.FuncOp: + def decl(self, codegen: MLIRCodeGenerator) -> func.FuncOp: with ir.InsertionPoint.at_block_begin(codegen.module.body): funcop: func.FuncOp = codegen.decls.get(self.keyword) or self.build() codegen.decls[self.keyword] = funcop return funcop - def global_string(self, val: str, codegen: EdslMLIRCodeGenerator) -> llvm.GlobalOp: + def global_string(self, val: str, codegen: MLIRCodeGenerator) -> llvm.GlobalOp: hdigest = blake2s(val.encode('utf-8'), digest_size=16) key: str = f"globalstr{base64.urlsafe_b64encode(hdigest.digest()).decode('ascii').rstrip('=')}" with ir.InsertionPoint.at_block_begin(codegen.module.body): @@ -59,7 +59,7 @@ def build(self) -> func.FuncOp: [ir.IntegerType.get_signless(32)]), visibility="private") @override - def call(self, codegen: EdslMLIRCodeGenerator) -> func.CallOp: + def call(self, codegen: MLIRCodeGenerator) -> func.CallOp: [format, *args] = self.args funcop: func.FuncOp = self.decl(codegen) format: llvm.GlobalOp = self.global_string(format, codegen) @@ -99,7 +99,7 @@ def build(self) -> func.FuncOp: visibility="private") @override - def call(self, codegen: EdslMLIRCodeGenerator) -> Any: + def call(self, codegen: MLIRCodeGenerator) -> Any: func_op = self.decl(codegen) true_const = arith.constant(ir.IntegerType.get_signless(1), 1) diff --git a/python/triton/experimental/tle/raw/runtime.py b/python/triton/experimental/tle/raw/runtime.py index cc95f0159..62e1bb5df 100644 --- a/python/triton/experimental/tle/raw/runtime.py +++ b/python/triton/experimental/tle/raw/runtime.py @@ -1,16 +1,17 @@ -from .mlir import EdslMLIRJITFunction -from typing import List +from .cuda import CUDAJITFunction +from .mlir import MLIRJITFunction -registry = {"mlir": EdslMLIRJITFunction} +registry = {"cuda": CUDAJITFunction, "mlir": MLIRJITFunction} -def dialect(*, name: str, pipeline: List[str] = [ - "convert-scf-to-cf", "finalize-memref-to-llvm", "convert-arith-to-llvm", "convert-cf-to-llvm", - "convert-func-to-llvm", "convert-index-to-llvm", "convert-nvvm-to-llvm", "cse" -]): +def dialect( + *, + name: str, + **kwargs, +): def decorator(fn): - edsl = registry[name](fn, pipeline=pipeline) + edsl = registry[name](fn, **kwargs) return edsl return decorator diff --git a/python/triton/experimental/tle/raw/utils/__init__.py b/python/triton/experimental/tle/raw/utils/__init__.py new file mode 100644 index 000000000..056356707 --- /dev/null +++ b/python/triton/experimental/tle/raw/utils/__init__.py @@ -0,0 +1,3 @@ +from .exception import UnknownSymbolError + +__all__ = ["UnknownSymbolError"] diff --git a/python/triton/experimental/tle/raw/utils/exception.py b/python/triton/experimental/tle/raw/utils/exception.py new file mode 100644 index 000000000..c7548e870 --- /dev/null +++ b/python/triton/experimental/tle/raw/utils/exception.py @@ -0,0 +1,4 @@ +class UnknownSymbolError(Exception): + + def __init__(self, name: str, *args, **kwargs) -> None: + super().__init__(f"unknown symbol {name}", *args, **kwargs) diff --git a/python/tutorials/tle/raw/cuda/01-vector-add.cu b/python/tutorials/tle/raw/cuda/01-vector-add.cu new file mode 100644 index 000000000..b6f6d00ec --- /dev/null +++ b/python/tutorials/tle/raw/cuda/01-vector-add.cu @@ -0,0 +1,9 @@ +__device__ void VectorAdd(__attribute__((address_space(1))) float *C, + __attribute__((address_space(1))) const float *A, + __attribute__((address_space(1))) const float *B, + const int N) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = idx; i < N; i += blockDim.x) { + C[i] = A[i] + B[i]; + } +} diff --git a/python/tutorials/tle/raw/cuda/01-vector-add.py b/python/tutorials/tle/raw/cuda/01-vector-add.py new file mode 100644 index 000000000..69835307d --- /dev/null +++ b/python/tutorials/tle/raw/cuda/01-vector-add.py @@ -0,0 +1,41 @@ +from pathlib import Path + +import torch +import triton +import triton.language as tl +from triton.experimental.tle.raw import dialect +import triton.experimental.tle.language.raw as tle_raw + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@dialect(name="cuda", file=Path(__file__).parent / "01-vector-add.cu") +def edsl(*args, **kwargs): + ... + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + tle_raw.call(edsl, [], [output_ptr, x_ptr, y_ptr, n_elements]) + + +def add(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + + +if __name__ == "__main__": + x = torch.randn(2048, device=DEVICE) + y = torch.randn(2048, device=DEVICE) + z = add(x, y) + assert torch.allclose(x + y, z), (x + y, z) diff --git a/python/tutorials/tle/raw/01-vector-add.py b/python/tutorials/tle/raw/mlir/01-vector-add.py similarity index 100% rename from python/tutorials/tle/raw/01-vector-add.py rename to python/tutorials/tle/raw/mlir/01-vector-add.py diff --git a/python/tutorials/tle/raw/02-fused-softmax.py b/python/tutorials/tle/raw/mlir/02-fused-softmax.py similarity index 100% rename from python/tutorials/tle/raw/02-fused-softmax.py rename to python/tutorials/tle/raw/mlir/02-fused-softmax.py diff --git a/python/tutorials/tle/raw/03-matrix-multiplication.py b/python/tutorials/tle/raw/mlir/03-matrix-multiplication.py similarity index 100% rename from python/tutorials/tle/raw/03-matrix-multiplication.py rename to python/tutorials/tle/raw/mlir/03-matrix-multiplication.py diff --git a/python/tutorials/tle/raw/04-hello-world.py b/python/tutorials/tle/raw/mlir/04-hello-world.py similarity index 100% rename from python/tutorials/tle/raw/04-hello-world.py rename to python/tutorials/tle/raw/mlir/04-hello-world.py diff --git a/python/tutorials/tle/raw/05-topk.py b/python/tutorials/tle/raw/mlir/05-topk.py similarity index 100% rename from python/tutorials/tle/raw/05-topk.py rename to python/tutorials/tle/raw/mlir/05-topk.py diff --git a/python/tutorials/tle/raw/06-test-vassert.py b/python/tutorials/tle/raw/mlir/06-test-vassert.py similarity index 100% rename from python/tutorials/tle/raw/06-test-vassert.py rename to python/tutorials/tle/raw/mlir/06-test-vassert.py diff --git a/third_party/tle/CMakeLists.txt b/third_party/tle/CMakeLists.txt index 847cb7f08..165966399 100644 --- a/third_party/tle/CMakeLists.txt +++ b/third_party/tle/CMakeLists.txt @@ -9,9 +9,10 @@ if(TRITON_BUILD_PYTHON_MODULE) ${CMAKE_CURRENT_SOURCE_DIR}/triton_tle_raw.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/lib/Protocol.cpp LINK_LIBS + MLIRTargetLLVMIRImport TritonTLETransforms TritonIR) - target_link_libraries(TritonTLE PRIVATE Python3::Module pybind11::headers) + target_link_libraries(TritonTLE PRIVATE Python3::Module pybind11::headers MLIRTargetLLVMIRImport) endif() add_subdirectory(test) diff --git a/third_party/tle/triton_tle.cc b/third_party/tle/triton_tle.cc index 75792c203..bbf33ef07 100644 --- a/third_party/tle/triton_tle.cc +++ b/third_party/tle/triton_tle.cc @@ -30,12 +30,15 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" +#include "mlir/Target/LLVMIR/Import.h" #include "passes.h" #include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "tle/dialect/include/IR/Dialect.h" #include "tle/dialect/include/Transforms/Passes.h" @@ -46,7 +49,11 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/IRReader/IRReader.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" namespace py = pybind11; using namespace mlir; @@ -54,9 +61,10 @@ namespace ttg = triton::gpu; namespace ttng = triton::nvidia_gpu; namespace tle = triton::tle; -extern SmallVector createTLERawRegionByLLVMFunc( - TritonOpBuilder &self, std::string_view text, std::string_view fnname, - const std::vector &outputs, const std::vector &inputs); +extern SmallVector createTLERawCall(TritonOpBuilder &self, + std::string_view text, + const std::vector &outputs, + const std::vector &inputs); void init_triton_tle_ir(py::module &&m) { using ret = py::return_value_policy; @@ -187,14 +195,14 @@ void init_tle_raw_ir(py::module &&m) { using ret = py::return_value_policy; auto *builder_cls = ir::getBuilderClass(); - builder_cls->def( - "create_tle_raw_region_by_llvm_func", - [](TritonOpBuilder &self, std::string_view text, std::string_view fnname, - const std::vector &outputs, const std::vector &inputs) { - SmallVector results = - createTLERawRegionByLLVMFunc(self, text, fnname, outputs, inputs); - return std::vector(results.begin(), results.end()); - }); + builder_cls->def("create_tle_raw_call", [](TritonOpBuilder &self, + std::string_view text, + const std::vector &outputs, + const std::vector &inputs) { + SmallVector results = createTLERawCall(self, text, outputs, inputs); + return std::vector(results.begin(), results.end()); + }); + builder_cls->def("get_context", &TritonOpBuilder::getContext); } void init_tle_raw_passes(py::module &&m) { @@ -202,6 +210,27 @@ void init_tle_raw_passes(py::module &&m) { mlir::triton::tle::createTleConvertArgToMemDesc); } +void init_llvm(py::module &&m) { + using ret = py::return_value_policy; + m.def("parse_llvm_ir", + [](std::string_view text, llvm::LLVMContext &llvmContext, + mlir::MLIRContext &mlirContext) -> mlir::ModuleOp { + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(text); + llvm::SMDiagnostic error; + std::unique_ptr llvmModule = + llvm::parseIR(buffer->getMemBufferRef(), error, llvmContext); + if (!llvmModule) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + return mlir::translateLLVMIRToModule(std::move(llvmModule), + &mlirContext) + ->clone(); + }); +} + void init_triton_tle(py::module &&m) { // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { @@ -216,4 +245,5 @@ void init_triton_tle(py::module &&m) { init_triton_tle_passes(m.def_submodule("passes")); init_tle_raw_ir(m.def_submodule("raw_ir")); init_tle_raw_passes(m.def_submodule("raw_passes")); + init_llvm(m.def_submodule("llvm")); } diff --git a/third_party/tle/triton_tle_raw.cc b/third_party/tle/triton_tle_raw.cc index 2a2e1cb69..54edf80e0 100644 --- a/third_party/tle/triton_tle_raw.cc +++ b/third_party/tle/triton_tle_raw.cc @@ -53,12 +53,23 @@ SmallVector flatten(TritonOpBuilder &builder, // - TT IR: i32 (IntegerType) // - LLVM func: 1 arg = i32 // - Conversion: SignaturePattern::apply directly passes the scalar value -SmallVector createTLERawRegionByLLVMFunc( - TritonOpBuilder &self, std::string_view text, std::string_view fnname, - const std::vector &outputs, const std::vector &inputs) { +SmallVector createTLERawCall(TritonOpBuilder &self, + std::string_view text, + const std::vector &outputs, + const std::vector &inputs) { ParserConfig config(self.getContext()); OwningOpRef module = parseSourceString(text, config); - LLVM::LLVMFuncOp func = module->lookupSymbol(fnname); + LLVM::LLVMFuncOp func = nullptr; + for (auto op : module->getOps()) { + if (!op.empty()) { + if (func) { + llvm_unreachable("Multiple functions found in LLVM IR text"); + } else { + func = op; + } + } + } + assert(func && "No function found in LLVM IR text"); OpBuilder &builder = self.getBuilder(); Operation *curOp = builder.getInsertionBlock()->getParentOp(); while (curOp && curOp->getParentOp() && !isa(curOp)) { @@ -71,12 +82,14 @@ SmallVector createTLERawRegionByLLVMFunc( for (Operation &op : module->getOps()) { if ((!isa(op) || (isa(op) && - !curModule.lookupSymbol(cast(op).getName())))) { + !curModule.lookupSymbol(cast(op).getName()))) && + !isa(op)) { builder.clone(op); } } }; - LLVM::LLVMFuncOp funcOp = curModule.lookupSymbol(fnname); + LLVM::LLVMFuncOp funcOp = + curModule.lookupSymbol(func.getSymName()); SmallVector operands = {}; TypeRange tgts = func.getArgumentTypes(); SmallVector outs = SmallVector(outputs.begin(), outputs.end()),