85 changes: 65 additions & 20 deletions benchmarks/python/test_transpose.py
@@ -15,17 +15,26 @@ def transpose_fusion(
is_copy_transpose: bool,
axes: list,
rank: int,
num_inputs: int = 2,
):
shape = [-1] * rank
contiguity = [True] * rank
T0 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)
T1 = fd.define_tensor(shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False)

if num_inputs == 2:
T1 = fd.define_tensor(
shape=shape, contiguity=contiguity, dtype=dtype, is_cpu=False
)

if dtype in PROMOTE_DTYPES:
T0 = fd.ops.cast(T0, dtype=DataType.Float)
T1 = fd.ops.cast(T1, dtype=DataType.Float)
if num_inputs == 2:
T1 = fd.ops.cast(T1, dtype=DataType.Float)

T4 = fd.ops.add(T0, T1)
if num_inputs == 2:
T4 = fd.ops.add(T0, T1)
else:
T4 = T0
T5 = fd.ops.permute(T4, dims=axes)

if dtype in PROMOTE_DTYPES:
@@ -46,11 +55,18 @@ def transpose_fusion(
# Without contiguous, transpose returns a view with swapped strides.
# contiguous() materializes a contiguous copy of the result.
# When compiled with Thunder, the contiguous version will use nvFuser's transpose scheduler; otherwise it will use the pointwise scheduler.
def transpose_fwd_fn(inputs: list): # [input1, input2, dim0, dim1, is_copy_transpose]
relu_transpose_result = torch.nn.functional.relu(
torch.transpose(inputs[0] + inputs[1], inputs[2], inputs[3])
)
is_copy_transpose = inputs[4]
def transpose_fwd_fn(
inputs: list,
): # [input1, input2 (optional), dim0, dim1, is_copy_transpose, num_inputs]
num_inputs = inputs[-1]
is_copy_transpose = inputs[-2]
if num_inputs == 2:
data = inputs[0] + inputs[1]
dim0, dim1 = inputs[2], inputs[3]
else:
data = inputs[0]
dim0, dim1 = inputs[1], inputs[2]
relu_transpose_result = torch.nn.functional.relu(torch.transpose(data, dim0, dim1))
if is_copy_transpose:
return relu_transpose_result.contiguous()
else:
@@ -75,6 +91,11 @@ def _generate_transpose_params():
[True, False],
ids=["copy_transpose", "view_transpose"],
)
@pytest.mark.parametrize(
"num_inputs",
[1, 2],
ids=["1_input", "2_inputs"],
)
@pytest.mark.pointwise
def test_transpose_nvf_benchmark(
benchmark,
@@ -83,11 +104,16 @@ def test_transpose_nvf_benchmark(
dtype: torch.dtype,
axes: tuple,
dims: int,
num_inputs: int,
disable_validation: bool,
disable_benchmarking: bool,
):
input1 = torch.randn(size, device="cuda", dtype=dtype)
input2 = torch.randn(size, device="cuda", dtype=dtype)
inputs = [input1]
if num_inputs == 2:
input2 = torch.randn(size, device="cuda", dtype=dtype)
inputs.append(input2)

permute_axes = list(range(len(size)))
permute_axes[axes[0]], permute_axes[axes[1]] = (
permute_axes[axes[1]],
@@ -101,16 +127,22 @@ def test_transpose_nvf_benchmark(
is_copy_transpose,
permute_axes,
rank=dims,
num_inputs=num_inputs,
)

if not disable_validation:
eager_output = transpose_fwd_fn(
[input1, input2, axes[0], axes[1], is_copy_transpose]
)
fd.validate([input1, input2], [eager_output])
if num_inputs == 2:
eager_output = transpose_fwd_fn(
[input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs]
)
else:
eager_output = transpose_fwd_fn(
[input1, axes[0], axes[1], is_copy_transpose, num_inputs]
)
fd.validate(inputs, [eager_output])

if not disable_benchmarking:
run_benchmark(benchmark, fd.execute, [input1, input2])
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("executor", DEFAULT_EXECUTORS)
@@ -121,25 +153,38 @@ def test_transpose_nvf_benchmark(
[True, False],
ids=["copy_transpose", "view_transpose"],
)
@pytest.mark.parametrize(
"num_inputs",
[1, 2],
ids=["1_input", "2_inputs"],
)
def test_transpose_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
is_copy_transpose: bool,
axes: tuple,
dims: int,
num_inputs: int,
executor: str,
):
if executor == "torchcompile":
clear_dynamo_cache()
input1 = torch.randn(size, device="cuda", dtype=dtype)
input2 = torch.randn(size, device="cuda", dtype=dtype)

benchmark_fn = with_executor(executor, transpose_fwd_fn)

# Inputs and outputs are the same as nvFuser's, so no manual IOByte computation is needed
run_benchmark(
benchmark,
benchmark_fn,
[input1, input2, axes[0], axes[1], is_copy_transpose],
)
if num_inputs == 2:
input2 = torch.randn(size, device="cuda", dtype=dtype)
run_benchmark(
benchmark,
benchmark_fn,
[input1, input2, axes[0], axes[1], is_copy_transpose, num_inputs],
)
else:
run_benchmark(
benchmark,
benchmark_fn,
[input1, axes[0], axes[1], is_copy_transpose, num_inputs],
)
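
For reference, a minimal standalone sketch of the view/copy distinction described in the comment above transpose_fwd_fn. This is not part of the diff; it is a plain PyTorch illustration on a CPU tensor:

    import torch

    x = torch.randn(4, 8)

    view_t = torch.transpose(x, 0, 1)  # a view: same storage, swapped strides
    copy_t = view_t.contiguous()       # a copy: materializes a row-major layout

    print(view_t.stride())  # (1, 8) -- strides swapped, no data movement
    print(copy_t.stride())  # (4, 1) -- contiguous layout for the new shape

    # The view aliases x's memory; the contiguous copy does not.
    assert view_t.data_ptr() == x.data_ptr()
    assert copy_t.data_ptr() != x.data_ptr()

The benchmark's is_copy_transpose flag selects between these two behaviors, which per the comment above is what steers Thunder toward nvFuser's transpose scheduler (copy) versus the pointwise scheduler (view).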
23 changes: 23 additions & 0 deletions csrc/scheduler/transpose.cpp
@@ -682,6 +682,29 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
"combination of view op with small transpose dimensions are not "
"supported by transpose scheduler");

// Double tile_size2 if the default configuration doesn't keep enough data
// in flight to saturate memory bandwidth. This is based on Little's law:
// bits_in_flight = bandwidth * latency. We estimate the bits in flight per
// SM as: (sum of input tensor element sizes in bits) * elements_per_tile *
// blocks_per_sm. If this is less than the required bits in flight (derived
// from hardware bandwidth and memory latency), we double tile_size2 to
// increase the data in flight.
const auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t max_blocks_per_sm = dev_prop->maxBlocksPerMultiProcessor;
const int64_t num_elems_per_tile = tparams->tile_size1 * tparams->tile_size2;
const int64_t required_bits_per_sm =
scheduler_utils::getRequiredBitsInFlight();
int64_t total_input_bits_per_elem = 0;
for (auto tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
total_input_bits_per_elem +=
dataTypeSizeBit(tv->getDataType().value(), index_type);
}
const int64_t bits_in_flight_per_sm =
total_input_bits_per_elem * num_elems_per_tile * max_blocks_per_sm;
if (bits_in_flight_per_sm < required_bits_per_sm) {
tparams->tile_size2 *= 2;
}

// Note [vectorization and unroll of input and output]
//
// The choice of vectorization size, block size and tile sizes needs to be
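
A Python sketch of the tile_size2 heuristic added above, with illustrative numbers. In the real code the hardware values come from cudaDeviceProp and scheduler_utils::getRequiredBitsInFlight(); the helper name and all constants below are assumptions for demonstration only:

    def maybe_double_tile_size2(tile_size1, tile_size2, input_dtype_bits,
                                max_blocks_per_sm, required_bits_per_sm):
        # Little's law: data in flight = bandwidth * latency. If the tiles
        # resident on one SM carry fewer input bits than that product, the
        # SM cannot saturate memory bandwidth, so enlarge the tile.
        elems_per_tile = tile_size1 * tile_size2
        bits_per_elem = sum(input_dtype_bits)  # one entry per fusion input
        bits_in_flight = bits_per_elem * elems_per_tile * max_blocks_per_sm
        if bits_in_flight < required_bits_per_sm:
            tile_size2 *= 2
        return tile_size2

    # Two fp16 inputs (16 bits each), 32x32 tiles, 16 resident blocks per
    # SM, and an assumed requirement of 2 MiB (expressed in bits) in flight
    # per SM: 32 * 1024 * 16 = 524,288 bits in flight is well below the
    # 16,777,216 bits required, so tile_size2 doubles from 32 to 64.
    print(maybe_double_tile_size2(32, 32, [16, 16], 16, 2 * 8 * 2**20))

As in the C++ change, the doubling is applied at most once; the heuristic only nudges the default tile shape rather than iterating to a target.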