
Reduce-based MM+RS in MultiDeviceExecutor #5923

Open
nsarka wants to merge 10 commits into NVIDIA:main from nsarka:nsarka/mm-rs-nvls

Conversation

@nsarka (Member) commented Feb 5, 2026

This PR is a follow-up to Sam's broadcast-based pipeline PR. Instead of broadcasting for AG+MM, this PR handles the same flow but with a reduce, for the following MM+RS fusion:

class MatmulRsCollectiveBasedPipelineFusion(FusionDefinition):
    def __init__(self, dtype, m, k, n, num_devices, communication_backend):
        super().__init__()
        self.m = m
        self.k = k
        self.n = n
        self._num_devices = num_devices
        self.dtype = dtype
        self.communication_backend = communication_backend

    def definition(self) -> None:
        m, k, n, d = (
            self.m,
            self.k,
            self.n,
            self._num_devices,
        )
        self.A = self.define_tensor(
            shape=[d, d, m // d, k // d], contiguity=True, dtype=torch_dtype_to_nvfuser_dtype(self.dtype) # [didx(d), stream(d), m/d, k/d]
        )
        self.B = self.define_tensor(
            shape=[d, 1, k // d, n], contiguity=True, dtype=torch_dtype_to_nvfuser_dtype(self.dtype) # [didx(d), 1, k/d, n]
        )

        self.C_unreduced = self.ops.matmul(
            self.A, self.B # [did(d), stream(d), m/d, n]
        )

        self.C = self.ops.sum(self.C_unreduced, 0) # [Stream(r(d)), didx(d), m/d, n]

        self.add_output(self.C)

    def multidevice_schedule(self):
        mesh = nvfuser.multidevice.DeviceMesh(range(self._num_devices))
        for tv in [
            self.A,
            self.B,
            self.C_unreduced,
            self.C,
        ]:
            tv.set_device_mesh(mesh)

        self.A.axis(0).parallelize(ParallelType.mesh_x)
        self.A.axis(1).parallelize(ParallelType.stream)
        self.B.axis(0).parallelize(ParallelType.mesh_x)
        self.C_unreduced.axis(1).parallelize(ParallelType.stream)
        self.C_unreduced.axis(0).parallelize(ParallelType.mesh_x)
        self.C.axis(1).parallelize(ParallelType.mesh_x)
        self.C.axis(0).parallelize(ParallelType.stream)

The fusion gets lowered into the host IR below. The key idea is that there is no "swizzle"; the root of each Reduce communication is the streamIdx.

%HostIrContainer { (T0_g___half[ideviceIdx.x0{4}, istreamIdx1{4}, iS2{128}, iS3{128}] (DeviceMesh{0 1 2 3}), T1_g___half[ideviceIdx.x4{4}, bS5{1}, iS6{128}, iS7{512}] (DeviceMesh{0 1 2 3})) -> (T3_g___half[rS13{4}, ideviceIdx.x14{4}, iS15{128}, iS16{512}] (DeviceMesh{0 1 2 3})) :

  T2_g___half[ideviceIdx.x8{4}, istreamIdx9{4}, iS10{128}, iS11{512}, rS12{128}] (DeviceMesh{0 1 2 3}) = ALLOCATE(buffer=T2_g___half[ideviceIdx.x8{4}, istreamIdx9{4}, iS10{128}, iS11{512}, rS12{128}] (DeviceMesh{0 1 2 3}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___half[rS13{4}, ideviceIdx.x14{4}, iS15{128}, iS16{512}] (DeviceMesh{0 1 2 3}) = ALLOCATE(buffer=T3_g___half[rS13{4}, ideviceIdx.x14{4}, iS15{128}, iS16{512}] (DeviceMesh{0 1 2 3}), mem_type=global, size=262144, zero_init=false, resets_to_zero=false)
  Stream 0x1c1e0b80 = GetCurrentStream()
  FOR streamIdx in istreamIdx9{4}:
    SetCurrentStream(Stream ( streamIdx % numberOfStreams ))
    Synchronize(Stream 0x1c1e0b80)
  FOR streamIdx in istreamIdx9{4}:
    SetCurrentStream(Stream ( streamIdx % numberOfStreams ))
    T4_l___half[ideviceIdx.x17{4}, iS18{128}, iS19{128}] (DeviceMesh{0 1 2 3})
       = HirAliasSelect( T0_g___half[ideviceIdx.x0{4}, istreamIdx1{4}, iS2{128}, iS3{128}] (DeviceMesh{0 1 2 3}), axis = istreamIdx1{4}, index = streamIdx )
    T5_l___half[ideviceIdx.x20{4}, iS21{128}, iS22{512}] (DeviceMesh{0 1 2 3})
       = HirAliasSelect( T1_g___half[ideviceIdx.x4{4}, bS5{1}, iS6{128}, iS7{512}] (DeviceMesh{0 1 2 3}), axis = bS5{1}, index = streamIdx )
    T6_l___half[ideviceIdx.x23{4}, iS24{128}, iS25{512}, rS26{128}] (DeviceMesh{0 1 2 3})
       = HirAliasSelect( T2_g___half[ideviceIdx.x8{4}, istreamIdx9{4}, iS10{128}, iS11{512}, rS12{128}] (DeviceMesh{0 1 2 3}), axis = istreamIdx9{4}, index = streamIdx )
    T6_l___half[ideviceIdx.x23{4}, iS24{128}, iS25{512}, rS26{128}] (DeviceMesh{0 1 2 3})
       = matmul(T4_l___half[ideviceIdx.x17{4}, iS18{128}, iS19{128}] (DeviceMesh{0 1 2 3}),
                T5_l___half[ideviceIdx.x20{4}, iS21{128}, iS22{512}] (DeviceMesh{0 1 2 3}))
    T6_l___half[ideviceIdx.x23{4}, iS24{128}, iS25{512}, rS26{128}] (DeviceMesh{0 1 2 3})
       = HirAliasSelect( T2_g___half[ideviceIdx.x8{4}, istreamIdx9{4}, iS10{128}, iS11{512}, rS12{128}] (DeviceMesh{0 1 2 3}), axis = istreamIdx9{4}, index = streamIdx )
    T7_l___half[ideviceIdx.x27{4}, iS28{128}, iS29{512}] (DeviceMesh{0 1 2 3})
       = HirAliasSelect( T3_g___half[rS13{4}, ideviceIdx.x14{4}, iS15{128}, iS16{512}] (DeviceMesh{0 1 2 3}), axis = rS13{4}, index = streamIdx )
    Communication 48 (type=Reduce, team=(0 1 2 3), root=streamIdx, input=T6_l___half[ideviceIdx.x23{4}, iS24{128}, iS25{512}, rS26{128}] (DeviceMesh{0 1 2 3}), output=T7_l___half[ideviceIdx.x27{4}, iS28{128}, iS29{512}] (DeviceMesh{0 1 2 3}), backend=NCCL)
    Wait(Communication 48)
    SetCurrentStream(Stream 0x1c1e0b80)
    Synchronize(Stream ( streamIdx % numberOfStreams ))
} // %HostIrContainer

@nsarka nsarka self-assigned this Feb 5, 2026
@greptile-apps bot (Contributor) commented Feb 5, 2026

Greptile Summary

Implements reduce-collective-based MM+RS (matmul + reduce-scatter) lowering for the MultiDeviceExecutor, complementing the existing P2P-based approach. The key change enables stream parallelization on reduction axes, where each stream performs a matmul slice and then reduces results using NCCL collectives with the stream index as the root.

Key changes:

  • Added new lowering path in stream_parallel_type.cpp for offset_stream_indexing_by_rank=false that lowers ReductionOp to Communication(Reduce) collective
  • Updated HirAliasSelect evaluator to handle reduction axes as no-ops (since reduced dimensions don't exist in allocated tensors)
  • Removed validation preventing stream parallelization on reduction axes
  • Added comprehensive test ReduceScatterReduceBased demonstrating the new flow

Critical issue:

  • Line 482 in stream_parallel_type.cpp contains an unsafe cast from Val* to TensorView* that will crash if the reduction input is not a tensor

Confidence Score: 2/5

  • This PR contains a critical unsafe cast that will cause runtime crashes
  • The unsafe cast at line 482 (reduction_op->in()->as<TensorView>()) is a blocking issue that will crash if the input is not a TensorView. The rest of the implementation appears sound, with proper handling of reduction axes in HirAliasSelect and appropriate test coverage.
  • Pay close attention to csrc/host_ir/pass/stream_parallel_type.cpp - the unsafe cast must be fixed before merge

Important Files Changed

Filename | Overview
csrc/host_ir/pass/stream_parallel_type.cpp | Adds the reduce-collective-based MM+RS lowering path; contains an unsafe cast on line 482 that will crash
csrc/host_ir/evaluator.cpp | Handles reduction axes in HirAliasSelect by treating them as no-ops; logic appears sound
tests/cpp/test_host_ir_stream_lowering.cpp | Removes the Matmul_K test, since stream parallelization on a reduction axis is now allowed
tests/cpp/test_multidevice_stream_parallel_type.cpp | Adds a new ReduceScatterReduceBased test for MM+RS with a Reduce collective; test setup looks correct

Sequence Diagram

sequenceDiagram
    participant Stream0
    participant Stream1
    participant StreamN
    participant Device0
    participant Device1
    participant DeviceN
    
    Note over Stream0,StreamN: FOR streamIdx in range(D)
    Stream0->>Device0: SetCurrentStream(Stream0)
    Stream0->>Device0: HirAliasSelect(A, axis=stream, idx=0)
    Stream0->>Device0: HirAliasSelect(B, axis=stream, idx=0)
    Stream0->>Device0: matmul(A_slice, B_slice)
    Stream0->>DeviceN: Communication(Reduce, root=streamIdx=0)
    Stream0->>Stream0: Wait(Reduce)
    
    par Stream1 execution
        Stream1->>Device1: SetCurrentStream(Stream1)
        Stream1->>Device1: HirAliasSelect(A, axis=stream, idx=1)
        Stream1->>Device1: HirAliasSelect(B, axis=stream, idx=1)
        Stream1->>Device1: matmul(A_slice, B_slice)
        Stream1->>DeviceN: Communication(Reduce, root=streamIdx=1)
        Stream1->>Stream1: Wait(Reduce)
    and StreamN execution
        StreamN->>DeviceN: SetCurrentStream(StreamN)
        StreamN->>DeviceN: HirAliasSelect(A, axis=stream, idx=N)
        StreamN->>DeviceN: HirAliasSelect(B, axis=stream, idx=N)
        StreamN->>DeviceN: matmul(A_slice, B_slice)
        StreamN->>DeviceN: Communication(Reduce, root=streamIdx=N)
        StreamN->>StreamN: Wait(Reduce)
    end
    
    Note over Stream0,StreamN: All streams synchronize with main stream

@nsarka (Member, Author) commented Feb 5, 2026

My plan after this PR is to implement an NVLS-based reduce for this algorithm, as well as an NVLS-based ReduceScatter for the default no-overlap algorithm.

@greptile-apps bot (Contributor) left a comment

2 files reviewed, 2 comments

Comment on lines 220 to 224
// If no stream axis found, check the input tensor for a stream axis
// If found, use that. If not, keep the expression as is
if (stream_axis == nullptr) {
new_top_level_exprs.push_back(expr);
continue;
auto* input = expr->input(0)->as<TensorView>();
IterDomain* input_stream_axis = getStreamAxis(input->getLoopDomain());
@greptile-apps bot (Contributor):

Unsafe cast on input(0)

When stream_axis == nullptr, the new logic unconditionally does expr->input(0)->as<TensorView>() (line ~223). This will crash if the first input is not a TensorView (e.g., scalar params/indices/predicates in some host ops). Previously this path didn’t assume anything about inputs. This needs a type check (or use ir_utils::filterByType<TensorView>(expr->inputs())) before casting.
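A minimal sketch of the kind of guard being asked for, reusing only the names visible in the hunk above (stream_axis, expr, new_top_level_exprs, getStreamAxis); the surrounding pass code may differ:

if (stream_axis == nullptr) {
  // Only inspect the first input if it actually is a TensorView; otherwise
  // keep the expression as is, matching the pre-PR behavior.
  if (expr->inputs().empty() || !expr->input(0)->isA<TensorView>()) {
    new_top_level_exprs.push_back(expr);
    continue;
  }
  auto* input = expr->input(0)->as<TensorView>();
  IterDomain* input_stream_axis = getStreamAxis(input->getLoopDomain());
  // ... proceed with the input-stream-axis fallback added by this PR.
}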

@greptile-apps bot (Contributor) commented Feb 5, 2026

Additional Comments (1)

csrc/host_ir/pass/stream_parallel_type.cpp
Dropped expr when using input stream axis

In groupStreamParallelRegions, when stream_axis == nullptr and you fall back to using the input tensor’s stream axis, you set is_input=true and then skip both the merge path and the for_loop creation via else if (!is_input) (line ~265). In the case where canMergeWithPreviousForLoop(...) is false (e.g., first such expr or previous expr not mergeable), the expression is not pushed into new_top_level_exprs at all, so it gets dropped from the IR. This will break any fusion where this fallback triggers. Consider either creating a new loop even for the input-stream-axis case or pushing expr as-is when you can’t merge.

Also appears in: same block controls all exprs that hit the fallback path in this function.
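A control-flow sketch of the fix suggested above; the merge and loop-creation bodies are elided, and canMergeWithPreviousForLoop / is_input are the names used in the comment, not necessarily the exact code in groupStreamParallelRegions:

if (canMergeWithPreviousForLoop(/* ... */)) {
  // Merge expr into the body of the previous stream for-loop.
} else if (!is_input) {
  // Create a new stream for-loop and place expr inside it.
} else {
  // Input-stream-axis fallback with nothing to merge into: keep the
  // expression instead of silently dropping it from the lowered IR.
  new_top_level_exprs.push_back(expr);
}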

@wujingyue (Collaborator) left a comment

Unit test?

@greptile-apps bot (Contributor) left a comment

3 files reviewed, 2 comments

@greptile-apps bot (Contributor) commented Feb 9, 2026

Additional Comments (1)

csrc/host_ir/pass/stream_parallel_type.cpp
Reduce path never triggers
In processForLoopBodies, the new reduce-collective-based branch is else if (!did_to_stream && stream_to_did), but stream_to_did is computed only by scanning outputs for a Stream→DID transition. A ReductionOp typically outputs a tensor without a DID-parallelized axis (the DID axis is reduced away), so stream_to_did will be false and this branch will never run; the reduction will fall through to the non-communication path, producing incorrect host IR for the intended MM+RS reduce-based pattern.

This also interacts with groupStreamParallelRegions: when the reduction output has no stream axis, you set tv=input/stream_axis=input_stream_axis and is_input=true, but then suppress loop creation (else if (!is_input)), which relies on the reduction being merged into a prior loop. If the prior loop doesn’t exist/doesn’t match, the reduction is silently dropped from new_top_level_exprs.

Also appears in the detection logic at csrc/host_ir/pass/stream_parallel_type.cpp:370-399 where stream_to_did is derived from outputs only.

body_expr);
NVF_ERROR(
body_expr->as<ReductionOp>()->getReductionOpType() == BinaryOpType::Add,
"expected a reduce operation but got ",
@wujingyue (Collaborator):

Change the error message to a more informative one.
Btw, it shouldn't be a problem to support the other reduction ops.
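For context on "the other reduction ops": supporting them would mostly mean mapping the ReductionOp's BinaryOpType to the collective's reduce op instead of hard-coding Add. A hypothetical mapping helper (nvFuser may already provide an equivalent; the name toC10dReduceOp is made up here) could look like:

c10d::ReduceOp::RedOpType toC10dReduceOp(BinaryOpType op) {
  // Map the fusion-level reduction op to the NCCL/c10d reduce op.
  switch (op) {
    case BinaryOpType::Add:
      return c10d::ReduceOp::SUM;
    case BinaryOpType::Mul:
      return c10d::ReduceOp::PRODUCT;
    case BinaryOpType::Min:
      return c10d::ReduceOp::MIN;
    case BinaryOpType::Max:
      return c10d::ReduceOp::MAX;
    default:
      NVF_THROW("Reduction op not supported by the collective lowering: ", op);
  }
}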

@greptile-apps bot (Contributor) left a comment

3 files reviewed, 5 comments

@greptile-apps bot (Contributor) commented Feb 10, 2026

Additional Comments (2)

csrc/host_ir/pass/stream_parallel_type.cpp
Reduction stream axis allowed

validateStreamAxis no longer rejects stream axes with IterType::Reduction, so a stream-parallel loop can now be created over a reduction ID. That breaks the pass’s own assumptions later (e.g., TensorSlicingCache::get clones IDs with cloneWithoutRFactor() and HostIrEvaluator::handle(HirAliasSelect) special-cases reduction IDs by making selects no-ops), which can silently produce incorrect slicing/binding rather than a real per-stream slice. Please restore the check (or otherwise ensure stream axes are never reduction IDs) so the pass fails fast instead of mis-lowering.
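For reference, the guard being referred to is the NVF_ERROR that this PR removes from validateStreamAxis (its tail is quoted in the hunk a few lines below); restoring it would look roughly like this:

NVF_ERROR(
    stream_axis->getIterType() == IterType::Iteration ||
        stream_axis->getIterType() == IterType::Broadcast,
    "Stream axis ",
    stream_axis,
    " should be an iteration or broadcast axis.");

Note that the author's reply below says the reduced axis is stream-parallelized on purpose, so restoring the check verbatim would block the new MM+RS path.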


csrc/host_ir/evaluator.cpp
Reduction-axis select no-op

HostIrEvaluator::handle(HirAliasSelect) now treats selecting a reduction axis as a no-op and binds out directly to in (when indexed_id->isReduction()). But the code still later relies on a well-defined output shape (it would normally drop the selected dimension). Binding the full input tensor to an output TensorView with one fewer logical dimension can cause downstream shape/stride mismatches when later ops consume the alias-select output. This needs to ensure the bound tensor matches the expected output rank/strides, not just reuse the input tensor.
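A sketch of the no-op path being described, based only on this summary; the accessor out() and the exact binding call are assumptions about HostIrEvaluator, not copied from the diff:

if (indexed_id->isReduction()) {
  // The reduced dimension is not materialized in the allocated tensor, so the
  // select is treated as a no-op: the output simply aliases the whole input.
  // The concern above: the output TensorView logically has one fewer
  // dimension, so the bound at::Tensor's rank/strides may not match what
  // downstream consumers of the alias-select output expect.
  expr_evaluator_.bind(hir_alias_select->out(), input);
  return;
}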

stream_axis->getIterType() == IterType::Broadcast,
"Stream axis ",
stream_axis,
" should be an iteration or broadcast axis.");
@nsarka (Member, Author):

We stream parallelize the reduced axis in the sum op

@nsarka nsarka requested a review from samnordmann February 10, 2026 18:52
@greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

@nsarka (Member, Author) commented Feb 10, 2026

!test

@greptile-apps bot (Contributor) left a comment

3 files reviewed, no comments

@nsarka (Member, Author) commented Feb 10, 2026

!test

@nsarka nsarka requested a review from wujingyue February 10, 2026 21:05

EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get()));
}

@nsarka (Member, Author):

Since the reduced axis can now be stream-parallelized, this test fails, so I removed it here.

@greptile-apps bot (Contributor) left a comment

4 files reviewed, 1 comment

if_sending_to_self->elseBody().pushBack(send);
break;
if (params.offset_stream_indexing_by_rank) {
// Lower to MM + RS p2p based algorithm
@nsarka (Member, Author):

This block is the same as before. The only difference is the indentation level

NVF_THROW(
"Unsupported communicator backend for lowering stream parallel "
"type into p2p: ",
// Lower to the MM+RS reduce-collective-based algorithm
@nsarka (Member, Author):

This block is the core change

auto index = (indexed_id->isBroadcast() || input.size(axis) == 1)
? 0
: expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();

@nsarka (Member, Author):

I removed the broadcast op from the fusion. It failed here because it was trying to select on the D axis which is 1 locally

@wujingyue (Collaborator):

> I removed the broadcast op from the fusion.

So why add the case of a broadcast axis here (which, btw, looks good to me)?

> It failed here because it was trying to select on the D axis which is 1 locally

So why not detect whether the axis is sharded? I think checking that the axis is of size 1 is not correct. First, if the dimension is DIDx, then the symbolic size will be D, not 1. Second, if the axis is neither broadcast nor sharded but just happens to be of size 1, then we want to error out.

Does it make sense?
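A sketch of the sharding-aware check being proposed, assuming IterDomain::isDeviceDim() is the right predicate for "the axis is sharded" here:

// Select index 0 only when the axis is not materialized locally (broadcast or
// sharded across the device mesh); otherwise evaluate the runtime index, so a
// genuinely size-1 iteration axis still errors out downstream.
int64_t index = 0;
if (!indexed_id->isBroadcast() && !indexed_id->isDeviceDim()) {
  index = expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
}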

