Skip to content

feat: support mori IR JIT compilation with shmem#418

Open
yanboshao wants to merge 5 commits intomainfrom
yanbo/shmem
Open

feat: support mori IR JIT compilation with shmem#418
yanboshao wants to merge 5 commits intomainfrom
yanbo/shmem

Conversation

@yanboshao
Copy link
Copy Markdown
Contributor

Motivation

Support mori IR JIT compilation with shmem, enabling FlyDSL kernels to call mori shmem device functions (e.g. my_pe, int32_p, quiet_thread) via ExternFunction, with automatic bitcode linking and runtime hook injection.

Technical Details

-Add ExternFunction for calling external LLVM bitcode from @flyc.kernel, with optional bitcode_path for explicit link-lib declaration
-Add shmem module-load hook in jit_executor for mori shmem_module_init injection after hipModuleLoadData
-Support link_libs in MlirCompiler — inject l= into rocdl-attach-target pass fragment
-Track extern_symbols and link_libs in CompilationContext; fallback to mori_shmem_* prefix auto-detection when bitcode_path is not set
-Add mgpuSetModuleLoadHook callback in FlyRocmRuntimeWrappers.cpp
-Remove targets from create_gpu_module to avoid rocdl-attach-target appending a duplicate target (which causes hipErrorNoBinaryForGpu when link_libs is used)

Test Plan

-torchrun --nproc_per_node=2 tests/kernels/test_flydsl_shmem.py — cross-PE shmem integration test:
--test_basic: verify mori_shmem_my_pe / mori_shmem_n_pes return correct PE identity
--test_put: verify cross-PE int32_p + quiet_thread delivers data correctly

Test Result

[PE 0/2] initialized on GPU 0
[PE 1/2] initialized on GPU 1
[PE 0] my_pe=0, n_pes=2 — basic PASS
[PE 1] my_pe=1, n_pes=2 — basic PASS
[PE 0] buf=142, expected=142 (from PE 1) — put PASS
[PE 1] buf=42, expected=42 (from PE 0) — put PASS
All tests PASSED on 2 PEs (FlyDSL + mori shmem)

Submission Checklist

@yanboshao yanboshao force-pushed the yanbo/shmem branch 3 times, most recently from c45ba00 to 0c768da Compare April 21, 2026 05:18
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this file be treated as some kind of expr? I think it would be better placed elsewhere.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExternFunction is essentially part of the compiler infrastructure: it declares llvm.func, generates llvm.call, and registers symbols into the CompilationContext.
However, ExternFunction.call generates IR at the kernel body’s insertion point, which is fundamentally the same level of operation as addi() in expr/arith.py.
Since users do not directly import ExternFunction—it is consumed internally by Mori—it makes more sense to place it under the compiler/ directory.
I have already moved this file to the compiler/ directory.

