diff --git a/csrc/base.h b/csrc/base.h
index 9255dc42c0d..9bceeac845c 100644
--- a/csrc/base.h
+++ b/csrc/base.h
@@ -7,7 +7,6 @@
 // clang-format on
 #pragma once

-#include
 #include
 #include
 #include
diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index f472d41f061..c397bd6af40 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -1732,9 +1732,10 @@ std::ostream& operator<<(
 }

 void SegmentedFusion::print() const {
-  debug() << "Segmented_Fusion Dump: -- Re-written complete fusion:{\n";
-  completeFusion()->printMath();
-  debug() << "} // {Re-written complete fusion}\n";
+  debug() << "Segmented_Fusion Dump: -- Re-written complete fusion:{"
+          << std::endl;
+  completeFusion()->print();
+  debug() << "} // {Re-written complete fusion}" << std::endl << std::endl;
   debug() << this << "\n";
 }

diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp
index d91fd4eda60..44f5a93510d 100644
--- a/csrc/host_ir/lower_to_communication.cpp
+++ b/csrc/host_ir/lower_to_communication.cpp
@@ -8,17 +8,14 @@

 #include "host_ir/lower_to_communication.h"

-#include "host_ir/container.h"
-#include "ir/all_nodes.h"
-#include "ir/allocation_utils.h"
 #include "ir/builder.h"
+#include "ir/interface_nodes.h"
 #include "ir/internal_base_nodes.h"
 #include "ir/iostream.h"
-#include "kernel_ir.h"
+#include "logical_domain_map.h"
 #include "multidevice/communication.h"
 #include "multidevice/resharding.h"
 #include "multidevice/utils.h"
-#include "ops/all_ops.h"

 namespace nvfuser {

@@ -56,10 +53,11 @@ void lowerToScatter(
     const CommunicatorBackend backend,
     std::vector& comms) {
   const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();
-  NVF_ERROR(
-      receiver_mesh.rank() == 1,
+  NVF_ERROR_EQ(
+      receiver_mesh.rank(),
+      1,
       "Gather only supported on a 1D mesh. Given ",
-      receiver_mesh);
+      output_tv->toString());

   // Find a common device between input and receiver meshes to be the root
   std::vector input_devices = input_tv->getDeviceMesh().vector();
@@ -348,13 +346,16 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       "getCommunicationInfo should only be called when `e` is known to be a "
       "communication. Given: ",
       e);
-
+  NVF_ERROR_EQ(
+      e->inputs().size(), 1, "Expected 1 input, but got ", e->toString());
   auto* producer = e->inputs().at(0)->as();
+  NVF_ERROR_EQ(
+      e->outputs().size(), 1, "Expected 1 output, but got ", e->toString());
   auto* consumer = e->outputs().at(0)->as();

-  std::optional communication_info = std::nullopt;
-  // Fill `communication_info` instead of returning the result, so we can catch
-  // errors when more than one DIDs have sharding changes.
+  std::optional communication_info = std::nullopt;
+  // Fill `communication_info` instead of returning the result, so we can
+  // catch errors when more than one DIDs have sharding changes.
   auto fill_communication_info = [&](CommunicationType type,
                                      IterDomain* p_sharded_id,
                                      IterDomain* c_sharded_id) {
@@ -375,19 +376,23 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
   auto consumer_pt_to_did =
       mapDeviceAndStreamParallelTypeToId(consumer->getLoopDomain());

+  const DeviceMesh& producer_mesh = producer->getDeviceMesh();
+  const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
+  const bool same_mesh = producer_mesh == consumer_mesh;
+
   for (ParallelType pt : kParallelTypeDIDs) {
+    if (!haveDifferentShardings(producer, consumer, {pt})) {
+      continue;
+    }
+
     IterDomain* p_loop_did = getOrDefault(producer_pt_to_did, pt);
     IterDomain* c_loop_did = getOrDefault(consumer_pt_to_did, pt);
     if (p_loop_did == nullptr && c_loop_did == nullptr) {
       // Not sharded on this parallel type
-      continue;
+      NVF_THROW("Not sharded on this parallel type: ", pt);
     }

-    const DeviceMesh& producer_mesh = producer->getDeviceMesh();
-    const DeviceMesh& consumer_mesh = consumer->getDeviceMesh();
-    const bool same_mesh = producer_mesh == consumer_mesh;
-
     if (e->isA()) {
       if (p_loop_did && !c_loop_did) {
         IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did);
@@ -435,7 +440,8 @@ CommunicationInfo getCommunicationInfo(Expr* e) {
       auto c_it = p2c_map.find(p_logical_id);
       NVF_ERROR(
           c_it != p2c_map.end(),
-          "Cannot find the mapped consumer logical ID for the producer logical "
+          "Cannot find the mapped consumer logical ID for the producer "
+          "logical "
           "ID ",
           p_logical_id->toString());
       if (!c_it->second->isReduction()) {
diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp
index b6867d238c9..459b2910cbe 100644
--- a/csrc/preseg_passes/propagate_shardings.cpp
+++ b/csrc/preseg_passes/propagate_shardings.cpp
@@ -186,6 +186,13 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
           PropagateDirection::kBackward);
     }
   }
+
+  if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
+    debug() << std::endl
+            << "Fusion Transforms after " << name() << ":" << std::endl;
+    fusion->printTransforms();
+    debug() << std::endl;
+  }
 }

 } // namespace nvfuser::preseg_passes
diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp
index 241c6b47e55..e0fb5efb69a 100644
--- a/csrc/scheduler/vectorize_helper.cpp
+++ b/csrc/scheduler/vectorize_helper.cpp
@@ -829,8 +829,13 @@ Val* ContiguousInnerDimensionsMapper::getContigMergeOfInnerSize(
       break;
     }

-    auto sharded_extent = SimplifyingIrBuilder::divExpr(
-        getProjectedExtent(logical_id), num_devices);
+    Val* sharded_extent;
+    if (logical_id->isDeviceDim()) {
+      sharded_extent = of_tv->container()->oneVal();
+    } else {
+      sharded_extent = SimplifyingIrBuilder::divExpr(
+          getProjectedExtent(logical_id), num_devices);
+    }
     product_of_inner_extents =
         SimplifyingIrBuilder::mulExpr(product_of_inner_extents, sharded_extent);
   }
diff --git a/tests/python/multidevice/test_alphafold3.py b/tests/python/multidevice/test_alphafold3.py
new file mode 100644
index 00000000000..a18b3682808
--- /dev/null
+++ b/tests/python/multidevice/test_alphafold3.py
@@ -0,0 +1,225 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+# This file contains certain building blocks of the AlphaFold3 model.
+
+import pytest
+import torch
+from dataclasses import dataclass
+from enum import Enum, auto
+
+import nvfuser_direct as nvfuser
+from nvfuser_direct import FusionDefinition, DataType, TensorView
+
+
+@dataclass
+class ModelConfig:
+    c_z: int = 128
+    c_hidden: int = 32
+    n_heads: int = 4
+
+
+_DEFAULT_CONFIG = ModelConfig()
+
+
+class Direction(Enum):
+    INCOMING = auto()  # aka ending node
+    OUTGOING = auto()  # aka starting node
+
+
+def layer_norm(
+    fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
+) -> TensorView:
+    io_dtype = x.dtype()
+    x = fd.ops.cast(x, dtype=DataType.Float)
+    var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
+    y = fd.ops.sub(x, mean)
+    var = fd.ops.add(var, fd.define_scalar(1e-5))
+    y = fd.ops.mul(y, fd.ops.rsqrt(var))
+    shape = fd.ops.shape(x)
+    w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.mul(y, w)
+    b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
+    y = fd.ops.add(y, b)
+    y = fd.ops.cast(y, dtype=io_dtype)
+    return y
+
+
+def gating(
+    fd: FusionDefinition,
+    z: TensorView,
+    w_p: TensorView,
+    z_in: TensorView,
+    w_g: TensorView,
+) -> TensorView:
+    io_dtype = z.dtype()
+    p = fd.ops.linear(z, w_p)
+    g = fd.ops.linear(z_in, w_g)
+    g = fd.ops.sigmoid(g)
+    z = fd.ops.mul(p, g)
+    return fd.ops.cast(z, dtype=io_dtype)
+
+
+# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
+#
+# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+# prediction with AlphaFold. Nature 596, 583–589 (2021).
+# https://doi.org/10.1038/s41586-021-03819-2
+# (see Supplementary Methods 1.6.5 for details)
+@pytest.mark.mpi
+@pytest.mark.parametrize(
+    "direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
+)
+def test_triangle_updates(direction, multidevice_test):
+    d = multidevice_test.size
+    cp_size = 2
+    if d % (cp_size * cp_size) != 0:
+        pytest.skip(
+            f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
+        )
+    dp_size = d // (cp_size * cp_size)
+
+    c_z = _DEFAULT_CONFIG.c_z
+
+    with FusionDefinition() as fd:
+        z_in_tv = fd.define_tensor(
+            shape=[-1, -1, -1, c_z],
+            dtype=DataType.BFloat16,
+            contiguity=True,
+        )  # [b, i, j, c_z]
+        w_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_in = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_in = fd.define_tensor(
+            shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        b_norm_out = fd.define_tensor(
+            shape=[c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_p_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        w_g_out = fd.define_tensor(
+            shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
+        )
+        # Masking is used in an internal implementation: http://nv/e-4
+        mask_tv = fd.define_tensor(
+            shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
+        )  # [b, i, j]
+
+        batch_size = fd.ops.size(z_in_tv, 0)
+        n_tokens = fd.ops.size(z_in_tv, 1)
+
+        z_in = layer_norm(fd, z_in_tv, w_norm_in, b_norm_in)
+        z = gating(fd, z_in_tv, w_p_in, z_in, w_g_in)
+        mask = fd.ops.broadcast_in_dim(
+            mask_tv,
+            shape=[batch_size, n_tokens, n_tokens, c_z],
+            broadcast_dims=[0, 1, 2],
+        )
+        z = fd.ops.where(mask, z, 0.0)
+        a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
+        b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
+
+        match direction:
+            case Direction.OUTGOING:
+                # z_out = einsum("bikc,bjkc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
+            case Direction.INCOMING:
+                # z_out = einsum("bkic,bkjc->bijc", a, b)
+                a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
+                b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
+        z = fd.ops.matmul(a, b)  # [b, c, i, j]
+        z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
+
+        einsum_out = z
+
+        z = layer_norm(fd, z, w_norm_out, b_norm_out)
+        z = gating(fd, z, w_p_out, z_in, w_g_out)
+        fd.add_output(z)
+
+    mesh = nvfuser.multidevice.DeviceMesh(
+        torch.arange(d).reshape(dp_size, cp_size, cp_size)
+    )
+    for tv in [
+        z_in_tv,
+        w_norm_in,
+        b_norm_in,
+        w_p_in,
+        w_g_in,
+        w_norm_out,
+        b_norm_out,
+        w_p_out,
+        w_g_out,
+        mask_tv,
+        einsum_out,
+    ]:
+        tv.set_device_mesh(mesh)
+
+    for tv in [z_in_tv, mask_tv, einsum_out]:
+        tv.outer_split(2, cp_size)
+        tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
+        tv.outer_split(1, cp_size)
+        tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
+        tv.outer_split(0, dp_size)
+        tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
+
+    batch_per_rank = 3
+    n_tokens_per_rank = 5
+    z_in_ref = torch.testing.make_tensor(
+        batch_per_rank * dp_size,
+        n_tokens_per_rank * cp_size,
+        n_tokens_per_rank * cp_size,
+        c_z,
+        dtype=torch.bfloat16,
+        device="cpu",
+    )
+    mask_ref = torch.testing.make_tensor(
+        batch_per_rank * dp_size,
+        n_tokens_per_rank * cp_size,
+        n_tokens_per_rank * cp_size,
+        dtype=torch.bool,
+        device="cpu",
+    )
+
+    z_in = multidevice_test.shard_tensor(z_in_ref, z_in_tv)
+    w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_g_in = torch.testing.make_tensor(
+        c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
+    )
+    w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
+    w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
+    mask = multidevice_test.shard_tensor(mask_ref, mask_tv)
+    (z_out,) = fd.execute(
+        [
+            z_in,
+            w_norm_in,
+            b_norm_in,
+            w_p_in,
+            w_g_in,
+            w_norm_out,
+            b_norm_out,
+            w_p_out,
+            w_g_out,
+            mask,
+        ]
+    )
+    assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)
diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py
index 99858ca2e41..965b3a7799f 100644
--- a/tests/python/multidevice/test_multidevice.py
+++ b/tests/python/multidevice/test_multidevice.py
@@ -29,7 +29,6 @@ def test_sizes_and_ranks(multidevice_test):
 @pytest.mark.mpi
 def test_pointwise(multidevice_test):
     num_devices = multidevice_test.size
-    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))

     with FusionDefinition() as fd:
         inp_tv = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float)
