Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions python/triton/compiler/hint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import sys
import importlib


class BaseHintHandler:
# dynamicly find method
def trigger(self, hook_name, *args, **kwargs):
if hasattr(self, hook_name):
method = getattr(self, hook_name)
if callable(method):
try:
return method(*args, **kwargs)

except TypeError as e:
import inspect

try:
sig = inspect.signature(method)
expected = str(sig)
except Exception:
expected = "(unknown)"

actual_args = f"{len(args)} positional"
actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"

print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}")
print(f" > Expect : {expected}")
print(f" > Actual : {actual_args}, {actual_kwargs}")
print(f" > Reason : {e}\n")

raise e
return None


class HintManager:

def __init__(self, backend_name):
self.backend_name = backend_name
# load Handler with backend name
self.handler = self._load_handler(backend_name)

def _load_handler(self, backend):
if backend == 'npu':
try:
module = importlib.import_module("triton.backends.ascend.ascend_hint_handler")
return module.AscendHintHandler()
except ImportError as e:
print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr)
return BaseHintHandler()
elif backend == 'aipu':
try:
module = importlib.import_module("triton.backends.aipu.aipu_hint_handler")
return module.AipuHintHandler()
except ImportError as e:
print(f"[FlagTree] Warning: Failed to load aipu Hint Handler: {e}", file=sys.stderr)
return BaseHintHandler()
elif backend == 'cuda':
try:
module = importlib.import_module("triton.backends.nvidia.nvidia_hint_handler")
return module.NvidiaHintHandler()
except ImportError as e:
print(f"[FlagTree] Warning: Failed to load Nvidia Hint Handler: {e}", file=sys.stderr)
return BaseHintHandler()
else:
return BaseHintHandler()


# supported backend with matched version
SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"]

# TODO : npu will have conflicts if more backend involved
# mapping name
BACKEND_ALIASES = {
"ascend": "npu",
"huawei": "npu",
"nvidia": "cuda",
}


def normalize_backend_name(name: str) -> str:
if not name:
return ""
name = name.lower()
return BACKEND_ALIASES.get(name, name)


def hint_get_flagtree_backend() -> str:
detected_backend = ""

import torch

# Priority 1: Triton Driver
try:
from triton.runtime import driver
if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
device = driver.active.get_active_torch_device()
if isinstance(device, torch.device):
detected_backend = device.type
# unimplemented support
elif isinstance(device, str):
detected_backend = device
except ImportError:
pass

# TODO : some backend may not support priority 1, so keep priority 2 is necessary
# Priority 2: Torch Global State
if not detected_backend:
check_priority = ["aipu", "npu", "cuda"]

# 3. parse according to benefit
for candidate in check_priority:
module = getattr(torch, candidate, None)
if module and hasattr(module, "is_available") and module.is_available():
detected_backend = candidate
break

# (Normalization and Validation)
canonical_backend = normalize_backend_name(detected_backend)

if not canonical_backend or canonical_backend not in SUPPORTED_BACKENDS:
return ""

return canonical_backend


# lazy load after first call hint trigger
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
global _global_hint_manager

if _global_hint_manager is None:
_global_hint_manager = HintManager(hint_get_flagtree_backend())
return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)
79 changes: 79 additions & 0 deletions third_party/ascend/backend/ascend_hint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# should store at thrid_party/???/backend/
from triton.compiler.hint_manager import BaseHintHandler
import triton.language as language
import ast
from triton.compiler.code_generator import _is_triton_value


class AscendHintHandler(BaseHintHandler):

@staticmethod
def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
import ast
from triton.compiler.code_generator import _is_triton_value
# flagtree: After normal processing, check if we need to add hint annotation
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = code_generator.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a tl.load call with dot_pad_only_k hint
if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name)
and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'):

# Add hint annotation to the loaded tensor(s)
for name, value in zip(names, values):
if _is_triton_value(value):
# print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
# Create hint annotation
hint_val = code_generator.builder.get_unit_attr()
code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)

@staticmethod
def check_override_bind_sub_block(code_generator, node, bind_sub_block):
# flagtree: After normal processing, check if we need to override bind_sub_block
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
line_num = node.lineno
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
function_def = code_generator.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)

# Check if this is a range/for loop with bind_sub_block hint
if flagtree_hints and 'bind_sub_block' in flagtree_hints:
return True
# print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")
return bind_sub_block

@staticmethod
def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))

@staticmethod
def maps_line_numbers_to_comment_hints(jit_fn):
import tokenize
from io import StringIO
# Maps line numbers to comment hints
line_flagtree_hints = {}
code_str = jit_fn.src
g = tokenize.generate_tokens(StringIO(code_str).readline)
for tok_type, tok_text, start, end, _ in g:
if tok_type == tokenize.COMMENT:
comment = tok_text.replace(" ", "").strip()
if comment.startswith('#@hint:'):
flagtree_hints = comment[len('#@hint:'):].strip()
# Record the line number of the comment
line_num = start[0]
line_flagtree_hints[line_num] = flagtree_hints

# print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")

return line_flagtree_hints

@staticmethod
def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
# Attach the line number to comment mapping to the function definition node
tree.body[0].line_flagtree_hints = line_flagtree_hints
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
from types import ModuleType
from .hint_manager import hint_trigger
# Central registry for all 'with' statement handlers
WITH_DISPATCH = {}

Expand Down Expand Up @@ -547,6 +548,9 @@ def visit_Assign(self, node):
value = language.semantic.to_tensor(value, self.builder)
self.set_value(name, value)

# switch into hintmanager
hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)

def visit_AugAssign(self, node):
name = node.target.id
lhs = ast.Name(id=name, ctx=ast.Load())
Expand Down Expand Up @@ -992,6 +996,11 @@ def visit_For(self, node):
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
else:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
# hint manager
new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
if new_bind_sub_block is not None:
bind_sub_block = new_bind_sub_block

# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if _is_constexpr(step) and step.value < 0:
Expand Down Expand Up @@ -1065,6 +1074,9 @@ def visit_For(self, node):
for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr())
if (IteratorClass is extension.parallel):
for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr())
# hint manager
if bind_sub_block:
hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)

self.scf_stack.append(node)
self.builder.set_insertion_point_to_start(for_op.get_body(0))
Expand Down
10 changes: 10 additions & 0 deletions third_party/ascend/backend/spec/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,20 @@ def preload(self, specialization_data):
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
# hint manager
# after removing flagtree backend specialization, hiding the implementation into hintmanager
from ..compiler.hint_manager import hint_trigger
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)

tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)

# hint manager
# Attach the line number to comment mapping to the function definition node
hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)

return tree

def __call__(self, *args, **kwargs):
Expand Down