diff --git a/benchmarks/python/test_transpose.py b/benchmarks/python/test_transpose.py
index 11b66774708..9e91fc0b2f2 100644
--- a/benchmarks/python/test_transpose.py
+++ b/benchmarks/python/test_transpose.py
@@ -15,17 +15,26 @@ def transpose_fusion(
     is_copy_transpose: bool,
     axes: list,
     rank: int,
+    num_inputs: int = 2,
 ):
     shape = [-1] * rank
     contiguity = [True] * rank
     T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
-    T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
+
+    if num_inputs == 2:
+        T1 = fd.define_tensor(
+            shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False
+        )
 
     if dtype in PROMOTE_DTYPES:
         T0 = fd.ops.cast(T0, dtype=DataType.Float)
-        T1 = fd.ops.cast(T1, dtype=DataType.Float)
+        if num_inputs == 2:
+            T1 = fd.ops.cast(T1, dtype=DataType.Float)
 
-    T4 = fd.ops.add(T0, T1)
+    if num_inputs == 2:
+        T4 = fd.ops.add(T0, T1)
+    else:
+        T4 = T0
     T5 = fd.ops.permute(T4, dims=axes)
 
     if dtype in PROMOTE_DTYPES:
@@ -46,11 +55,18 @@ def transpose_fusion(
 # Without contiguous, transpose returns a view with swapped strides.
 # contiguous() materializes a contiguous copy of the result.
 # When compiled with thunder, contiguous version will use nvFuser's transpose scheduler, otherwise it will use the pointwise scheduler.
-def transpose_fwd_fn(inputs: list):  # [input1, input2, dim0, dim1, is_copy_transpose]
-    relu_transpose_result = torch.nn.functional.relu(
-        torch.transpose(inputs[0] + inputs[1], inputs[2], inputs[3])
-    )
-    is_copy_transpose = inputs[4]
+def transpose_fwd_fn(
+    inputs: list,
+):  # [input1, input2 (optional), dim0, dim1, is_copy_transpose, num_inputs]
+    num_inputs = inputs[-1]
+    is_copy_transpose = inputs[-2]
+    if num_inputs == 2:
+        data = inputs[0] + inputs[1]
+        dim0, dim1 = inputs[2], inputs[3]
+    else:
+        data = inputs[0]
+        dim0, dim1 = inputs[1], inputs[2]
+    relu_transpose_result = torch.nn.functional.relu(torch.transpose(data, dim0, dim1))
     if is_copy_transpose:
         return relu_transpose_result.contiguous()
     else:
@@ -75,6 +91,11 @@ def _generate_transpose_params():
     [True, False],
     ids=["copy_transpose", "view_transpose"],
 )
+@pytest.mark.parametrize(
+    "num_inputs",
+    [1, 2],
+    ids=["1_input", "2_inputs"],
+)
 @pytest.mark.pointwise
 def test_transpose_nvf_benchmark(
     benchmark,
@@ -83,11 +104,16 @@ def test_transpose_nvf_benchmark(
     dtype: torch.dtype,
     axes: tuple,
     dims: int,
+    num_inputs: int,
     disable_validation: bool,
     disable_benchmarking: bool,
 ):
     input1 = torch.randn(size, device="cuda", dtype=dtype)
-    input2 = torch.randn(size, device="cuda", dtype=dtype)
+    inputs = [input1]
+    if num_inputs == 2:
+        input2 = torch.randn(size, device="cuda", dtype=dtype)
+        inputs.append(input2)
+
     permute_axes = list(range(len(size)))
     permute_axes[axes[0]], permute_axes[axes[1]] = (
         permute_axes[axes[1]],
@@ -101,16 +127,22 @@ def test_transpose_nvf_benchmark(
         is_copy_transpose,
         permute_axes,
         rank=dims,
+        num_inputs=num_inputs,
     )
 
     if not disable_validation:
-        eager_output = transpose_fwd_fn(
-            [input1, input2, axes[0], axes[1], is_copy_transpose]
-        )
-        fd.validate([input1, input2], [eager_output])
+        if num_inputs == 2:
+            eager_output = transpose_fwd_fn(
+                [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs]
+            )
+        else:
+            eager_output = transpose_fwd_fn(
+                [input1, axes[0], axes[1], is_copy_transpose, num_inputs]
+            )
+        fd.validate(inputs, [eager_output])
 
     if not disable_benchmarking:
-        run_benchmark(benchmark, fd.execute, [input1, input2])
+        run_benchmark(benchmark, fd.execute, inputs)
 
 
 @pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@@ -121,6 +153,11 @@ def test_transpose_nvf_benchmark(
     [True, False],
     ids=["copy_transpose", "view_transpose"],
 )
+@pytest.mark.parametrize(
+    "num_inputs",
+    [1, 2],
+    ids=["1_input", "2_inputs"],
+)
 def test_transpose_baseline_benchmark(
     benchmark,
     size: tuple,
@@ -128,18 +165,26 @@ def test_transpose_baseline_benchmark(
     is_copy_transpose: bool,
     axes: tuple,
     dims: int,
+    num_inputs: int,
     executor: str,
 ):
     if executor == "torchcompile":
         clear_dynamo_cache()
     input1 = torch.randn(size, device="cuda", dtype=dtype)
-    input2 = torch.randn(size, device="cuda", dtype=dtype)
 
     benchmark_fn = with_executor(executor, transpose_fwd_fn)
 
     # Inputs and outputs are same as nvFuser, no need for manual IOByte computation
-    run_benchmark(
-        benchmark,
-        benchmark_fn,
-        [input1, input2, axes[0], axes[1], is_copy_transpose],
-    )
+    if num_inputs == 2:
+        input2 = torch.randn(size, device="cuda", dtype=dtype)
+        run_benchmark(
+            benchmark,
+            benchmark_fn,
+            [input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs],
+        )
+    else:
+        run_benchmark(
+            benchmark,
+            benchmark_fn,
+            [input1, axes[0], axes[1], is_copy_transpose, num_inputs],
+        )
diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp
index c2511ec5c2f..7f4e5da0634 100644
--- a/csrc/scheduler/transpose.cpp
+++ b/csrc/scheduler/transpose.cpp
@@ -682,6 +682,29 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
       "combination of view op with small transpose dimensions are not "
      "supported by transpose scheduler");
 
+  // Double tile_size2 if the default configuration doesn't provide enough
+  // bytes in flight to saturate memory bandwidth. This is based on Little's
+  // law: bytes_in_flight = bandwidth * latency. We estimate the bits in flight
+  // per SM as: (sum of input tensor element sizes) * elements_per_tile *
+  // blocks_per_sm. If this is less than the required bits in flight (derived
+  // from hardware bandwidth and memory latency), we double tile_size2 to
+  // increase the data in flight.
+  const auto dev_prop = at::cuda::getCurrentDeviceProperties();
+  const int64_t max_blocks_per_sm = dev_prop->maxBlocksPerMultiProcessor;
+  const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2;
+  const int64_t required_bits_per_sm =
+      scheduler_utils::getRequiredBitsInFlight();
+  int64_t total_input_bits_per_elem = 0;
+  for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
+    total_input_bits_per_elem +=
+        dataTypeSizeBit(tv->getDataType().value(), index_type);
+  }
+  const int64_t bits_in_flight_per_sm =
+      total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm;
+  if (bits_in_flight_per_sm < required_bits_per_sm) {
+    tparams->tile_size2 *= 2;
+  }
+
   // Note [vectorization and unroll of input and output]
   //
   // The choice of vectorization size, block size and tile sizes needs to be
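
Two review notes with small sketches follow; neither is part of the patch.

First, the copy/view split that transpose_fwd_fn benchmarks: torch.transpose only swaps strides and returns a view of the same storage, while .contiguous() launches a kernel that materializes the permuted layout. A minimal standalone PyTorch sketch of the difference:

    import torch

    x = torch.randn(4, 8)
    v = torch.transpose(x, 0, 1)  # view: same storage, swapped strides
    assert v.data_ptr() == x.data_ptr() and not v.is_contiguous()

    c = v.contiguous()  # copy: new storage, row-major strides
    assert c.data_ptr() != x.data_ptr() and c.is_contiguous()

Only the copy variant moves every element through memory, which is why, per the comment in the patch, the contiguous version reaches nvFuser's transpose scheduler under thunder while the view version stays on the pointwise scheduler.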
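Second, the tile_size2 adjustment in transpose.cpp: by Little's law, the data that must be in flight to sustain a given bandwidth is bandwidth * latency, so the heuristic doubles tile_size2 whenever the estimated per-SM bits in flight fall short of the required figure. Below is a hypothetical Python model of that check to make the arithmetic concrete; the hardware numbers are invented, and deriving the requirement as bandwidth * latency / SM count is an assumption of this sketch (the patch delegates that to scheduler_utils::getRequiredBitsInFlight()):

    # Hypothetical model of the bits-in-flight check; not nvFuser code.
    def needs_larger_tile(
        tile_size1,
        tile_size2,
        input_elem_bits,  # bits per element of each fusion input
        max_blocks_per_sm,
        bandwidth_bytes_per_s,  # assumed hardware numbers
        latency_s,
        num_sms,
    ):
        # Little's law, converted to a per-SM requirement in bits.
        required_bits_per_sm = bandwidth_bytes_per_s * 8 * latency_s / num_sms
        elems_per_tile = tile_size1 * tile_size2
        in_flight_bits_per_sm = sum(input_elem_bits) * elems_per_tile * max_blocks_per_sm
        return in_flight_bits_per_sm < required_bits_per_sm

    # One fp16 input, 32x32 tiles, 8 resident blocks per SM, on a made-up
    # 2 TB/s, 700 ns, 132-SM device:
    #   required  = 2e12 * 8 * 700e-9 / 132 ~= 8.5e4 bits per SM
    #   in flight = 16 * 32 * 32 * 8        = 1.3e5 bits per SM
    assert not needs_larger_tile(32, 32, [16], 8, 2e12, 700e-9, 132)

With a single 8-bit input instead, the in-flight estimate halves to about 6.6e4 bits and drops below the requirement, so tile_size2 would be doubled; that is the kind of single-input, narrow-dtype case the new num_inputs=1 benchmark variants presumably add coverage for.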