diff --git a/examples/rl/model_configs/common.sh b/examples/rl/model_configs/common.sh
index 389880d6289..e40e4be2c9c 100644
--- a/examples/rl/model_configs/common.sh
+++ b/examples/rl/model_configs/common.sh
@@ -15,6 +15,7 @@ COMMON_OPTIONS="\
     --use-mcore-models \
     --transformer-impl transformer_engine \
     --${PRECISION:-bf16} \
+    --inference-logits-dtype bf16 \
     --te-rng-tracker \
     --inference-dynamic-batching-buffer-size-gb 20 \
     --data-parallel-random-init \
diff --git a/examples/rl/model_configs/nemotron5p5_12b_H.sh b/examples/rl/model_configs/nemotron5p5_12b_H.sh
index bfb4c7e4727..cea2eb6d874 100644
--- a/examples/rl/model_configs/nemotron5p5_12b_H.sh
+++ b/examples/rl/model_configs/nemotron5p5_12b_H.sh
@@ -121,6 +121,7 @@ MODEL_OPTIONS="\
     --log-num-zeros-in-grad \
     --log-throughput \
     --bf16 \
+    --inference-logits-dtype bf16 \
     --adam-beta1 0.9 \
     --adam-beta2 0.95 \
     --use-distributed-optimizer \
diff --git a/examples/rl/model_configs/qwen3_30b_a3b_moe.sh b/examples/rl/model_configs/qwen3_30b_a3b_moe.sh
index 637b431280f..c55d678f577 100644
--- a/examples/rl/model_configs/qwen3_30b_a3b_moe.sh
+++ b/examples/rl/model_configs/qwen3_30b_a3b_moe.sh
@@ -43,6 +43,7 @@ MODEL_OPTIONS="
 --seq-length 8192 \
 --inference-max-seq-length 8192 \
 --bf16 \
+--inference-logits-dtype bf16 \
 --tensor-model-parallel-size $TP \
 --pipeline-model-parallel-size $PP \
 --expert-model-parallel-size $EP \
diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
index 5fbbcc376f3..59a5e0418bd 100644
--- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
+++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py
@@ -15,6 +15,7 @@
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.process_groups_config import ProcessGroupCollection
+from megatron.core.transformer.module import Float16Module
 from megatron.core.utils import deprecate_args, get_attr_wrapped_model, get_model_config
 
 DEPRECATED_ARGS = ["inference_wrapper_config", "pg_collection"]
@@ -118,13 +119,14 @@ def _forward(self, inference_input):
         tokens = inference_input["tokens"]
         position_ids = inference_input["position_ids"]
         attention_mask = inference_input["attention_mask"]
-        return self.model(
-            tokens,
-            position_ids,
-            attention_mask,
+        kwargs = dict(
             inference_context=self.inference_context,
             runtime_gather_output=True,  # Inference should always gather the logits
         )
+        if isinstance(self.model, Float16Module):
+            kwargs["fp32_output"] = self.config.inference_logits_dtype == torch.float32
+        logits = self.model(tokens, position_ids, attention_mask, **kwargs)
+        return logits.to(self.config.inference_logits_dtype)
 
     def _get_batch_size_and_seq_len(
         self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None
@@ -218,9 +220,9 @@ def forward_pass_with_pipeline_parallel(
         logits = None
         if is_pipeline_last_stage(self.pp_group):
             logits = output_tensor
-
-        # Explicitly cast logits to expected dtype
-        logits = logits.to(self.config.params_dtype)
+            assert (
+                logits.dtype == self.config.inference_logits_dtype
+            ), f"Expected logits dtype {self.config.inference_logits_dtype}, got {logits.dtype}"
 
         return logits
 
diff --git a/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py
index 922fba13f4c..1210c30dd7a 100644
--- a/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py
+++ b/megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py
@@ -12,6 +12,7 @@
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
+from megatron.core.transformer.module import Float16Module
 
 # pylint: disable=line-too-long
 
@@ -137,20 +138,21 @@ def _forward(self, inference_input: Dict[str, Any]):
         position_ids = inference_input["position_ids"]
         num_image_tiles = inference_input["num_tiles"]
 
-        output = self.model(
-            images,
-            tokens,
+        kwargs = dict(
             position_ids=position_ids,
             attention_mask=None,
             inference_context=self.inference_context,
             num_image_tiles=num_image_tiles,
             runtime_gather_output=True,
         )
+        if isinstance(self.model, Float16Module):
+            kwargs["fp32_output"] = self.config.inference_logits_dtype == torch.float32
+        output = self.model(images, tokens, **kwargs)
         if isinstance(output, tuple):
             logits, _ = output
         else:
             logits = output
-        return logits
+        return logits.to(self.config.inference_logits_dtype)
 
     def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
         """The forward pass of the model for inference
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index 0bdc5853aaf..4789f71f553 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -151,7 +151,7 @@ def _init_dynamic_sampling_tensors(self):
         self._get_stop_word_finished_ids_callback = None
 
         device = torch.cuda.current_device()
-        logits_dtype = self.inference_wrapped_model.config.params_dtype
+        logits_dtype = self.model_config.inference_logits_dtype
 
         self._sampling_backend = "torch"
         self._enable_cuda_graph = self.model_config.cuda_graph_impl == "local"
@@ -711,7 +711,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
 
         logits = broadcast_from_last_pipeline_stage(
             logits_shape,
-            dtype=self.model_config.params_dtype,
+            dtype=self.model_config.inference_logits_dtype,
             tensor=logits,
             pp_group=self.pp_group,
         )
@@ -955,7 +955,7 @@ def _compute_serial_mtp_and_sample(self):
             nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/pp-broadcast")
             mtp_logits_2d = broadcast_from_last_pipeline_stage(
                 [active_request_count, self.vocab_size],
-                dtype=self.model_config.params_dtype,
+                dtype=self.model_config.inference_logits_dtype,
                 tensor=mtp_logits_2d,
                 pp_group=self.pp_group,
             )
@@ -2210,7 +2210,7 @@ def generate_all_output_tokens_static_batch(
             # and then broadcast the sampled tokens rather than broadcasting the raw logits.
             logits = broadcast_from_last_pipeline_stage(
                 [batch_size, logits_seq_len, self.vocab_size],
-                dtype=self.model_config.params_dtype,
+                dtype=self.model_config.inference_logits_dtype,
                 tensor=logits,
                 pp_group=self.pp_group,
             )
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
index 84b0ca2fea3..9edb9d56e72 100644
--- a/megatron/core/models/common/language_module/language_module.py
+++ b/megatron/core/models/common/language_module/language_module.py
@@ -384,6 +384,7 @@ def compute_mtp_single_step(
 
         logits, _ = self.output_layer(mtp_hidden, weight=output_weight, runtime_gather_output=True)
         logits = self._scale_logits(logits)
+        logits = logits.to(self.config.inference_logits_dtype)
 
         return mtp_hidden, logits
 
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index ebad45b3de0..ecb500ffd22 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -920,6 +920,10 @@ class TransformerConfig(ModelParallelConfig):
     inference_sampling_seed: int = 42
     """ Random seed to use for sampling during inference. """
 
+    inference_logits_dtype: torch.dtype = torch.float32
+    """ Dtype for logits during inference. Float32 improves sampling stability
+    by reducing tie-breaking non-determinism in argmax/multinomial. """
+
     symmetric_ar_type: Optional[Literal['two_shot', "one_shot", "multimem_all_reduce"]] = None
     """What type of symmetric all reduce to use. The default is None which is no use of symmetric
     memory.
diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index 4d018217cec..98127389445 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -797,7 +797,7 @@ def get_logprobs(model, tokens, position_ids, no_grad=False, sequence_packing=Fa
     # This is a hack to fix megatron's behaviour when flash-decode affects the training code flow.
     flash_decode = model.config.flash_decode
     model.config.flash_decode = False
-    fp32_output = not (args.fp16 or args.bf16)
+    fp32_output = model.config.inference_logits_dtype == torch.float32
     with torch.no_grad() if no_grad else nullcontext():
         logits_or_hidden_states = model(
             tokens,
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index f28eb8733ee..ba4da7ecbc1 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -973,6 +973,7 @@ def validate_args(args, defaults={}):
     args.main_params_dtype = map_dtype(args.main_params_dtype)
     args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
     args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
+    args.inference_logits_dtype = map_dtype(args.inference_logits_dtype)
     args.mamba_inference_conv_states_dtype = map_dtype(args.mamba_inference_conv_states_dtype)
     args.mamba_inference_ssm_states_dtype = map_dtype(args.mamba_inference_ssm_states_dtype)
 
@@ -1788,6 +1789,7 @@ def core_transformer_config_from_args(args, config_class=None):
         kw_args['experimental_attention_variant'] = 'dsa'
 
     kw_args['inference_sampling_seed'] = args.seed
+    kw_args['inference_logits_dtype'] = args.inference_logits_dtype
 
     # handle quantization config
     # NOTE: Kitchen arguments are only added to the namespace when
@@ -1949,6 +1951,12 @@ def _add_inference_args(parser):
                        help="Enable chunked prefill (disabled by default)")
     group.add_argument('--num-speculative-tokens', type=int, default=0,
                        help='Number of speculative tokens generated during decode')
+    group.add_argument('--inference-logits-dtype', type=str, default='fp32',
+                       choices=['fp32', 'fp16', 'bf16'],
+                       help='Dtype for logits during inference. fp32 (default) '
+                            'improves sampling determinism by reducing tie-breaking '
+                            'non-determinism in argmax/multinomial.',
+                       dest='inference_logits_dtype')
     group.add_argument('--inference-dynamic-batching-prefix-caching',
                        dest='inference_dynamic_batching_enable_prefix_caching',
                        action=argparse.BooleanOptionalAction,
@@ -2073,6 +2081,7 @@ def _add_network_size_args(parser):
         "cuda_graph_retain_backward_graph",
         "disable_parameter_transpose_cache",
         "inference_sampling_seed",
+        "inference_logits_dtype",
         "use_inference_optimized_layers",
         "heterogeneous_block_specs",
         "hetereogenous_dist_checkpoint",
diff --git a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
index f49a67790e7..8c3ddeb2cad 100644
--- a/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
+++ b/tests/unit_tests/inference/model_inference_wrappers/gpt/test_gpt_inference_wrapper.py
@@ -17,7 +17,9 @@
 
 
 class TestGPTInferenceWrapper:
-    def setup_model(self, tensor_parallel_size, pipeline_parallel_size):
+    def setup_model(
+        self, tensor_parallel_size, pipeline_parallel_size, inference_logits_dtype=torch.float32
+    ):
         Utils.initialize_model_parallel(
             tensor_model_parallel_size=tensor_parallel_size,
             pipeline_model_parallel_size=pipeline_parallel_size,
@@ -27,12 +29,14 @@ def setup_model(self, tensor_parallel_size, pipeline_parallel_size):
         self.batch_size = 4
         self.sequence_length = 32
         hidden_size = 32
+        self.inference_logits_dtype = inference_logits_dtype
 
         transformer_config = TransformerConfig(
             num_layers=4,
             hidden_size=hidden_size,
             num_attention_heads=4,
             use_cpu_initialization=True,
+            inference_logits_dtype=inference_logits_dtype,
         )
 
         gpt_model = GPTModel(
@@ -86,6 +90,9 @@ def test_inference_pipeline_parallel(self, materialize_only_last_token_logits):
                 logits_seq_len,
                 self.vocab_size,
             ), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}"
+            assert (
+                logits.dtype == self.inference_logits_dtype
+            ), f"Expected logits dtype {self.inference_logits_dtype}, got {logits.dtype}"
 
     @pytest.mark.parametrize("materialize_only_last_token_logits", [True, False])
     def test_inference_only_tensor_parallel(self, materialize_only_last_token_logits):
@@ -120,3 +127,43 @@ def test_inference_only_tensor_parallel(self, materialize_only_last_token_logits
             logits_seq_len,
             self.vocab_size,
         ), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}"
+        assert (
+            logits.dtype == self.inference_logits_dtype
+        ), f"Expected logits dtype {self.inference_logits_dtype}, got {logits.dtype}"
+
+    @pytest.mark.parametrize("inference_logits_dtype", [torch.float32, torch.bfloat16])
+    @pytest.mark.parametrize(
+        "tp_pp", [pytest.param((2, 2), id="pp"), pytest.param((4, 1), id="tp")]
+    )
+    def test_inference_logits_dtype(self, tp_pp, inference_logits_dtype):
+        tp, pp = tp_pp
+        self.setup_model(
+            tensor_parallel_size=tp,
+            pipeline_parallel_size=pp,
+            inference_logits_dtype=inference_logits_dtype,
+        )
+
+        batch_prompt_tokens = (
+            torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length))
+            .int()
+            .cuda()
+        )
+        self.inference_wrapped_model.prep_model_for_inference()
+
+        inference_input = self.inference_wrapped_model.prep_inference_input(
+            prompts_tokens=batch_prompt_tokens
+        )
+        inference_input_for_context_window = (
+            self.inference_wrapped_model.get_batch_for_context_window(inference_input, 0, 5)
+        )
+
+        logits = self.inference_wrapped_model.run_one_forward_step(
+            inference_input_for_context_window
+        )
+
+        if pp > 1 and not parallel_state.is_pipeline_last_stage():
+            assert logits is None
+        else:
+            assert (
+                logits.dtype == inference_logits_dtype
+            ), f"Expected {inference_logits_dtype}, got {logits.dtype}"
diff --git a/tests/unit_tests/rl/test_rl_utils.py b/tests/unit_tests/rl/test_rl_utils.py
index 6bf6e994ffb..36c6e7db0ba 100644
--- a/tests/unit_tests/rl/test_rl_utils.py
+++ b/tests/unit_tests/rl/test_rl_utils.py
@@ -48,7 +48,10 @@ def __init__(self, batch=BATCH, seq=SEQ, vocab=VOCAB):
         self.vocab = vocab
         self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()
         self.config = TransformerConfig(
-            num_attention_heads=8, num_layers=8, pipeline_dtype=torch.bfloat16
+            num_attention_heads=8,
+            num_layers=8,
+            pipeline_dtype=torch.bfloat16,
+            inference_logits_dtype=torch.bfloat16,
         )
         self.model_type = ModelType.encoder_or_decoder
 
@@ -796,6 +799,7 @@ def test_get_logprobs_cuda_graphs(self, initialize_model_parallel):
             use_cpu_initialization=True,
             cuda_graph_impl="local",
             bf16=True,
+            inference_logits_dtype=torch.bfloat16,
         )
         model = GPTModel(
             config=transformer_config,
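
Note (not part of the patch): the TransformerConfig docstring and the CLI help above motivate the fp32 default by pointing at tie-breaking in argmax/multinomial. A minimal standalone sketch of that effect, assuming only PyTorch and no Megatron code, is:

# Illustrative only: logits that are distinct in fp32 can collapse to exact ties
# after rounding to bf16, so greedy argmax (and multinomial sampling) must break
# ties arbitrarily, which is one source of run-to-run non-determinism.
import torch

torch.manual_seed(0)
# Closely spaced logits on top of a large common offset; bf16's ~8-bit mantissa
# cannot represent the small differences at this magnitude.
logits_fp32 = 10.0 + torch.randn(8, dtype=torch.float32) * 1e-3
logits_bf16 = logits_fp32.to(torch.bfloat16)

print("unique fp32 values:", logits_fp32.unique().numel())  # 8 distinct values
print("unique bf16 values:", logits_bf16.unique().numel())  # far fewer, often just 1
print("fp32 argmax:", logits_fp32.argmax().item())          # index of the true maximum
print("bf16 argmax:", logits_bf16.argmax().item())          # an arbitrary tied index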