@@ -37,6 +36,7 @@ def test_pointwise(multidevice_test):
         tv2 = fd.ops.add(tv1, tv1)
         fd.add_output(tv2)

+    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(num_devices))
     for tv in [inp_tv, tv1, tv2]:
         tv.set_device_mesh(mesh)

@@ -50,6 +50,63 @@ def test_pointwise(multidevice_test):
     torch.testing.assert_close(out.cpu(), out_ref)


+@pytest.mark.mpi
+def test_transpose(multidevice_test):
+    d = multidevice_test.size
+    cp_size = 2
+    if d % (cp_size * cp_size) != 0:
+        pytest.skip(
+            f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
+        )
+    dp_size = d // (cp_size * cp_size)
+
+    c = 128
+    with FusionDefinition() as fd:
+        inp_tv = fd.define_tensor(
+            (-1, c, -1, -1, cp_size), contiguity=True, dtype=DataType.BFloat16
+        )
+        out_tv = fd.ops.set(inp_tv)
+        fd.add_output(out_tv)
+
+    mesh = nvfuser.multidevice.DeviceMesh(
+        torch.arange(d).reshape(dp_size, cp_size, cp_size)
+    )
+    for tv in [inp_tv, out_tv]:
+        tv.set_device_mesh(mesh)
+
+    inp_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y)
+    inp_tv.outer_split(3, cp_size)
+    inp_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
+    inp_tv.outer_split(0, dp_size)
+    inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
+
+    out_tv.axis(4).parallelize(nvfuser.ParallelType.mesh_y)
+    out_tv.outer_split(3, cp_size)
+    out_tv.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
+    out_tv.outer_split(0, dp_size)
+    out_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
+    out_tv.set_allocation_domain(
+        (
+            out_tv.axis(3),
+            out_tv.axis(0),
+            out_tv.axis(1),
+            out_tv.axis(2),
+            out_tv.axis(4),
+            out_tv.axis(5),
+            out_tv.axis(6),
+        ),
+        True,
+    )
+
+    b = dp_size * 3
+    s = cp_size * 5
+    inp_ref = torch.randn(b, c, s, s, cp_size, dtype=torch.bfloat16)
+    out_ref = inp_ref
+
+    inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
+    fd.execute([inp])
+
+
 class QkvFormat(Enum):
     BHSE = auto()
     BSHE = auto()