@coderfeli @sjfeng1999 Could you please review this PR when you have time? Thanks!
I think a better general approach is to verify the op before lowering |
Freezes current (buggy) lowering: a non-unit-stride `!fly.memref` on one side of `fly.copy_atom_call` is lowered to a single contiguous `llvm.load` / `llvm.store` against the memory-side pointer, silently ignoring the stride. Next commit will fix `emitAtomCallSSA` and update these CHECKs.
`fly.copy_atom_call` / `fly.copy_atom_call_ssa` with `!fly.universal_copy<...>` used to accept non-contiguous memrefs and lower them as if the atom operated on a contiguous slice. Fix this by verifying universal copy operands before lowering:

* the memref layout must coalesce to a single static leaf
* the coalesced bit count must match the copy atom bit width
* the contiguous bit count must not be smaller than the copy granularity

This turns strided/otherwise incompatible memrefs into a clear verification error instead of silently generating code that deviates from the original atom semantics.
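The three checks can be modeled as a small predicate. This is a hypothetical sketch in Python: the function name and the plain-tuple layout representation are illustrative, not the actual Fly verifier API.

```python
# Hypothetical sketch of the three operand checks; names and the
# plain-number layout model are illustrative, not the Fly API.
def verify_universal_copy_operand(leaf_bit_counts, contiguous_bits,
                                  atom_bits, granularity_bits):
    """Return None if the memref operand is compatible, else a diagnostic."""
    # 1. The memref layout must coalesce to a single static leaf.
    if len(leaf_bit_counts) != 1:
        return "layout does not coalesce to a single static leaf"
    # 2. The coalesced bit count must match the copy atom bit width.
    if leaf_bit_counts[0] != atom_bits:
        return (f"coalesced bit count {leaf_bit_counts[0]} does not match "
                f"atom bit width {atom_bits}")
    # 3. The contiguous bit count must not be smaller than the granularity.
    if contiguous_bits < granularity_bits:
        return (f"contiguous bit count {contiguous_bits} is smaller than "
                f"copy granularity {granularity_bits}")
    return None
```

Under this model, an operand that coalesces to two leaves, or whose contiguous run is shorter than the atom's copy granularity, produces a diagnostic instead of reaching the lowering.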
Makes sense, thanks. I updated the fix to validate the atom-call operands up front and emit a clear diagnostic for incompatible layouts instead.
Motivation
`fly.copy_atom_call` using a `universal_copy` atom with a strided `!fly.memref` on one side (e.g. `!fly.memref<f16, global, 4:8>`) was silently lowered to a single contiguous `llvm.load`/`llvm.store`/`llvm.memcpy` against the memory-side pointer, ignoring the memref layout's stride. The result is that adjacent lanes are read or written instead of stride-spaced elements — a silent correctness bug whenever a single atom moves a non-unit-stride slice of memory.

Technical Details
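Concretely, for the `4:8` layout above (4 elements, stride 8), the atom should address elements 0, 8, 16, 24, not elements 0 through 3. A minimal Python model of the difference, in element indices only (a hypothetical sketch, not Fly code):

```python
# Illustrative model: element offsets touched by a count:stride layout
# versus the buggy contiguous lowering.
def contiguous_offsets(count):
    # Buggy path: one contiguous slice starting at the base pointer.
    return list(range(count))

def strided_offsets(count, stride):
    # Correct semantics of a layout like 4:8 - count elements, stride apart.
    return [i * stride for i in range(count)]
```

For the `4:8` layout, `strided_offsets(4, 8)` yields `[0, 8, 16, 24]`, while the buggy path reads `[0, 1, 2, 3]`; the two agree only when the stride is 1.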
Bug is in `CopyOpUniversalCopyType::emitAtomCallSSA` and `emitAtomCall` (lib/Dialect/Fly/IR/FlyUniversalOps.cpp): both paths assumed contiguous memory and emitted a single `vector<N × elemTy>` load/store (or a single `llvm.memcpy` for the non-SSA memref-to-memref path).

Fix:

* Coalesce the `LayoutAttr` to obtain static `(count, stride)`.
* If `count <= 1 || stride == 1`: keep the existing fast path (single vector load/store, or `llvm.memcpy` for the memref-to-memref path).
* Otherwise, emit per-element `llvm.getelementptr` with offset `i * stride` (in elements), `applySwizzleOnPtr`, then `llvm.load` + `llvm.insertelement` (gather side) or `llvm.extractelement` + `llvm.store` (scatter side).

Helpers added in the same file: `getCoalescedLeafCountAndStride`, `emitStridedLoadAsVector`, `emitStridedStoreFromVector`.

Behavior change worth noting: memrefs whose layout does not coalesce to a single static leaf previously went through the silently-wrong contiguous path; they now make the lowering fail explicitly. Register-promoted memrefs are unaffected because `ConvertAtomCallToSSAForm::isEligibleToPromote` already restricts them to single-leaf layouts with stride 1 or shape 1.

Changes are split into two review-friendly commits:
1. `[Test] Add reproducer for strided universal_copy in convert-fly-to-rocdl` — adds `tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir` with FileCheck lines frozen to the currently-buggy IR.
2. `[Fix] Honor stride in universal copy atom lowering` — fixes `emitAtomCallSSA` / `emitAtomCall` and updates the CHECK lines to the corrected IR. The diff on the test file in commit 2 is exactly the IR before/after the fix.

Test Plan
* `tests/mlir/Transforms/convert_fly_to_rocdl_universal_copy_strided.mlir` runs `--fly-convert-atom-call-to-ssa-form --convert-fly-to-rocdl` against two kernels:
  * `load_strided_global_into_register`: `!fly.memref<f16, global, 4:8>` -> `!fly.memref<f16, register, 4:1>`.
  * `store_register_into_strided_global`: the reverse direction.
* CHECK lines expect `llvm.getelementptr %arg0[0]`/`[8]`/`[16]`/`[24]` with `f16` loads/stores and the matching `llvm.insertelement`/`llvm.extractelement` chain, proving the stride-8 layout is honored.
* Re-ran `tests/mlir/Transforms/*.mlir` (including `canonicalize`, `convert-atom-call-to-ssa-form`, `layout_lowering`, `promote_regmem_to_ssa*`, `rewrite_func_signature`) to ensure the strided path did not regress existing contiguous/register lowerings.

Test Result
Submission Checklist