1 change: 1 addition & 0 deletions examples/rl/model_configs/common.sh
@@ -15,6 +15,7 @@ COMMON_OPTIONS="\
--use-mcore-models \
--transformer-impl transformer_engine \
--${PRECISION:-bf16} \
--inference-logits-dtype bf16 \
--te-rng-tracker \
--inference-dynamic-batching-buffer-size-gb 20 \
--data-parallel-random-init \
1 change: 1 addition & 0 deletions examples/rl/model_configs/nemotron5p5_12b_H.sh
@@ -121,6 +121,7 @@ MODEL_OPTIONS="\
--log-num-zeros-in-grad \
--log-throughput \
--bf16 \
--inference-logits-dtype bf16 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--use-distributed-optimizer \
1 change: 1 addition & 0 deletions examples/rl/model_configs/qwen3_30b_a3b_moe.sh
@@ -43,6 +43,7 @@ MODEL_OPTIONS="
--seq-length 8192 \
--inference-max-seq-length 8192 \
--bf16 \
--inference-logits-dtype bf16 \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--expert-model-parallel-size $EP \
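The three RL example configs above pass --inference-logits-dtype bf16 next to --bf16, opting rollouts out of the new fp32 default. A rough sizing sketch of why that matters (illustrative shapes and vocabulary size, not values from this PR): the full logits tensor is [batch, seq, vocab], so halving the element width halves both its memory footprint and the traffic of any gather or broadcast over it.

# Rough sizing sketch (illustrative numbers, not taken from this PR):
# footprint of a full [batch, seq, vocab] logits tensor at fp32 vs bf16.
def logits_bytes(batch: int, seq: int, vocab: int, bytes_per_elem: int) -> int:
    return batch * seq * vocab * bytes_per_elem

batch, seq, vocab = 8, 8192, 151_936   # assumed Qwen3-like vocabulary size
print(f"fp32: {logits_bytes(batch, seq, vocab, 4) / 2**30:.1f} GiB")  # ~37.1 GiB
print(f"bf16: {logits_bytes(batch, seq, vocab, 2) / 2**30:.1f} GiB")  # ~18.5 GiB
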
@@ -15,6 +15,7 @@
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import deprecate_args, get_attr_wrapped_model, get_model_config

DEPRECATED_ARGS = ["inference_wrapper_config", "pg_collection"]
@@ -118,13 +119,14 @@ def _forward(self, inference_input):
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
attention_mask = inference_input["attention_mask"]
return self.model(
tokens,
position_ids,
attention_mask,
kwargs = dict(
inference_context=self.inference_context,
runtime_gather_output=True, # Inference should always gather the logits
)
if isinstance(self.model, Float16Module):
kwargs["fp32_output"] = self.config.inference_logits_dtype == torch.float32
logits = self.model(tokens, position_ids, attention_mask, **kwargs)
return logits.to(self.config.inference_logits_dtype)

def _get_batch_size_and_seq_len(
self, tokens: torch.Tensor, recv_buffer_seq_len: Optional[int] = None
@@ -218,9 +220,9 @@ def forward_pass_with_pipeline_parallel(
logits = None
if is_pipeline_last_stage(self.pp_group):
logits = output_tensor

# Explicitly cast logits to expected dtype
logits = logits.to(self.config.params_dtype)
assert (
logits.dtype == self.config.inference_logits_dtype
), f"Expected logits dtype {self.config.inference_logits_dtype}, got {logits.dtype}"

return logits
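
In the hunks above, _forward now asks a Float16Module-wrapped model for fp32 output only when fp32 logits were actually requested, then casts the result to config.inference_logits_dtype, so forward_pass_with_pipeline_parallel can merely assert the dtype instead of re-casting. A minimal standalone sketch of that decision (the names below are illustrative, not the Megatron API):

import torch

def forward_logits(model_fn, half_precision_wrapper: bool, target_dtype: torch.dtype, **inputs):
    kwargs = {}
    if half_precision_wrapper:
        # Ask the fp16/bf16 wrapper to upcast its output only when fp32
        # logits were requested; otherwise keep the half-precision tensor
        # and skip an upcast that would be thrown away anyway.
        kwargs["fp32_output"] = target_dtype == torch.float32
    logits = model_fn(**inputs, **kwargs)
    # .to() returns the same tensor when the dtype already matches.
    return logits.to(target_dtype)

# Toy usage: a stand-in "model" that upcasts only when asked to.
toy_model = lambda x, fp32_output=False: x.float() if fp32_output else x
out = forward_logits(toy_model, True, torch.bfloat16, x=torch.randn(2, 4, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16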

@@ -12,6 +12,7 @@
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
from megatron.core.transformer.module import Float16Module


# pylint: disable=line-too-long
@@ -137,20 +138,21 @@ def _forward(self, inference_input: Dict[str, Any]):
position_ids = inference_input["position_ids"]
num_image_tiles = inference_input["num_tiles"]

output = self.model(
images,
tokens,
kwargs = dict(
position_ids=position_ids,
attention_mask=None,
inference_context=self.inference_context,
num_image_tiles=num_image_tiles,
runtime_gather_output=True,
)
if isinstance(self.model, Float16Module):
kwargs["fp32_output"] = self.config.inference_logits_dtype == torch.float32
output = self.model(images, tokens, **kwargs)
if isinstance(output, tuple):
logits, _ = output
else:
logits = output
return logits
return logits.to(self.config.inference_logits_dtype)

def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
"""The forward pass of the model for inference
@@ -151,7 +151,7 @@ def _init_dynamic_sampling_tensors(self):
self._get_stop_word_finished_ids_callback = None

device = torch.cuda.current_device()
logits_dtype = self.inference_wrapped_model.config.params_dtype
logits_dtype = self.model_config.inference_logits_dtype

self._sampling_backend = "torch"
self._enable_cuda_graph = self.model_config.cuda_graph_impl == "local"
@@ -711,7 +711,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):

logits = broadcast_from_last_pipeline_stage(
logits_shape,
dtype=self.model_config.params_dtype,
dtype=self.model_config.inference_logits_dtype,
tensor=logits,
pp_group=self.pp_group,
)
@@ -955,7 +955,7 @@ def _compute_serial_mtp_and_sample(self):
nvtx_range_push(f"mtp-spec-decoding/depth-{depth}/pp-broadcast")
mtp_logits_2d = broadcast_from_last_pipeline_stage(
[active_request_count, self.vocab_size],
dtype=self.model_config.params_dtype,
dtype=self.model_config.inference_logits_dtype,
tensor=mtp_logits_2d,
pp_group=self.pp_group,
)
@@ -2210,7 +2210,7 @@ def generate_all_output_tokens_static_batch(
# and then broadcast the sampled tokens rather than broadcasting the raw logits.
logits = broadcast_from_last_pipeline_stage(
[batch_size, logits_seq_len, self.vocab_size],
dtype=self.model_config.params_dtype,
dtype=self.model_config.inference_logits_dtype,
tensor=logits,
pp_group=self.pp_group,
)
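
Each of the controller hunks above switches the logits broadcast from params_dtype to inference_logits_dtype. The dtype argument matters because ranks that are not on the last pipeline stage allocate their receive buffer from the advertised shape and dtype alone, so it has to agree with what the last stage actually produces (hence the new assert in the pipeline-parallel forward pass earlier in this diff). A single-process sketch of that contract, not the Megatron helper itself:

import torch

def recv_buffer_for_broadcast(shape, dtype, is_last_stage, logits=None):
    # The producing (last) stage passes its real tensor into the collective;
    # every other stage allocates an empty buffer that the broadcast fills.
    # Shape and dtype therefore have to match what the producer sends.
    if is_last_stage:
        assert logits is not None, "last stage must provide the logits tensor"
        assert logits.dtype == dtype, (
            f"producer emits {logits.dtype}, consumers allocate {dtype}")
        return logits
    return torch.empty(shape, dtype=dtype)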
@@ -384,6 +384,7 @@ def compute_mtp_single_step(

logits, _ = self.output_layer(mtp_hidden, weight=output_weight, runtime_gather_output=True)
logits = self._scale_logits(logits)
logits = logits.to(self.config.inference_logits_dtype)

return mtp_hidden, logits

4 changes: 4 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -920,6 +920,10 @@ class TransformerConfig(ModelParallelConfig):
inference_sampling_seed: int = 42
""" Random seed to use for sampling during inference. """

inference_logits_dtype: torch.dtype = torch.float32
""" Dtype for logits during inference. Float32 improves sampling stability
by reducing tie-breaking non-determinism in argmax/multinomial. """

symmetric_ar_type: Optional[Literal['two_shot', "one_shot", "multimem_all_reduce"]] = None
"""What type of symmetric all reduce to use. The default is None
which is no use of symmetric memory.
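
The new config field defaults to fp32 and states the rationale. A small self-contained illustration of the tie-breaking problem (synthetic values, not from this PR): two logits that are distinct in fp32 can round to the same bf16 value, at which point the argmax winner depends on the kernel rather than on the model.

import torch

# bf16 keeps only 8 mantissa bits, so values this close collapse together.
logits_fp32 = torch.tensor([10.2500, 10.2501, 1.0])
logits_bf16 = logits_fp32.to(torch.bfloat16)

print(logits_fp32.argmax().item())              # 1 -- the true maximum
print(bool(logits_bf16[0] == logits_bf16[1]))   # True -- a tie after rounding
# With a tie, argmax/multinomial may return either index depending on the
# backend, which is the non-determinism the fp32 default avoids.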
2 changes: 1 addition & 1 deletion megatron/rl/rl_utils.py
@@ -797,7 +797,7 @@ def get_logprobs(model, tokens, position_ids, no_grad=False, sequence_packing=Fa
# This is a hack to fix megatron's behaviour when flash-decode affects the training code flow.
flash_decode = model.config.flash_decode
model.config.flash_decode = False
fp32_output = not (args.fp16 or args.bf16)
fp32_output = model.config.inference_logits_dtype == torch.float32
with torch.no_grad() if no_grad else nullcontext():
logits_or_hidden_states = model(
tokens,
9 changes: 9 additions & 0 deletions megatron/training/arguments.py
@@ -973,6 +973,7 @@ def validate_args(args, defaults={}):
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
args.inference_logits_dtype = map_dtype(args.inference_logits_dtype)
args.mamba_inference_conv_states_dtype = map_dtype(args.mamba_inference_conv_states_dtype)
args.mamba_inference_ssm_states_dtype = map_dtype(args.mamba_inference_ssm_states_dtype)

@@ -1788,6 +1789,7 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['experimental_attention_variant'] = 'dsa'

kw_args['inference_sampling_seed'] = args.seed
kw_args['inference_logits_dtype'] = args.inference_logits_dtype

# handle quantization config
# NOTE: Kitchen arguments are only added to the namespace when
@@ -1949,6 +1951,12 @@ def _add_inference_args(parser):
help="Enable chunked prefill (disabled by default)")
group.add_argument('--num-speculative-tokens', type=int, default=0,
help='Number of speculative tokens generated during decode')
group.add_argument('--inference-logits-dtype', type=str, default='fp32',
choices=['fp32', 'fp16', 'bf16'],
help='Dtype for logits during inference. fp32 (default) '
'improves sampling determinism by reducing tie-breaking '
'non-determinism in argmax/multinomial.',
dest='inference_logits_dtype')
group.add_argument('--inference-dynamic-batching-prefix-caching',
dest='inference_dynamic_batching_enable_prefix_caching',
action=argparse.BooleanOptionalAction,
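
The new flag is parsed as a string and converted to a torch dtype later in validate_args via map_dtype, whose implementation is outside this diff. A standalone sketch of that flow with a stand-in mapping (assumed, not the actual map_dtype):

import argparse
import torch

# Stand-in for Megatron's map_dtype; the real helper is not shown in this diff.
_DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

parser = argparse.ArgumentParser()
parser.add_argument("--inference-logits-dtype", type=str, default="fp32",
                    choices=["fp32", "fp16", "bf16"])
args = parser.parse_args(["--inference-logits-dtype", "bf16"])
args.inference_logits_dtype = _DTYPES[args.inference_logits_dtype]
print(args.inference_logits_dtype)   # torch.bfloat16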
@@ -2073,6 +2081,7 @@ def _add_network_size_args(parser):
"cuda_graph_retain_backward_graph",
"disable_parameter_transpose_cache",
"inference_sampling_seed",
"inference_logits_dtype",
"use_inference_optimized_layers",
"heterogeneous_block_specs",
"hetereogenous_dist_checkpoint",
@@ -17,7 +17,9 @@

class TestGPTInferenceWrapper:

def setup_model(self, tensor_parallel_size, pipeline_parallel_size):
def setup_model(
self, tensor_parallel_size, pipeline_parallel_size, inference_logits_dtype=torch.float32
):
Utils.initialize_model_parallel(
tensor_model_parallel_size=tensor_parallel_size,
pipeline_model_parallel_size=pipeline_parallel_size,
@@ -27,12 +29,14 @@ def setup_model(self, tensor_parallel_size, pipeline_parallel_size):
self.batch_size = 4
self.sequence_length = 32
hidden_size = 32
self.inference_logits_dtype = inference_logits_dtype

transformer_config = TransformerConfig(
num_layers=4,
hidden_size=hidden_size,
num_attention_heads=4,
use_cpu_initialization=True,
inference_logits_dtype=inference_logits_dtype,
)

gpt_model = GPTModel(
@@ -86,6 +90,9 @@ def test_inference_pipeline_parallel(self, materialize_only_last_token_logits):
logits_seq_len,
self.vocab_size,
), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}"
assert (
logits.dtype == self.inference_logits_dtype
), f"Expected logits dtype {self.inference_logits_dtype}, got {logits.dtype}"

@pytest.mark.parametrize("materialize_only_last_token_logits", [True, False])
def test_inference_only_tensor_parallel(self, materialize_only_last_token_logits):
@@ -120,3 +127,43 @@ def test_inference_only_tensor_parallel(self, materialize_only_last_token_logits
logits_seq_len,
self.vocab_size,
), f"Shape mismatch . Expected {(self.batch_size, logits_seq_len, self.vocab_size)}, but got {logits.shape}"
assert (
logits.dtype == self.inference_logits_dtype
), f"Expected logits dtype {self.inference_logits_dtype}, got {logits.dtype}"

@pytest.mark.parametrize("inference_logits_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
"tp_pp", [pytest.param((2, 2), id="pp"), pytest.param((4, 1), id="tp")]
)
def test_inference_logits_dtype(self, tp_pp, inference_logits_dtype):
tp, pp = tp_pp
self.setup_model(
tensor_parallel_size=tp,
pipeline_parallel_size=pp,
inference_logits_dtype=inference_logits_dtype,
)

batch_prompt_tokens = (
torch.randint(low=0, high=self.vocab_size, size=(self.batch_size, self.sequence_length))
.int()
.cuda()
)
self.inference_wrapped_model.prep_model_for_inference()

inference_input = self.inference_wrapped_model.prep_inference_input(
prompts_tokens=batch_prompt_tokens
)
inference_input_for_context_window = (
self.inference_wrapped_model.get_batch_for_context_window(inference_input, 0, 5)
)

logits = self.inference_wrapped_model.run_one_forward_step(
inference_input_for_context_window
)

if pp > 1 and not parallel_state.is_pipeline_last_stage():
assert logits is None
else:
assert (
logits.dtype == inference_logits_dtype
), f"Expected {inference_logits_dtype}, got {logits.dtype}"
6 changes: 5 additions & 1 deletion tests/unit_tests/rl/test_rl_utils.py
@@ -48,7 +48,10 @@ def __init__(self, batch=BATCH, seq=SEQ, vocab=VOCAB):
self.vocab = vocab
self.pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.config = TransformerConfig(
num_attention_heads=8, num_layers=8, pipeline_dtype=torch.bfloat16
num_attention_heads=8,
num_layers=8,
pipeline_dtype=torch.bfloat16,
inference_logits_dtype=torch.bfloat16,
)
self.model_type = ModelType.encoder_or_decoder

@@ -796,6 +799,7 @@ def test_get_logprobs_cuda_graphs(self, initialize_model_parallel):
use_cpu_initialization=True,
cuda_graph_impl="local",
bf16=True,
inference_logits_dtype=torch.bfloat16,
)
model = GPTModel(
config=transformer_config,