[HINT] Add Triton v3.2.x hint manager #316
Open
starrryz wants to merge 15 commits into triton_v3.2.x from triton_v3.2.x_hint_manager
Commits (15):
b75b621 the initial design for unified hint framework (starrryz)
0078272 update the logic of how to call backend method in basehinthandler (starrryz)
bfe1966 update hintmanager, wrap additional code into hintmanager, back no-hi… (starrryz)
2c64367 remove redundant code (starrryz)
a430a9e fix import and python bugs (starrryz)
06a032a fix import and python bugs_2 (starrryz)
854b504 apply code-format change (starrryz)
9e2ef64 apply code-format change_2 (starrryz)
51756b2 fix bug : circular import (starrryz)
19f80c5 fix bug : hintmanager name into hint_manager (starrryz)
45c93b4 fix bug : massive useless print (starrryz)
869c357 update spec hint-related codegen && jit (starrryz)
cc97432 remove redundant code in python triton src (starrryz)
d484e12 update hintmanager, Align with triton_v3.5.x branch. (starrryz)
5f7a336 fix hint manager import error (starrryz)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| import sys | ||
| import importlib | ||
|
|
||
|
|
||
| class BaseHintHandler: | ||
| # dynamically look up the hook method by name | ||
| def trigger(self, hook_name, *args, **kwargs): | ||
| if hasattr(self, hook_name): | ||
| method = getattr(self, hook_name) | ||
| if callable(method): | ||
| try: | ||
| return method(*args, **kwargs) | ||
|
|
||
| except TypeError as e: | ||
| import inspect | ||
|
|
||
| try: | ||
| sig = inspect.signature(method) | ||
| expected = str(sig) | ||
| except Exception: | ||
| expected = "(unknown)" | ||
|
|
||
| actual_args = f"{len(args)} positional" | ||
| actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords" | ||
|
|
||
| print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}") | ||
| print(f" > Expect : {expected}") | ||
| print(f" > Actual : {actual_args}, {actual_kwargs}") | ||
| print(f" > Reason : {e}\n") | ||
|
|
||
| raise | ||
| return None | ||
|
|
||
|
|
||
| class HintManager: | ||
|
|
||
| def __init__(self, backend_name): | ||
| self.backend_name = backend_name | ||
| # load the handler that matches the backend name | ||
| self.handler = self._load_handler(backend_name) | ||
|
|
||
| def _load_handler(self, backend): | ||
| if backend == 'npu': | ||
| try: | ||
| module = importlib.import_module("triton.backends.ascend.ascend_hint_handler") | ||
| return module.AscendHintHandler() | ||
| except ImportError as e: | ||
| print(f"[FlagTree] Warning: Failed to load Ascend Hint Handler: {e}", file=sys.stderr) | ||
| return BaseHintHandler() | ||
| elif backend == 'aipu': | ||
| try: | ||
| module = importlib.import_module("triton.backends.aipu.aipu_hint_handler") | ||
| return module.AipuHintHandler() | ||
| except ImportError as e: | ||
| print(f"[FlagTree] Warning: Failed to load aipu Hint Handler: {e}", file=sys.stderr) | ||
| return BaseHintHandler() | ||
| elif backend == 'cuda': | ||
| try: | ||
| module = importlib.import_module("triton.backends.nvidia.nvidia_hint_handler") | ||
| return module.NvidiaHintHandler() | ||
| except ImportError as e: | ||
| print(f"[FlagTree] Warning: Failed to load Nvidia Hint Handler: {e}", file=sys.stderr) | ||
| return BaseHintHandler() | ||
| else: | ||
| return BaseHintHandler() | ||
|
|
||
|
|
||
| # supported backend with matched version | ||
| SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"] | ||
|
|
||
| # TODO: the "npu" name will conflict if more backends are added | ||
| # maps alias names to canonical backend names | ||
| BACKEND_ALIASES = { | ||
| "ascend": "npu", | ||
| "huawei": "npu", | ||
| "nvidia": "cuda", | ||
| } | ||
|
|
||
|
|
||
| def normalize_backend_name(name: str) -> str: | ||
| if not name: | ||
| return "" | ||
| name = name.lower() | ||
| return BACKEND_ALIASES.get(name, name) | ||
|
|
||
|
|
||
| def hint_get_flagtree_backend() -> str: | ||
| detected_backend = "" | ||
|
|
||
| import torch | ||
|
|
||
| # Priority 1: Triton Driver | ||
| try: | ||
| from triton.runtime import driver | ||
| if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'): | ||
| device = driver.active.get_active_torch_device() | ||
| if isinstance(device, torch.device): | ||
| detected_backend = device.type | ||
| # unimplemented support | ||
| elif isinstance(device, str): | ||
| detected_backend = device | ||
| except ImportError: | ||
| pass | ||
|
|
||
| # TODO: some backends may not support priority 1, so priority 2 must be kept as a fallback | ||
| # Priority 2: Torch Global State | ||
| if not detected_backend: | ||
| check_priority = ["aipu", "npu", "cuda"] | ||
|
|
||
| # probe the candidates in priority order | ||
| for candidate in check_priority: | ||
| module = getattr(torch, candidate, None) | ||
| if module and hasattr(module, "is_available") and module.is_available(): | ||
| detected_backend = candidate | ||
| break | ||
|
|
||
| # (Normalization and Validation) | ||
| canonical_backend = normalize_backend_name(detected_backend) | ||
|
|
||
| if not canonical_backend or canonical_backend not in SUPPORTED_BACKENDS: | ||
| return "" | ||
|
|
||
| return canonical_backend | ||
|
|
||
|
|
||
| # lazily initialized on the first call to hint_trigger | ||
| _global_hint_manager = None | ||
|
|
||
|
|
||
| def hint_trigger(hook_name, *args, **kwargs): | ||
| global _global_hint_manager | ||
|
|
||
| if _global_hint_manager is None: | ||
| _global_hint_manager = HintManager(hint_get_flagtree_backend()) | ||
| return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs) |
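
For readers following the dispatch path in the file above, here is a minimal, self-contained sketch of how `trigger` resolves hooks by name and how backend aliases are normalized. It is a simplified stand-in rather than the PR's code: the signature-mismatch diagnostics are omitted, and `DemoHintHandler` is a hypothetical backend handler invented only for illustration.

```python
class BaseHintHandler:
    """Simplified stand-in for the PR's BaseHintHandler (diagnostics omitted)."""

    def trigger(self, hook_name, *args, **kwargs):
        # Resolve the hook by name; unknown hooks are silently ignored.
        method = getattr(self, hook_name, None)
        if callable(method):
            return method(*args, **kwargs)
        return None


class DemoHintHandler(BaseHintHandler):
    """Hypothetical backend handler, not part of the PR."""

    @staticmethod
    def check_override_bind_sub_block(code_generator, node, bind_sub_block):
        # Pretend a bind_sub_block hint was found on this line.
        return True


# Same alias table as in the diff above.
BACKEND_ALIASES = {"ascend": "npu", "huawei": "npu", "nvidia": "cuda"}


def normalize_backend_name(name):
    if not name:
        return ""
    name = name.lower()
    return BACKEND_ALIASES.get(name, name)


base_result = BaseHintHandler().trigger("check_override_bind_sub_block", None, None, False)
demo_result = DemoHintHandler().trigger("check_override_bind_sub_block", None, None, False)
```

In this sketch a missing hook yields `None` rather than the `bind_sub_block` value that was passed in, so call sites presumably need to treat `None` as "no override" whenever a backend falls back to `BaseHintHandler`.
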
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # should be stored at third_party/???/backend/ | ||
| from triton.compiler.hint_manager import BaseHintHandler | ||
| import triton.language as language | ||
| import ast | ||
| from triton.compiler.code_generator import _is_triton_value | ||
|
|
||
|
|
||
| class AscendHintHandler(BaseHintHandler): | ||
|
|
||
| @staticmethod | ||
| def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): | ||
| import ast | ||
| from triton.compiler.code_generator import _is_triton_value | ||
| # flagtree: After normal processing, check if we need to add hint annotation | ||
| if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): | ||
| line_num = node.lineno | ||
| # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later | ||
| function_def = code_generator.jit_fn.parse() | ||
| line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) | ||
| flagtree_hints = line_flagtree_hints.get(line_num) | ||
|
|
||
| # Check if this is a tl.load call with dot_pad_only_k hint | ||
| if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call) | ||
| and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) | ||
| and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'): | ||
|
|
||
| # Add hint annotation to the loaded tensor(s) | ||
| for name, value in zip(names, values): | ||
| if _is_triton_value(value): | ||
| # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") | ||
| # Create hint annotation | ||
| hint_val = code_generator.builder.get_unit_attr() | ||
| code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) | ||
|
|
||
| @staticmethod | ||
| def check_override_bind_sub_block(code_generator, node, bind_sub_block): | ||
| # flagtree: After normal processing, check if we need to override bind_sub_block | ||
| if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): | ||
| line_num = node.lineno | ||
| # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later | ||
| function_def = code_generator.jit_fn.parse() | ||
| line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) | ||
| flagtree_hints = line_flagtree_hints.get(line_num) | ||
|
|
||
| # Check if this is a range/for loop with bind_sub_block hint | ||
| if flagtree_hints and 'bind_sub_block' in flagtree_hints: | ||
| return True | ||
| # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") | ||
|
||
| return bind_sub_block | ||
|
|
||
| @staticmethod | ||
| def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): | ||
| for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) | ||
|
|
||
| @staticmethod | ||
| def maps_line_numbers_to_comment_hints(jit_fn): | ||
| import tokenize | ||
| from io import StringIO | ||
| # Maps line numbers to comment hints | ||
| line_flagtree_hints = {} | ||
| code_str = jit_fn.src | ||
| g = tokenize.generate_tokens(StringIO(code_str).readline) | ||
| for tok_type, tok_text, start, end, _ in g: | ||
| if tok_type == tokenize.COMMENT: | ||
| comment = tok_text.replace(" ", "").strip() | ||
| if comment.startswith('#@hint:'): | ||
| flagtree_hints = comment[len('#@hint:'):].strip() | ||
| # Record the line number of the comment | ||
| line_num = start[0] | ||
| line_flagtree_hints[line_num] = flagtree_hints | ||
|
|
||
| # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") | ||
|
|
||
| return line_flagtree_hints | ||
|
|
||
| @staticmethod | ||
| def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): | ||
| # Attach the line number to comment mapping to the function definition node | ||
| tree.body[0].line_flagtree_hints = line_flagtree_hints | ||
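
The comment-scanning step in `maps_line_numbers_to_comment_hints` can be tried outside Triton. The sketch below follows the same `tokenize` approach on a plain source string; `collect_line_hints` and `KERNEL_SRC` are hypothetical names for illustration, and the sample kernel text is never executed, only tokenized.

```python
import tokenize
from io import StringIO


def collect_line_hints(source):
    """Map line numbers to the text of '#@hint:' comments found in source."""
    hints = {}
    for tok in tokenize.generate_tokens(StringIO(source).readline):
        if tok.type == tokenize.COMMENT:
            # All spaces are stripped, so '# @hint: x' and '#@hint:x' are equivalent.
            comment = tok.string.replace(" ", "").strip()
            if comment.startswith("#@hint:"):
                hints[tok.start[0]] = comment[len("#@hint:"):]
    return hints


# Hypothetical kernel source carrying two hint comments.
KERNEL_SRC = (
    "def kernel(x_ptr, n):\n"
    "    a = tl.load(x_ptr)  # @hint: dot_pad_only_k\n"
    "    for i in range(n):  # @hint: bind_sub_block\n"
    "        pass\n"
)

hints = collect_line_hints(KERNEL_SRC)
print(hints)
```

Because every space is removed before matching, a hint value that itself contains spaces would be collapsed into a single token. The handlers above then test membership with `in`, which on a string is a substring check, so one hint name that is a prefix or infix of another could match unintentionally.
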