From 595423d6c69e701c690880fe592126bc904e2a61 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 3 Feb 2026 17:14:03 +0000 Subject: [PATCH 1/5] Add benchmark capabilities for ops. --- benchmark/ops/all_gather_matmul/benchmark.py | 376 ++++++++++++++++ benchmark/ops/matmul_all_gather/benchmark.py | 367 +++++++++++++++ benchmark/ops/matmul_all_reduce/benchmark.py | 378 ++++++++++++++++ .../ops/matmul_reduce_scatter/benchmark.py | 421 ++++++++++++++++++ iris/ops/__init__.py | 12 +- 5 files changed, 1547 insertions(+), 7 deletions(-) create mode 100644 benchmark/ops/all_gather_matmul/benchmark.py create mode 100644 benchmark/ops/matmul_all_gather/benchmark.py create mode 100644 benchmark/ops/matmul_all_reduce/benchmark.py create mode 100644 benchmark/ops/matmul_reduce_scatter/benchmark.py diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py new file mode 100644 index 000000000..3bc45579e --- /dev/null +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops all_gather_matmul fused operation. + +This benchmark showcases the fused All-Gather + GEMM operation where each rank +has a sharded A matrix that gets gathered, then multiplied with B. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all_gather_matmul fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension total (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="all_gather_matmul.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (all_gather_into_tensor + matmul) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + K_local = K // world_size # Sharded K dimension + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "all_gather_matmul") + json_writer.add_field("k_local", K_local) + json_writer.add_field("k_total", K) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_sharded is M x K_local, B is K x N, output is M x N + A_sharded = shmem.zeros((M, K_local), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A_sharded, same B + torch.manual_seed(123 + rank) + A_sharded_data = torch.randn((M, K_local), dtype=datatype, device=f"cuda:{rank}") + A_sharded.copy_(A_sharded_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_sharded matrices and compute expected result + A_sharded_list = [torch.zeros((M, K_local), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_sharded_list, A_sharded_data) + + # Concatenate along K dimension: A_gathered = [A_0 | A_1 | ... | A_n] + A_gathered = torch.cat(A_sharded_list, dim=1) # (M, K) + + # Expected: A_gathered @ B + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_result = torch.matmul(A_gathered, B_data) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather_matmul": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "all_gather_matmul_preamble"): + workspace = shmem.ops.all_gather_matmul_preamble( + C, + A_sharded, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather-Matmul") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather_matmul"]["start_event"].record() + shmem.ops.all_gather_matmul( + C, + A_sharded, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["all_gather_matmul"]["end_event"].record() + kernel_timing["all_gather_matmul"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather_matmul"]["start_event"].elapsed_time( + kernel_timing["all_gather_matmul"]["end_event"] + ) + kernel_timing["all_gather_matmul"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather-matmul validation passed!") + else: + shmem.error("All-gather-matmul validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["all_gather_matmul"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M * K_local * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + input_bytes = M * K_local * element_size + total_bytes = input_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather-matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_matmul_ms", + kernel_timing["all_gather_matmul"]["ms"] / kernel_timing["all_gather_matmul"]["experiments"], + ) + json_writer.add_field("all_gather_matmul_experiments", kernel_timing["all_gather_matmul"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (all_gather_into_tensor + matmul) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (all_gather_into_tensor + matmul)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_sharded = torch.randn(M, K_local, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_A_gathered = torch.zeros(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + dist.all_gather_into_tensor(pytorch_A_gathered, pytorch_A_sharded) + pytorch_C = torch.matmul(pytorch_A_gathered, pytorch_B) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch all_gather_into_tensor+matmul (M={M}, K_local={K_local}, K_total={K}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_gather/benchmark.py b/benchmark/ops/matmul_all_gather/benchmark.py new file mode 100644 index 000000000..22c914e8d --- /dev/null +++ b/benchmark/ops/matmul_all_gather/benchmark.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_gather fused operation. + +This benchmark showcases the fused GEMM + All-Gather operation where each rank +computes a local matmul and then gathers results along M dimension. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_gather fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows per rank in matrix A (M_local)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_gather.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_gather_into_tensor) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29529", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M_local = args["m"] # Local M dimension + M = M_local * world_size # Total M after gather + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_gather") + json_writer.add_field("m_local", M_local) + json_writer.add_field("m_total", M) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Create input and output tensors + # A_local is M_local x K, output is M x N (gathered) + A_local = shmem.zeros((M_local, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A_local, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M_local, K), dtype=datatype, device=f"cuda:{rank}") + A_local.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + if args["validate"]: + # Gather all A_local matrices and compute expected result + A_local_list = [torch.zeros((M_local, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_local_list, A_local_data) + + # Expected: [A_0 @ B; A_1 @ B; ...; A_n @ B] stacked along M + expected_tensor = shmem.zeros((M, N), dtype=datatype) + expected_parts = [] + for i, A_rank_local in enumerate(A_local_list): + C_rank_local = torch.matmul(A_rank_local, B_data) + expected_parts.append(C_rank_local) + expected_result = torch.cat(expected_parts, dim=0) + expected_tensor.copy_(expected_result) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Gather") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_gather"]["start_event"].record() + shmem.ops.matmul_all_gather( + C, + A_local, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_gather"]["end_event"].record() + kernel_timing["matmul_all_gather"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_gather"]["start_event"].elapsed_time( + kernel_timing["matmul_all_gather"]["end_event"] + ) + kernel_timing["matmul_all_gather"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-1 if datatype == torch.float16 else 1e-3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-gather validation passed!") + else: + shmem.error("Matmul-all-gather validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M_local*N*K flops per rank (but total is same across all ranks) + total_flops = 2 * M_local * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-gather part + # All-gather moves (world_size - 1) * M_local * N * element_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M_local * N * element_size + total_bytes = output_bytes * (world_size - 1) + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-gather (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_gather_ms", + kernel_timing["matmul_all_gather"]["ms"] / kernel_timing["matmul_all_gather"]["experiments"], + ) + json_writer.add_field("matmul_all_gather_experiments", kernel_timing["matmul_all_gather"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_gather_into_tensor) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_gather_into_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A_local = torch.randn(M_local, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C_local = torch.zeros(M_local, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C_local = torch.matmul(pytorch_A_local, pytorch_B) + dist.all_gather_into_tensor(pytorch_C, pytorch_C_local) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_gather_into_tensor (M_local={M_local}, M_total={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_all_reduce/benchmark.py b/benchmark/ops/matmul_all_reduce/benchmark.py new file mode 100644 index 000000000..fd923e051 --- /dev/null +++ b/benchmark/ops/matmul_all_reduce/benchmark.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_all_reduce fused operation. + +This benchmark showcases the fused GEMM + All-Reduce operation and reports +achieved TFLOPS and communication bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_all_reduce fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_all_reduce.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--all_reduce_variant", + type=str, + default="two_shot", + choices=["atomic", "ring", "two_shot", "one_shot", "spinlock"], + help="All-reduce variant to use", + ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29528", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + "all_reduce_variant": args["all_reduce_variant"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_all_reduce") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("all_reduce_variant", config.all_reduce_variant) + + # Create input and output tensors + # Must use shmem.zeros() to allocate on Iris symmetric heap + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tensor = None + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result + # Reference: each rank computes local C = A @ B, then all_reduce + if args["validate"]: + expected_tensor = shmem.zeros((M, N), dtype=datatype) + C_local_ref = torch.matmul(A_local_data, B_data) + pytorch_output = C_local_ref.clone() + shmem.barrier() + dist.all_reduce(pytorch_output, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + expected_tensor.copy_(pytorch_output) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_all_reduce": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_all_reduce_preamble"): + workspace = shmem.ops.matmul_all_reduce_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-All-Reduce") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_all_reduce"]["start_event"].record() + shmem.ops.matmul_all_reduce( + C, + A, + B, + config=config, + async_op=False, + workspace=workspace, + ) + kernel_timing["matmul_all_reduce"]["end_event"].record() + kernel_timing["matmul_all_reduce"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_all_reduce"]["start_event"].elapsed_time( + kernel_timing["matmul_all_reduce"]["end_event"] + ) + kernel_timing["matmul_all_reduce"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 0.2 if datatype == torch.float16 else 0.3 + success = torch.allclose(C, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(C - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("Matmul-all-reduce validation passed!") + else: + shmem.error("Matmul-all-reduce validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_all_reduce"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + # Calculate bandwidth for all-reduce part + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Matmul-all-reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}, variant={args['all_reduce_variant']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_all_reduce_ms", + kernel_timing["matmul_all_reduce"]["ms"] / kernel_timing["matmul_all_reduce"]["experiments"], + ) + json_writer.add_field("matmul_all_reduce_experiments", kernel_timing["matmul_all_reduce"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ops/matmul_reduce_scatter/benchmark.py b/benchmark/ops/matmul_reduce_scatter/benchmark.py new file mode 100644 index 000000000..301444f25 --- /dev/null +++ b/benchmark/ops/matmul_reduce_scatter/benchmark.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris.ops matmul_reduce_scatter fused operation. + +This benchmark showcases the fused GEMM + Reduce-Scatter operation where each rank +computes a local matmul, reduces across all ranks, and scatters tiles to ranks. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ops import FusedConfig + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark matmul_reduce_scatter fused operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in matrix A (M)") + parser.add_argument("-n", type=int, default=2048, help="Number of columns in matrix B (N)") + parser.add_argument("-k", type=int, default=131072, help="Common dimension (K)") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="matmul_reduce_scatter.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=None, help="Number of SMs for operation (auto-detect if None)") + parser.add_argument( + "--benchmark_pytorch", + action="store_true", + help="Also benchmark PyTorch (matmul + all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=256, help="Block size for M dimension") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension") + parser.add_argument("--block_size_k", type=int, default=64, help="Block size for K dimension") + parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29531", help="Initialization URL for distributed setup" + ) + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + shmem = iris.iris(args["heap_size"]) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + K = args["k"] + + # Create config with parameters + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "block_size_k": args["block_size_k"], + "group_size_m": args["group_size_m"], + } + if args["comm_sms"] is not None: + config_kwargs["num_sms"] = args["comm_sms"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + + config = FusedConfig(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("operation", "matmul_reduce_scatter") + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export actual config values to JSON (including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("block_size_k", config.block_size_k) + json_writer.add_field("group_size_m", config.group_size_m) + json_writer.add_field("num_sms", config.num_sms) + json_writer.add_field("num_xcds", config.num_xcds) + + # Calculate tile distribution + num_pid_m = (M + config.block_size_m - 1) // config.block_size_m + num_pid_n = (N + config.block_size_n - 1) // config.block_size_n + total_tiles = num_pid_m * num_pid_n + tiles_per_rank = total_tiles // world_size + start_tile = rank * tiles_per_rank + if rank == world_size - 1: + tiles_per_rank = total_tiles - start_tile + + json_writer.add_field("total_tiles", total_tiles) + json_writer.add_field("tiles_per_rank", tiles_per_rank) + + # Create input and output tensors + # Each rank computes full A @ B, but only keeps its assigned tiles + A = shmem.zeros((M, K), dtype=datatype) + B = shmem.zeros((K, N), dtype=datatype) + C = shmem.zeros((M, N), dtype=datatype) + expected_tiles = [] + + # Fill inputs with deterministic values + # Each rank has different A, same B + torch.manual_seed(123 + rank) + A_local_data = torch.randn((M, K), dtype=datatype, device=f"cuda:{rank}") + A.copy_(A_local_data) + + torch.manual_seed(456) # Same B for all ranks + B_data = torch.randn((K, N), dtype=datatype, device=f"cuda:{rank}") + B.copy_(B_data) + + # For validation: compute expected result for this rank's tiles + if args["validate"]: + # Gather all A matrices to compute expected result + A_list = [torch.zeros((M, K), dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + dist.all_gather(A_list, A_local_data) + + # Expected: sum of all (A_i @ B) for each rank i, but only for this rank's tiles + expected_full = torch.zeros((M, N), dtype=datatype, device=f"cuda:{rank}") + for A_rank in A_list: + expected_full += torch.matmul(A_rank, B_data) + + # Extract only this rank's tiles + for local_tile_idx in range(tiles_per_rank): + tile_id = start_tile + local_tile_idx + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + m_start = pid_m * config.block_size_m + m_end = min(m_start + config.block_size_m, M) + n_start = pid_n * config.block_size_n + n_end = min(n_start + config.block_size_n, N) + + expected_tiles.append( + { + "tile_id": tile_id, + "pid_m": pid_m, + "pid_n": pid_n, + "m_start": m_start, + "m_end": m_end, + "n_start": n_start, + "n_end": n_end, + "data": expected_full[m_start:m_end, n_start:n_end].clone(), + } + ) + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "matmul_reduce_scatter": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + workspace = None + + def run_experiment(): + nonlocal kernel_timing, workspace + + # Preamble if available + if hasattr(shmem.ops, "matmul_reduce_scatter_preamble"): + workspace = shmem.ops.matmul_reduce_scatter_preamble( + C, + A, + B, + config=config, + workspace=workspace, + ) + + shmem.barrier() + + torch.cuda.nvtx.range_push("Matmul-Reduce-Scatter") + with torch.cuda.stream(comm_stream): + kernel_timing["matmul_reduce_scatter"]["start_event"].record() + shmem.ops.matmul_reduce_scatter( + C, + A, + B, + async_op=False, + config=config, + workspace=workspace, + ) + kernel_timing["matmul_reduce_scatter"]["end_event"].record() + kernel_timing["matmul_reduce_scatter"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["matmul_reduce_scatter"]["start_event"].elapsed_time( + kernel_timing["matmul_reduce_scatter"]["end_event"] + ) + kernel_timing["matmul_reduce_scatter"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + C.zero_() + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 2e-1 if datatype == torch.float16 else 1e-1 + success = True + + # Validate each tile assigned to this rank + for tile_info in expected_tiles: + C_tile = C[tile_info["m_start"] : tile_info["m_end"], tile_info["n_start"] : tile_info["n_end"]] + expected_tile = tile_info["data"] + + tile_match = torch.allclose(C_tile, expected_tile, atol=atol) + if not tile_match: + max_diff = torch.abs(C_tile - expected_tile).max().item() + shmem.error( + f"Rank {rank}, tile {tile_info['tile_id']} ({tile_info['pid_m']},{tile_info['pid_n']}): " + f"Validation failed, max diff: {max_diff}" + ) + success = False + + if success: + shmem.info("Matmul-reduce-scatter validation passed!") + else: + shmem.error("Matmul-reduce-scatter validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + iris.do_bench(run_experiment, shmem.barrier, n_warmup=25, n_repeat=1) + + for k in ["matmul_reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + C.zero_() + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate TFLOPS: 2*M*N*K flops + total_flops = 2 * M * N * K + total_tflops_unit = total_flops * 1e-12 + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + tflops = total_tflops_unit / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + # Calculate bandwidth for reduce-scatter part + # Similar to all-reduce: 2 * (world_size - 1) / world_size * data_size bytes + element_size = torch.tensor([], dtype=datatype).element_size() + output_bytes = M * N * element_size + total_bytes = output_bytes * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"]) + * 1e-3 + ) + + shmem.info( + f"Matmul-reduce-scatter (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {tflops:.3f} TFLOPS, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("tflops", tflops) + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_flops", total_flops) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "matmul_reduce_scatter_ms", + kernel_timing["matmul_reduce_scatter"]["ms"] / kernel_timing["matmul_reduce_scatter"]["experiments"], + ) + json_writer.add_field( + "matmul_reduce_scatter_experiments", kernel_timing["matmul_reduce_scatter"]["experiments"] + ) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark PyTorch (matmul + all_reduce) for comparison + # Note: We use all_reduce since PyTorch's reduce_scatter has different semantics + if args["benchmark_pytorch"]: + shmem.info("Benchmarking PyTorch (matmul + all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_A = torch.randn(M, K, dtype=datatype, device=f"cuda:{rank}") + pytorch_B = torch.randn(K, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_C = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + dist.barrier() + + def run_pytorch_experiment(): + pytorch_C = torch.matmul(pytorch_A, pytorch_B) + dist.all_reduce(pytorch_C, op=dist.ReduceOp.SUM) + + pytorch_ms = iris.do_bench(run_pytorch_experiment, dist.barrier) + + # Calculate TFLOPS and bandwidth + pytorch_tflops = total_tflops_unit / (pytorch_ms * 1e-3) + pytorch_bandwidth_gbps = total_bytes_gb / (pytorch_ms * 1e-3) + + shmem.info( + f"PyTorch matmul+all_reduce (M={M}, N={N}, K={K}, world_size={world_size}, dtype={args['datatype']}): " + f"{pytorch_ms:.3f} ms, {pytorch_tflops:.3f} TFLOPS, {pytorch_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_tflops = tflops + speedup = (iris_tflops / pytorch_tflops) if pytorch_tflops > 0 else 0 + shmem.info(f"Speedup (Iris/PyTorch): {speedup:.2f}x") + + json_writer.add_field("pytorch_tflops", pytorch_tflops) + json_writer.add_field("pytorch_bandwidth_gbps", pytorch_bandwidth_gbps) + json_writer.add_field("pytorch_ms", pytorch_ms) + json_writer.add_field("iris_speedup", speedup) + + # Wait for all to finish PyTorch benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = args["init_url"] + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba51..a6ed4a659 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -141,17 +141,16 @@ def matmul_all_gather(self, output_tensor, A, B, bias=None, async_op=False, conf """ return matmul_all_gather(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) - def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None): + def matmul_reduce_scatter(self, output_tensor, A, B, async_op=False, config=None, workspace=None): """ Fused matrix multiplication and reduce-scatter. - Computes: output = reduce_scatter(A @ B + bias) along N dimension + Computes: output = reduce_scatter(A @ B) where each rank keeps assigned tiles Args: - output_tensor: Output tensor (M, N_local) where N_local = N / world_size + output_tensor: Output tensor (M, N) - will contain reduced tiles for this rank A: Input matrix A (M, K) B: Input matrix B (K, N) - bias: Optional bias vector (M,) or (N,) async_op: If False, performs barrier at end config: Optional FusedConfig for tuning workspace: Optional pre-allocated workspace @@ -160,11 +159,10 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, workspace: Updated workspace object Example: - >>> N_local = N // world_size - >>> output = shmem.zeros((M, N_local), dtype=torch.float16) + >>> output = shmem.zeros((M, N), dtype=torch.float16) >>> shmem.ops.matmul_reduce_scatter(output, A, B) """ - return matmul_reduce_scatter(self._shmem, output_tensor, A, B, bias, async_op, config, workspace) + return matmul_reduce_scatter(self._shmem, output_tensor, A, B, async_op, config, workspace) # Export public API From ef227b08acacc7534f96349e3845064db09589ea Mon Sep 17 00:00:00 2001 From: neoblizz Date: Sat, 7 Feb 2026 19:14:58 +0000 Subject: [PATCH 2/5] Merge conflicts. --- benchmark/ops/all_gather_matmul/benchmark.py | 8 + iris/iris.py | 15 +- iris/iris.py.backup | 2255 ++++++++++++++++++ iris/ops/all_gather_matmul.py.with_chunked | 521 ++++ iris/ops/config.py | 26 +- iris/ops/workspace.py | 4 + iris/x/gather.py | 2 +- tests/ops/test_all_gather_matmul.py | 21 +- 8 files changed, 2831 insertions(+), 21 deletions(-) create mode 100644 iris/iris.py.backup create mode 100644 iris/ops/all_gather_matmul.py.with_chunked diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py index 3bc45579e..20ff0c536 100644 --- a/benchmark/ops/all_gather_matmul/benchmark.py +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -61,6 +61,13 @@ def parse_args(): parser.add_argument("--group_size_m", type=int, default=1, help="Group size for M dimension tiling") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument( + "--variant", + type=str, + default="pull", + choices=["pull", "chunked"], + help="All-gather matmul variant (pull or chunked)", + ) parser.add_argument( "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" ) @@ -106,6 +113,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): "block_size_n": args["block_size_n"], "block_size_k": args["block_size_k"], "group_size_m": args["group_size_m"], + "all_gather_matmul_variant": args["variant"], } if args["comm_sms"] is not None: config_kwargs["num_sms"] = args["comm_sms"] diff --git a/iris/iris.py b/iris/iris.py index 5032a640e..9b8a3d35a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1793,17 +1793,12 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + # Vectorization hints: must be <= minimum block size used by any caller. + # (32, 32) is safe since all supported block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits = 8 x fp16). + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) - - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) return translated_ptr diff --git a/iris/iris.py.backup b/iris/iris.py.backup new file mode 100644 index 000000000..e8932c3c8 --- /dev/null +++ b/iris/iris.py.backup @@ -0,0 +1,2255 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Iris: Multi-GPU Communication and Memory Management Framework + +Iris is a high-performance framework that enables seamless multi-GPU programming in Triton, +enabling fine-grained communication and compute overlap natively in Triton +across multiple GPUs with SHMEM-like Remote Memory Access (RMA) capabilities. + +Key Features: +- Symmetric heap management across multiple GPUs +- High-performance atomic operations (add, cas, xchg, xor, and, or, min, max) +- Efficient load/store operations with rank-to-rank communication +- Memory allocation and deallocation utilities +- Built-in logging with rank information +- PyTorch distributed integration for distributed computing + +Example: + >>> import iris + >>> ctx = iris.iris(heap_size=2**30) # 1GB heap + >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) +""" + +import triton +import triton.language as tl + +from iris._distributed_helpers import ( + init_distributed, + distributed_barrier, + distributed_broadcast_scalar, + distributed_broadcast_tensor, +) +from iris.hip import ( + set_device, + get_cu_count, + count_devices, +) +from iris.symmetric_heap import SymmetricHeap +import numpy as np +import math +import torch +import logging + +# Import logging functionality from the separate logging module +from .logging import logger + + +class Iris: + """ + Main Iris class for multi-GPU communication and memory management. + + This class provides a unified interface for distributed GPU operations including + memory allocation, atomic operations, and inter-rank communication. + + Args: + heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) + + Example: + >>> ctx = iris.iris(heap_size=2**31) # 2GB heap + >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 + >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) + """ + + def __init__(self, heap_size=1 << 30): + # Initialize distributed environment + comm, cur_rank, num_ranks = init_distributed() + num_gpus = count_devices() + + gpu_id = cur_rank % num_gpus + set_device(gpu_id) + + self.comm = comm + self.num_ranks = num_ranks + self.cur_rank = cur_rank + self.gpu_id = gpu_id + self.heap_size = heap_size + + # Initialize symmetric heap + self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks) + self.device = f"cuda:{gpu_id}" + self.heap_bases = self.heap.get_heap_bases() + + for i in range(num_ranks): + self.debug(f"GPU {i}: Heap base {hex(int(self.heap_bases[i].item()))}") + + distributed_barrier() + + # Initialize CCL interface + self.ccl = self.CCL(self) + + # Lazy initialization for ops interface + self._ops = None + + def _log_with_rank(self, level, message): + """Helper method to log with rank information injected into the record.""" + if logger.isEnabledFor(level): + record = logging.LogRecord( + name=logger.name, level=level, pathname="", lineno=0, msg=message, args=(), exc_info=None + ) + # Inject rank information into the record + record.iris_rank = self.cur_rank + record.iris_num_ranks = self.num_ranks + logger.handle(record) + + def debug(self, message): + """ + Log a debug message with rank information. + + Args: + message (str): Human-readable message to log at debug level. + + Notes: + The log record is enriched with ``iris_rank`` and ``iris_num_ranks`` so + formatters can display the originating rank and world size. + + Example: + >>> ctx = iris.iris() + >>> iris.set_logger_level(iris.DEBUG) + >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers + """ + self._log_with_rank(logging.DEBUG, message) + + def info(self, message): + """ + Log an info message with rank information. + + Args: + message (str): Human-readable message to log at info level. + + Example: + >>> ctx = iris.iris() + >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 + """ + self._log_with_rank(logging.INFO, message) + + def warning(self, message): + """ + Log a warning message with rank information. + + Args: + message (str): Human-readable message to log at warning level. + + Example: + >>> ctx = iris.iris() + >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high + """ + self._log_with_rank(logging.WARNING, message) + + def error(self, message): + """ + Log an error message with rank information. + + Args: + message (str): Human-readable message to log at error level. + + Example: + >>> ctx = iris.iris() + >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory + """ + self._log_with_rank(logging.ERROR, message) + + @property + def ops(self): + """ + Access fused GEMM+CCL operations. + + This property provides a namespace for high-level fused operations that combine + matrix multiplication with collective communication. Operations automatically infer + dimensions, strides, and hardware parameters from input tensors. + + Available operations: + - matmul_all_reduce: GEMM + All-Reduce + - all_gather_matmul: All-Gather + GEMM + - matmul_all_gather: GEMM + All-Gather + - matmul_reduce_scatter: GEMM + Reduce-Scatter + + Returns: + OpsNamespace: Namespace with fused operation methods + + Raises: + ImportError: If tritonBLAS is not available + + Example: + >>> ctx = iris.iris() + >>> A = ctx.randn((1024, 512), dtype=torch.float16) + >>> B = ctx.randn((512, 2048), dtype=torch.float16) + >>> output = ctx.zeros((1024, 2048), dtype=torch.float16) + >>> ctx.ops.matmul_all_reduce(output, A, B, ctx) + """ + if self._ops is None: + from iris.ops import OpsNamespace + + self._ops = OpsNamespace(self) + return self._ops + + def broadcast(self, value, source_rank=0): + """ + Broadcast a value from one rank to all ranks. + + This method automatically detects the type of value and uses the appropriate + broadcast mechanism: + - For tensors and arrays: uses efficient PyTorch distributed tensor collectives + - For scalars and other objects: uses object broadcast + + Args: + value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, + or any picklable object. Only the ``source_rank`` value is used; + other ranks should pass a placeholder (e.g., ``None``). + source_rank (int): Rank id that holds the authoritative value. + + Returns: + Any: The value broadcast to all ranks. Tensors and arrays are returned as + numpy arrays; scalars and objects are returned in their original type. + + Examples: + >>> ctx = iris.iris() + >>> # Broadcasting a scalar + >>> value = 42 if ctx.cur_rank == 0 else None + >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 + >>> + >>> # Broadcasting a tensor + >>> if ctx.cur_rank == 0: + >>> data = torch.randn(10, 10) + >>> else: + >>> data = None + >>> data = ctx.broadcast(data, source_rank=0) # All ranks get the same array + """ + # Check if the value on source_rank is a tensor or array-like + if self.cur_rank == source_rank and value is not None: + # Explicitly exclude strings and non-numeric types + if isinstance(value, (str, dict, bool)): + is_tensor = False + elif isinstance(value, torch.Tensor): + is_tensor = True + elif isinstance(value, np.ndarray): + is_tensor = True + elif isinstance(value, (list, tuple)): + # Try to convert list/tuple to tensor to check if it's numeric + try: + torch.as_tensor(value) + is_tensor = True + except (TypeError, ValueError): + is_tensor = False + else: + # For other types, try to convert and check + try: + test_array = np.asarray(value) + # Check if it's a numeric dtype that torch can handle + if np.issubdtype(test_array.dtype, np.number): + torch.as_tensor(test_array) + is_tensor = True + else: + is_tensor = False + except (TypeError, ValueError): + is_tensor = False + else: + is_tensor = False + + # Broadcast the type decision to all ranks + is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) + + if is_tensor: + return distributed_broadcast_tensor(value, root=source_rank) + else: + return distributed_broadcast_scalar(value, source_rank) + + def __allocate(self, num_elements, dtype): + """Allocate memory using the symmetric heap.""" + self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") + return self.heap.allocate(num_elements, dtype) + + def __parse_size(self, size): + # Handle nested tuples/lists by flattening them recursively + while len(size) == 1 and isinstance(size[0], (tuple, list)): + size = size[0] + num_elements = math.prod(size) + return size, num_elements + + def zeros_like( + self, input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format + ): + """ + Returns a tensor filled with the scalar value 0, with the same size as input, allocated on the Iris symmetric heap. + + Args: + input (Tensor): the size of input will determine size of the output tensor. + + Keyword Arguments: + dtype (torch.dtype, optional): the desired data type of returned Tensor. + Default: if None, defaults to the dtype of input. + layout (torch.layout, optional): the desired layout of returned tensor. + Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). + device (torch.device, optional): the desired device of returned tensor. + Default: if None, defaults to the device of input. Must be compatible with this Iris instance. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.preserve_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> input_tensor = ctx.ones(2, 3) + >>> zeros_tensor = ctx.zeros_like(input_tensor) + >>> print(zeros_tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use input's properties as defaults if not specified + if dtype is None: + dtype = input.dtype + if layout is None: + layout = input.layout + if device is None: + device = input.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Get the size from input tensor + size = input.size() + num_elements = input.numel() + + # Allocate new tensor with the same size + new_tensor = self.__allocate(num_elements, dtype) + new_tensor.zero_() + + # Reshape to match input size + new_tensor = new_tensor.reshape(size) + + # Apply the requested memory format + new_tensor = self.__apply_memory_format(new_tensor, size, memory_format, input) + + # Apply the requested layout + new_tensor = self.__apply_layout(new_tensor, layout) + + # Set requires_grad if specified + if requires_grad: + new_tensor.requires_grad_() + + return new_tensor + + def arange( + self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a 1-D tensor of size ⌈(end - start) / step⌉ with values from the interval [start, end) + taken with common difference step beginning from start. The tensor is allocated on the symmetric heap. + + Note: When using floating-point dtypes (especially reduced precision types like bfloat16), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer step is subject to floating point rounding errors when comparing + against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases. + + Args: + start (Number, optional): the starting value for the set of points. Default: 0. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: 1. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.get_default_dtype()). + If dtype is not given, infer the data type from the other input arguments. + If any of start, end, or step are floating-point, the dtype is inferred + be the default dtype, see get_default_dtype(). Otherwise, the dtype is inferred + to be torch.int64. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> print(tensor.shape) # torch.Size([5]) + """ + self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") + + # Handle the case where only one argument is provided (end) + if end is None: + end = start + start = 0 + + # Validate inputs + if step == 0: + raise ValueError("step must be non-zero") + + # Validate step direction consistency + if step > 0 and start >= end: + raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") + elif step < 0 and start <= end: + raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") + + # Calculate the number of elements + num_elements = math.ceil((end - start) / step) + + # Infer dtype if not provided + if dtype is None: + if any(isinstance(x, float) for x in [start, end, step]): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + + target_device = tensor.device + arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) + + tensor[:] = arange_tensor + + tensor = self.__apply_layout(tensor, layout) + + if requires_grad: + tensor.requires_grad_() + + return tensor + + def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.zeros(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') + """ + self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with zeros + out.zero_() + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with zeros + tensor.zero_() + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def randn( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 + (also called the standard normal distribution). The tensor is allocated on the Iris symmetric heap. + + .. math:: + \\text{out}_i \\sim \\mathcal{N}(0, 1) + + For complex dtypes, the tensor is i.i.d. sampled from a complex normal distribution with zero mean + and unit variance as + + .. math:: + \\text{out}_i \\sim \\mathcal{CN}(0, 1) + + This is equivalent to separately sampling the real :math:`(\\text{Re})` and imaginary :math:`(\\text{Im})` + part of :math:`\\text{out}_i` as + + .. math:: + \\text{Re}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}), \\quad \\text{Im}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}) + + The shape of the tensor is defined by the variable argument size. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). + device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randn(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') + """ + self.debug( + f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Generate random data and copy to out tensor + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + out.copy_(random_data) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Generate random data and copy to tensor + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + tensor.copy_(random_data) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.ones(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') + """ + self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with ones + out.fill_(1) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with ones + tensor.fill_(1) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. + The tensor is allocated on the Iris symmetric heap. + + Args: + size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.full((2, 3), 3.14) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') + """ + self.debug( + f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Infer dtype from fill_value if not provided + if dtype is None: + if isinstance(fill_value, (int, float)): + if isinstance(fill_value, float): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + else: + # For other types (like tensors), use their dtype + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with the specified value + out.fill_(fill_value) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Fill with the specified value + tensor.fill_(fill_value) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): + """ + Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. + + Args: + size (int or tuple of ints): the size of the output tensor. + low (float, optional): the lower bound of the uniform distribution. Default: 0.0. + high (float, optional): the upper bound of the uniform distribution. Default: 1.0. + dtype (torch.dtype, optional): the desired data type of returned tensor. Default: torch.float. + + Returns: + Tensor: A tensor filled with random numbers from a uniform distribution. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') + """ + self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") + size, num_elements = self.__parse_size(size) + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + tensor.uniform_(low, high) + return tensor.reshape(size) + + def empty( + self, + *size, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ): + """ + Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Note: + If torch.use_deterministic_algorithms() and torch.utils.deterministic.fill_uninitialized_memory are both set to True, + the output tensor is initialized to prevent any possible nondeterministic behavior from using the data as an input to an operation. + Floating point and complex tensors are filled with NaN, and integer tensors are filled with the maximum value. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.contiguous_format. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.empty(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested memory format + tensor = self.__apply_memory_format(tensor, size, memory_format) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def randint( + self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). + The shape of the tensor is defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Note: + With the global dtype default (torch.float32), this function returns a tensor with dtype torch.int64. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): if None, this function returns a tensor with dtype torch.int64. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') + """ + self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Parse arguments to determine low, high, and size + # PyTorch randint signatures: + # randint(high, size) - where high is the upper bound and size is the shape + # randint(low, high, size) - where low and high are bounds, size is the shape + if len(args) == 2: + # randint(high, size) + high, size = args + low = 0 + elif len(args) == 3: + # randint(low, high, size) + low, high, size = args + else: + raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") + + # Use default dtype if None is provided + if dtype is None: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate random integers using PyTorch's randint + # Use specified device or fall back to current device + target_device = device if device is not None else self.device + + # Handle generator parameter + if generator is not None: + torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) + else: + torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive. + The tensor is allocated on the Iris symmetric heap. + + The values are: + (start, start + (end-start)/(steps-1), ..., start + (steps-2)*(end-start)/(steps-1), end) + + Args: + start (float or Tensor): the starting value for the set of points. If Tensor, it must be 0-dimensional. + end (float or Tensor): the ending value for the set of points. If Tensor, it must be 0-dimensional. + steps (int): size of the constructed tensor. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + Default: if None, uses the global default dtype when both start and end are real, + and corresponding complex dtype when either is complex. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') + """ + self.debug( + f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use global default dtype if None is provided + if dtype is None: + # Check if start or end are complex numbers + start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) + end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) + + if start_is_complex or end_is_complex: + # Infer complex dtype based on default dtype + dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 + else: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse steps and extract the integer value + if isinstance(steps, (tuple, list)): + if len(steps) == 1: + # Single-element tuple/list like (5,) or [5] + steps_int = steps[0] + # Handle nested tuples like ((5,),) + if isinstance(steps_int, (tuple, list)): + steps_int = steps_int[0] + else: + # Multi-element tuple/list - use __parse_size for compatibility + size, num_elements = self.__parse_size(steps) + steps_int = num_elements + else: + # steps is a single integer + steps_int = steps + + # Ensure steps_int is an integer + steps_int = int(steps_int) + size = (steps_int,) + num_elements = steps_int + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate linspace using PyTorch's linspace + # Use specified device or fall back to current device + target_device = device if device is not None else self.device + torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=target_device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def rand( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. + Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') + """ + self.debug( + f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self.__throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self.__parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self.__throw_if_invalid_output_tensor(out, num_elements, dtype) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = self.__allocate(num_elements=num_elements, dtype=dtype) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Generate random numbers using PyTorch's rand + # Use specified device (already validated and set above) + + # Handle generator parameter + if generator is not None: + torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) + else: + torch.rand(size, out=tensor, dtype=dtype, device=device) + + # Apply the requested layout + tensor = self.__apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def __deallocate(self, pointer): + pass + + def get_heap_bases(self): + """ + Return the tensor of symmetric heap base addresses for all ranks. + + Returns: + torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` + on the Iris device. Pass this to device-side Triton kernels that require + heap translation. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> heap_bases = ctx.get_heap_bases() + >>> print(heap_bases.shape) # torch.Size([num_ranks]) + """ + return self.heap_bases + + def barrier(self, stream=None, group=None): + """ + Synchronize ranks within the specified group and their CUDA devices. + + This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has + finished all queued work, then performs a distributed barrier so that all + ranks in the group reach the same point before proceeding. + + Args: + stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). + group (ProcessGroup, optional): The process group to synchronize. + If None, uses the default process group (all ranks). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.barrier() # Synchronize all ranks + >>> ctx.barrier(group=my_group) # Synchronize only ranks in my_group + """ + # Wait for all GPUs to finish work + if stream is None: + torch.cuda.synchronize() + else: + stream.synchronize() + + # Distributed barrier + distributed_barrier(group=group) + + def get_device(self): + """ + Get the underlying device where the Iris symmetric heap resides. + + Returns: + torch.device: The CUDA device of Iris-managed memory. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> device = ctx.get_device() + >>> print(device) # cuda:0 + """ + return self.heap.get_device() + + def get_cu_count(self): + """ + Get the number of compute units (CUs) for the current GPU. + + Returns: + int: Number of compute units on this rank's GPU. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> cu_count = ctx.get_cu_count() + >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs + """ + return get_cu_count(self.gpu_id) + + def get_rank(self): + """ + Get this process's rank id in the distributed communicator. + + Returns: + int: Zero-based rank id of the current process. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> rank = ctx.get_rank() + >>> print(f"This is rank {rank}") # This is rank 0 + """ + return self.cur_rank + + def get_num_ranks(self): + """ + Get the total number of ranks in the distributed communicator. + + Returns: + int: World size (number of ranks). + + Example: + >>> ctx = iris.iris(1 << 20) + >>> num_ranks = ctx.get_num_ranks() + >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 + """ + return self.num_ranks + + def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): + if not self.__tensor_on_device(tensor): + raise RuntimeError( + f"The output tensor is not on the same device as the Iris instance. The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" + ) + if not self.__on_symmetric_heap(tensor): + raise RuntimeError( + f"The output tensor is not on the symmetric heap. The Iris instance is on heap base {self.heap_bases[self.cur_rank]} but the output tensor is on heap base {tensor.data_ptr()}" + ) + if tensor.numel() != num_elements: + raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") + if tensor.dtype != dtype: + raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") + + def __throw_if_invalid_device(self, device): + """ + Throw a RuntimeError if the requested device is not compatible with this Iris instance. + + Args: + device: The requested device (can be string, torch.device, or None) + + Raises: + RuntimeError: If the device is not compatible + """ + if not self.__is_valid_device(device): + raise RuntimeError( + f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " + f"Iris only supports tensors on its own device." + ) + + def __apply_memory_format( + self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None + ): + """ + Apply the requested memory format to a tensor by setting appropriate strides. + This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. + + Args: + tensor: The tensor to modify + size: The tensor's size/dimensions + memory_format: The desired memory format + input_tensor: The original input tensor (needed for preserve_format detection) + """ + if memory_format == torch.contiguous_format: + # Default format, no changes needed + return tensor + elif memory_format == torch.channels_last and len(size) == 4: + # For channels_last format: preserve shape (N, C, H, W) but change strides + # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) + N, C, H, W = size[0], size[1], size[2], size[3] + # Keep the original shape (N, C, H, W) but use channels_last strides + tensor = self.__create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) + return tensor + elif memory_format == torch.channels_last_3d and len(size) == 5: + # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides + # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides + tensor = self.__create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) + return tensor + elif memory_format == torch.preserve_format: + # For preserve_format, we need to detect the input tensor's memory format + # and apply the same format to the output + if input_tensor is not None: + # Check the actual memory format of the input tensor + if len(size) == 4: + # Check if input tensor is in channels_last format by examining strides + # channels_last format has strides[1] == 1 (channels dimension is contiguous) + input_strides = input_tensor.stride() + if len(input_strides) == 4 and input_strides[1] == 1: + # Input is in channels_last format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 4: + # Input is already in channels_last format (N, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + elif len(size) == 5: + # Check if input tensor is in channels_last_3d format + input_strides = input_tensor.stride() + if len(input_strides) == 5 and input_strides[1] == 1: + # Input is in channels_last_3d format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 5: + # Input is already in channels_last_3d format (N, D, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + # If no special format detected or no input tensor provided, use contiguous format + return tensor + else: + # Unsupported format or dimension combination + self.debug( + f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" + ) + # For unsupported formats, return the tensor as-is (contiguous) + return tensor + + def __create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: tuple) -> torch.Tensor: + """ + Create a new tensor with the specified strides while keeping the data on the symmetric heap. + + Args: + original_tensor: The original tensor (source of data and heap allocation) + size: The tensor's size/dimensions + strides: The desired strides for the new memory format + + Returns: + A new tensor with the specified strides, data copied from original, on the same heap + """ + + # First, create a temporary tensor with the correct strides using PyTorch + temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) + + # Handle different cases based on whether size changes and what the strides indicate + if size != original_tensor.shape: + # Size is different - this might be a format change that requires permutation + # Check if this is a channels_last format by comparing strides + if len(size) == 4: + # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) + N, H, W, C = size[0], size[1], size[2], size[3] + expected_strides = (H * W * C, 1, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + elif len(size) == 5: + # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) + N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] + expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # For other dimensions, just try to reshape + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # Size is the same - this is a stride-only change (like channels_last with preserved shape) + # We need to reorder the data to match the new stride pattern + if len(size) == 4: + # Check if this is channels_last format with preserved shape + N, C, H, W = size[0], size[1], size[2], size[3] + expected_strides = (C * H * W, 1, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + elif len(size) == 5: + # Check if this is channels_last_3d format with preserved shape + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + else: + permuted = original_tensor + + # Copy the permuted data to the temporary tensor + temp_tensor.copy_(permuted) + + # Now allocate a new tensor on our symmetric heap + num_elements = math.prod(size) + heap_tensor = self.__allocate(num_elements, original_tensor.dtype) + + # Reshape to the desired size + heap_tensor = heap_tensor.reshape(size) + + # Copy the data from the temporary tensor to our heap tensor + heap_tensor.copy_(temp_tensor) + + # Clean up the temporary tensor + del temp_tensor + + # Now we need to create a view with the correct strides + # We can't use as_strided directly on our heap tensor, but we can + # create a new tensor with the right strides and copy the data again + final_tensor = torch.as_strided(heap_tensor, size, strides) + + return final_tensor + + def __apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: + """ + Apply the requested layout to a tensor. + + Args: + tensor: The tensor to modify + layout: The desired layout + + Returns: + Tensor with the requested layout + """ + + if layout == torch.strided: + # Strided layout is the default - no changes needed + return tensor + else: + # Only support strided layout for now + raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") + + def __tensor_on_device(self, tensor: torch.Tensor): + # Get the Iris device from memory_pool.device + iris_device = self.get_device() + tensor_device = tensor.device + + # For CUDA devices, check if they're compatible + if tensor_device.type == "cuda" and iris_device.type == "cuda": + if iris_device.index is None: + return True + return tensor_device.index == iris_device.index + + # For non-CUDA devices, they must be exactly equal + return tensor_device == iris_device + + def __on_symmetric_heap(self, tensor: torch.Tensor): + """Check if a tensor is allocated on the symmetric heap.""" + return self.heap.on_symmetric_heap(tensor) + + def __is_valid_device(self, device) -> bool: + """ + Check if the requested device is compatible with this Iris instance. + + Args: + device: The requested device (can be string, torch.device, or None) + + Returns: + bool: True if the device is compatible, False otherwise + """ + if device is None: + return True # None means use default device + + # Convert device strings to torch.device objects for proper comparison + requested_device = torch.device(device) if isinstance(device, str) else device + iris_device = self.get_device() + + # Check if both are CUDA devices + if requested_device.type == "cuda" and iris_device.type == "cuda": + # Check if index matches or if requested is "cuda" (any index) + if requested_device.index is None: + return True + else: + return requested_device.index == iris_device.index + + # For non-CUDA devices, always return False + return False + + class CCL: + """ + Collective Communication Library (CCL) interface for Iris. + + Provides collective operations that can be called as methods on the Iris instance. + Example usage: + >>> shmem = iris.iris() + >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + """ + + def __init__(self, iris_instance): + """ + Initialize CCL with a reference to the parent Iris instance. + + Args: + iris_instance: The parent Iris instance + """ + self._iris = iris_instance + + def all_to_all(self, output_tensor, input_tensor, group=None, async_op=False, config=None): + """ + All-to-all collective operation. + + Each rank sends a tensor chunk to each other rank and receives + a tensor chunk from each other rank. Input/output tensors should have + shape (M, N * world_size) where each chunk of N columns corresponds to one rank. + + Args: + output_tensor: Output tensor of shape (M, N * world_size) + input_tensor: Input tensor of shape (M, N * world_size) + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_to_all(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_to_all import all_to_all as _all_to_all + + _all_to_all(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) + + def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, config=None): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + + Example: + >>> shmem = iris.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, group=group, async_op=async_op, config=config) + + def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): + """ + Prepare reusable workspace for all-reduce. + + Args: + output_tensor: Output tensor that will receive the reduced data. + input_tensor: Input tensor providing the local contribution. + config: Optional Config describing variant parameters. + workspace: Optional existing workspace to update/reuse. + + Returns: + Workspace object that can be passed to ``all_reduce``. + """ + from iris.ccl.all_reduce import all_reduce_preamble as _all_reduce_preamble + + return _all_reduce_preamble( + output_tensor, + input_tensor, + self._iris, + config=config, + workspace=workspace, + ) + + def all_reduce( + self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None, workspace=None + ): + """ + All-reduce collective operation. + + Each rank has a local input tensor, and all ranks compute the sum of all + input tensors. The result is written to output_tensor on all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain sum of all inputs + input_tensor: Input tensor of shape (M, N) - local rank's partial data + op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. + Default: ReduceOp.SUM. + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Set config.all_reduce_variant to choose variant: "atomic", "ring", or "two_shot" + workspace: Optional workspace prepared by ``all_reduce_preamble`` to + reuse internal buffers across invocations. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.all_reduce(output_tensor, input_tensor) + + >>> # Custom configuration with ring variant + >>> from iris.ccl import Config + >>> config = Config(all_reduce_variant="ring") + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + + >>> # Two-shot variant with block distribution + >>> config = Config(all_reduce_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_reduce(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_reduce import all_reduce as _all_reduce + from iris.ccl import ReduceOp + + # Default to SUM if not specified + if op is None: + op = ReduceOp.SUM + + return _all_reduce( + output_tensor, + input_tensor, + self._iris, + op=op, + group=group, + async_op=async_op, + config=config, + workspace=workspace, + ) + + def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async_op=False, config=None): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + op: Reduction operation to apply. Currently only ReduceOp.SUM is supported. + Default: ReduceOp.SUM. + group: ProcessGroup or None. If None, uses all ranks in shmem context. + Default: None. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + from iris.ccl import ReduceOp + + # Default to SUM if not specified + if op is None: + op = ReduceOp.SUM + + _reduce_scatter( + output_tensor, input_tensor, self._iris, op=op, group=group, async_op=async_op, config=config + ) + + +@triton.jit +def __translate(ptr, from_rank, to_rank, heap_bases): + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + # convert to int to compute difference + ptr_int = tl.cast(ptr, tl.uint64) + # Find the offset from from_rank heap + offset = ptr_int - from_base + # Byte cast for byte offset addition + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + # Find the offset into the to_rank heap + translated_ptr_byte = to_base_byte + offset + # Cast to_base back to pointer type + translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) + + # Optimization to vectorize the load/store + # We can't do this in general because we don't know the shape of the tensor or block sizes + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + + # 0 You can use this if your block sizes are multiples of 32. + # Largest vectorized load instruction is dwordx4 (128-bits) + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) + + # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) + # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) + return translated_ptr + + +@triton.jit +def load(pointer, to_rank, from_rank, heap_bases, mask=None): + """ + Loads a value from the specified rank's memory location. + + This function performs a memory read operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and loading + data from the target memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local load operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. + from_rank (int): The rank ID from which to read the data. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + + Returns: + Block: The loaded value from the target memory location. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Load data from rank 1's memory into the current rank + >>> cur_rank = 0 # Current rank + >>> remote_rank = 1 # Remote rank to load from + >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) + >>> return data + """ + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + result = tl.load(translated_ptr, mask=mask) + return result + + +@triton.jit +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): + """ + Writes data to the specified rank's memory location. + + This function performs a memory write operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and storing + the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local store operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + value (Block): The tensor of elements to be stored. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Store value 42 into rank 1's heap from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> value = 42 + >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + tl.store(translated_ptr, value, mask=mask) + + +@triton.jit +def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): + """ + Copies data from the specified rank's memory into the destination rank's memory. + This function performs the transfer by translating `src_ptr` from the `from_rank`'s address + space to the `to_rank`'s address space, performing a masked load from the translated + source, and storing the loaded data to `dst_ptr` in the `to_rank` memory location. + If `from_rank` and `to_rank` are the same, this function performs a local copy operation. + It is undefined behaviour if neither `from_rank` nor `to_rank` is the `cur_rank`. + + Args: + src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s local memory from which to read data. + dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `to_rank`'s local memory where the data will be written. + from_rank (int): The rank ID that owns `src_ptr` (source rank). + to_rank (int): The rank ID that will receive the data (destination rank). + cur_rank (int): The rank ID issuing the copy operation. Must be either `from_rank` or `to_rank`. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.copy(remote_ptr, local_ptr, from_rank, to_rank, to_rank, heap_bases) + """ + + cur_base = tl.load(heap_bases + cur_rank) + + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + + src_ptr_int = tl.cast(src_ptr, tl.uint64) + src_offset = src_ptr_int - cur_base + + dst_ptr_int = tl.cast(dst_ptr, tl.uint64) + dst_offset = dst_ptr_int - cur_base + + from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8)) + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + + translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) + translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + + data = tl.load(translated_src, mask=mask) + tl.store(translated_dst, data, mask=mask) + + +@triton.jit +def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): + """ + Copies data from the specified rank's memory to the current rank's local memory. + + This function performs a memory read operation by translating the `from_ptr` + from the current rank's address space to the `from_rank`'s address space, loading data + from the `from_rank` memory location, and storing it to the local `to_ptr`. + If the `from_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. Must be the current rank where the pointer is local. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory where the data will be stored. + from_rank (int): The `from_rank` ID from which to read the data. + to_rank (int): The current rank ID where the data will be stored. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(remote_ptr, local_ptr, heap_bases): + >>> from_rank = 1 + >>> to_rank = 0 + >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) + """ + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + + data = tl.load(translated_from_ptr, mask=mask) + + tl.store(to_ptr, data, mask=mask) + + +@triton.jit +def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): + """ + Copies data from the current rank's local memory to the specified rank's memory. + This function performs a memory write operation by loading data from the current + rank's `from_ptr`, translating the `to_ptr` from the current rank's address + space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. + If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + from_rank (int): The current rank ID from which to read the data. + to_rank (int): The `to_rank` ID to which the data will be written. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> @triton.jit + >>> def kernel(local_ptr, remote_ptr, heap_bases): + >>> from_rank = 0 + >>> to_rank = 1 + >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) + """ + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + + data = tl.load(from_ptr, mask=mask) + + tl.store(translated_to_ptr, data, mask=mask) + + +@triton.jit +def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic add at the specified rank's memory location. + + This function performs an atomic addition operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + adding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic addition operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically add 5 to rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> increment = 5 + >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Atomically subtracts data from the specified rank's memory location. + + This function performs an atomic subtraction operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + subtracting the provided data from the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic subtraction operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block): The tensor of elements to be subtracted atomically. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + + Returns: + Block: The value at the memory location before the atomic subtraction. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically subtract 3 from rank 2's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 2 # Remote rank (destination) + >>> decrement = 3 + >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): + """ + Atomically compares and exchanges the specified rank's memory location. + + This function performs an atomic compare-and-swap operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + comparing the current value with the expected value, then writing the new value if they match. + If the `from_rank` and `to_rank` are the same, this function performs a local atomic compare-and-swap operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + cmp (Block): The expected value to be compared with the current value at the memory location. + val (Block): The new value to be written if the compare succeeds. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + + Returns: + Block: The value contained at the memory location before the atomic operation attempt. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Compare-and-swap on rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> expected = 0 + >>> new_val = 42 + >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) + + +@triton.jit +def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic exchange at the specified rank's memory location. + + This function performs an atomic exchange operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + exchanging the current value with the provided new value. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic exchange operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Exchange value with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_value = 99 + >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic xor at the specified rank's memory location. + + This function performs an atomic xor operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + xoring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic xor operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically XOR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xFF + >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic and at the specified rank's memory location. + + This function performs an atomic and operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + anding the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic and operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically AND with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0x0F + >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic or at the specified rank's memory location. + + This function performs an atomic or operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + oring the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic or operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically OR with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> mask_val = 0xF0 + >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic min at the specified rank's memory location. + + This function performs an atomic min operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + performing the min on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic min operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find minimum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 10 + >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +@triton.jit +def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): + """ + Performs an atomic max at the specified rank's memory location. + + This function performs an atomic max operation by translating the pointer + from the `from_rank`'s address space to the `to_rank`'s address space and atomically + performing the max on the provided data to the `to_rank` memory location. If the `from_rank` and `to_rank` are the same, + this function performs a local atomic max operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local. + to_rank (int): The rank ID to which the atomic operation will be performed. + heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> @triton.jit + >>> def kernel(ptr, heap_bases): + >>> # Atomically find maximum with rank 1's memory from rank 0 + >>> cur_rank = 0 # Current rank (source) + >>> remote_rank = 1 # Remote rank (destination) + >>> new_val = 100 + >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) + """ + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + +def iris(heap_size=1 << 30): + """ + Create and return an Iris instance with the specified heap size. + + Args: + heap_size (int): Size of the heap in bytes. Defaults to 1GB. + + Returns: + Iris: An initialized Iris instance. + + Example: + >>> import iris + >>> iris_ctx = iris.iris(2**30) # 1GB heap + >>> tensor = iris_ctx.zeros(1024, 1024) + """ + return Iris(heap_size) diff --git a/iris/ops/all_gather_matmul.py.with_chunked b/iris/ops/all_gather_matmul.py.with_chunked new file mode 100644 index 000000000..ddc03d027 --- /dev/null +++ b/iris/ops/all_gather_matmul.py.with_chunked @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Fused All-Gather + GEMM operation using pull pattern. + +Each rank has a column-sharded input A_sharded (M x K_local). +This operation computes C = all_gather(A_sharded) @ B by pulling +tiles from remote ranks on-demand during GEMM computation. +""" + +from typing import Optional +import torch +import triton +import triton.language as tl +import iris +import iris.x + +from tritonblas.kernels.stages.algorithms.binary import add_vector +from tritonblas.kernels.stages.algorithms.unary import convert_dtype + +from .config import FusedConfig +from .workspace import FusedWorkspace + + +@triton.jit() +def _fused_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + K_local: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_bias: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """Fused all-gather + GEMM kernel using pull pattern.""" + pid = tl.program_id(0) + + # Handle multi-XCD devices + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + + # Persistent loop over output tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Compute tile coordinates with swizzling + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute row and column indices + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # Create DeviceContext and TensorView for gather operations + ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) + src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) + + # Loop over all ranks to pull and accumulate + for source_rank_id in range(world_size): + loop_k_local = tl.cdiv(K_local, BLOCK_SIZE_K) + if not EVEN_K: + loop_k_local -= 1 + + # Loop over K dimension for this rank's shard + for k_block_idx in range(0, loop_k_local): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Create tile view for this K block + tile_k = k_offset // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id using gather primitive + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + # Load B tile + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_global = (source_rank_id * K_local) + rk_local + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Handle remaining K elements if not evenly divisible + if not EVEN_K: + k_offset = loop_k_local * BLOCK_SIZE_K + tile_k = k_offset // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id using gather primitive + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_global = (source_rank_id * K_local) + rk_local + rk_global_mask = rk_global < K + B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_global_mask[:, None], other=0.0) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Add bias if provided using tritonBLAS + if BIAS: + bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = add_vector(acc, bias_vector, QUANTIZED=False) + + # Convert to output dtype using tritonBLAS + c = convert_dtype(acc, C.type.element_ty) + + # Store result (manual for now, tritonBLAS store has issues with our indices) + C_ptr = ( + C + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + ) + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N + ) + tl.store(C_ptr, c, mask=mask) + + +@triton.jit() +def _fused_chunked_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + temp_buffer, # Temporary buffer: BLOCK_M x K x num_tiles + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + K_local: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_bias: tl.constexpr, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """ + Fused all-gather + GEMM kernel using chunked/buffered pattern. + + This variant pre-gathers all of A into a temporary buffer before computing GEMM. + Eliminates the world_size loop by using iris.x.all_gather upfront. + + Memory layout: + - temp_buffer: BLOCK_M x K x num_tiles (stores gathered A for each tile) + - Each program gathers its M-tile of A, then does GEMM + """ + pid = tl.program_id(0) + + # Handle multi-XCD devices + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + + # Persistent loop over output tiles + for tile_id in range(pid, total_tiles, NUM_SMS): + # Compute tile coordinates with swizzling + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + # Compute row and column indices + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Buffer pointer for this tile: BLOCK_M x K for this pid_m + buffer_ptr = temp_buffer + tile_id * BLOCK_SIZE_M * K + + # Step 1: Pre-gather entire M-tile of A (BLOCK_M x K) + # Create DeviceContext and TensorView for gather operations + ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) + src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) + + # Gather K-tiles from all ranks + for source_rank_id in range(world_size): + k_start = source_rank_id * K_local + # Loop over K dimension in blocks + for k_local_idx in range(0, K_local, BLOCK_SIZE_K): + k_global = k_start + k_local_idx + rk = k_global + tl.arange(0, BLOCK_SIZE_K) + rk_mask = rk < K + + tile_k = k_local_idx // BLOCK_SIZE_K + k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) + + # Pull A tile from source_rank_id + a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) + + # Store in buffer + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + tl.store(buffer_A_ptr, a, mask=rk_mask[None, :]) + + # Step 2: Standard GEMM from buffer + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + # Loop over K dimension + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if EVEN_K: + for k_block_idx in range(loop_k): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Load A from temp buffer + rk = k_offset + tl.arange(0, BLOCK_SIZE_K) + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + a = tl.load(buffer_A_ptr) + + # Load B tile + B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + + # Accumulate + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + else: + # Handle case where K is not evenly divisible by BLOCK_SIZE_K + for k_block_idx in range(loop_k): + k_offset = k_block_idx * BLOCK_SIZE_K + + # Load A from temp buffer + rk = k_offset + tl.arange(0, BLOCK_SIZE_K) + rk_mask = rk < K + buffer_A_ptr = buffer_ptr + rm[:, None] * K + rk[None, :] + a = tl.load(buffer_A_ptr, mask=rk_mask[None, :], other=0.0) + + # Load B tile + B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_mask[:, None], other=0.0) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + # Convert accumulator and add bias + c = convert_dtype(acc, C.type.element_ty) + if BIAS: + bias_offset = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) * stride_bias + bias_val = tl.load(bias_ptr + bias_offset) + c = add_vector(c, bias_val, 0) + + # Store result + C_ptr = ( + C + + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm + + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn + ) + mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N + ) + tl.store(C_ptr, c, mask=mask) + + +def all_gather_matmul_preamble( + shmem, + A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, +) -> FusedWorkspace: + """Allocate workspace for all_gather_matmul (buffer needed for chunked variant).""" + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + + # Detect hardware configuration + device = A_sharded.device + if config.num_sms is None: + import iris.hip + num_sms = iris.hip.get_cu_count(device.index) + else: + num_sms = config.num_sms + + if config.num_xcds == 1: + # Auto-detect XCDs if default value is used + import iris.hip + num_xcds = iris.hip.get_num_xcc(device.index) + else: + num_xcds = config.num_xcds + + # Allocate temporary buffer for chunked variant + aux_buffer = None + if config.all_gather_matmul_variant == "chunked": + # Calculate grid size to determine buffer size + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + + # Allocate buffer: BLOCK_M x K x num_tiles + buffer_size = config.block_size_m * K * num_tiles + aux_buffer = torch.empty(buffer_size, dtype=A_sharded.dtype, device=device) + + return FusedWorkspace( + operation="all_gather_matmul", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + num_sms=num_sms, + num_xcds=num_xcds, + variant=config.all_gather_matmul_variant, + aux_buffer=aux_buffer, + prepared=True, + ) + + +def all_gather_matmul( + shmem, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, +) -> FusedWorkspace: + """Fused all-gather and matrix multiplication using pull pattern.""" + if config is None: + config = FusedConfig() + + M, K_local = A_sharded.shape + K, N = B.shape + world_size = shmem.get_num_ranks() + rank = shmem.get_rank() + + expected_K = world_size * K_local + assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" + assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" + + # Validate problem size against block sizes + assert M >= config.block_size_m, ( + f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." + ) + assert K_local >= config.block_size_k, ( + f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " + f"Use smaller block sizes for small problems." + ) + assert N >= config.block_size_n, ( + f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." + ) + + if workspace is None: + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) + + stride_am, stride_ak = A_sharded.stride() + stride_bk, stride_bn = B.stride() + stride_cm, stride_cn = output_tensor.stride() + + if bias is not None: + assert bias.shape[0] == M + bias_ptr = bias + stride_bias = bias.stride()[0] if bias.dim() > 0 else 1 + use_bias = True + else: + bias_ptr = output_tensor + stride_bias = 1 + use_bias = False + + # Get hardware configuration from workspace + num_sms = workspace.num_sms + num_xcds = workspace.num_xcds + + even_k = K_local % config.block_size_k == 0 + + # Use SM-based grid (persistent kernels) + grid = (num_sms,) + + # Select kernel variant based on config + if config.all_gather_matmul_variant == "chunked": + # Chunked variant: pre-gather into buffer, then GEMM + assert workspace.aux_buffer is not None, "Chunked variant requires aux_buffer in workspace" + _fused_chunked_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, # Temporary buffer + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + num_xcds, + use_bias, + even_k, + config.allow_tf32, + ) + else: + # Pull variant (default): on-demand pull from remote ranks + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + num_xcds, + use_bias, + even_k, + config.allow_tf32, + ) + + if not async_op: + shmem.barrier() + + return workspace diff --git a/iris/ops/config.py b/iris/ops/config.py index 3ca085c31..77c0b5ab9 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -19,10 +19,10 @@ class FusedConfig: but users can override specific settings for performance tuning. GEMM Parameters: - block_size_m: Block size for M dimension (rows). Default: 256. - block_size_n: Block size for N dimension (columns). Default: 64. + block_size_m: Block size for M dimension (rows). Default: 128. + block_size_n: Block size for N dimension (columns). Default: 256. block_size_k: Block size for K dimension (reduction). Default: 64. - group_size_m: Group size for M dimension tiling. Default: 1. + group_size_m: Group size for M dimension tiling. Default: 4. num_sms: Number of SMs to use. If None, auto-detects from device. Default: None. num_xcds: Number of XCDs (chiplets). Default: 1. chunk_size: Chunk size for chiplet transform. Default: 1. @@ -32,8 +32,12 @@ class FusedConfig: CCL Parameters (for operations that need collective communication): all_reduce_variant: All-reduce algorithm variant. Options: "atomic", "ring", - "one_shot", "two_shot", "spinlock". Default: "one_shot". + "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. + all_gather_matmul_variant: All-gather + matmul algorithm variant. Options: + "pull" (on-demand pull from remote ranks), + "chunked" (pre-gather into buffer then GEMM). + Default: "pull". Example: >>> # Use defaults @@ -47,10 +51,10 @@ class FusedConfig: """ # GEMM parameters - block_size_m: int = 256 - block_size_n: int = 64 + block_size_m: int = 128 + block_size_n: int = 256 block_size_k: int = 64 - group_size_m: int = 1 + group_size_m: int = 4 num_sms: Optional[int] = None # Auto-detect if None num_xcds: int = 1 chunk_size: int = 1 @@ -61,6 +65,7 @@ class FusedConfig: # CCL-specific parameters all_reduce_variant: str = "two_shot" # atomic, ring, one_shot, two_shot, spinlock all_reduce_num_rings: int = 1 + all_gather_matmul_variant: str = "pull" # pull, chunked def validate(self, world_size: Optional[int] = None): """ @@ -102,3 +107,10 @@ def validate(self, world_size: Optional[int] = None): if self.all_reduce_num_rings <= 0: raise ValueError(f"all_reduce_num_rings must be positive, got {self.all_reduce_num_rings}") + + # Validate all_gather_matmul_variant + valid_ag_variants = ["pull", "chunked"] + if self.all_gather_matmul_variant not in valid_ag_variants: + raise ValueError( + f"all_gather_matmul_variant must be one of {valid_ag_variants}, got {self.all_gather_matmul_variant}" + ) diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index a9c7cb616..9328e9f9e 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -38,6 +38,10 @@ class FusedWorkspace: world_size: int = 1 variant: str = "" + # Hardware configuration (detected in preamble) + num_sms: Optional[int] = None # Number of streaming multiprocessors + num_xcds: int = 1 # Number of XCDs/chiplets + # Temporary buffers (allocated as needed) aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9c..51f489a03 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -52,7 +52,7 @@ def gather( if source_rank == ctx.rank: # Local load - tile_data = tl.load(src_tile_ptr, mask=mask, other=0.0) + tile_data = tl.load(src_tile_ptr, mask=mask) else: # Remote load using RMA tile_data = iris.load( diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 193505011..7dceea126 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -28,7 +28,14 @@ (256, 64, 128), ], ) -def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): +@pytest.mark.parametrize( + "variant", + [ + "pull", + "chunked", + ], +) +def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N, variant): """Test all_gather_matmul against torch all_gather + matmul.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") @@ -77,12 +84,20 @@ def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): # Run fused all_gather + matmul using shmem.ops API from iris.ops.config import FusedConfig + if rank == 0: + print(f"\n[Test] Testing variant={variant}, M={M}, K_local={K_local}, N={N}, dtype={dtype}") + # Use appropriate block sizes based on problem size # For small problems, use smaller blocks if M <= 256 or K_local <= 64 or N <= 128: - config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + config = FusedConfig( + block_size_m=64, + block_size_n=64, + block_size_k=32, + all_gather_matmul_variant=variant, + ) else: - config = FusedConfig() + config = FusedConfig(all_gather_matmul_variant=variant) # Validate config against problem size assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" From f132cebf3c4202d56da4e81e973f5811fb33d7c5 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Sat, 7 Feb 2026 20:13:20 +0000 Subject: [PATCH 3/5] Up the tritonBLAS commit. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 18e71badb..025337641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "numpy", "requests", "ruff", - "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@df58476a4520b72495a3f03f911368a184126568", + "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@cd119279f3df543a558aa6d2cd4a3daed0b1ec7a", ] From 1628a6192b72f5120d3ec78665c7f9f5430fd646 Mon Sep 17 00:00:00 2001 From: neoblizz Date: Tue, 10 Feb 2026 00:03:37 +0000 Subject: [PATCH 4/5] ... --- benchmark/ops/all_gather_matmul/benchmark.py | 20 ++++--------- iris/iris.py | 4 +-- iris/ops/all_gather_matmul.py | 31 ++++++++++++++++---- iris/ops/config.py | 6 ++-- iris/ops/workspace.py | 6 ++++ 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/benchmark/ops/all_gather_matmul/benchmark.py b/benchmark/ops/all_gather_matmul/benchmark.py index 20ff0c536..ae0443e6d 100644 --- a/benchmark/ops/all_gather_matmul/benchmark.py +++ b/benchmark/ops/all_gather_matmul/benchmark.py @@ -18,6 +18,7 @@ from examples.common.utils import JSONWriter import iris +from iris.ops.all_gather_matmul import all_gather_matmul_preamble from iris.ops import FusedConfig torch.manual_seed(123) @@ -65,8 +66,8 @@ def parse_args(): "--variant", type=str, default="pull", - choices=["pull", "chunked"], - help="All-gather matmul variant (pull or chunked)", + choices=["pull", "chunked", "push", "pipelined_pull"], + help="All-gather matmul variant", ) parser.add_argument( "--init_url", type=str, default="tcp://127.0.0.1:29530", help="Initialization URL for distributed setup" @@ -181,20 +182,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): }, } - workspace = None + # Pre-allocate workspace once (important for push variant which needs large buffers) + workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config) def run_experiment(): - nonlocal kernel_timing, workspace - - # Preamble if available - if hasattr(shmem.ops, "all_gather_matmul_preamble"): - workspace = shmem.ops.all_gather_matmul_preamble( - C, - A_sharded, - B, - config=config, - workspace=workspace, - ) + nonlocal kernel_timing shmem.barrier() diff --git a/iris/iris.py b/iris/iris.py index 9b8a3d35a..21aaddd8a 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1796,8 +1796,8 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Vectorization hints: must be <= minimum block size used by any caller. # (32, 32) is safe since all supported block sizes are multiples of 32. # Largest vectorized load instruction is dwordx4 (128-bits = 8 x fp16). - translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) + # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + # translated_ptr = tl.max_contiguous(translated_ptr, (32, 32)) return translated_ptr diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 5d700206c..0dad98aee 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -17,6 +17,7 @@ import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext +from tritonblas.kernels.stages.indexing.pid_transforms import chiplet_transform_chunked from .config import FusedConfig from .workspace import FusedWorkspace @@ -164,7 +165,7 @@ def all_gather_matmul_preamble( B: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (none needed for pull pattern).""" + """Allocate workspace for all_gather_matmul.""" if config is None: config = FusedConfig() @@ -175,14 +176,27 @@ def all_gather_matmul_preamble( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - return FusedWorkspace( + ws = FusedWorkspace( operation="all_gather_matmul", shape=(M, N, K), dtype=A_sharded.dtype, world_size=world_size, + variant=config.all_gather_matmul_variant, prepared=True, ) + # Allocate push variant workspace + if config.all_gather_matmul_variant == "push": + num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m + num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k + ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) + ws.signal_flags = shmem.zeros( + (world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32 + ) + shmem.barrier() + + return ws + def all_gather_matmul( shmem, @@ -245,10 +259,15 @@ def all_gather_matmul( even_k = K_local % config.block_size_k == 0 num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k - # Launch single fused kernel - grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid]( - A_sharded, + variant = config.all_gather_matmul_variant + + if variant == "pull": + num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m + num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n + num_tiles = num_tiles_m * num_tiles_n + # grid = (num_tiles,) + grid = (num_sms,) + _fused_all_gather_matmul_kernel[grid](A_sharded, B, output_tensor, bias_ptr, diff --git a/iris/ops/config.py b/iris/ops/config.py index 77c0b5ab9..a92925035 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -54,9 +54,9 @@ class FusedConfig: block_size_m: int = 128 block_size_n: int = 256 block_size_k: int = 64 - group_size_m: int = 4 + group_size_m: int = 1 num_sms: Optional[int] = None # Auto-detect if None - num_xcds: int = 1 + num_xcds: int = 8 chunk_size: int = 1 cache_modifier_a: str = ".ca" cache_modifier_b: str = ".ca" @@ -109,7 +109,7 @@ def validate(self, world_size: Optional[int] = None): raise ValueError(f"all_reduce_num_rings must be positive, got {self.all_reduce_num_rings}") # Validate all_gather_matmul_variant - valid_ag_variants = ["pull", "chunked"] + valid_ag_variants = ["pull"] if self.all_gather_matmul_variant not in valid_ag_variants: raise ValueError( f"all_gather_matmul_variant must be one of {valid_ag_variants}, got {self.all_gather_matmul_variant}" diff --git a/iris/ops/workspace.py b/iris/ops/workspace.py index 9328e9f9e..e519f0823 100644 --- a/iris/ops/workspace.py +++ b/iris/ops/workspace.py @@ -46,6 +46,10 @@ class FusedWorkspace: aux_buffer: Optional[torch.Tensor] = None # Generic buffer for intermediate results locks: Optional[torch.Tensor] = None # Synchronization primitives + # Push variant workspace + a_inbox: Optional[torch.Tensor] = None # (world_size, M, K_local) inbox buffer + signal_flags: Optional[torch.Tensor] = None # (world_size, world_size, m_tiles, k_tiles) + prepared: bool = False def matches( @@ -86,4 +90,6 @@ def clear(self): """Free all allocated buffers.""" self.aux_buffer = None self.locks = None + self.a_inbox = None + self.signal_flags = None self.prepared = False From c26e87275043e996c9dca78e44c60fc34d6d2eac Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 10 Feb 2026 00:04:25 +0000 Subject: [PATCH 5/5] Apply Ruff auto-fixes --- iris/ops/all_gather_matmul.py | 64 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 0dad98aee..6000f50ef 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -17,7 +17,6 @@ import iris.x from tritonblas.kernels.stages import GemmContext, ScheduleContext -from tritonblas.kernels.stages.indexing.pid_transforms import chiplet_transform_chunked from .config import FusedConfig from .workspace import FusedWorkspace @@ -190,9 +189,7 @@ def all_gather_matmul_preamble( num_m_tiles = (M + config.block_size_m - 1) // config.block_size_m num_k_tiles = (K_local + config.block_size_k - 1) // config.block_size_k ws.a_inbox = shmem.zeros((world_size, M, K_local), dtype=A_sharded.dtype) - ws.signal_flags = shmem.zeros( - (world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32 - ) + ws.signal_flags = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32) shmem.barrier() return ws @@ -267,35 +264,36 @@ def all_gather_matmul( num_tiles = num_tiles_m * num_tiles_n # grid = (num_tiles,) grid = (num_sms,) - _fused_all_gather_matmul_kernel[grid](A_sharded, - B, - output_tensor, - bias_ptr, - M, - N, - K, - K_local, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bias, - shmem.get_device_context(), - rank, - world_size, - config.block_size_m, - config.block_size_n, - config.block_size_k, - config.group_size_m, - num_sms, - config.num_xcds, - num_k_blocks_local, - use_bias, - even_k, - config.allow_tf32, - ) + _fused_all_gather_matmul_kernel[grid]( + A_sharded, + B, + output_tensor, + bias_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + shmem.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_sms, + config.num_xcds, + num_k_blocks_local, + use_bias, + even_k, + config.allow_tf32, + ) if not async_op: shmem.barrier()