From b75b621ff10072861ee7891acbc9890ad396b490 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 12 Jan 2026 09:19:06 +0000 Subject: [PATCH 01/15] the initial design for unified hint framework --- python/triton/compiler/code_generator.py | 5 + python/triton/compiler/hint_manager.py | 160 ++++++++++++++++++ python/triton/runtime/jit.py | 30 ++++ .../ascend/backend/ascend_hint_handler.py | 64 +++++++ 4 files changed, 259 insertions(+) create mode 100644 python/triton/compiler/hint_manager.py create mode 100644 third_party/ascend/backend/ascend_hint_handler.py diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index d8ca58d8d..0f0aa1a68 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -241,6 +241,11 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # special handling. self.visiting_arg_default_value = False + # adding unified hint manager init + from .hint_manager import HintManager + from .hint_manager import hint_get_flagtree_backend + self.hint_manager = HintManager(hint_get_flagtree_backend()) + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} builtin_namespace.update(( ('print', language.core.device_print), diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py new file mode 100644 index 000000000..33a52782c --- /dev/null +++ b/python/triton/compiler/hint_manager.py @@ -0,0 +1,160 @@ +import os +import torch +import triton +from typing import Optional + +class BaseHintHandler: + # 这里是不是该变成动态的,所有都注册,或者不注册的就不解析 + # --- Assign 相关 --- + def ext_CodeGenerator_visit_Assign_hint_anno(self, code_generator, node, names, values): + """默认为空,不做任何标注""" + pass + + # --- For Loop 相关 (完全沿用原名) --- + + def visit_For_ext_support(self): + """默认只支持 range,不增加额外 Iterator 支持""" + return [] + + def set_bind_sub_block_when_parallel(self, IteratorClass, iterator, bind_sub_block): + """默认不修改,直接把传进来的 bind_sub_block 返回回去""" + return bind_sub_block + + def check_override_bind_sub_block(self, code_generator, node, bind_sub_block): + """默认不覆盖,直接返回原值""" + return bind_sub_block + + def forop_setattr_for_bind_sub_block(self, code_generator, for_op, bind_sub_block): + """默认不设置属性""" + pass + + def need_repr_in_CodeGenerator_CompilationError(self): + """默认不需要额外报错信息""" + return False + + + +class HintManager: + def __init__(self, backend_name): + self.backend_name = backend_name + self.hints_cache = {} # { lineno: { key: value } } + # 根据后端名称加载对应的 Handler + self.handler = self._load_handler(backend_name) + + def _load_handler(self, backend): + # 简单的工厂模式 + if backend == 'npu': + try: + # 假设 ascend 的代码在 python path 中可见 + # 这里根据你项目的实际 import 路径修改 + # 假如是在 third_party.ascend... 下 + # need to be optimized + module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler") + return module.AscendHintHandler() + except ImportError as e: + logging.warning(f"Failed to load Ascend Hint Handler: {e}") + return BaseHintHandler() + elif backend == 'aipu': + from .backends.aipu import AipuHintHandler + return AipuHintHandler() + else: + return BaseHintHandler() + + def parse_hints_once(self, jit_fn): + """只解析一次,缓存结果""" + if not self.hints_cache and jit_fn: + import ast + # 假设你的前端 parse 逻辑能提取出 {lineno: hints} + # 这里优化了 3.2 中重复 parse 的问题 + tree = jit_fn.parse() + # 递归或遍历 tree 获取所有 hints,存入 self.hints_cache + self.hints_cache = self._extract_hints_from_tree(tree) + + def apply_hints(self, builder, node, instruction_handle, ...): + """CodeGenerator 调用的唯一入口""" + if not hasattr(node, 'lineno'): + return + + hints = self.hints_cache.get(node.lineno) + if hints: + # 委托给具体后端的 Handler 处理 + self.handler.process(builder, instruction_handle, hints) + + +# supported backend with matched version +SUPPORTED_CONFIG = { + "cuda": {"3.5"}, + "npu": {"3.2"}, + "aipu": {"3.3"}, +} + +# mapping name +BACKEND_ALIASES = { + "ascend": "npu", + "huawei": "npu", + "nv": "cuda", +} + + +def normalize_backend_name(name: str) -> str: + # convert name + if not name: + return "" + name = name.lower() + return BACKEND_ALIASES.get(name, name) + +def hint_get_flagtree_backend() -> str: + detected_backend = "" + + # --- 阶段一:多源探测 (Chain of Detection) --- + + # Priority 1: Triton Driver + try: + from triton.runtime import driver + if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'): + device = driver.active.get_active_torch_device() + if isinstance(device, torch.device): + detected_backend = device.type + # unimplemented support + elif isinstance(device, str): + detected_backend = device + except ImportError: + pass + + # Priority 2: Torch Global State + if not detected_backend: + candidates = list(SUPPORTED_CONFIG.keys()) + # cuda priority least + candidates.sort(key=lambda x: 1 if x == "cuda" else 0) + + # 3. 按优先级顺序遍历 + for candidate in candidates: + module_name = candidate + module = getattr(torch, module_name, None) + if module and hasattr(module, "is_available") and module.is_available(): + detected_backend = candidate + break + + # Priority 3: Environment Variable (need to remove!!!) + if not detected_backend: + detected_backend = os.environ.get("FLAGTREE_BACKEND", "") + + # (Normalization and Validation) + canonical_backend = normalize_backend_name(detected_backend) + + if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG: + return "" + + # verify name and version match + current_triton_version = ".".join(triton.__version__.split(".")[:2]) + supported_versions = SUPPORTED_CONFIG[canonical_backend] + + if current_triton_version in supported_versions: + return canonical_backend + else: + # version and backend mismatch + logging.warning( + f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " + f"'{current_triton_version}' matches no supported versions {supported_versions}." + ) + return "" \ No newline at end of file diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40b..4909a3441 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -565,6 +565,8 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend + # tip_for_runtime_device_get + # torch.device = device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() @@ -752,6 +754,34 @@ def preload(self, specialization_data): self.cache[device][key] = kernel return kernel + # need to remove and place right + def get_flagtree_backend(): + from triton.runtime.driver import driver + + # non-driver : proton f2reduce + # GPUdriver : self.get_current_device = torch.cuda.current_device + # NPUDriver(DriverBase) : get_current_device(self) return torch.npu.current_device() + # AIPUDriver(DriverBase) : get_active_torch_device(self): torch.device("aipu", 0) 但是3.3的jit.run是nv的get_device方式 + # _GCUDriver(DriverBase) : get_active_torch_device(self): torch.device("gcu", self.get_current_device()) + # BangDriver(DriverBase) : get_device_interface(self): return torch.mlu + # CudaDriver(GPUDriver) : get_active_torch_device(self): return "iluvatar" !to implemet; + # MusaDriver(GPUDriver) : get_active_torch_device(self): return "musa" !to implemet + # TXDADriver(GPUDriver) : get_active_torch_device(self): return torch.device("txda", self.get_current_device()) + # HIPDriver(GPUDriver): get_active_torch_device(self): return torch.device("cuda", self.get_current_device()) + # CudaDriver(GPUDriver): get_active_torch_device(self): return torch.device("cuda", self.get_current_device()) + # XPUDriver(GPUDriver): get_active_torch_device(self): return "xpu" + # return torch.npu.current_device() 本质貌似还是torch + device = driver.active.get_current_device() + + # 稳定得到str + name = getattr(device, "name", "").lower() + + # 可能不叫ascend,有可能是device编号 + if "ascend" in name: + return "ascend" + return "default" + + # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py new file mode 100644 index 000000000..1a9079f54 --- /dev/null +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -0,0 +1,64 @@ +from triton.compiler.hint_manager import BaseHintHandler +import triton.language as language +import ast +from triton.compiler.code_generator import _is_triton_value + +class AscendHintHandler(BaseHintHandler): + + def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + import ast + from triton.compiler.code_generator import _is_triton_value + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): + + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = code_generator.builder.get_unit_attr() + code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + + def visit_For_ext_support(): + import triton.language as language + return [language.parallel] + + def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): + import triton.language as language + if (IteratorClass is language.parallel): + return iterator.bind_sub_block + return bind_sub_block + + def check_override_bind_sub_block(code_generator, node, bind_sub_block): + # flagtree: After normal processing, check if we need to override bind_sub_block + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + + # Check if this is a range/for loop with bind_sub_block hint + if flagtree_hints and 'bind_sub_block' in flagtree_hints: + return True + # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") + return bind_sub_block + + def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): + for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) + + def need_repr_in_CodeGenerator_CompilationError(): + return True \ No newline at end of file From 0078272074c11be9307721e0703bc2f21dccfb41 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 13 Jan 2026 07:37:11 +0000 Subject: [PATCH 02/15] update the logic of how to call backend method in basehinthandler --- python/triton/compiler/hint_manager.py | 57 +++++++++---------- .../ascend/backend/ascend_hint_handler.py | 13 +++-- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index 33a52782c..ac2f50519 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -4,35 +4,34 @@ from typing import Optional class BaseHintHandler: - # 这里是不是该变成动态的,所有都注册,或者不注册的就不解析 - # --- Assign 相关 --- - def ext_CodeGenerator_visit_Assign_hint_anno(self, code_generator, node, names, values): - """默认为空,不做任何标注""" - pass - - # --- For Loop 相关 (完全沿用原名) --- - - def visit_For_ext_support(self): - """默认只支持 range,不增加额外 Iterator 支持""" - return [] - - def set_bind_sub_block_when_parallel(self, IteratorClass, iterator, bind_sub_block): - """默认不修改,直接把传进来的 bind_sub_block 返回回去""" - return bind_sub_block - - def check_override_bind_sub_block(self, code_generator, node, bind_sub_block): - """默认不覆盖,直接返回原值""" - return bind_sub_block - - def forop_setattr_for_bind_sub_block(self, code_generator, for_op, bind_sub_block): - """默认不设置属性""" - pass - - def need_repr_in_CodeGenerator_CompilationError(self): - """默认不需要额外报错信息""" - return False - - + # dynamicly find method + def trigger(self, hook_name, *args, **kwargs): + if hasattr(self, hook_name): + method = getattr(self, hook_name) + if callable(method): + try: + return method(*args, **kwargs) + + except TypeError as e: + import inspect + + try: + sig = inspect.signature(method) + expected = str(sig) + except: + expected = "(unknown)" + + actual_args = f"{len(args)} positional" + actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords" + + print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}") + print(f" > Expect : {expected}") + print(f" > Actual : {actual_args}, {actual_kwargs}") + print(f" > Reason : {e}\n") + + raise e + print(f"no capable method in backend handler") + return None class HintManager: def __init__(self, backend_name): diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py index 1a9079f54..5d240a354 100644 --- a/third_party/ascend/backend/ascend_hint_handler.py +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -1,3 +1,4 @@ +# should store at thrid_party/???/backend/ from triton.compiler.hint_manager import BaseHintHandler import triton.language as language import ast @@ -5,7 +6,7 @@ class AscendHintHandler(BaseHintHandler): - def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + def ext_CodeGenerator_visit_Assign_hint_anno(self, code_generator, node, names, values): import ast from triton.compiler.code_generator import _is_triton_value # flagtree: After normal processing, check if we need to add hint annotation @@ -32,17 +33,17 @@ def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values hint_val = code_generator.builder.get_unit_attr() code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) - def visit_For_ext_support(): + def visit_For_ext_support(self): import triton.language as language return [language.parallel] - def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): + def set_bind_sub_block_when_parallel(self, IteratorClass, iterator, bind_sub_block): import triton.language as language if (IteratorClass is language.parallel): return iterator.bind_sub_block return bind_sub_block - def check_override_bind_sub_block(code_generator, node, bind_sub_block): + def check_override_bind_sub_block(self, code_generator, node, bind_sub_block): # flagtree: After normal processing, check if we need to override bind_sub_block if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): line_num = node.lineno @@ -57,8 +58,8 @@ def check_override_bind_sub_block(code_generator, node, bind_sub_block): # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") return bind_sub_block - def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): + def forop_setattr_for_bind_sub_block(self, code_generator, for_op, bind_sub_block): for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) - def need_repr_in_CodeGenerator_CompilationError(): + def need_repr_in_CodeGenerator_CompilationError(self): return True \ No newline at end of file From bfe196609b1c8b9587ead5ab799e945fb2793acd Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Wed, 21 Jan 2026 10:05:38 +0000 Subject: [PATCH 03/15] update hintmanager, wrap additional code into hintmanager, back no-hint-related handler func into spec, update import, change jit implement into hintmanager, simplify trigger call --- python/triton/compiler/code_generator.py | 6 +- python/triton/compiler/hint_manager.py | 12 ++- python/triton/runtime/jit.py | 1 + .../ascend/backend/ascend_hint_handler.py | 93 +++++++++++-------- 4 files changed, 67 insertions(+), 45 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 0f0aa1a68..1fbddd895 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,6 +15,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +from .hintmanager import hint_trigger def mangle_ty(ty): @@ -241,11 +242,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n # special handling. self.visiting_arg_default_value = False - # adding unified hint manager init - from .hint_manager import HintManager - from .hint_manager import hint_get_flagtree_backend - self.hint_manager = HintManager(hint_get_flagtree_backend()) - builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} builtin_namespace.update(( ('print', language.core.device_print), diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index ac2f50519..df3fc38f0 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -152,8 +152,16 @@ def hint_get_flagtree_backend() -> str: return canonical_backend else: # version and backend mismatch - logging.warning( + msg = ( f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " f"'{current_triton_version}' matches no supported versions {supported_versions}." ) - return "" \ No newline at end of file + print(msg, file=sys.stderr) + return "" +# lazy load after first call hint trigger +_global_hint_manager = None + +def hint_trigger(hook_name, *args, **kwargs): + if _global_hint_manager is None: + _global_hint_manager = HintManager(hint_get_flagtree_backend()) + return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) \ No newline at end of file diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4909a3441..366fca93e 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -11,6 +11,7 @@ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver from types import ModuleType +from ..compiler.hintmanager import hint_trigger TRITON_MODULE = __name__[:-len(".runtime.jit")] diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py index 5d240a354..6d90fdf52 100644 --- a/third_party/ascend/backend/ascend_hint_handler.py +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -6,44 +6,36 @@ class AscendHintHandler(BaseHintHandler): - def ext_CodeGenerator_visit_Assign_hint_anno(self, code_generator, node, names, values): - import ast - from triton.compiler.code_generator import _is_triton_value - # flagtree: After normal processing, check if we need to add hint annotation - if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = code_generator.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and - isinstance(node.value, ast.Call) and - isinstance(node.value.func, ast.Attribute) and - isinstance(node.value.func.value, ast.Name) and - node.value.func.value.id == 'tl' and - node.value.func.attr == 'load'): - - # Add hint annotation to the loaded tensor(s) - for name, value in zip(names, values): - if _is_triton_value(value): - # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") - # Create hint annotation - hint_val = code_generator.builder.get_unit_attr() - code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) + @staticmethod + def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): + import ast + from triton.compiler.code_generator import _is_triton_value + # flagtree: After normal processing, check if we need to add hint annotation + if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): + line_num = node.lineno + # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later + function_def = code_generator.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) - def visit_For_ext_support(self): - import triton.language as language - return [language.parallel] + # Check if this is a tl.load call with dot_pad_only_k hint + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and + isinstance(node.value, ast.Call) and + isinstance(node.value.func, ast.Attribute) and + isinstance(node.value.func.value, ast.Name) and + node.value.func.value.id == 'tl' and + node.value.func.attr == 'load'): - def set_bind_sub_block_when_parallel(self, IteratorClass, iterator, bind_sub_block): - import triton.language as language - if (IteratorClass is language.parallel): - return iterator.bind_sub_block - return bind_sub_block + # Add hint annotation to the loaded tensor(s) + for name, value in zip(names, values): + if _is_triton_value(value): + # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") + # Create hint annotation + hint_val = code_generator.builder.get_unit_attr() + code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) - def check_override_bind_sub_block(self, code_generator, node, bind_sub_block): + @staticmethod + def check_override_bind_sub_block(code_generator, node, bind_sub_block): # flagtree: After normal processing, check if we need to override bind_sub_block if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): line_num = node.lineno @@ -58,8 +50,33 @@ def check_override_bind_sub_block(self, code_generator, node, bind_sub_block): # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") return bind_sub_block - def forop_setattr_for_bind_sub_block(self, code_generator, for_op, bind_sub_block): + @staticmethod + def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) - def need_repr_in_CodeGenerator_CompilationError(self): - return True \ No newline at end of file + + @staticmethod + def maps_line_numbers_to_comment_hints(jit_fn): + import tokenize + from io import StringIO + # Maps line numbers to comment hints + line_flagtree_hints = {} + code_str = jit_fn.src + g = tokenize.generate_tokens(StringIO(code_str).readline) + for tok_type, tok_text, start, end, _ in g: + if tok_type == tokenize.COMMENT: + comment = tok_text.replace(" ", "").strip() + if comment.startswith('#@hint:'): + flagtree_hints = comment[len('#@hint:'):].strip() + # Record the line number of the comment + line_num = start[0] + line_flagtree_hints[line_num] = flagtree_hints + + # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + + return line_flagtree_hints + + @staticmethod + def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): + # Attach the line number to comment mapping to the function definition node + tree.body[0].line_flagtree_hints = line_flagtree_hints \ No newline at end of file From 2c64367cebd8628c7b7572c33d4a3a6103b5f623 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 26 Jan 2026 09:15:04 +0000 Subject: [PATCH 04/15] remove redundant code --- python/triton/compiler/hint_manager.py | 68 +++++++++----------------- python/triton/runtime/jit.py | 30 ------------ 2 files changed, 23 insertions(+), 75 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index df3fc38f0..f46cc437f 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -1,4 +1,7 @@ import os +import sys +import logging +import importlib import torch import triton from typing import Optional @@ -36,22 +39,16 @@ def trigger(self, hook_name, *args, **kwargs): class HintManager: def __init__(self, backend_name): self.backend_name = backend_name - self.hints_cache = {} # { lineno: { key: value } } - # 根据后端名称加载对应的 Handler + # load Handler with backend name self.handler = self._load_handler(backend_name) def _load_handler(self, backend): - # 简单的工厂模式 if backend == 'npu': try: - # 假设 ascend 的代码在 python path 中可见 - # 这里根据你项目的实际 import 路径修改 - # 假如是在 third_party.ascend... 下 - # need to be optimized module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler") return module.AscendHintHandler() except ImportError as e: - logging.warning(f"Failed to load Ascend Hint Handler: {e}") + print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr) return BaseHintHandler() elif backend == 'aipu': from .backends.aipu import AipuHintHandler @@ -59,26 +56,6 @@ def _load_handler(self, backend): else: return BaseHintHandler() - def parse_hints_once(self, jit_fn): - """只解析一次,缓存结果""" - if not self.hints_cache and jit_fn: - import ast - # 假设你的前端 parse 逻辑能提取出 {lineno: hints} - # 这里优化了 3.2 中重复 parse 的问题 - tree = jit_fn.parse() - # 递归或遍历 tree 获取所有 hints,存入 self.hints_cache - self.hints_cache = self._extract_hints_from_tree(tree) - - def apply_hints(self, builder, node, instruction_handle, ...): - """CodeGenerator 调用的唯一入口""" - if not hasattr(node, 'lineno'): - return - - hints = self.hints_cache.get(node.lineno) - if hints: - # 委托给具体后端的 Handler 处理 - self.handler.process(builder, instruction_handle, hints) - # supported backend with matched version SUPPORTED_CONFIG = { @@ -96,7 +73,6 @@ def apply_hints(self, builder, node, instruction_handle, ...): def normalize_backend_name(name: str) -> str: - # convert name if not name: return "" name = name.lower() @@ -105,8 +81,6 @@ def normalize_backend_name(name: str) -> str: def hint_get_flagtree_backend() -> str: detected_backend = "" - # --- 阶段一:多源探测 (Chain of Detection) --- - # Priority 1: Triton Driver try: from triton.runtime import driver @@ -126,7 +100,7 @@ def hint_get_flagtree_backend() -> str: # cuda priority least candidates.sort(key=lambda x: 1 if x == "cuda" else 0) - # 3. 按优先级顺序遍历 + # 3. parse according to benefit for candidate in candidates: module_name = candidate module = getattr(torch, module_name, None) @@ -145,23 +119,27 @@ def hint_get_flagtree_backend() -> str: return "" # verify name and version match - current_triton_version = ".".join(triton.__version__.split(".")[:2]) - supported_versions = SUPPORTED_CONFIG[canonical_backend] - - if current_triton_version in supported_versions: - return canonical_backend - else: - # version and backend mismatch - msg = ( - f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " - f"'{current_triton_version}' matches no supported versions {supported_versions}." - ) - print(msg, file=sys.stderr) - return "" + try: + current_triton_version = ".".join(triton.__version__.split(".")[:2]) + supported_versions = SUPPORTED_CONFIG[canonical_backend] + if current_triton_version not in supported_versions: + msg = ( + f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " + f"'{current_triton_version}' matches no supported versions {supported_versions}." + ) + print(msg, file=sys.stderr) + return "" + except Exception: + pass + + return canonical_backend + # lazy load after first call hint trigger _global_hint_manager = None def hint_trigger(hook_name, *args, **kwargs): + global _global_hint_manager + if _global_hint_manager is None: _global_hint_manager = HintManager(hint_get_flagtree_backend()) return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) \ No newline at end of file diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 366fca93e..33ee561d0 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -566,8 +566,6 @@ def run(self, *args, grid, warmup, **kwargs): # parse options from ..compiler import make_backend - # tip_for_runtime_device_get - # torch.device = device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) target = driver.active.get_current_target() @@ -755,34 +753,6 @@ def preload(self, specialization_data): self.cache[device][key] = kernel return kernel - # need to remove and place right - def get_flagtree_backend(): - from triton.runtime.driver import driver - - # non-driver : proton f2reduce - # GPUdriver : self.get_current_device = torch.cuda.current_device - # NPUDriver(DriverBase) : get_current_device(self) return torch.npu.current_device() - # AIPUDriver(DriverBase) : get_active_torch_device(self): torch.device("aipu", 0) 但是3.3的jit.run是nv的get_device方式 - # _GCUDriver(DriverBase) : get_active_torch_device(self): torch.device("gcu", self.get_current_device()) - # BangDriver(DriverBase) : get_device_interface(self): return torch.mlu - # CudaDriver(GPUDriver) : get_active_torch_device(self): return "iluvatar" !to implemet; - # MusaDriver(GPUDriver) : get_active_torch_device(self): return "musa" !to implemet - # TXDADriver(GPUDriver) : get_active_torch_device(self): return torch.device("txda", self.get_current_device()) - # HIPDriver(GPUDriver): get_active_torch_device(self): return torch.device("cuda", self.get_current_device()) - # CudaDriver(GPUDriver): get_active_torch_device(self): return torch.device("cuda", self.get_current_device()) - # XPUDriver(GPUDriver): get_active_torch_device(self): return "xpu" - # return torch.npu.current_device() 本质貌似还是torch - device = driver.active.get_current_device() - - # 稳定得到str - name = getattr(device, "name", "").lower() - - # 可能不叫ascend,有可能是device编号 - if "ascend" in name: - return "ascend" - return "default" - - # we do not parse `src` in the constructor because # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. From a430a9e214340635092dbda1656fcf21660408c4 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 26 Jan 2026 11:12:08 +0000 Subject: [PATCH 05/15] fix import and python bugs --- python/triton/compiler/hint_manager.py | 9 +++--- .../ascend/backend/ascend_hint_handler.py | 32 +++++++++---------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index f46cc437f..351a1fc9e 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -2,8 +2,6 @@ import sys import logging import importlib -import torch -import triton from typing import Optional class BaseHintHandler: @@ -33,8 +31,8 @@ def trigger(self, hook_name, *args, **kwargs): print(f" > Reason : {e}\n") raise e - print(f"no capable method in backend handler") - return None + print(f"no capable method in backend handler") + return None class HintManager: def __init__(self, backend_name): @@ -80,6 +78,9 @@ def normalize_backend_name(name: str) -> str: def hint_get_flagtree_backend() -> str: detected_backend = "" + + import torch + import triton # Priority 1: Triton Driver try: diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py index 6d90fdf52..cd48f9361 100644 --- a/third_party/ascend/backend/ascend_hint_handler.py +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -57,24 +57,24 @@ def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): @staticmethod def maps_line_numbers_to_comment_hints(jit_fn): - import tokenize - from io import StringIO - # Maps line numbers to comment hints - line_flagtree_hints = {} - code_str = jit_fn.src - g = tokenize.generate_tokens(StringIO(code_str).readline) - for tok_type, tok_text, start, end, _ in g: - if tok_type == tokenize.COMMENT: - comment = tok_text.replace(" ", "").strip() - if comment.startswith('#@hint:'): - flagtree_hints = comment[len('#@hint:'):].strip() - # Record the line number of the comment - line_num = start[0] - line_flagtree_hints[line_num] = flagtree_hints + import tokenize + from io import StringIO + # Maps line numbers to comment hints + line_flagtree_hints = {} + code_str = jit_fn.src + g = tokenize.generate_tokens(StringIO(code_str).readline) + for tok_type, tok_text, start, end, _ in g: + if tok_type == tokenize.COMMENT: + comment = tok_text.replace(" ", "").strip() + if comment.startswith('#@hint:'): + flagtree_hints = comment[len('#@hint:'):].strip() + # Record the line number of the comment + line_num = start[0] + line_flagtree_hints[line_num] = flagtree_hints - # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") - return line_flagtree_hints + return line_flagtree_hints @staticmethod def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): From 06a032a2483662b3111c8eb270f3123fd07f4128 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 26 Jan 2026 11:16:30 +0000 Subject: [PATCH 06/15] fix import and python bugs_2 --- python/triton/compiler/hint_manager.py | 2 +- third_party/ascend/backend/ascend_hint_handler.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index 351a1fc9e..c0e6284d1 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -19,7 +19,7 @@ def trigger(self, hook_name, *args, **kwargs): try: sig = inspect.signature(method) expected = str(sig) - except: + except Exception: expected = "(unknown)" actual_args = f"{len(args)} positional" diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py index cd48f9361..0b7834330 100644 --- a/third_party/ascend/backend/ascend_hint_handler.py +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -78,5 +78,5 @@ def maps_line_numbers_to_comment_hints(jit_fn): @staticmethod def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): - # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints \ No newline at end of file + # Attach the line number to comment mapping to the function definition node + tree.body[0].line_flagtree_hints = line_flagtree_hints \ No newline at end of file From 854b504da518947d3330130a5ff217ca4feb35ec Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 26 Jan 2026 11:31:08 +0000 Subject: [PATCH 07/15] apply code-format change --- python/triton/compiler/hint_manager.py | 24 +++++++++---------- .../ascend/backend/ascend_hint_handler.py | 15 +++++------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index c0e6284d1..4161078b5 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -1,8 +1,6 @@ import os import sys -import logging import importlib -from typing import Optional class BaseHintHandler: # dynamicly find method @@ -31,7 +29,7 @@ def trigger(self, hook_name, *args, **kwargs): print(f" > Reason : {e}\n") raise e - print(f"no capable method in backend handler") + print("no capable method in backend handler") return None class HintManager: @@ -58,7 +56,7 @@ def _load_handler(self, backend): # supported backend with matched version SUPPORTED_CONFIG = { "cuda": {"3.5"}, - "npu": {"3.2"}, + "npu": {"3.2"}, "aipu": {"3.3"}, } @@ -82,7 +80,7 @@ def hint_get_flagtree_backend() -> str: import torch import triton - # Priority 1: Triton Driver + # Priority 1: Triton Driver try: from triton.runtime import driver if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'): @@ -103,19 +101,19 @@ def hint_get_flagtree_backend() -> str: # 3. parse according to benefit for candidate in candidates: - module_name = candidate + module_name = candidate module = getattr(torch, module_name, None) if module and hasattr(module, "is_available") and module.is_available(): detected_backend = candidate break - + # Priority 3: Environment Variable (need to remove!!!) if not detected_backend: detected_backend = os.environ.get("FLAGTREE_BACKEND", "") # (Normalization and Validation) canonical_backend = normalize_backend_name(detected_backend) - + if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG: return "" @@ -124,10 +122,8 @@ def hint_get_flagtree_backend() -> str: current_triton_version = ".".join(triton.__version__.split(".")[:2]) supported_versions = SUPPORTED_CONFIG[canonical_backend] if current_triton_version not in supported_versions: - msg = ( - f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " - f"'{current_triton_version}' matches no supported versions {supported_versions}." - ) + msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " + f"'{current_triton_version}' matches no supported versions {supported_versions}.") print(msg, file=sys.stderr) return "" except Exception: @@ -135,12 +131,14 @@ def hint_get_flagtree_backend() -> str: return canonical_backend + # lazy load after first call hint trigger _global_hint_manager = None + def hint_trigger(hook_name, *args, **kwargs): global _global_hint_manager if _global_hint_manager is None: _global_hint_manager = HintManager(hint_get_flagtree_backend()) - return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) \ No newline at end of file + return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) diff --git a/third_party/ascend/backend/ascend_hint_handler.py b/third_party/ascend/backend/ascend_hint_handler.py index 0b7834330..65e492c6c 100644 --- a/third_party/ascend/backend/ascend_hint_handler.py +++ b/third_party/ascend/backend/ascend_hint_handler.py @@ -1,9 +1,10 @@ # should store at thrid_party/???/backend/ from triton.compiler.hint_manager import BaseHintHandler -import triton.language as language +import triton.language as language import ast from triton.compiler.code_generator import _is_triton_value + class AscendHintHandler(BaseHintHandler): @staticmethod @@ -19,12 +20,9 @@ def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values flagtree_hints = line_flagtree_hints.get(line_num) # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and - isinstance(node.value, ast.Call) and - isinstance(node.value.func, ast.Attribute) and - isinstance(node.value.func.value, ast.Name) and - node.value.func.value.id == 'tl' and - node.value.func.attr == 'load'): + if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) + and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'): # Add hint annotation to the loaded tensor(s) for name, value in zip(names, values): @@ -54,7 +52,6 @@ def check_override_bind_sub_block(code_generator, node, bind_sub_block): def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) - @staticmethod def maps_line_numbers_to_comment_hints(jit_fn): import tokenize @@ -79,4 +76,4 @@ def maps_line_numbers_to_comment_hints(jit_fn): @staticmethod def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints \ No newline at end of file + tree.body[0].line_flagtree_hints = line_flagtree_hints From 9e2ef64e3e1ae411d1a793723e6648610993a6bb Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Mon, 26 Jan 2026 11:35:19 +0000 Subject: [PATCH 08/15] apply code-format change_2 --- python/triton/compiler/hint_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index 4161078b5..f4a6e8f69 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -2,6 +2,7 @@ import sys import importlib + class BaseHintHandler: # dynamicly find method def trigger(self, hook_name, *args, **kwargs): @@ -32,7 +33,9 @@ def trigger(self, hook_name, *args, **kwargs): print("no capable method in backend handler") return None + class HintManager: + def __init__(self, backend_name): self.backend_name = backend_name # load Handler with backend name @@ -56,7 +59,7 @@ def _load_handler(self, backend): # supported backend with matched version SUPPORTED_CONFIG = { "cuda": {"3.5"}, - "npu": {"3.2"}, + "npu": {"3.2"}, "aipu": {"3.3"}, } @@ -74,12 +77,13 @@ def normalize_backend_name(name: str) -> str: name = name.lower() return BACKEND_ALIASES.get(name, name) + def hint_get_flagtree_backend() -> str: detected_backend = "" import torch import triton - + # Priority 1: Triton Driver try: from triton.runtime import driver From 51756b26fb426e0778a0c1098ea3621add1860a2 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 27 Jan 2026 02:39:15 +0000 Subject: [PATCH 09/15] fix bug : circular import --- python/triton/runtime/jit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 33ee561d0..45178a40b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -11,7 +11,6 @@ from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple from ..runtime.driver import driver from types import ModuleType -from ..compiler.hintmanager import hint_trigger TRITON_MODULE = __name__[:-len(".runtime.jit")] From 19f80c51301a2a9ae5846a107db1d82c534a5b67 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 27 Jan 2026 02:52:37 +0000 Subject: [PATCH 10/15] fix bug : hintmanager name into hint_manager --- python/triton/compiler/code_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 1fbddd895..9c614658c 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,7 +15,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType -from .hintmanager import hint_trigger +from .hint_manager import hint_trigger def mangle_ty(ty): From 45c93b4bb9a287629b9ef98b67b4cff2a8cc2b9d Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 27 Jan 2026 03:20:24 +0000 Subject: [PATCH 11/15] fix bug : massive useless print --- python/triton/compiler/hint_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index f4a6e8f69..719605175 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -30,7 +30,6 @@ def trigger(self, hook_name, *args, **kwargs): print(f" > Reason : {e}\n") raise e - print("no capable method in backend handler") return None From 869c357b988b76e4f0bc93603f77b4533f06b90c Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 10 Mar 2026 08:41:52 +0000 Subject: [PATCH 12/15] update spec hint-related codegen && jit --- .../backend/spec/triton/compiler/code_generator.py | 12 ++++++++++++ .../ascend/backend/spec/triton/runtime/jit.py | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 172ba90b4..a20fe0e9a 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -22,6 +22,7 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType +from .hint_manager import hint_trigger # Central registry for all 'with' statement handlers WITH_DISPATCH = {} @@ -547,6 +548,9 @@ def visit_Assign(self, node): value = language.semantic.to_tensor(value, self.builder) self.set_value(name, value) + # switch into hintmanager + hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values) + def visit_AugAssign(self, node): name = node.target.id lhs = ast.Name(id=name, ctx=ast.Load()) @@ -992,6 +996,11 @@ def visit_For(self, node): step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) else: raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # hint manager + new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block) + if new_bind_sub_block is not None: + bind_sub_block = new_bind_sub_block + # handle negative constant step (not supported by scf.for in MLIR) negative_step = False if _is_constexpr(step) and step.value < 0: @@ -1065,6 +1074,9 @@ def visit_For(self, node): for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) if (IteratorClass is extension.parallel): for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) + # hint manager + if bind_sub_block: + hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py b/third_party/ascend/backend/spec/triton/runtime/jit.py index 45178a40b..da8ba230e 100644 --- a/third_party/ascend/backend/spec/triton/runtime/jit.py +++ b/third_party/ascend/backend/spec/triton/runtime/jit.py @@ -756,10 +756,20 @@ def preload(self, specialization_data): # the user might want to monkey-patch self.src dynamically. # Our unit tests do this, for example. def parse(self): + # hint manager + # after removing flagtree backend specialization, hiding the implementation into hintmanager + from ..compiler.hint_manager import hint_trigger + line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self) + tree = ast.parse(self.src) assert isinstance(tree, ast.Module) assert len(tree.body) == 1 assert isinstance(tree.body[0], ast.FunctionDef) + + # hint manager + # Attach the line number to comment mapping to the function definition node + hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints) + return tree def __call__(self, *args, **kwargs): From cc97432a139c98dd769234b8d2991e6d5accedb2 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 10 Mar 2026 08:47:02 +0000 Subject: [PATCH 13/15] remove redundant code in python triton src --- python/triton/compiler/code_generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 9c614658c..d8ca58d8d 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,7 +15,6 @@ from ..runtime import JITFunction from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType -from .hint_manager import hint_trigger def mangle_ty(ty): From d484e12b5f5563fa19ff3bc01c472e3a881cfd7a Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Tue, 10 Mar 2026 08:52:59 +0000 Subject: [PATCH 14/15] update hintmanager, Align with triton_v3.5.x branch. --- python/triton/compiler/hint_manager.py | 54 ++++++++++---------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index 719605175..e3eb7afc5 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -1,4 +1,3 @@ -import os import sys import importlib @@ -49,24 +48,32 @@ def _load_handler(self, backend): print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr) return BaseHintHandler() elif backend == 'aipu': - from .backends.aipu import AipuHintHandler - return AipuHintHandler() + try: + module = importlib.import_module("third_party.aipu.backend.aipu_hint_handler") + return module.AipuHintHandler() + except ImportError as e: + print(f"[FlagTree] Warning: Failed to load aipu Hint Handler: {e}", file=sys.stderr) + return BaseHintHandler() + elif backend == 'cuda': + try: + module = importlib.import_module("third_party.nvidia.backend.nvidia_hint_handler") + return module.NvidiaHintHandler() + except ImportError as e: + print(f"[FlagTree] Warning: Failed to load Nvidia Hint Handler: {e}", file=sys.stderr) + return BaseHintHandler() else: return BaseHintHandler() # supported backend with matched version -SUPPORTED_CONFIG = { - "cuda": {"3.5"}, - "npu": {"3.2"}, - "aipu": {"3.3"}, -} +SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"] +# TODO : npu will have conflicts if more backend involved # mapping name BACKEND_ALIASES = { "ascend": "npu", "huawei": "npu", - "nv": "cuda", + "nvidia": "cuda", } @@ -81,7 +88,6 @@ def hint_get_flagtree_backend() -> str: detected_backend = "" import torch - import triton # Priority 1: Triton Driver try: @@ -96,42 +102,24 @@ def hint_get_flagtree_backend() -> str: except ImportError: pass + # TODO : some backend may not support priority 1, so keep priority 2 is necessary # Priority 2: Torch Global State if not detected_backend: - candidates = list(SUPPORTED_CONFIG.keys()) - # cuda priority least - candidates.sort(key=lambda x: 1 if x == "cuda" else 0) + check_priority = ["aipu", "npu", "cuda"] # 3. parse according to benefit - for candidate in candidates: - module_name = candidate - module = getattr(torch, module_name, None) + for candidate in check_priority: + module = getattr(torch, candidate, None) if module and hasattr(module, "is_available") and module.is_available(): detected_backend = candidate break - # Priority 3: Environment Variable (need to remove!!!) - if not detected_backend: - detected_backend = os.environ.get("FLAGTREE_BACKEND", "") - # (Normalization and Validation) canonical_backend = normalize_backend_name(detected_backend) - if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG: + if not canonical_backend or canonical_backend not in SUPPORTED_BACKENDS: return "" - # verify name and version match - try: - current_triton_version = ".".join(triton.__version__.split(".")[:2]) - supported_versions = SUPPORTED_CONFIG[canonical_backend] - if current_triton_version not in supported_versions: - msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version " - f"'{current_triton_version}' matches no supported versions {supported_versions}.") - print(msg, file=sys.stderr) - return "" - except Exception: - pass - return canonical_backend From 5f7a336d5a40787e10e177add4d7e34bc06e0654 Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Wed, 11 Mar 2026 02:40:09 +0000 Subject: [PATCH 15/15] fix hint manager import error --- python/triton/compiler/hint_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/triton/compiler/hint_manager.py b/python/triton/compiler/hint_manager.py index e3eb7afc5..e7860cf64 100644 --- a/python/triton/compiler/hint_manager.py +++ b/python/triton/compiler/hint_manager.py @@ -42,21 +42,21 @@ def __init__(self, backend_name): def _load_handler(self, backend): if backend == 'npu': try: - module = importlib.import_module("third_party.ascend.backend.ascend_hint_handler") + module = importlib.import_module("triton.backends.ascend.ascend_hint_handler") return module.AscendHintHandler() except ImportError as e: print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr) return BaseHintHandler() elif backend == 'aipu': try: - module = importlib.import_module("third_party.aipu.backend.aipu_hint_handler") + module = importlib.import_module("triton.backends.aipu.aipu_hint_handler") return module.AipuHintHandler() except ImportError as e: print(f"[FlagTree] Warning: Failed to load aipu Hint Handler: {e}", file=sys.stderr) return BaseHintHandler() elif backend == 'cuda': try: - module = importlib.import_module("third_party.nvidia.backend.nvidia_hint_handler") + module = importlib.import_module("triton.backends.nvidia.nvidia_hint_handler") return module.NvidiaHintHandler() except ImportError as e: print(f"[FlagTree] Warning: Failed to load Nvidia Hint Handler: {e}", file=sys.stderr)