diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py
index 77354d59320..1a202c35af5 100644
--- a/megatron/core/inference/batch_dimensions_utils.py
+++ b/megatron/core/inference/batch_dimensions_utils.py
@@ -235,6 +235,25 @@ def _calculate_cuda_graph_token_counts(
         (tp_size=2, num_cuda_graphs=4, cuda_graph_max_tokens=1000)
         [1000, 752, 504, 256]
     """
+    if num_cuda_graphs == -1:
+        # Automatically determine the number of CUDA graphs to
+        # capture based on the `cuda_graph_max_tokens` value.
+        cuda_graph_token_counts = (
+            [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16))
+        )
+        # Align each entry to the TP size.
+        cuda_graph_token_counts = list(
+            dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts)
+        )
+        # Clamp to the max token count.
+        cuda_graph_token_counts = [
+            s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens
+        ]
+        if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens:
+            cuda_graph_token_counts.append(cuda_graph_max_tokens)
+        cuda_graph_token_counts.reverse()
+        return cuda_graph_token_counts
+
     assert num_cuda_graphs >= 1, f"num_cuda_graphs must be >= 1, got {num_cuda_graphs}"
     assert (
         cuda_graph_max_tokens > 0
@@ -340,7 +359,12 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
         or cuda_graph_max_tokens <= 0
     ):
         cuda_graph_max_tokens = max_tokens
-    num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)
+
+    if num_cuda_graphs != -1:
+        # If -1, no adjustment is needed: _calculate_cuda_graph_token_counts
+        # generates the token counts directly from the cuda_graph_max_tokens
+        # value and its built-in step sizes.
+        num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens)
 
     # Calculate token counts for prefill and mixed graphs.
     # These need the full cuda_graph_max_tokens to handle variable-length sequences.
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index bd6143409e7..b1525396647 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -857,6 +857,11 @@ def validate_args(args, defaults={}):
         assert args.fp8 is None, \
             "fp8 is not supported with inference dynamic batching and full_iteration_inference CUDA graph"
 
+    if args.cuda_graph_impl == 'local':
+        assert args.inference_dynamic_batching_num_cuda_graphs > 0 or args.inference_dynamic_batching_num_cuda_graphs == -1, \
+            'inference_dynamic_batching_num_cuda_graphs should be a positive integer or -1. ' \
+            '-1 means the number of CUDA graphs to capture is determined automatically from the `max_requests` value.'
+
     print_rank_0('using {} for parameters ...'.format(args.params_dtype))
 
     if args.dataloader_type is None:
@@ -1655,7 +1660,9 @@ def _add_inference_args(parser):
                        'cuda graph batch sizes range from 1 to `max_requests`. '
                        '(See `dynamic_context.py` for details on how '
                        '`max_requests` is computed). Due to rounding, the actual '
-                       'number of cuda graphs may not equal this argument.')
+                       'number of cuda graphs may not equal this argument. '
+                       'The user can also pass -1, in which case the number of graphs '
+                       'to capture is determined automatically from `max_requests`.')
     group.add_argument('--inference-dynamic-batching-track-paused-request-events',
                        action='store_true',
                        help='Track paused request ids by adding \'paused\' events '
diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py
index 02be5c136fd..e679b5d7c64 100644
--- a/tests/unit_tests/inference/engines/test_dynamic_engine.py
+++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -543,7 +543,7 @@ def teardown_method(self, method):
         not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching"
     )
     @pytest.mark.parametrize("model_provider", ["gpt", "mamba"])
-    @pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4])
+    @pytest.mark.parametrize("num_cuda_graphs", [None, 1, 4, -1])
    @pytest.mark.parametrize("cuda_graph_scope", [[], [CudaGraphScope.full_iteration_inference]])
     def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None:
         """Simple test that runs without errors, and validates output."""
@@ -557,6 +557,7 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None:
             num_cuda_graphs=num_cuda_graphs,
             cuda_graph_scope=cuda_graph_scope,
             force_build_cuda_graphs=True,
+            context_max_requests=128,
         )
 
         # Validate max_requests, max_tokens.
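For reference, below is a minimal standalone sketch of the new `num_cuda_graphs == -1` auto-sizing path, runnable outside Megatron to inspect which graph sizes the auto mode would capture. The helper name `auto_cuda_graph_token_counts` is hypothetical; its body mirrors the lines added to `_calculate_cuda_graph_token_counts` in the diff above.

import math

def auto_cuda_graph_token_counts(tp_size: int, cuda_graph_max_tokens: int) -> list:
    # Dense coverage for small token counts, coarser 16-wide steps above 256.
    counts = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16))
    # Round each entry up to a multiple of tp_size and drop duplicates.
    counts = list(dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in counts))
    # Clamp to the maximum and ensure the maximum itself is captured.
    counts = [s for s in counts if s <= cuda_graph_max_tokens]
    if not counts or counts[-1] != cuda_graph_max_tokens:
        counts.append(cuda_graph_max_tokens)
    counts.reverse()  # largest graph first, matching the fixed-count path
    return counts

print(auto_cuda_graph_token_counts(tp_size=2, cuda_graph_max_tokens=64))
# -> [64, 56, 48, 40, 32, 24, 16, 8, 4, 2]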