diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py
new file mode 100644
index 000000000..e7860cf64
--- /dev/null
+++ b/python/triton/compiler/hint_manager.py
@@ -0,0 +1,135 @@
+import sys
+import importlib
+
+
+class BaseHintHandler:
+    # dynamically find method
+    def trigger(self, hook_name, *args, **kwargs):
+        if hasattr(self, hook_name):
+            method = getattr(self, hook_name)
+            if callable(method):
+                try:
+                    return method(*args, **kwargs)
+
+                except TypeError as e:
+                    import inspect
+
+                    try:
+                        sig = inspect.signature(method)
+                        expected = str(sig)
+                    except Exception:
+                        expected = "(unknown)"
+
+                    actual_args = f"{len(args)} positional"
+                    actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"
+
+                    print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}")
+                    print(f" > Expect : {expected}")
+                    print(f" > Actual : {actual_args}, {actual_kwargs}")
+                    print(f" > Reason : {e}\n")
+
+                    raise e
+        return None
+
+
+class HintManager:
+
+    def __init__(self, backend_name):
+        self.backend_name = backend_name
+        # load Handler with backend name
+        self.handler = self._load_handler(backend_name)
+
+    def _load_handler(self, backend):
+        if backend == 'npu':
+            try:
+                module = importlib.import_module("triton.backends.ascend.ascend_hint_handler")
+                return module.AscendHintHandler()
+            except ImportError as e:
+                print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr)
+                return BaseHintHandler()
+        elif backend == 'aipu':
+            try:
+                module = importlib.import_module("triton.backends.aipu.aipu_hint_handler")
+                return module.AipuHintHandler()
+            except ImportError as e:
+                print(f"[FlagTree] Warning: Failed to load aipu Hint Handler: {e}", file=sys.stderr)
+                return BaseHintHandler()
+        elif backend == 'cuda':
+            try:
+                module = importlib.import_module("triton.backends.nvidia.nvidia_hint_handler")
+                return module.NvidiaHintHandler()
+            except ImportError as e:
+                print(f"[FlagTree] Warning: Failed to load Nvidia Hint Handler: {e}", file=sys.stderr)
+                return BaseHintHandler()
+        else:
+            return BaseHintHandler()
+
+
+# supported backend with matched version
+SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"]
+
+# TODO : npu will have conflicts if more backend involved
+# mapping name
+BACKEND_ALIASES = {
+    "ascend": "npu",
+    "huawei": "npu",
+    "nvidia": "cuda",
+}
+
+
+def normalize_backend_name(name: str) -> str:
+    if not name:
+        return ""
+    name = name.lower()
+    return BACKEND_ALIASES.get(name, name)
+
+
+def hint_get_flagtree_backend() -> str:
+    detected_backend = ""
+
+    import torch
+
+    # Priority 1: Triton Driver
+    try:
+        from triton.runtime import driver
+        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
+            device = driver.active.get_active_torch_device()
+            if isinstance(device, torch.device):
+                detected_backend = device.type
+            # unimplemented support
+            elif isinstance(device, str):
+                detected_backend = device
+    except ImportError:
+        pass
+
+    # TODO : some backend may not support priority 1, so keep priority 2 is necessary
+    # Priority 2: Torch Global State
+    if not detected_backend:
+        check_priority = ["aipu", "npu", "cuda"]
+
+        # 3. parse according to benefit
+        for candidate in check_priority:
+            module = getattr(torch, candidate, None)
+            if module and hasattr(module, "is_available") and module.is_available():
+                detected_backend = candidate
+                break
+
+    # (Normalization and Validation)
+    canonical_backend = normalize_backend_name(detected_backend)
+
+    if not canonical_backend or canonical_backend not in SUPPORTED_BACKENDS:
+        return ""
+
+    return canonical_backend
+
+
+# lazy load after first call hint trigger
+_global_hint_manager = None
+
+
+def hint_trigger(hook_name, *args, **kwargs):
+    global _global_hint_manager
+
+    if _global_hint_manager is None:
+        _global_hint_manager = HintManager(hint_get_flagtree_backend())
+    return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)
diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py
new file mode 100644
index 000000000..65e492c6c
--- /dev/null
+++ b/third_party/ascend/backend/ascend_hint_handler.py
@@ -0,0 +1,79 @@
+# should store at third_party/???/backend/
+from triton.compiler.hint_manager import BaseHintHandler
+import triton.language as language
+import ast
+from triton.compiler.code_generator import _is_triton_value
+
+
+class AscendHintHandler(BaseHintHandler):
+
+    @staticmethod
+    def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
+        import ast
+        from triton.compiler.code_generator import _is_triton_value
+        # flagtree: After normal processing, check if we need to add hint annotation
+        if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
+            line_num = node.lineno
+            # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
+            function_def = code_generator.jit_fn.parse()
+            line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+            flagtree_hints = line_flagtree_hints.get(line_num)
+
+            # Check if this is a tl.load call with dot_pad_only_k hint
+            if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call)
+                    and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name)
+                    and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'):
+
+                # Add hint annotation to the loaded tensor(s)
+                for name, value in zip(names, values):
+                    if _is_triton_value(value):
+                        # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
+                        # Create hint annotation
+                        hint_val = code_generator.builder.get_unit_attr()
+                        code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
+
+    @staticmethod
+    def check_override_bind_sub_block(code_generator, node, bind_sub_block):
+        # flagtree: After normal processing, check if we need to override bind_sub_block
+        if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
+            line_num = node.lineno
+            # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
+            function_def = code_generator.jit_fn.parse()
+            line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+            flagtree_hints = line_flagtree_hints.get(line_num)
+
+            # Check if this is a range/for loop with bind_sub_block hint
+            if flagtree_hints and 'bind_sub_block' in flagtree_hints:
+                return True
+                # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")
+        return bind_sub_block
+
+    @staticmethod
+    def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
+        for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))
+
+    @staticmethod
+    def maps_line_numbers_to_comment_hints(jit_fn):
+        import tokenize
+        from io import StringIO
+        # Maps line numbers to comment hints
+        line_flagtree_hints = {}
+        code_str = jit_fn.src
+        g = tokenize.generate_tokens(StringIO(code_str).readline)
+        for tok_type, tok_text, start, end, _ in g:
+            if tok_type == tokenize.COMMENT:
+                comment = tok_text.replace(" ", "").strip()
+                if comment.startswith('#@hint:'):
+                    flagtree_hints = comment[len('#@hint:'):].strip()
+                    # Record the line number of the comment
+                    line_num = start[0]
+                    line_flagtree_hints[line_num] = flagtree_hints
+
+                    # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")
+
+        return line_flagtree_hints
+
+    @staticmethod
+    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
+        # Attach the line number to comment mapping to the function definition node
+        tree.body[0].line_flagtree_hints = line_flagtree_hints
diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py
index 172ba90b4..a20fe0e9a 100644
--- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py
+++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py
@@ -22,6 +22,7 @@
 from ..runtime import JITFunction
 from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
 from types import ModuleType
+from .hint_manager import hint_trigger
 
 # Central registry for all 'with' statement handlers
 WITH_DISPATCH = {}
@@ -547,6 +548,9 @@ def visit_Assign(self, node):
             value = language.semantic.to_tensor(value, self.builder)
             self.set_value(name, value)
 
+        # switch into hintmanager
+        hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
+
     def visit_AugAssign(self, node):
         name = node.target.id
         lhs = ast.Name(id=name, ctx=ast.Load())
@@ -992,6 +996,11 @@ def visit_For(self, node):
             step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
         else:
             raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
+        # hint manager
+        new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
+        if new_bind_sub_block is not None:
+            bind_sub_block = new_bind_sub_block
+
         # handle negative constant step (not supported by scf.for in MLIR)
         negative_step = False
         if _is_constexpr(step) and step.value < 0:
@@ -1065,6 +1074,9 @@ def visit_For(self, node):
             for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr())
         if (IteratorClass is extension.parallel):
             for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr())
+        # hint manager
+        if bind_sub_block:
+            hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
 
         self.scf_stack.append(node)
         self.builder.set_insertion_point_to_start(for_op.get_body(0))
diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py b/third_party/ascend/backend/spec/triton/runtime/jit.py
index 45178a40b..da8ba230e 100644
--- a/third_party/ascend/backend/spec/triton/runtime/jit.py
+++ b/third_party/ascend/backend/spec/triton/runtime/jit.py
@@ -756,10 +756,20 @@ def preload(self, specialization_data):
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
     def parse(self):
+        # hint manager
+        # after removing flagtree backend specialization, hiding the implementation into hintmanager
+        from ..compiler.hint_manager import hint_trigger
+        line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
+
         tree = ast.parse(self.src)
         assert isinstance(tree, ast.Module)
         assert len(tree.body) == 1
         assert isinstance(tree.body[0], ast.FunctionDef)
+
+        # hint manager
+        # Attach the line number to comment mapping to the function definition node
+        hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)
+
        return tree
 
     def __call__(self, *args, **kwargs):