26 changes: 25 additions & 1 deletion megatron/core/inference/batch_dimensions_utils.py
@@ -235,6 +235,25 @@ def _calculate_cuda_graph_token_counts(
(tp_size=2, num_cuda_graphs=4, cuda_graph_max_tokens=1000)
[1000, 752, 504, 256]
"""
if num_cuda_graphs == -1:
    # Automatically determine the CUDA graph token counts to
    # capture, based on the `cuda_graph_max_tokens` value
cuda_graph_token_counts = (
[1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16))
)
# Align each entry to TP size
cuda_graph_token_counts = list(
dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts)
)
# Clamp to max tokens
cuda_graph_token_counts = [
s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens
]
if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens:
cuda_graph_token_counts.append(cuda_graph_max_tokens)
cuda_graph_token_counts.reverse()
return cuda_graph_token_counts

assert num_cuda_graphs >= 1, f"num_cuda_graphs must be >= 1, got {num_cuda_graphs}"
assert (
cuda_graph_max_tokens > 0
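
For illustration, here is a minimal standalone sketch of the auto-selection ladder introduced above. It mirrors only the arithmetic visible in this hunk, and the helper name auto_token_counts is hypothetical:

import math

def auto_token_counts(tp_size: int, cuda_graph_max_tokens: int) -> list:
    # Dense ladder: sizes 1, 2, 4, then steps of 8 up to 256,
    # then steps of 16 up to the cap.
    counts = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16))
    # Round each entry up to a multiple of tp_size, deduplicating in order.
    counts = list(dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in counts))
    # Clamp to the cap and guarantee the cap itself is captured.
    counts = [s for s in counts if s <= cuda_graph_max_tokens]
    if not counts or counts[-1] != cuda_graph_max_tokens:
        counts.append(cuda_graph_max_tokens)
    counts.reverse()  # largest graph first
    return counts

print(auto_token_counts(tp_size=8, cuda_graph_max_tokens=64))
# -> [64, 56, 48, 40, 32, 24, 16, 8]

Rounding up to the TP size keeps every captured shape divisible across tensor-parallel ranks, and the dedup pass avoids capturing the same shape twice after rounding.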
@@ -340,7 +359,12 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
or cuda_graph_max_tokens <= 0
):
cuda_graph_max_tokens = max_tokens
num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)

if num_cuda_graphs != -1:
        # If -1, no adjustment is needed; _calculate_cuda_graph_token_counts
        # generates the token counts automatically from the
        # cuda_graph_max_tokens value and its built-in step sizes.
num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)

# Calculate token counts for prefill and mixed graphs.
# These need the full cuda_graph_max_tokens to handle variable-length sequences.
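
A hypothetical helper (not in the diff) condensing the sentinel handling above:

def normalize_num_cuda_graphs(num_cuda_graphs: int, cuda_graph_max_tokens: int) -> int:
    # -1 is the auto-mode sentinel and passes through untouched;
    # any other value is clamped into [1, cuda_graph_max_tokens].
    if num_cuda_graphs == -1:
        return -1
    return min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)

assert normalize_num_cuda_graphs(-1, 1000) == -1      # auto mode preserved
assert normalize_num_cuda_graphs(0, 1000) == 1        # floored at 1
assert normalize_num_cuda_graphs(5000, 1000) == 1000  # capped at max tokens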
9 changes: 8 additions & 1 deletion megatron/training/arguments.py
@@ -857,6 +857,11 @@ def validate_args(args, defaults={}):
assert args.fp8 is None, \
"fp8 is not supported with inference dynamic batching and full_iteration_inference CUDA graph"

if args.cuda_graph_impl == 'local':
assert args.inference_dynamic_batching_num_cuda_graphs > 0 or args.inference_dynamic_batching_num_cuda_graphs == -1, \
            'inference_dynamic_batching_num_cuda_graphs should be a positive integer or -1. ' \
            '-1 means that we will automatically determine the number of CUDA graphs to capture based on the `max_requests` value.'

print_rank_0('using {} for parameters ...'.format(args.params_dtype))

if args.dataloader_type is None:
@@ -1655,7 +1660,9 @@ def _add_inference_args(parser):
'cuda graph batch sizes range from 1 to `max_requests`. '
'(See `dynamic_context.py` for details on how '
'`max_requests` is computed). Due to rounding, the actual '
'number of cuda graphs may not equal this argument.')
'number of cuda graphs may not equal this argument. '
'The user can also pass -1, in which case we automatically determine the number of graphs ' \
'to capture based on `max_requests`.')
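
As a usage sketch, a standalone parser with just this flag shows that argparse accepts -1 like any other integer value (the default of 16 here is illustrative, not Megatron's actual default):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--inference-dynamic-batching-num-cuda-graphs', type=int, default=16)

# argparse treats '-1' as a negative-number value here because no registered
# option string itself looks like a negative number.
args = parser.parse_args(['--inference-dynamic-batching-num-cuda-graphs', '-1'])
assert args.inference_dynamic_batching_num_cuda_graphs == -1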
group.add_argument('--inference-dynamic-batching-track-paused-request-events',
action='store_true',
help='Track paused request ids by adding \'paused\' events '
3 changes: 2 additions & 1 deletion tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -543,7 +543,7 @@ def teardown_method(self, method):
not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
)
@pytest.mark.parametrize("model_provider", ["gpt", "mamba"])
@pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4])
@pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4, -1])
@pytest.mark.parametrize("cuda_graph_scope", [[], [CudaGraphScope.full_iteration_inference]])
def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None:
"""Simple test that runs without errors, and validates output."""
@@ -557,6 +557,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None
num_cuda_graphs=num_cuda_graphs,
cuda_graph_scope=cuda_graph_scope,
force_build_cuda_graphs=True,
context_max_requests=128,
)

# Validate max_requests, max_tokens.