@yanboshao yanboshao force-pushed the yanbo/shmem branch 3 times, most recently from cf91f76 to d63acb1 Compare April 21, 2026 06:44
Comment thread lib/Runtime/FlyRocmRuntimeWrappers.cpp Outdated
hipModule_t module = nullptr;
HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
if (module && s_moduleLoadHook) {
s_moduleLoadHook(module);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hook all kernels?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Replaced with an opt-in thread-local callback (mgpuSetModuleLoadCallback). Only engines whose kernels
declare extern bitcode install it, and it is cleared immediately after ExecutionEngine.initialize(), so kernels without an ExternFunction see no hook at all.

HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
if (module)
s_loadedModules.push_back(module);
return module;
Copy link
Copy Markdown
Collaborator

@coderfeli coderfeli Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module is returned. Can we maintain it in python class private var instead of a global var in cpp?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment thread python/flydsl/compiler/jit_function.py Outdated
pm.enable_verifier(env.debug.enable_verifier)
_t0 = time.perf_counter()
pm.run(module.operation)
_dt = time.perf_counter() - _t0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clean debug codes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

Comment thread python/flydsl/compiler/jit_function.py Outdated

if comp_ctx.extern_symbols and not link_libs:
try:
from mori.ir.flydsl.compile_helper import prepare_compile
Copy link
Copy Markdown
Collaborator

@coderfeli coderfeli Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a good idea to import mori here. We need a clean jit func here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed. Extern integration now flows through CompilationContext.link_libs + post_load_processors, which each ExternFunction populates at declaration time; the JIT path just reads those two lists.

- Add ExternFunction for calling external LLVM bitcode from @flyc.kernel
- Add shmem module-load hook in jit_executor for mori shmem injection
- Support link_libs and mori shmem bitcode linking in MlirCompiler
- Track extern_symbols in CompilationContext, use gpu_func_op for known_block_size
- Add mgpuSetModuleLoadHook callback in FlyRocmRuntimeWrappers
- C++ runtime: replace mgpuSetModuleLoadHook with module tracker APIs
  (mgpuResetModuleTracker/GetLoadedModuleCount/GetLoadedModule)
- CompiledArtifact: accept post_load_processors instead of needs_shmem,
  apply processors to tracked modules after engine initialization
- Restore owned reference keeping in __call__ to prevent UAF
- Delegate shmem detection to mori compile_helper.prepare_compile(),
  removing hardcoded mori_shmem_* prefix and direct find_bitcode import
- Improve rocdl-attach-target fragment string splicing robustness
- Add clarifying comment for void-return Operation.create in extern.py
- Remove hardcoded absolute paths from test_flydsl_shmem.py

Made-with: Cursor
- Raise ImportError (not silent pass) when mori compile helper is
  missing for kernels using extern symbols
- Always obtain module_init_fn even with explicit link_libs
- Fix rstrip to only remove trailing brace, not all braces
- Replace assert with raise RuntimeError for missing rocdl-attach-target
- Add bounds check in mgpuGetLoadedModule (C++)
- Use Optional type annotations in CompiledArtifact
- Rename test_ functions to run_ to avoid pytest collection conflicts

Made-with: Cursor
Post-rebase cleanup round covering correctness, concurrency, pickling
and documentation gaps spotted during deep review.

Correctness & safety:
- extern.py: replace id()-keyed `_declared_in` cache with an idempotent
  scan of `gpu_module_body.operations`; the old cache could false-hit on
  recycled Python ids across compilations and grew unboundedly.
- jit_executor.py: `_ensure_engine` now raises RuntimeError if
  post_load_processors are registered but the engine tracker reports
  zero loaded modules -- surfaces any future ExecutionEngine behaviour
  change (e.g. async load) instead of silently skipping module init.
- jit_executor.py: `__getstate__` now raises pickle.PicklingError when
  any post_load_processor cannot be serialised as "module:qualname",
  instead of silently dropping processors and producing a cache entry
  that would segfault on reload.
- jit_executor.py: `cb_ref` typed as Optional[Any] (private
  ctypes._CFuncPtr was never a stable annotation).

Dead code / structure:
- kernel_function.py: drop unused `CompilationContext.extern_symbols`
  (was written but never read anywhere).
- jit_function.py: shorten the long comment above `create_gpu_module` to
  just explain *why* targets are not passed here (duplicate-target
  attribute -> hipErrorNoBinaryForGpu), no behaviour change.

Concurrency contract documentation:
- FlyRocmRuntimeWrappers.cpp: document the thread_local `s_moduleLoadCb`
  design in detail -- why a mutex-protected global is *not* a fix, and
  the explicit set-before / clear-after contract that callers must
  honour around `mgpuModuleLoad`.

Docs:
- docs/extern_integration_guide.md: new FlyDSL-side integration guide
  covering the ExternFunction surface, the post-load callback contract,
  and the on-disk JIT cache / pickling contract.

Made-with: Cursor
…trip

The previous check only inspected whether ``__module__`` and
``__qualname__`` exist — both are present on lambdas, nested functions,
and bound methods, so those slipped through and ``__getstate__`` happily
wrote strings like ``__main__:<lambda>`` into the pickle stream.  On
unpickle ``_resolve_qualname`` would then silently return None, the
processor list would come back short, and the next kernel launch would
GPU-fault on an uninitialised device-side global with no usable
stacktrace — defeating the whole point of raising eagerly at cache-write
time.

Fixes:
- reject ``__qualname__`` containing ``<`` (lambdas, ``<locals>``,
  comprehensions, …)
- reject bound methods (``__self__`` is set; resolving the name gives
  the unbound function, silently dropping ``self``)
- final round-trip check: ``_resolve_qualname(ref) is fn`` must hold,
  otherwise we refuse

All three rules documented in the ``_qualname`` docstring.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants