Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion csrc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
// clang-format on
#pragma once

#include <concepts>
#include <coroutine>
#include <deque>
#include <iterator>
Expand Down
7 changes: 4 additions & 3 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1732,9 +1732,10 @@ std::ostream& operator<<(
}

// Dumps this segmented fusion for debugging: first the re-written complete
// fusion, then the segmented fusion itself via the stream operator.
//
// NOTE: the pasted diff contained both the pre-change lines (printMath() plus
// "\n"-terminated banners) and the post-change lines; this is the resolved
// post-change version, which would otherwise have printed the banner twice.
void SegmentedFusion::print() const {
  debug() << "Segmented_Fusion Dump: -- Re-written complete fusion:{"
          << std::endl;
  // print() presumably emits a fuller dump than the old printMath() --
  // confirm against Fusion::print()'s definition.
  completeFusion()->print();
  debug() << "} // {Re-written complete fusion}" << std::endl << std::endl;
  debug() << this << "\n";
}

Expand Down
42 changes: 24 additions & 18 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@

#include "host_ir/lower_to_communication.h"

#include "host_ir/container.h"
#include "ir/all_nodes.h"
#include "ir/allocation_utils.h"
#include "ir/builder.h"
#include "ir/interface_nodes.h"
#include "ir/internal_base_nodes.h"
#include "ir/iostream.h"
#include "kernel_ir.h"
#include "logical_domain_map.h"
#include "multidevice/communication.h"
#include "multidevice/resharding.h"
#include "multidevice/utils.h"
#include "ops/all_ops.h"

namespace nvfuser {

Expand Down Expand Up @@ -56,10 +53,11 @@ void lowerToScatter(
const CommunicatorBackend backend,
std::vector<Expr*>& comms) {
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
NVF_ERROR(
receiver_mesh.rank() == 1,
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"Gather only supported on a 1D mesh. Given ",
receiver_mesh);
output_tv->toString());

// Find a common device between input and receiver meshes to be the root
std::vector<DeviceIdxType> input_devices = input_tv->getDeviceMesh().vector();
Expand Down Expand Up @@ -348,13 +346,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
"getCommunicationInfo should only be called when `e` is known to be a "
"communication. Given: ",
e);

NVF_ERROR_EQ(
e->inputs().size(), 1, "Expected 1 input, but got ", e->toString());
auto* producer = e->inputs().at(0)->as<TensorView>();
NVF_ERROR_EQ(
e->outputs().size(), 1, "Expected 1 output, but got ", e->toString());
auto* consumer = e->outputs().at(0)->as<TensorView>();
std::optional<CommunicationInfo> communication_info = std::nullopt;

// Fill `communication_info` instead of returning the result, so we can catch
// errors when more than one DIDs have sharding changes.
std::optional<CommunicationInfo> communication_info = std::nullopt;
// Fill `communication_info` instead of returning the result, so we can
// catch errors when more than one DIDs have sharding changes.
auto fill_communication_info = [&](CommunicationType type,
IterDomain* p_sharded_id,
IterDomain* c_sharded_id) {
Expand All @@ -375,19 +376,23 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
auto consumer_pt_to_did =
mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());

const DeviceMesh& producer_mesh = producer->getDeviceMesh();
const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
const bool same_mesh = producer_mesh == consumer_mesh;

for (ParallelType pt : kParallelTypeDIDs) {
if (!haveDifferentShardings(producer, consumer, {pt})) {
continue;
}

IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);

if (p_loop_did == nullptr && c_loop_did == nullptr) {
// Not sharded on this parallel type
continue;
NVF_THROW("Not sharded on this parallel type: ", pt);
}

const DeviceMesh& producer_mesh = producer->getDeviceMesh();
const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
const bool same_mesh = producer_mesh == consumer_mesh;

if (e->isA<LoadStoreOp>()) {
if (p_loop_did && !c_loop_did) {
IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
Expand Down Expand Up @@ -435,7 +440,8 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
auto c_it = p2c_map.find(p_logical_id);
NVF_ERROR(
c_it != p2c_map.end(),
"Cannot find the mapped consumer logical ID for the producer logical "
"Cannot find the mapped consumer logical ID for the producer "
"logical "
"ID ",
p_logical_id->toString());
if (!c_it->second->isReduction()) {
Expand Down
7 changes: 7 additions & 0 deletions csrc/preseg_passes/propagate_shardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
PropagateDirection::kBackward);
}
}

if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
debug() << std::endl
<< "Fusion Transforms after " << name() << ":" << std::endl;
fusion->printTransforms();
debug() << std::endl;
}
}

} // namespace nvfuser::preseg_passes
9 changes: 7 additions & 2 deletions csrc/scheduler/vectorize_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,13 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
break;
}

auto sharded_extent = SimplifyingIrBuilder::divExpr(
getProjectedExtent(logical_id), num_devices);
Val* sharded_extent;
if (logical_id->isDeviceDim()) {
sharded_extent = of_tv->container()->oneVal();
} else {
sharded_extent = SimplifyingIrBuilder::divExpr(
getProjectedExtent(logical_id), num_devices);
}
product_of_inner_extents =
SimplifyingIrBuilder::mulExpr(product_of_inner_extents, sharded_extent);
}
Expand Down
225 changes: 225 additions & 0 deletions tests/python/multidevice/test_alphafold3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause


# This file contains certain building blocks of the AlphaFold3 model.

import pytest
import torch
from dataclasses import dataclass
from enum import Enum, auto

import nvfuser_direct as nvfuser
from nvfuser_direct import FusionDefinition, DataType, TensorView


@dataclass
class ModelConfig:
    """Hyperparameters of the AlphaFold3 building blocks exercised below."""

    # Channel dimension of the pair representation z; used as the last-dim
    # extent of the pair tensors in test_triangle_updates.
    c_z: int = 128
    # Hidden channel dimension -- not referenced by the visible tests yet;
    # presumably for attention blocks to be added later (confirm).
    c_hidden: int = 32
    # Number of attention heads -- likewise unused in the visible tests.
    n_heads: int = 4


# Shared default configuration for the tests in this file.
_DEFAULT_CONFIG = ModelConfig()


class Direction(Enum):
    """Direction of the triangular multiplicative update."""

    INCOMING = auto()  # aka ending node
    OUTGOING = auto()  # aka starting node


def layer_norm(
    fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
) -> TensorView:
    """Layer normalization of `x` over its last dimension.

    Statistics are computed in fp32 with population variance (correction=0)
    and epsilon 1e-5; the affine weight `w` and bias `b` are broadcast along
    the last dimension. The result is cast back to `x`'s original dtype.
    """
    out_dtype = x.dtype()
    x_fp32 = fd.ops.cast(x, dtype=DataType.Float)
    variance, mean = fd.ops.var_mean(x_fp32, dims=[-1], correction=0, keepdim=True)
    centered = fd.ops.sub(x_fp32, mean)
    inv_std = fd.ops.rsqrt(fd.ops.add(variance, fd.define_scalar(1e-5)))
    normalized = fd.ops.mul(centered, inv_std)
    full_shape = fd.ops.shape(x_fp32)
    scale = fd.ops.broadcast_in_dim(w, shape=full_shape, broadcast_dims=[-1])
    shift = fd.ops.broadcast_in_dim(b, shape=full_shape, broadcast_dims=[-1])
    affine = fd.ops.add(fd.ops.mul(normalized, scale), shift)
    return fd.ops.cast(affine, dtype=out_dtype)


def gating(
    fd: FusionDefinition,
    z: TensorView,
    w_p: TensorView,
    z_in: TensorView,
    w_g: TensorView,
) -> TensorView:
    """Gated linear projection: linear(z, w_p) * sigmoid(linear(z_in, w_g)).

    The product is cast back to `z`'s original dtype before returning.
    """
    out_dtype = z.dtype()
    projected = fd.ops.linear(z, w_p)
    gate = fd.ops.sigmoid(fd.ops.linear(z_in, w_g))
    gated = fd.ops.mul(projected, gate)
    return fd.ops.cast(gated, dtype=out_dtype)


# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
#
# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
# prediction with AlphaFold. Nature 596, 583–589 (2021).
# https://doi.org/10.1038/s41586-021-03819-2
# (see Supplementary Methods 1.6.5 for details)
@pytest.mark.mpi
@pytest.mark.parametrize(
    "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
def test_triangle_updates(direction, multidevice_test):
    """AlphaFold triangle multiplicative update, sharded on a 3D device mesh.

    The d available devices are arranged as a (dp, cp, cp) mesh: the batch
    dimension is data-parallel (mesh_z) and both token dimensions are
    context-parallel (mesh_y / mesh_x), each CP axis of size `cp_size`.
    """
    d = multidevice_test.size
    cp_size = 2
    if d % (cp_size * cp_size) != 0:
        pytest.skip(
            f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
        )
    dp_size = d // (cp_size * cp_size)

    c_z = _DEFAULT_CONFIG.c_z

    with FusionDefinition() as fd:
        # Pair representation and the parameters of the in/out projections.
        z_in_tv = fd.define_tensor(
            shape=[-1, -1, -1, c_z],
            dtype=DataType.BFloat16,
            contiguity=True,
        )  # [b, i, j, c_z]
        w_norm_in = fd.define_tensor(
            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
        )
        b_norm_in = fd.define_tensor(
            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
        )
        # The input projection produces both halves (a and b) at once,
        # hence the 2*c_z output dimension.
        w_p_in = fd.define_tensor(
            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_g_in = fd.define_tensor(
            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_norm_out = fd.define_tensor(
            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
        )
        b_norm_out = fd.define_tensor(
            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_p_out = fd.define_tensor(
            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        w_g_out = fd.define_tensor(
            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
        )
        # Masking is used in an internal implementation: http://nv/e-4
        mask_tv = fd.define_tensor(
            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
        )  # [b, i, j]

        batch_size = fd.ops.size(z_in_tv, 0)
        n_tokens = fd.ops.size(z_in_tv, 1)

        # Input layer norm, then the gated input projection to 2*c_z channels.
        z_in = layer_norm(fd, z_in_tv, w_norm_in, b_norm_in)
        z = gating(fd, z_in_tv, w_p_in, z_in, w_g_in)
        # NOTE(review): mask is broadcast to last dim c_z, but `z` coming out
        # of gating has last dim 2*c_z (w_p_in is [c_z*2, c_z]) -- confirm
        # this shape mismatch is intended/handled.
        mask = fd.ops.broadcast_in_dim(
            mask_tv,
            shape=[batch_size, n_tokens, n_tokens, c_z],
            broadcast_dims=[0, 1, 2],
        )
        z = fd.ops.where(mask, z, 0.0)
        # Split the projection into the two triangle operands a and b.
        a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
        b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])

        # Express the per-channel einsum as a batched matmul by moving the
        # channel dimension next to the batch dimension.
        match direction:
            case Direction.OUTGOING:
                # z_out = einsum("bikc,bjkc->bijc", a, b)
                a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
                b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
            case Direction.INCOMING:
                # z_out = einsum("bkic,bkjc->bijc", a, b)
                a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
                b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
        z = fd.ops.matmul(a, b)  # [b, c, i, j]
        z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]

        einsum_out = z

        # Output layer norm and gating; the gate is computed from the
        # layer-normed *input* representation z_in.
        z = layer_norm(fd, z, w_norm_out, b_norm_out)
        z = gating(fd, z, w_p_out, z_in, w_g_out)
        fd.add_output(z)

        # Schedule: put every tensor on the 3D (dp, cp, cp) mesh ...
        mesh = nvfuser.multidevice.DeviceMesh(
            torch.arange(d).reshape(dp_size, cp_size, cp_size)
        )
        for tv in [
            z_in_tv,
            w_norm_in,
            b_norm_in,
            w_p_in,
            w_g_in,
            w_norm_out,
            b_norm_out,
            w_p_out,
            w_g_out,
            mask_tv,
            einsum_out,
        ]:
            tv.set_device_mesh(mesh)

        # ... and shard batch (mesh_z) and both token dims (mesh_y/mesh_x)
        # of the pair tensors.
        for tv in [z_in_tv, mask_tv, einsum_out]:
            tv.outer_split(2, cp_size)
            tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
            tv.outer_split(1, cp_size)
            tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
            tv.outer_split(0, dp_size)
            tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)

    # Build unsharded reference inputs on CPU, then shard them per the
    # schedule above. Weights are replicated, so they go straight to CUDA.
    batch_per_rank = 3
    n_tokens_per_rank = 5
    z_in_ref = torch.testing.make_tensor(
        batch_per_rank * dp_size,
        n_tokens_per_rank * cp_size,
        n_tokens_per_rank * cp_size,
        c_z,
        dtype=torch.bfloat16,
        device="cpu",
    )
    mask_ref = torch.testing.make_tensor(
        batch_per_rank * dp_size,
        n_tokens_per_rank * cp_size,
        n_tokens_per_rank * cp_size,
        dtype=torch.bool,
        device="cpu",
    )

    z_in = multidevice_test.shard_tensor(z_in_ref, z_in_tv)
    w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
    b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
    w_p_in = torch.testing.make_tensor(
        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
    )
    w_g_in = torch.testing.make_tensor(
        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
    )
    w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
    b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
    w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
    w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
    mask = multidevice_test.shard_tensor(mask_ref, mask_tv)
    (z_out,) = fd.execute(
        [
            z_in,
            w_norm_in,
            b_norm_in,
            w_p_in,
            w_g_in,
            w_norm_out,
            b_norm_out,
            w_p_out,
            w_g_out,
            mask,
        ]
    )
    # Each rank holds its local shard: batch/dp x tokens/cp x tokens/cp x c_z.
    assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)
Loading