diff --git a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh b/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh index d0da539..4d27dd3 100755 --- a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh +++ b/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh @@ -92,9 +92,6 @@ elif [ -v OMPI_COMM_WORLD_RANK ]; then export CCOM_SOCKET_IFNAME=eth0 export FI_EFA_FORK_SAFE=1 - # Dataset is in shared location (Don't know where the s3 bucket for this artifact is to onboard to Kaizen) - #DATA_PATH="$SHARED_PATH_PREFIX/examples_datasets/databricks-dolly-15k" - # Store metrics in shared location METRICS_FILE=$ARTIFACT_PATH/results.json mkdir -p $ARTIFACT_PATH diff --git a/examples/training/llama/modeling_llama_nxd.py b/examples/training/llama/modeling_llama_nxd.py index cc6a230..f46fe50 100644 --- a/examples/training/llama/modeling_llama_nxd.py +++ b/examples/training/llama/modeling_llama_nxd.py @@ -822,9 +822,12 @@ def forward( shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + per_token_loss = loss_fct(shift_logits, shift_labels) + valid_mask = (shift_labels != -100).float() + per_token_loss = per_token_loss * valid_mask - loss = torch.mean(loss) + denom = valid_mask.sum().clamp_min(1.0) + loss = per_token_loss.sum() / denom if not return_dict: output = (logits,) + outputs[1:] diff --git a/examples/training/tp_dp_gpt_neox_hf_pretrain/common/get_dataset.py b/examples/training/tp_dp_gpt_neox_hf_pretrain/common/get_dataset.py index d05c6ba..72584d8 100644 --- a/examples/training/tp_dp_gpt_neox_hf_pretrain/common/get_dataset.py +++ b/examples/training/tp_dp_gpt_neox_hf_pretrain/common/get_dataset.py @@ -4,7 +4,7 @@ from datasets import load_dataset from transformers import AutoTokenizer -dataset_name = "wikicorpus" +dataset_name = "gboleda/wikicorpus" dataset_config_name = "raw_en" save_path = "~/examples_datasets/wikicorpus_gpt_neox_tokenized_2k" diff --git a/src/neuronx_distributed/_version.py b/src/neuronx_distributed/_version.py index df9340a..fb9229a 100644 --- a/src/neuronx_distributed/_version.py +++ b/src/neuronx_distributed/_version.py @@ -1,3 +1,3 @@ # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. # ============================================================================== -__version__ = "0.16.0" +__version__ = "0.17.0" diff --git a/src/neuronx_distributed/kernels/find_nonzero_indices.py b/src/neuronx_distributed/kernels/find_nonzero_indices.py index 5781881..7a652a6 100644 --- a/src/neuronx_distributed/kernels/find_nonzero_indices.py +++ b/src/neuronx_distributed/kernels/find_nonzero_indices.py @@ -101,9 +101,9 @@ def find_nonzero_indices(input_tensor: nt.tensor, n_partitions = nl.tile_size.pmax quadrant=32 n_quadrant = 4 - n_q7 = 8 - n_q7_per_quadrant = 2 - n_partition_per_q7 = 16 + n_gpsimd_cores = 8 + n_gpsimd_cores_per_quadrant = 2 + n_partition_per_gpsimd_core = 16 # Initialize (E,T) output on HBM to all -1s. indices = nl.ndarray((E, T), dtype=index_dtype, buffer=nl.shared_hbm) @@ -116,44 +116,44 @@ def find_nonzero_indices(input_tensor: nt.tensor, nonzero_counts = nl.ndarray((E, ), dtype=nl.int32, buffer=nl.shared_hbm) nonzero_counts_local = nl.zeros((1, E_per_shard), dtype=nl.int32, buffer=nl.sbuf) - # 2. Handle experts in groups of 8: 8 Q7 cores run in parallel. + # 2. Handle experts in groups of 8: 8 GPSIMD cores run in parallel. # Number of rounds to find index in groups of 8. - n_rounds = (E_per_shard + n_q7 - 1) // n_q7 + n_rounds = (E_per_shard + n_gpsimd_cores - 1) // n_gpsimd_cores # Number of chunks to chunk along the T dim. n_T_chunks = T // chunk_size for r in nl.static_range(n_rounds): # Get number of experts to process this round. - n_e = n_q7 + n_e = n_gpsimd_cores if r == n_rounds - 1: - n_e = E_per_shard - n_q7 * r + n_e = E_per_shard - n_gpsimd_cores * r # Counts of nonzeros so far. This is the offset at which the next chunk write should begin. # Keep offsets in int32. - offsets = nl.zeros((1, n_q7), dtype=nl.int32, buffer=nl.sbuf) + offsets = nl.zeros((1, n_gpsimd_cores), dtype=nl.int32, buffer=nl.sbuf) # Handle sequences in chunks with chunk_size for c in nl.static_range(n_T_chunks): # Load input: (T,E) layout on HBM - # Q7 kernel needs: (128, chunk_size) with partition 0, 16, 32, ... each filled with a different expert. - input_local_q7_aligned = nl.ndarray((n_partitions, 1, chunk_size), dtype=input_tensor.dtype, buffer=nl.sbuf) + # GPSIMD kernel needs: (128, chunk_size) with partition 0, 16, 32, ... each filled with a different expert. + input_local_gpsimd_core_aligned = nl.ndarray((n_partitions, 1, chunk_size), dtype=input_tensor.dtype, buffer=nl.sbuf) n_tiles = chunk_size // n_partitions for i in nl.affine_range(n_tiles): # Read (128, n_e) chunk i_packed_t = nl.arange(n_partitions)[:, None] i_packed_e = nl.arange(n_e)[None, :] - offset = r * n_q7 + row_start_id + E_offset + offset = r * n_gpsimd_cores + row_start_id + E_offset input_local_te = nl.load(input_tensor[i_packed_t + c * chunk_size + i * n_partitions, i_packed_e + offset]) # Copy to (128, 128) chunk, putting the n_e columns at column 0, 16, 32, ..., 112 input_local_aligned_te = nl.ndarray((n_partitions, n_partitions), dtype=input_tensor.dtype, buffer=nl.sbuf) for q in nl.affine_range(n_e): - input_local_aligned_te[:, nl.ds(q * n_partition_per_q7, 1)] = input_local_te[:, nl.ds(q, 1)] + input_local_aligned_te[:, nl.ds(q * n_partition_per_gpsimd_core, 1)] = input_local_te[:, nl.ds(q, 1)] # Transpose so we have expert data at row 0, 16, 32, ..., 112 - input_local_q7_aligned[:, 0, nl.ds(i*n_partitions, n_partitions)] = nisa.nc_transpose(input_local_aligned_te) + input_local_gpsimd_core_aligned[:, 0, nl.ds(i*n_partitions, n_partitions)] = nisa.nc_transpose(input_local_aligned_te) - # Run Q7 kernel for ISA NonzeroWithCount. - output_local = nki_asm_nonzero_with_count(input_local_q7_aligned, c*chunk_size) + # Run GPSIMD kernel for ISA NonzeroWithCount. + output_local = nki_asm_nonzero_with_count(input_local_gpsimd_core_aligned, c*chunk_size) # Write out rows 0, 32, 64, 96 i_0 = nl.arange(1)[:, None] @@ -161,35 +161,35 @@ def find_nonzero_indices(input_tensor: nt.tensor, i_indices = nl.arange(chunk_size)[None, :] for q in nl.affine_range(n_quadrant): nl.store( - indices[i_0 + E_offset + r * n_q7 + q * n_q7_per_quadrant, i_indices + offsets[i_0, i_1 + q * n_q7_per_quadrant]], + indices[i_0 + E_offset + r * n_gpsimd_cores + q * n_gpsimd_cores_per_quadrant, i_indices + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant]], value=output_local[nl.ds(q*quadrant, 1), 0, nl.ds(0, chunk_size)], - mask=i_0 + q * n_q7_per_quadrant < n_e + mask=i_0 + q * n_gpsimd_cores_per_quadrant < n_e ) - offsets[i_0, i_1 + q * n_q7_per_quadrant] = nl.add( - offsets[i_0, i_1 + q * n_q7_per_quadrant], + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant] = nl.add( + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant], output_local[nl.ds(q*quadrant, 1), 0, nl.ds(chunk_size, 1)], - mask=i_0 + q * n_q7_per_quadrant < n_e + mask=i_0 + q * n_gpsimd_cores_per_quadrant < n_e ) # Stream shuffle to move rows 16, 48, 80, 112 to rows 0, 32, 64, 96, and then write them out quad_mask = [255] * quadrant - quad_mask[0] = n_partition_per_q7 + quad_mask[0] = n_partition_per_gpsimd_core nisa.nc_stream_shuffle(src=output_local, dst=output_local, shuffle_mask = quad_mask) for q in nl.affine_range(n_quadrant): nl.store( - indices[i_0 + E_offset + r * n_q7 + q * n_q7_per_quadrant + 1, i_indices + offsets[i_0, i_1 + q * n_q7_per_quadrant + 1]], + indices[i_0 + E_offset + r * n_gpsimd_cores + q * n_gpsimd_cores_per_quadrant + 1, i_indices + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1]], value=output_local[nl.ds(q*quadrant, 1), 0, nl.ds(0, chunk_size)], - mask=i_0 + q * n_q7_per_quadrant + 1 < n_e + mask=i_0 + q * n_gpsimd_cores_per_quadrant + 1 < n_e ) - offsets[i_0, i_1 + q * n_q7_per_quadrant + 1] = nl.add( - offsets[i_0, i_1 + q * n_q7_per_quadrant + 1], + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1] = nl.add( + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1], output_local[nl.ds(q*quadrant, 1), 0, nl.ds(chunk_size, 1)], - mask=i_0 + q * n_q7_per_quadrant + 1 < n_e + mask=i_0 + q * n_gpsimd_cores_per_quadrant + 1 < n_e ) # Final offsets are the nonzero counts per expert. i_n_e = nl.arange(n_e)[None, :] - nonzero_counts_local[i_0, i_n_e + r * n_q7] = offsets[i_0, i_n_e] + nonzero_counts_local[i_0, i_n_e + r * n_gpsimd_cores] = offsets[i_0, i_n_e] nonzero_counts_reshape = nonzero_counts.reshape((1, E)) nl.store(nonzero_counts_reshape[:, nl.ds(E_offset, E_per_shard)], nonzero_counts_local) diff --git a/src/neuronx_distributed/modules/moe/blockwise.py b/src/neuronx_distributed/modules/moe/blockwise.py index c4a271d..b81d4a3 100644 --- a/src/neuronx_distributed/modules/moe/blockwise.py +++ b/src/neuronx_distributed/modules/moe/blockwise.py @@ -162,7 +162,7 @@ class BlockwiseMatmulArgs: gate_clamp_lower_limit: Optional[float] = None up_clamp_upper_limit: Optional[float] = None up_clamp_lower_limit: Optional[float] = None - + use_shard_on_block_dynamic_while: bool = False # Optional Input tensors conditions: Optional[torch.Tensor] = None num_static_blocks: Optional[int] = None @@ -1229,8 +1229,21 @@ def backward(ctx, grad_output): ) def _call_bwmm_fp4_shard_on_block(args: BlockwiseMatmulArgs): + #directly inject the conditions vector here + # FIXME: move this code to expert_mlps_v2.py instead of injecting it directly here + + # Reshape the token positions into blocks + padded_conditions = None + if args.use_shard_on_block_dynamic_while: + num_blocks = args.block_to_expert.shape[0] + blocks = args.token_position_to_id.view(num_blocks, args.block_size) + # Check each block for non padded tokens (any position != -1) + conditions = torch.any(blocks != -1, dim=1).to(torch.int32) + padded_conditions = torch.cat([conditions, torch.zeros(2, device=conditions.device)]) + output = _bwmm_fp4_shard_on_block_nki_call[VNC(2)]( hidden_states=args.hidden_states, + conditions=padded_conditions, expert_affinities_masked=args.expert_affinities_masked, gate_up_proj_weight=args.gate_up_proj_weight, down_proj_weight=args.down_proj_weight, @@ -1326,12 +1339,14 @@ def forward( gate_clamp_lower_limit=gate_clamp_lower_limit, up_clamp_upper_limit=up_clamp_upper_limit, up_clamp_lower_limit=up_clamp_lower_limit, + use_shard_on_block_dynamic_while=blockwise_matmul_config.use_shard_on_block_dynamic_while, gate_up_proj_bias=gate_up_proj_bias, down_proj_bias=down_proj_bias, kernel_act_fn=ActivationFunction(kernel_act_fn_id), is_tensor_update_accumulating=multi_expert_per_token, expert_affinities_scaling_mode=expert_affinities_scaling_mode, - output=output + output=output, + skip_dma=skip_dma, ) output = _call_bwmm_fp4_shard_on_block(args) diff --git a/src/neuronx_distributed/modules/moe/expert_mlps_v2.py b/src/neuronx_distributed/modules/moe/expert_mlps_v2.py index 1531a1a..732aa24 100644 --- a/src/neuronx_distributed/modules/moe/expert_mlps_v2.py +++ b/src/neuronx_distributed/modules/moe/expert_mlps_v2.py @@ -719,6 +719,11 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe num_blocks = math.ceil((total_tokens * self.routed_experts_mlp_config.top_k - (local_experts - 1)) / self.blockwise_matmul_config.block_size) + local_experts - 1 # Handle case where T*top_k is smaller than E. We will need atmost T*top_k blocks. num_blocks = min(num_blocks, total_tokens * self.routed_experts_mlp_config.top_k) + # Padding num_blocks to even (TODO: currently only for MXFP4 BWMM kernel to support dynamic while) + + pad_num_blocks_to_even = self.blockwise_matmul_config.pad_num_blocks_to_even + if pad_num_blocks_to_even: + num_blocks += num_blocks % 2 # Get num_static_block from blockwise_matmul_config, if not set, the default num_static_blocks will be computed as # NUM_STATIC_BLOCK = T * TopK / (EP_degree * B) @@ -754,6 +759,7 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe tensor_parallel_group=self.moe_tensor_model_parallel_group, expert_parallel_group=self.moe_expert_model_parallel_group, logical_nc_config=self.blockwise_matmul_config.logical_nc_config, + pad_num_blocks_to_even=pad_num_blocks_to_even ) else: # expert_mask: (T, E). Still happens on all the experts, not just the local experts @@ -823,6 +829,7 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe tensor_parallel_group=self.moe_tensor_model_parallel_group, optimized_block_to_token_mapping=ues_optimized_block_to_token_mapping, parallelize_token_to_block_mapping=self.blockwise_matmul_config.parallelize_token_to_block_mapping, + pad_num_blocks_to_even=pad_num_blocks_to_even, ) if self.blockwise_matmul_config.use_shard_on_block_dynamic_while: @@ -1079,6 +1086,7 @@ def get_blockwise_expert_and_token_mapping_kernel( tensor_parallel_group: ProcessGroup, expert_parallel_group: ProcessGroup, logical_nc_config: int, + pad_num_blocks_to_even: bool = False, ): ''' Equivalent function of get_blockwise_expert_and_token_mapping, but nstead of torch code, @@ -1109,6 +1117,7 @@ def get_blockwise_expert_and_token_mapping_kernel( tp_rank = torch.remainder(global_rank, tp_size) ep_size = expert_parallel_group.size() ep_rank = global_rank // tp_size + max_chunk_size = 16384 # Every EP rank needs the information about its local expert (`E_local` of those). E_local = E // ep_size @@ -1123,7 +1132,7 @@ def get_blockwise_expert_and_token_mapping_kernel( input_tensor=expert_affinities_masked.to(torch.float32), row_start_id=ep_rank*E_kernel, n_rows=E_kernel, - chunk_size = T, + chunk_size=min(T, max_chunk_size), index_dtype=nl.int32, ) else: @@ -1136,7 +1145,7 @@ def get_blockwise_expert_and_token_mapping_kernel( input_tensor=expert_affinities_masked.to(torch.float32), row_start_id=global_rank*E_kernel, n_rows=E_kernel, - chunk_size = T, + chunk_size=min(T, max_chunk_size), index_dtype=nl.int32, ) # Gather nonzero_counts: [E_kernel,] --> [E/EP,] @@ -1146,6 +1155,13 @@ def get_blockwise_expert_and_token_mapping_kernel( # Get number of blocks and cumulative number of blocks per expert. blocks_per_expert = ((nonzero_counts + block_size - 1) // block_size).to(dtype=torch.long) # (E_EP,) + + # Calculate padding blocks needed and add to last expert + if pad_num_blocks_to_even: + total_needed_blocks = torch.sum(blocks_per_expert) + padding_blocks = num_blocks - total_needed_blocks + blocks_per_expert[-1] += padding_blocks + blocks_per_expert_expanded = blocks_per_expert.unsqueeze(1) # (E_EP, 1) cum_blocks_per_expert = cumsum(blocks_per_expert_expanded) # (E_EP, 1) cum_blocks_per_expert[1:] = cum_blocks_per_expert[:-1] @@ -1199,7 +1215,8 @@ def get_blockwise_expert_and_token_mapping( tensor_parallel_group, optimized_block_to_token_mapping=True, parallelize_token_to_block_mapping=False, - ): + pad_num_blocks_to_even=False, + ): """ Token position: position in blocks. E.g. given block_size=2, num_token=6. The following expert_mask @@ -1242,6 +1259,13 @@ def get_blockwise_expert_and_token_mapping( tokens_per_expert = torch.sum(expert_mask, dim=0) # blocks_per_expert: (E, ) blocks_per_expert = ((tokens_per_expert + block_size - 1) // block_size).to(dtype=torch.long) + + # Calculate padding blocks needed and add to last expert + if pad_num_blocks_to_even: + total_needed_blocks = torch.sum(blocks_per_expert) + padding_blocks = num_blocks - total_needed_blocks + blocks_per_expert[-1] += padding_blocks + # block_to_expert: (N, ). Block id to expert id mapping. # The simplest way to do this is to use repeat_interleave after padding blocks_per_expert with unassigned blocks. # But this op is not lowered to xla with vector 'repeats', so we use the equivalent implementation below. diff --git a/src/neuronx_distributed/modules/moe/model.py b/src/neuronx_distributed/modules/moe/model.py index d57af25..d7ba92d 100644 --- a/src/neuronx_distributed/modules/moe/model.py +++ b/src/neuronx_distributed/modules/moe/model.py @@ -152,35 +152,9 @@ def __init__( self.moe_fused_tkg.eval() def _forward_compute_bound(self, hidden_states, padding_mask=None): - """Forward pass of the MoE layer. - - Common nomenclature: - S: Sequence length, B: Batch size, H: Hidden Size - S': Sequence length (when the input is in SP) - T: Tokens = S * B (token dimension obtained by flattening S and B) - - Layout of input hidden_states: - - In the training flow, - With SP enabled : (S', B, H) - With SP disabled : (B, S, H) - - In the inference flow, - With SP enabled : (B, S', H) - With SP disabled : (B, S, H) - - Arguments: - hidden_states: Input tensor (of shape as described above) - padding_mask: (Optional) Padding mask for the input tokens. If passed in will mask out - expert affinities mask and expert mask for the padded tokens. - - Returns: - output: Output tensor of the same shape as hidden_states, containing the output of the MoE layer. - bias: (Optional) Returned if expert_mlps.return_bias is True. Currently bias is not supported for the MoE layer. - router_logits: (Optional) Tensor of shape (T, E) containing the router logits for each token. - Returned if self.return_router_logits is True. - expert_index: (Optional) Tensor of shape (T, E) containing the experts assigned to each token. - Returned if self.return_expert_index is True. """ - + MoE forward pass for compute-bound workloads. Compute-bound workloads are almost always context encoding workloads. + """ if self.rmsnorm is not None: hidden_states = self.rmsnorm(hidden_states) @@ -283,11 +257,50 @@ def _forward_compute_bound(self, hidden_states, padding_mask=None): return_op += (expert_index,) return return_op - def forward(self, hidden_states, padding_mask=None, is_speculative_decoding=False): + def forward(self, hidden_states, padding_mask=None, is_speculative_decoding=False, residual=None): + """ + Forward pass of the MoE layer. + + Common nomenclature: + S: Sequence length, B: Batch size, H: Hidden Size + S': Sequence length (when the input is in SP) + T: Tokens = S * B (token dimension obtained by flattening S and B) + + Layout of input hidden_states: + - In the training flow, + With SP enabled : (S', B, H) + With SP disabled : (B, S, H) + - In the inference flow, + With SP enabled : (B, S', H) + With SP disabled : (B, S, H) + + Arguments: + hidden_states: Input tensor (of shape as described above). + padding_mask: (Optional) Padding mask for the input tokens. If passed in will mask out + expert affinities mask and expert mask for the padded tokens. + is_speculative_decoding: (Optional) Indicates whether the current forward pass is using speculative decoding. + residual: (Optional) Residual tensor same layout as hidden_states. + + Returns: + output: Output tensor of the same shape as hidden_states, containing the output of the MoE layer. + bias: (Optional) Returned if expert_mlps.return_bias is True. Currently bias is not supported for the MoE layer. + router_logits: (Optional) Tensor of shape (T, E) containing the router logits for each token. + Returned if self.return_router_logits is True. + expert_index: (Optional) Tensor of shape (T, E) containing the experts assigned to each token. + Returned if self.return_expert_index is True. + residual: (Optional) Tensor of same shape as hidden_states, containing the residual connection from the beginning of the MoE layer. + """ seq_len = hidden_states.shape[self.sequence_dimension] if (seq_len == 1 or is_speculative_decoding) and self.moe_fused_tkg is not None: # spec decode and normal tkg will use the fused kernel - return self.moe_fused_tkg(hidden_states) - return self._forward_compute_bound(hidden_states, padding_mask) + return self.moe_fused_tkg(hidden_states, residual=residual) + else: + # Fused residual add is not supported in _forward_compute_bound, so compute it in torch here. + if residual is not None: + hidden_states = hidden_states + residual + residual = hidden_states.clone() + return self._forward_compute_bound(hidden_states, padding_mask) + (residual,) + else: + return self._forward_compute_bound(hidden_states, padding_mask) def _apply_shared_experts(self, output, full_hidden_states, hidden_states, hidden_states_shape, seq_len): """ diff --git a/src/neuronx_distributed/modules/moe/moe_configs.py b/src/neuronx_distributed/modules/moe/moe_configs.py index 27cf4b3..da68e18 100644 --- a/src/neuronx_distributed/modules/moe/moe_configs.py +++ b/src/neuronx_distributed/modules/moe/moe_configs.py @@ -56,6 +56,7 @@ def __init__(self, use_shard_on_intermediate_dynamic_while: bool, use_shard_on_block_dynamic_while: bool, num_static_blocks: int, + pad_num_blocks_to_even: bool = False, ): self.block_size = block_size self.logical_nc_config = logical_nc_config @@ -71,6 +72,7 @@ def __init__(self, self.use_shard_on_intermediate_dynamic_while = use_shard_on_intermediate_dynamic_while self.use_shard_on_block_dynamic_while = use_shard_on_block_dynamic_while self.num_static_blocks = num_static_blocks + self.pad_num_blocks_to_even = pad_num_blocks_to_even #TODO: refactor this function @staticmethod @@ -107,6 +109,7 @@ def from_kwargs(**kwargs): block_sharding_strategy = BlockShardStrategy.PING_PONG else: raise ValueError(f"Unsupported block_sharding_strategy: {block_sharding_strategy}") + pad_num_blocks_to_even = kwargs.pop("pad_num_blocks_to_even", False) return BlockwiseMatmulConfig(block_size=block_size, use_block_parallel=use_block_parallel, @@ -122,6 +125,7 @@ def from_kwargs(**kwargs): use_shard_on_intermediate_dynamic_while=use_shard_on_intermediate_dynamic_while, use_shard_on_block_dynamic_while=use_shard_on_block_dynamic_while, num_static_blocks=num_static_blocks, + pad_num_blocks_to_even=pad_num_blocks_to_even, ) @staticmethod @@ -254,16 +258,16 @@ class RouterConfig: def __init__( self, act_fn: str = "softmax", - dtype: torch.dtype = torch.float32): + dtype: Optional[Union[torch.dtype, str]] = torch.float32): self.act_fn = act_fn + if isinstance(dtype, str): + dtype = to_torch_dtype(dtype) self.dtype = dtype @staticmethod def from_kwargs(**kwargs): act_fn = kwargs.pop("router_act_fn", "softmax") dtype = kwargs.pop("router_dtype", torch.float32) - if isinstance(dtype, str): - dtype = to_torch_dtype(dtype) return RouterConfig(act_fn=act_fn, dtype=dtype) class MoEFusedTKGConfig: @@ -275,7 +279,8 @@ def __init__( expert_mlp_kernel_enabled: Optional[bool] = None, shared_mlp_kernel_enabled: Optional[bool] = None, norm_topk_prob: bool = False, # Boolean to normalize top k expert affinities, defaults to no normalization - is_mxfp4_compute: Optional[bool] = None + is_mxfp4_compute: Optional[bool] = None, + router_mm_dtype: torch.dtype = torch.float32, ): self.quantized = quantized self.moe_fused_kernel_enabled = moe_fused_kernel_enabled @@ -284,3 +289,4 @@ def __init__( self.shared_mlp_kernel_enabled = shared_mlp_kernel_enabled self.norm_topk_prob = norm_topk_prob self.is_mxfp4_compute = is_mxfp4_compute + self.router_mm_dtype = router_mm_dtype diff --git a/src/neuronx_distributed/modules/moe/moe_fused_tkg.py b/src/neuronx_distributed/modules/moe/moe_fused_tkg.py index 23ed521..d93456b 100644 --- a/src/neuronx_distributed/modules/moe/moe_fused_tkg.py +++ b/src/neuronx_distributed/modules/moe/moe_fused_tkg.py @@ -102,6 +102,16 @@ def expert_isa_kernel_wrapper( ) # [T, H] return out +def _convert_torch_dtype_to_nki_dtype(dtype: torch.dtype): + TORCH_NKI_DTYPE_MAP = { + torch.float16: nl.float16, + torch.bfloat16: nl.bfloat16, + torch.float32: nl.float32, + } + + assert dtype in TORCH_NKI_DTYPE_MAP.keys(), f"expected dtype in {TORCH_NKI_DTYPE_MAP.keys()}, got {dtype=}" + return TORCH_NKI_DTYPE_MAP[dtype] + def _post_create_quantized_module_hook(layer): # This is a workaround in order to avoid loading weights for MoEFusedTKG during tracing. # Quantized modules implement these functions so they will always attempt to load these @@ -226,6 +236,12 @@ def _can_use_nki_kernel(self, kernel_type, hidden_states): return True + def _can_use_fused_residual_add(self, hidden_states): + """ + Currently, no fused kernels called by MoEFusedTKG support fused residual add. + """ + return False + def _slice_shared_experts_weights(self): """ When sequence parallel is enabled for shared experts, their weights will be replicated on each core. @@ -248,7 +264,6 @@ def _slice_shared_experts_weights(self): shared_experts_down_proj_weight = shared_experts.down_proj.weight return shared_experts_gate_proj_weight, shared_experts_up_proj_weight, shared_experts_down_proj_weight - def _router_topk(self, hidden_states): """ Args: @@ -404,15 +419,18 @@ def _shared_mlp(self, hidden_states): return shared_output - def _moe_fused_tkg_kernel(self, hidden_states): + def _moe_fused_tkg_kernel(self, hidden_states, residual=None): """ Args: hidden_states: [B, S, H] or [S, B, H] + residual (optional): None TODO CR Returns: output: original shape + router_logits: [B*S, E] """ hidden_states_shape = hidden_states.shape + router_mm_dtype = _convert_torch_dtype_to_nki_dtype(self.config.router_mm_dtype) if self.expert_mlps.routed_experts_mlp_config.early_expert_affinity_modulation: expert_affinities_scaling_mode = ExpertAffinityScaleMode.PRE_SCALE else: @@ -445,7 +463,7 @@ def _moe_fused_tkg_kernel(self, hidden_states): top_k=self.num_experts_per_tok, router_act_fn=ROUTER_ACT_FN_MAPPING[self.router.act_fn], expert_affinities_scaling_mode=expert_affinities_scaling_mode, - router_mm_dtype=nl.float32, + router_mm_dtype=router_mm_dtype, ) # this is a temporary check that can be removed once release compiler supports @@ -507,24 +525,31 @@ def _moe_fused_tkg_kernel(self, hidden_states): return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype) - def forward(self, hidden_states): + def forward(self, hidden_states, residual=None): """ Forward through MoE TKG mega-kernel if conditions are satisfied. Otherwise forward through flat compiler / individual kernels. - Conditions for MoE TKG mega-kernel: - - batch_size <= 32 - - must use RMSNorm, RouterTopK, ExpertMLPs, SharedExperts - Args: hidden_states: [B, S, H] or [S, B, H] + residual (optional): [B, S, H] or [S, B, H] Returns: output: original shape + router_logits: [B*S, E] + residual (optional): residual tensor of same shape as output """ + + if not self._can_use_fused_residual_add(hidden_states) and residual is not None: + hidden_states = hidden_states + residual + residual = hidden_states.clone() + if self._can_use_nki_kernel("moe_fused", hidden_states): logger.info("Running MoE Fused NKI kernel") - output, router_logits = self._moe_fused_tkg_kernel(hidden_states) + if not self._can_use_fused_residual_add(hidden_states) or residual is None: + output, router_logits = self._moe_fused_tkg_kernel(hidden_states) + else: + output, router_logits, residual = self._moe_fused_tkg_kernel(hidden_states, residual=residual) if self.return_expert_index: # return_expert_index not supported in kernel, return emtpy tensor to match with cte tracing when return_expert_index is set to True @@ -561,6 +586,8 @@ def forward(self, hidden_states): return_op += (router_logits,) if self.return_expert_index: return_op += (expert_index,) + if residual is not None: + return_op += (residual,) return return_op diff --git a/src/neuronx_distributed/modules/moe/moe_fused_tkg_mx.py b/src/neuronx_distributed/modules/moe/moe_fused_tkg_mx.py index cdbb35b..e5d19f9 100644 --- a/src/neuronx_distributed/modules/moe/moe_fused_tkg_mx.py +++ b/src/neuronx_distributed/modules/moe/moe_fused_tkg_mx.py @@ -2,6 +2,7 @@ import logging import warnings from typing import Optional +from importlib import import_module from neuronxcc import nki import neuronxcc.nki.language as nl @@ -20,7 +21,7 @@ from neuronx_distributed.modules.moe.routing import RouterBase from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPsV2 from neuronx_distributed.modules.moe.shared_experts import SharedExperts -from neuronx_distributed.modules.moe.moe_fused_tkg import MoEFusedTKG, ROUTER_ACT_FN_MAPPING +from neuronx_distributed.modules.moe.moe_fused_tkg import MoEFusedTKG, ROUTER_ACT_FN_MAPPING, _convert_torch_dtype_to_nki_dtype logger = logging.getLogger("Neuron") @@ -34,6 +35,7 @@ def initialize_nki_components() -> dict: dict: Mapping of component names to their imported values """ imports = { + "moe_token_gen_forward_all_experts": NKIImport("moe_token_gen_all_experts_kernel", module_name="moe_token_gen"), "moe_token_gen_selective_load_kernel": NKIImport("moe_token_gen_selective_load_kernel", module_name="moe_token_gen"), "nki_expert_mlp_tkg_isa_kernel": NKIImport("nki_expert_mlp_tkg_isa_kernel", module_name="mlp_tkg.expert_mlp_tkg_isa"), "float4_e2m1fn_x4": NKIImport("float4_e2m1fn_x4", module_name="private_api"), @@ -60,11 +62,12 @@ def initialize_nki_components() -> dict: _nki_expert_mlp_tkg_isa_kernel_call = nki_components["nki_expert_mlp_tkg_isa_kernel"] _moe_token_gen_selective_load_kernel_call = nki_components["moe_token_gen_selective_load_kernel"] +_moe_tkg_forward_all_experts_nki_call = nki_components["moe_token_gen_forward_all_experts"] @nki.compiler.skip_middle_end_transformations @nki.jit(show_compiler_tb=True, debug_kernel=True, experimental_flags='skip-non-top-level-shared-hbm-check') -def mxfp4_nki_expert_mlp_tkg_isa_kernel_wrapper( +def mxfp4_nki_expert_mlp_tkg_isa_standalone_kernel_wrapper( inp: nl.ndarray, gate_up_weights: nl.ndarray, down_weights: nl.ndarray, @@ -112,6 +115,85 @@ def mxfp4_nki_expert_mlp_tkg_isa_kernel_wrapper( base_addr=base_addr, ) + +@nki.compiler.skip_middle_end_transformations +@nki.jit(show_compiler_tb=True, debug_kernel=True, experimental_flags='skip-non-top-level-shared-hbm-check') +def mxfp4_moe_token_gen_forward_all_experts_kernel_wrapper( + inp: nl.ndarray, + gamma: nl.ndarray, + router_weights: nl.ndarray, + expert_gate_up_weights: nl.ndarray, + expert_down_weights: nl.ndarray, + rank_id: nl.ndarray, + shared_expert_gate_w: Optional[nl.ndarray] = None, + shared_expert_up_w: Optional[nl.ndarray] = None, + shared_expert_down_w: Optional[nl.ndarray] = None, + expert_gate_up_weights_scale: Optional[nl.ndarray] = None, + expert_down_weights_scale: Optional[nl.ndarray] = None, + router_bias: Optional[nl.ndarray] = None, + expert_gate_up_bias: Optional[nl.ndarray] = None, + expert_down_bias: Optional[nl.ndarray] = None, + shared_expert_gate_bias: Optional[nl.ndarray] = None, + shared_expert_up_bias: Optional[nl.ndarray] = None, + shared_expert_down_bias: Optional[nl.ndarray] = None, + eps: float = 1e-6, + top_k: int = 1, + router_act_fn = RouterActFnType.SIGMOID, + router_pre_norm: bool = True, + norm_topk_prob = False, + expert_affinities_scaling_mode = ExpertAffinityScaleMode.NO_SCALE, + hidden_act_fn = ActFnType.SiLU, + hidden_act_scale_factor: Optional[float] = None, + hidden_act_bias: Optional[float] = None, + gate_clamp_upper_limit: Optional[float] = None, + gate_clamp_lower_limit: Optional[float] = None, + up_clamp_upper_limit: Optional[float] = None, + up_clamp_lower_limit: Optional[float] = None, + router_mm_dtype = nl.bfloat16, + hidden_actual: Optional[int] = None, + residual: nl.ndarray = None, +): + expert_gate_up_weights = expert_gate_up_weights.view(float4_e2m1fn_x4) + expert_down_weights = expert_down_weights.view(float4_e2m1fn_x4) + return _moe_tkg_forward_all_experts_nki_call( + inp=inp, + gamma=gamma, + router_weights=router_weights, + expert_gate_up_weights=expert_gate_up_weights, + expert_down_weights=expert_down_weights, + rank_id=rank_id, + shared_expert_gate_w=shared_expert_gate_w, + shared_expert_up_w=shared_expert_up_w, + shared_expert_down_w=shared_expert_down_w, + expert_gate_up_weights_scale=expert_gate_up_weights_scale, + expert_down_weights_scale=expert_down_weights_scale, + router_bias=router_bias, + expert_gate_up_bias=expert_gate_up_bias, + expert_down_bias=expert_down_bias, + shared_expert_gate_bias=shared_expert_gate_bias, + shared_expert_up_bias=shared_expert_up_bias, + shared_expert_down_bias=shared_expert_down_bias, + eps=eps, + top_k=top_k, + router_act_fn=router_act_fn, + router_pre_norm=router_pre_norm, + norm_topk_prob=norm_topk_prob, + expert_affinities_scaling_mode=expert_affinities_scaling_mode, + hidden_act_fn=hidden_act_fn, + hidden_act_scale_factor=hidden_act_scale_factor, + hidden_act_bias=hidden_act_bias, + gate_clamp_upper_limit=gate_clamp_upper_limit, + gate_clamp_lower_limit=gate_clamp_lower_limit, + up_clamp_upper_limit=up_clamp_upper_limit, + up_clamp_lower_limit=up_clamp_lower_limit, + router_mm_dtype=router_mm_dtype, + hidden_actual=hidden_actual, + # router topk kernel must always be enabled + use_router_topk_nki_kernel=True, + residual=residual, + ) + + @nki.compiler.skip_middle_end_transformations @nki.jit(show_compiler_tb=True, debug_kernel=True, experimental_flags='skip-non-top-level-shared-hbm-check') def mxfp4_moe_token_gen_selective_load_kernel_wrapper( @@ -146,7 +228,6 @@ def mxfp4_moe_token_gen_selective_load_kernel_wrapper( up_clamp_lower_limit: Optional[float] = None, router_mm_dtype = nl.bfloat16, hidden_actual: Optional[int] = None, - use_router_topk_nki_kernel: bool = False ): """ Wrapper kernel for MoE TKG fused kernel that bitcasts FP4 weights to float4x4 NKI dtype prior to calling the kernel @@ -185,7 +266,8 @@ def mxfp4_moe_token_gen_selective_load_kernel_wrapper( up_clamp_lower_limit=up_clamp_lower_limit, router_mm_dtype=router_mm_dtype, hidden_actual=hidden_actual, - use_router_topk_nki_kernel=use_router_topk_nki_kernel, + # router topk kernel must always be enabled + use_router_topk_nki_kernel=True, ) class MoEFusedTKGMX(MoEFusedTKG): @@ -210,6 +292,7 @@ def __init__( logger.info("Selected MXFP4 variant of MoE Fused TKG") def _prepare_kernel_inputs(self): + router_mm_dtype = _convert_torch_dtype_to_nki_dtype(self.config.router_mm_dtype) routed_experts_mlp_config = self.expert_mlps.routed_experts_mlp_config if routed_experts_mlp_config.early_expert_affinity_modulation: expert_affinities_scaling_mode = ExpertAffinityScaleMode.PRE_SCALE @@ -223,11 +306,13 @@ def _prepare_kernel_inputs(self): hidden_size = self.expert_mlps.mlp_op.gate_up_proj.input_size assert hidden_size % 512 == 0, f"Hidden size must be divisible by 512, got {hidden_size}" - input_size_per_partition = self.expert_mlps.mlp_op.down_proj.input_size_per_partition - assert input_size_per_partition % 4 == 0, f"Intermediate size must be divisible by 4, got {input_size_per_partition}" - num_I_TP_blocks = math.ceil(input_size_per_partition / 512.0) + intermediate_size_per_partition = self.expert_mlps.mlp_op.down_proj.input_size_per_partition + assert intermediate_size_per_partition % 4 == 0, f"Intermediate size must be divisible by 4, got {intermediate_size_per_partition}" + num_I_TP_blocks = math.ceil(intermediate_size_per_partition / 512.0) + assert intermediate_size_per_partition % num_I_TP_blocks == 0, f"{intermediate_size_per_partition=} must be divisible by {num_I_TP_blocks=}" + I_TP_block_size = intermediate_size_per_partition // num_I_TP_blocks gate_up_weights_bias = self.expert_mlps.mlp_op.gate_up_proj.bias - gate_up_weights_bias = gate_up_weights_bias.view(self.num_local_experts, input_size_per_partition // 4, 2, num_I_TP_blocks, 4) + gate_up_weights_bias = gate_up_weights_bias.view(self.num_local_experts, I_TP_block_size // 4, 2, num_I_TP_blocks, 4) # run kernels that's compatible with clamp, bias, non-shared experts, SWIGLU optional_kwargs = {} @@ -240,8 +325,11 @@ def _prepare_kernel_inputs(self): optional_kwargs["up_clamp_upper_limit"] = routed_experts_mlp_config.up_clamp_upper_limit if routed_experts_mlp_config.up_clamp_lower_limit is not None: optional_kwargs["up_clamp_lower_limit"] = routed_experts_mlp_config.up_clamp_lower_limit + if self.router.bias: + optional_kwargs["router_bias"] = self.router.linear_router.bias.unsqueeze(0) return dict( + router_mm_dtype=router_mm_dtype, kernel_activation_func_id=kernel_activation_func_id, expert_affinities_scaling_mode=expert_affinities_scaling_mode, down_weights=self.expert_mlps.mlp_op.down_proj.weight, @@ -250,10 +338,45 @@ def _prepare_kernel_inputs(self): gate_up_weights_scale=self.expert_mlps.mlp_op.gate_up_proj.scale, gate_up_weights_bias=gate_up_weights_bias, down_weights_bias=self.expert_mlps.mlp_op.down_proj.bias, + router_weights=self.router.weight_T, **optional_kwargs, ) - def _moe_fused_tkg_kernel(self, hidden_states): + def _should_use_all_expert(self, hidden_states): + """ + Helper function to determine whether to use selective loading or all experts algorithm for MoE compute. Selective loading is + generally more performant when B * S * topk < E, and all experts is more performant when B * S * topk > E. + """ + + hidden_states_shape = hidden_states.shape + total_tokens = hidden_states_shape[0] * hidden_states_shape[1] + perc_experts_loaded = total_tokens * self.num_experts_per_tok / self.num_local_experts + return perc_experts_loaded >= DEFAULT_SELECTIVE_LOADING_THRESHOLD + + def _can_use_fused_residual_add(self, hidden_states): + """ + Helper function to determine whether we can use fused residual add feature inside fused TKG kernel. Currently + fused residual add is only supported in recent versions of all expert kernel when MXFP weights are used. + """ + + # MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT constant is used to determine whether the imported all expert kernel supportes + # fused residual add. When this constant is not importable, we know that the imported version of the kernel does not support this feature. + mod = import_module("neuronxcc.nki._pre_prod_kernels.moe_token_gen") + _has_fused_residual_add_support = getattr(mod, "MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT", False) + logger.info(f"Residual add support: {_has_fused_residual_add_support=}") + return self._should_use_all_expert(hidden_states) and _has_fused_residual_add_support + + def _moe_fused_tkg_kernel(self, hidden_states, residual=None): + """ + Args: + hidden_states: [B, S, H] or [S, B, H] + residual (optional): [B, S, H] or [S, B, H] + + Returns: + output: original shape + router_logits: [B*S, E] + residual (optional): same shape as output + """ hidden_states_shape = hidden_states.shape local_rank = self.expert_mlps.spmd_rank.get_rank() # TODO: make this compatible with hybrid sharding, current issue is moe_tensor_model_parallel_group will be the tensor_model_parallel_group used in CTE @@ -268,56 +391,104 @@ def _moe_fused_tkg_kernel(self, hidden_states): if routed_experts_mlp_config.hidden_size_actual is not None: prepared_kernel_inputs["hidden_actual"] = self.expert_mlps.routed_experts_mlp_config.hidden_size_actual - total_tokens = hidden_states_shape[0] * hidden_states_shape[1] - perc_experts_loaded = total_tokens * self.num_experts_per_tok / self.num_local_experts - - kernel_call = None - if (perc_experts_loaded >= DEFAULT_SELECTIVE_LOADING_THRESHOLD): + if self._should_use_all_expert(hidden_states): logger.info("Percentage of experts loaded >= selective loading threshold, run forward all experts fused megakernel") - raise NotImplementedError("Forward all experts fused megakernel not integrated yet for MXFP4") - prepared_kernel_inputs["rank_id"] = local_ep_rank.reshape(1, 1) - elif (perc_experts_loaded < DEFAULT_SELECTIVE_LOADING_THRESHOLD and self.shared_experts is None): - logger.info("Run MXFP4 selective loading fused megakernel: _moe_token_gen_selective_load_kernel_nki_call") - # kernel_call = mxfp4_moe_token_gen_selective_load_kernel_wrapper - kernel_call = mxfp4_moe_token_gen_selective_load_kernel_wrapper - - out, router_logits = kernel_call[grid]( - inp=hidden_states, - gamma=self.post_attention_layernorm.weight.unsqueeze(0), - router_weights=self.router.weight_T, - shared_expert_gate_w=shared_experts_gate_proj_weight, - shared_expert_up_w=shared_experts_up_proj_weight, - shared_expert_down_w=shared_experts_down_proj_weight, - expert_gate_up_weights=prepared_kernel_inputs["gate_up_weights"], # [E, 128, 2, H/512, I_TP] - expert_down_weights=prepared_kernel_inputs["down_weights"], # [E, I_TP//4, num_I_TP_blocks, H] - expert_gate_up_weights_scale=prepared_kernel_inputs["gate_up_weights_scale"], # [E, 128/8, 2, H/512, I_TP] - expert_down_weights_scale=prepared_kernel_inputs["down_weights_scale"], # [E, I_TP//4, num_I_TP_blocks, H] - eps=self.post_attention_layernorm.variance_epsilon, - top_k=self.num_experts_per_tok, - router_act_fn=ROUTER_ACT_FN_MAPPING[self.router.act_fn], - expert_affinities_scaling_mode=prepared_kernel_inputs["expert_affinities_scaling_mode"], - router_mm_dtype=nl.float32, - router_bias=self.router.linear_router.bias if self.router.bias else None, - expert_gate_up_bias=prepared_kernel_inputs["gate_up_weights_bias"] if routed_experts_mlp_config.bias else None, - expert_down_bias=prepared_kernel_inputs["down_weights_bias"] if routed_experts_mlp_config.bias else None, - shared_expert_gate_bias=None, # kernel only supports None - shared_expert_up_bias=None, # kernel only supports None - shared_expert_down_bias=None, # kernel only supports None - router_pre_norm=not self.router.apply_act_fn_over_topk, - hidden_act_fn=ActFnType(prepared_kernel_inputs["kernel_activation_func_id"]), - hidden_act_scale_factor=None, # kernel only supports None - hidden_act_bias=None, # kernel only supports None - norm_topk_prob=self.config.norm_topk_prob, - gate_clamp_upper_limit=prepared_kernel_inputs.get("gate_clamp_upper_limit"), - gate_clamp_lower_limit=prepared_kernel_inputs.get("gate_clamp_lower_limit"), - up_clamp_upper_limit=prepared_kernel_inputs.get("up_clamp_upper_limit"), - up_clamp_lower_limit=prepared_kernel_inputs.get("up_clamp_lower_limit"), - hidden_actual=prepared_kernel_inputs.get("hidden_actual"), - ) + kernel_kwargs = dict( + rank_id=local_ep_rank.reshape(1, 1), + inp=hidden_states, + gamma=self.post_attention_layernorm.weight.unsqueeze(0), + router_weights=prepared_kernel_inputs["router_weights"], + shared_expert_gate_w=shared_experts_gate_proj_weight, + shared_expert_up_w=shared_experts_up_proj_weight, + shared_expert_down_w=shared_experts_down_proj_weight, + expert_gate_up_weights=prepared_kernel_inputs["gate_up_weights"], # [E, 128, 2, H/512, I_TP] + expert_down_weights=prepared_kernel_inputs["down_weights"], # [E, 128, num_I_TP_blocks, H] + expert_gate_up_weights_scale=prepared_kernel_inputs["gate_up_weights_scale"], # [E, 128/8, 2, H/512, I_TP] + expert_down_weights_scale=prepared_kernel_inputs["down_weights_scale"], # [E, 128/8, num_I_TP_blocks, H] + eps=self.post_attention_layernorm.variance_epsilon, + top_k=self.num_experts_per_tok, + router_act_fn=ROUTER_ACT_FN_MAPPING[self.router.act_fn], + expert_affinities_scaling_mode=prepared_kernel_inputs["expert_affinities_scaling_mode"], + router_mm_dtype=prepared_kernel_inputs["router_mm_dtype"], + router_bias=prepared_kernel_inputs.get("router_bias"), + expert_gate_up_bias=prepared_kernel_inputs.get("gate_up_weights_bias"), + expert_down_bias=prepared_kernel_inputs.get("down_weights_bias"), + shared_expert_gate_bias=None, # kernel only supports None + shared_expert_up_bias=None, # kernel only supports None + shared_expert_down_bias=None, # kernel only supports None + router_pre_norm=not self.router.apply_act_fn_over_topk, + hidden_act_fn=ActFnType(prepared_kernel_inputs["kernel_activation_func_id"]), + hidden_act_scale_factor=None, # kernel only supports None + hidden_act_bias=None, # kernel only supports None + norm_topk_prob=self.config.norm_topk_prob, + gate_clamp_upper_limit=prepared_kernel_inputs.get("gate_clamp_upper_limit"), + gate_clamp_lower_limit=prepared_kernel_inputs.get("gate_clamp_lower_limit"), + up_clamp_upper_limit=prepared_kernel_inputs.get("up_clamp_upper_limit"), + up_clamp_lower_limit=prepared_kernel_inputs.get("up_clamp_lower_limit"), + hidden_actual=prepared_kernel_inputs.get("hidden_actual"), + ) + # Utilize fused residual add when possible. Residual add is computed in parent class forward call if fused residual add is not available. + if self._can_use_fused_residual_add(hidden_states) and residual is not None: + kernel_kwargs["residual"] = residual + out, router_logits, residual = mxfp4_moe_token_gen_forward_all_experts_kernel_wrapper[grid](**kernel_kwargs) + return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype), residual.view(hidden_states_shape) + else: + out, router_logits = mxfp4_moe_token_gen_forward_all_experts_kernel_wrapper[grid](**kernel_kwargs) + return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype) - return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype) + elif self.shared_experts is None: + logger.info("Run MXFP4 selective loading fused megakernel: _moe_token_gen_selective_load_kernel_nki_call") + out, router_logits = mxfp4_moe_token_gen_selective_load_kernel_wrapper[grid]( + inp=hidden_states, + gamma=self.post_attention_layernorm.weight.unsqueeze(0), + router_weights=self.router.weight_T, + shared_expert_gate_w=shared_experts_gate_proj_weight, + shared_expert_up_w=shared_experts_up_proj_weight, + shared_expert_down_w=shared_experts_down_proj_weight, + expert_gate_up_weights=prepared_kernel_inputs["gate_up_weights"], # [E, 128, 2, H/512, I_TP] + expert_down_weights=prepared_kernel_inputs["down_weights"], # [E, I_TP//4, num_I_TP_blocks, H] + expert_gate_up_weights_scale=prepared_kernel_inputs["gate_up_weights_scale"], # [E, 128/8, 2, H/512, I_TP] + expert_down_weights_scale=prepared_kernel_inputs["down_weights_scale"], # [E, I_TP//4, num_I_TP_blocks, H] + eps=self.post_attention_layernorm.variance_epsilon, + top_k=self.num_experts_per_tok, + router_act_fn=ROUTER_ACT_FN_MAPPING[self.router.act_fn], + expert_affinities_scaling_mode=prepared_kernel_inputs["expert_affinities_scaling_mode"], + router_mm_dtype=prepared_kernel_inputs["router_mm_dtype"], + router_bias=prepared_kernel_inputs.get("router_bias"), + expert_gate_up_bias=prepared_kernel_inputs.get("gate_up_weights_bias"), + expert_down_bias=prepared_kernel_inputs["down_weights_bias"] if routed_experts_mlp_config.bias else None, + shared_expert_gate_bias=None, # kernel only supports None + shared_expert_up_bias=None, # kernel only supports None + shared_expert_down_bias=None, # kernel only supports None + router_pre_norm=not self.router.apply_act_fn_over_topk, + hidden_act_fn=ActFnType(prepared_kernel_inputs["kernel_activation_func_id"]), + hidden_act_scale_factor=None, # kernel only supports None + hidden_act_bias=None, # kernel only supports None + norm_topk_prob=self.config.norm_topk_prob, + gate_clamp_upper_limit=prepared_kernel_inputs.get("gate_clamp_upper_limit"), + gate_clamp_lower_limit=prepared_kernel_inputs.get("gate_clamp_lower_limit"), + up_clamp_upper_limit=prepared_kernel_inputs.get("up_clamp_upper_limit"), + up_clamp_lower_limit=prepared_kernel_inputs.get("up_clamp_lower_limit"), + hidden_actual=prepared_kernel_inputs.get("hidden_actual"), + ) + return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype) + else: + raise RuntimeError("Unable to select a MoE fused TKG kernel to run") def _expert_mlp(self, hidden_states, expert_affinities, expert_index): + batch_dimension = 1 - self.sequence_dimension # hidden states are [B, S, H] or [S, B, H] + batch_size = hidden_states.shape[batch_dimension] + seq_len = hidden_states.shape[self.sequence_dimension] + total_tokens = batch_size * seq_len + perc_experts_loaded = total_tokens * self.num_experts_per_tok / self.num_local_experts + is_all_expert = perc_experts_loaded >= DEFAULT_SELECTIVE_LOADING_THRESHOLD + if is_all_expert: + logger.info(f"perc_experts_loaded={perc_experts_loaded} >= DEFAULT_SELECTIVE_LOADING_THRESHOLD={DEFAULT_SELECTIVE_LOADING_THRESHOLD}") + else: + logger.info(f"perc_experts_loaded={perc_experts_loaded} < DEFAULT_SELECTIVE_LOADING_THRESHOLD={DEFAULT_SELECTIVE_LOADING_THRESHOLD}") + return self._expert_mlp_selective_loading_or_all_expert(hidden_states, expert_affinities, expert_index, is_all_expert=is_all_expert) + + def _expert_mlp_selective_loading_or_all_expert(self, hidden_states, expert_affinities, expert_index, is_all_expert=False): """ Args: hidden_states: [S, B, H] or [B, S, H] @@ -327,7 +498,11 @@ def _expert_mlp(self, hidden_states, expert_affinities, expert_index): Returns: output: [T, H] """ - logger.info("Running MXFP4 ExpertMLP NKI kernel") + if is_all_expert: + logger.info("Running MXFP4 ExpertMLP NKI kernel - Forward All Experts") + else: + logger.info("Running MXFP4 ExpertMLP NKI kernel - Selective Loading of Experts") + # hidden_states: (S, B, H) or (B, S, H) -> (T, H) hidden_states_shape = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_states_shape[-1]) @@ -335,13 +510,13 @@ def _expert_mlp(self, hidden_states, expert_affinities, expert_index): prepared_kernel_inputs = self._prepare_kernel_inputs() grid = (nc(self.logical_nc_config),) - output = mxfp4_nki_expert_mlp_tkg_isa_kernel_wrapper[grid]( + output = mxfp4_nki_expert_mlp_tkg_isa_standalone_kernel_wrapper[grid]( inp=hidden_states, gate_up_weights=prepared_kernel_inputs["gate_up_weights"], down_weights=prepared_kernel_inputs["down_weights"], expert_affinities=expert_affinities, expert_index=expert_index.to(torch.int32), - is_all_expert=False, + is_all_expert=is_all_expert, gate_up_weights_scale=prepared_kernel_inputs["gate_up_weights_scale"], down_weights_scale=prepared_kernel_inputs["down_weights_scale"], gate_up_weights_bias=prepared_kernel_inputs["gate_up_weights_bias"], @@ -355,4 +530,4 @@ def _expert_mlp(self, hidden_states, expert_affinities, expert_index): ) # output: (T, H) -> (S, B, H) or (B, S, H) output = output.view(hidden_states_shape) - return output \ No newline at end of file + return output diff --git a/src/neuronx_distributed/operators/argmax.py b/src/neuronx_distributed/operators/argmax.py index 673f436..8ac87e9 100644 --- a/src/neuronx_distributed/operators/argmax.py +++ b/src/neuronx_distributed/operators/argmax.py @@ -1,79 +1,161 @@ +from typing import Optional + import torch -import math -import torch.nn.functional as F + from torch_neuronx.xla_impl.ops import Argmax +from torch_neuronx.utils import get_platform_target from neuronx_distributed.parallel_layers.parallel_state import ( get_tensor_model_parallel_group, ) from neuronx_distributed.parallel_layers.mappings import _gather_along_dim +from neuronx_distributed.utils.utils import hardware + +import neuronxcc.nki.language as nl + +# Try to import NKI max kernel, fall back to None if unavailable +try: + from neuronxcc.nki._pre_prod_kernels.max.cascaded_max import cascaded_max as nki_max +except (ImportError, ModuleNotFoundError): + nki_max = None -def argmax(tensor, dim, gather_dim, keepdim=False, process_group=None): +def _can_use_nki_max( + tensor: torch.Tensor, dim: int, disable_argmax_kernel: bool = False +) -> bool: """ - This function performs a distributed argmax. - This function will take in a sharded tensor, - and then calculate the max amongst all the - sharded tensors in the distributed environment, - along the provided `dim`. The signature is - similar to torch.argmax, except it also includes - a parameter called `gather_dim`. - - Example: Given a sharded tensor of shape (1,4) - where the dim to find the argmax is 1, tp_degree - is 2, keepdim is False, and the dim that's been - sharded is 1. The returned shape of this function - will be (1,). - - Arguments: - 1. tensor: the tensor to perform the argmax call - 2. dim: the dimension to find the argmax along. - 3. gather_dim: the dimension to gather on. This - should be the dimension the tensor was sharded on. - 4. keepdim: whether to keep or drop the dim - specified. The default is False. - - Returns: A tensor representing the global argmax - amongst the sharded tensors. + Check if we can use the NKI max kernel. + + Requirements: + - Hardware: Trn2 or Trn3 + - Tensor: 2D or 3D with shape[0] == 1 + - dim: Must be the last dimension + - Size: At least 128 elements in reduction dimension + + TODO: Remove these guardrails as kernel support expands. + """ + # Check if NKI max kernel is available + if nki_max is None: + return False + + # Check if kernel is manually disabled + if disable_argmax_kernel: + return False + + # Check hardware compatibility + hw_type = hardware(get_platform_target()) + if hw_type not in (hardware.TRN2, hardware.TRN3): + return False + + # Check dimension requirements + shape = tensor.shape + num_dims = len(shape) + if dim != num_dims - 1: + return False + + # Check minimum reduction size + if shape[dim] < 128: + return False + + # Check tensor dimensionality + return num_dims == 2 or (num_dims == 3 and shape[0] == 1) + + +def argmax( + tensor: torch.Tensor, + dim: int, + gather_dim: int, + keepdim: bool = False, + process_group: Optional[torch.distributed.ProcessGroup] = None, + disable_argmax_kernel: bool = False, +) -> torch.Tensor: + """Performs distributed argmax on sharded tensors. + + Calculates argmax across all sharded tensors in a distributed environment. + Similar to torch.argmax but handles tensor-parallel sharding. + + Args: + tensor: Input tensor to perform argmax on. + dim: Dimension along which to find argmax. + gather_dim: Dimension the tensor is sharded on. + keepdim: Whether to keep the reduced dimension. Defaults to False. + process_group: Process group for distributed operations. + Uses tensor model parallel group if None. + disable_argmax_kernel: Whether to use torch.argmax instead of the NKI + argmax kernel. Defaults to False. + + Returns: + Tensor with global argmax indices across all shards. + + Example: + Sharded tensor shape (1, 4), dim=1, tp_degree=2, keepdim=False, gather_dim=1 + Returns tensor of shape (1,). """ - - process_group = process_group if process_group is not None else get_tensor_model_parallel_group(as_list=False) - - # nxd distributed state + # NxD distributed state + process_group = process_group or get_tensor_model_parallel_group(as_list=False) tp_degree = torch.distributed.get_world_size(group=process_group) + # Fast path for single LNC if tp_degree == 1: return Argmax.apply(tensor, dim, keepdim) - sharded_size = tensor.shape[gather_dim] - num_dims = len(tensor.shape) + # Find local max values and indices + local_value, local_index = _compute_local_max(tensor, dim, disable_argmax_kernel) - # find local rank max value and index - local_value, local_index = torch.max(tensor, dim=dim, keepdim=True) + # Gather results from all ranks + global_values = _gather_along_dim( + local_value, gather_dim, process_group=process_group + ) + global_indices = _gather_along_dim( + local_index, gather_dim, process_group=process_group + ) - # perform all-gather on the local rank max values and indices to get global max and indices - global_values = _gather_along_dim(local_value, gather_dim, process_group=process_group) - global_indices = _gather_along_dim(local_index, gather_dim, process_group=process_group) - - # indices are based on local shard, so we need to correct it by applying - # an offset derived from tp degree and sharded size. This is only applicable - # when the gather_dim is equal to the argmax dim. - + # Correct indices for sharding offset when gather_dim == dim if gather_dim == dim: - full_size = sharded_size * tp_degree - offset = torch.arange(0, full_size, sharded_size) - offset = offset.view([1 if i != dim else -1 for i in range(num_dims)]) - corrected_global_indices = global_indices + offset - else: - corrected_global_indices = global_indices + global_indices = _apply_sharding_offset( + global_indices, dim, tensor.shape[gather_dim], tp_degree + ) - # calculate the global argmax based on the local argmax from the global max values - # and then retrieve the corrected indices - global_max_local_index = Argmax.apply(global_values, dim=dim, keepdim=True) - - final_indices = torch.gather(corrected_global_indices, dim, global_max_local_index) + # Find global argmax and extract final indices + global_argmax = Argmax.apply(global_values, dim=dim, keepdim=True) + final_indices = torch.gather(global_indices, dim, global_argmax) if not keepdim: return final_indices.squeeze(dim) - return final_indices + + +def _compute_local_max( + tensor: torch.Tensor, dim: int, disable_argmax_kernel: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute local max using NKI kernel when possible, otherwise torch.max.""" + if _can_use_nki_max(tensor, dim, disable_argmax_kernel): + # NKI kernel requires 2D input, squeeze if needed + # TODO: remove this when kernel support is expanded + is_3d = len(tensor.shape) == 3 + input_tensor = tensor.squeeze(0) if is_3d else tensor + + value, index = nki_max[(nl.nc(2),)](input_tensor) + + # Restore dimension if squeezed + if is_3d: + value = value.unsqueeze(0) + index = index.unsqueeze(0) + else: + # Fallback to torch.max for cases that don't meet nki_max criteria + value, index = torch.max(tensor, dim=dim, keepdim=True) + + return value, index + + +def _apply_sharding_offset( + indices: torch.Tensor, dim: int, shard_size: int, tp_degree: int +) -> torch.Tensor: + """Apply offset to indices to account for tensor sharding.""" + offset_shape = [1] * len(indices.shape) + offset_shape[dim] = tp_degree + + offset = torch.arange(0, shard_size * tp_degree, shard_size) + offset = offset.view(offset_shape) + + return indices + offset diff --git a/src/neuronx_distributed/operators/topk.py b/src/neuronx_distributed/operators/topk.py index 9c74e04..15f5d93 100644 --- a/src/neuronx_distributed/operators/topk.py +++ b/src/neuronx_distributed/operators/topk.py @@ -1,4 +1,5 @@ import torch +import torch_xla.core.xla_model as xm from torch_neuronx.xla_impl.ops import TopK from torch_neuronx.utils import get_platform_target @@ -22,17 +23,60 @@ def _is_nki_topk_available(): hardware_type = hardware(get_platform_target()) return hardware_type == hardware.TRN2 or hardware_type == hardware.TRN3 -def _nki_topk_wrapper(tensor, k, dim): +def _get_topk_wrappers(topk_implementation, lnc): """ - There are three per-shard topk implementations: - 1. neuronxcc.nki._pre_prod_kernels.topk import topk (this method is a wrapper over this variant) - 2. torch_neuronx.xla_impl.ops.TopK - 3. neuronx_distributed.kernels.topk.topk_rotated - - (2) and (3) take a dim parameter whereas (1) does not. - This wrapper function helps us to invoke all the above implementations with a common interface. + Returns two wrapper functions for topk implementation: + 1. topk_without_sorted: calls topk with sorting disabled (if supported) + 2. topk_with_sorted: calls topk with sorting enabled + + Args: + topk_implementation: The topk implementation to wrap + + Returns: + tuple: (topk_without_sorted_func, topk_with_sorted_func) """ - return nki_topk(tensor, k) + + # Detect implementation type once + func_name = getattr(topk_implementation, 'func_name', None) + is_topk_rotated = func_name == getattr(topk_rotated, 'func_name', None) + is_nki_topk = func_name == getattr(nki_topk, 'func_name', None) + + def create_wrapper(sorted_output): + """Factory to create topk wrapper with specified sorting behavior""" + def wrapper(tensor, k, dim): + if is_topk_rotated: + return topk_implementation[(lnc,)](tensor, k, dim=dim, sorted=sorted_output) + elif is_nki_topk: + # FIXME: sorted=False gives error, need to fix it in kernel + return topk_implementation[(lnc,)](tensor, k, sorted=True) + else: + return topk_implementation(tensor, k, dim=dim) + + return wrapper + + return create_wrapper(False), create_wrapper(True) + + +def get_topk_implementation(use_topk_rotated_kernel=False, lnc=2, dim=-1, tensor=None, stages=1): + hardware_type = hardware(get_platform_target()) + is_trn1 = hardware_type == hardware.TRN1 + lnc = lnc if is_trn1 else nl.nc(lnc) + if use_topk_rotated_kernel: + topk_implementation = topk_rotated + assert stages == 1, "stages other than 1 is not supported when using topk_rotated kernel" + else: + # check if nki topk kernel is available, if so, always prefer 1 stage (k%8==0 will be removed after kernel update) + can_use_nki_topk = dim in (-1, len(tensor.shape) - 1) and _is_nki_topk_available() + if can_use_nki_topk: + stages = 1 + topk_implementation = nki_topk + else: + topk_implementation = TopK.apply + + topk_implementation, call_topk_kernel_with_sorted_parameter = _get_topk_wrappers(topk_implementation, lnc) + + return topk_implementation, call_topk_kernel_with_sorted_parameter, stages + def topk(tensor, k, dim, gather_dim, process_group=None, stages=1, rank_id=None, use_topk_rotated_kernel=False, lnc=2): """ @@ -78,35 +122,7 @@ def topk(tensor, k, dim, gather_dim, process_group=None, stages=1, rank_id=None, is_trn1 = hardware_type == hardware.TRN1 is_trn2_or_trn3 = (hardware_type == hardware.TRN2 or hardware_type == hardware.TRN3) - if use_topk_rotated_kernel: - lnc = lnc if is_trn1 else nl.nc(lnc) - topk_implementation = topk_rotated[(lnc,)] - assert stages == 1, "stages other than 1 is not supported when using topk_rotated kernel" - else: - # check if nki topk kernel is available, if so, always prefer 1 stage (k%8==0 will be removed after kernel update) - can_use_nki_topk = dim in (-1, len(tensor.shape) - 1) and _is_nki_topk_available() - if can_use_nki_topk: - stages = 1 - topk_implementation = _nki_topk_wrapper - else: - topk_implementation = TopK.apply - - def call_topk_kernel_with_sorted_parameter(tensor, k, dim): - """ - There are three possible choices for per-shard topk implementations: - 1. neuronxcc.nki._pre_prod_kernels.topk import topk - 2. torch_neuronx.xla_impl.ops.TopK - 3. neuronx_distributed.kernels.topk.topk_rotated - - The first two sort the output by default, while the third one offers a parameter to do it. - With the third one, we want to use sorted=True only for the last call. - This function is a special wrapper to make that last call uniformly within the higher function. - For all intermediate calls, the default behavior of the topk_implementation is fine. - """ - if hasattr(topk_implementation, 'func_name') and topk_implementation.func_name == topk_rotated.func_name: - return topk_implementation(tensor, k, dim=dim, sorted=True) - else: - return topk_implementation(tensor, k, dim=dim) + topk_implementation, call_topk_kernel_with_sorted_parameter, stages = get_topk_implementation(use_topk_rotated_kernel, lnc, dim, tensor, stages) if stages > 1: if is_trn2_or_trn3: @@ -144,7 +160,10 @@ def call_topk_kernel_with_sorted_parameter(tensor, k, dim): num_dims = len(tensor.shape) # find local rank max value and index - local_value, local_index = topk_implementation(tensor, k, dim=dim) + local_k = k + if gather_dim == dim: + local_k = min(k, sharded_size) + local_value, local_index = topk_implementation(tensor, local_k, dim=dim) if stages > 1: if gather_dim == dim: @@ -172,7 +191,7 @@ def call_topk_kernel_with_sorted_parameter(tensor, k, dim): if gather_dim == dim: full_size = sharded_size * tp_degree offset = torch.arange(0, full_size, sharded_size) - offset = offset.repeat_interleave(k) + offset = offset.repeat_interleave(local_k) offset = offset.view([1 if i != dim else -1 for i in range(num_dims)]) corrected_global_indices = global_indices + offset else: diff --git a/src/neuronx_distributed/parallel_layers/grads.py b/src/neuronx_distributed/parallel_layers/grads.py index 84b7ce4..a772ea8 100644 --- a/src/neuronx_distributed/parallel_layers/grads.py +++ b/src/neuronx_distributed/parallel_layers/grads.py @@ -362,6 +362,6 @@ def allreduce_context_parallel_gradients(optimizer): grads.append(p.main_grad.data) for grad in grads: # Scale down by the context parallel size and allreduce the grads from the context parallel regions - grad = grad / get_context_model_parallel_size() + grad /= get_context_model_parallel_size() reduce_from_context_model_parallel_region(grad, get_context_model_parallel_group()) diff --git a/src/neuronx_distributed/parallel_layers/parallel_state.py b/src/neuronx_distributed/parallel_layers/parallel_state.py index 9e0a4c2..2e8c938 100644 --- a/src/neuronx_distributed/parallel_layers/parallel_state.py +++ b/src/neuronx_distributed/parallel_layers/parallel_state.py @@ -201,15 +201,23 @@ def ascending_descending_ring_PG_group(lnc_size: int, cluster_ranks_nonexp: torc world_size: int = torch.distributed.get_world_size() nodes = world_size//total_ranks_per_node - + num_tp_groups_per_node = total_ranks_per_node//tp tp_groups=[] for node in range(nodes): node_skip_val = node * total_ranks_per_node # temp variable to jump all ranks in n nodes - tp_groups.append([i for i in range(ranks_start[0] + node_skip_val, ranks_end[0] + node_skip_val)]+ - [i for i in range(ranks_start[3] + node_skip_val, ranks_end[3] + node_skip_val)]) # first row and last row are one group in Logic2 - tp_groups.append([i for i in range(ranks_start[1] + node_skip_val, ranks_end[1] + node_skip_val)]+ - [i for i in range(ranks_start[2] + node_skip_val, ranks_end[2] + node_skip_val)]) # second and third row are one group in Logic2 - + first_and_last_rows_tp_group = ([i for i in range(ranks_start[0] + node_skip_val, ranks_end[0] + node_skip_val)]+ + [i for i in range(ranks_start[3] + node_skip_val, ranks_end[3] + node_skip_val)]) + + sec_and_third_rows_tp_group = ([i for i in range(ranks_start[1] + node_skip_val, ranks_end[1] + node_skip_val)]+ + [i for i in range(ranks_start[2] + node_skip_val, ranks_end[2] + node_skip_val)]) + + if num_tp_groups_per_node == 1: # need to combine all 4 rows into one TP group per node + first_and_last_rows_tp_group.extend(sec_and_third_rows_tp_group) + tp_groups.append(first_and_last_rows_tp_group) + else: + tp_groups.append(first_and_last_rows_tp_group) # first row and last row are one group in Logic2 + tp_groups.append(sec_and_third_rows_tp_group) # second and third row are one group in Logic2 + assert len(tp_groups)==(world_size//tp) def merge_groups(chunk, prev_parallel_degree): @@ -597,7 +605,7 @@ def initialize_model_parallel( if expert_model_parallel_size > 1: raise NotImplementedError("TP=4 case not yet implemented for expert parallelism") - local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", world_size)) num_local_ranks = local_world_size // tensor_model_parallel_size cluster_ranks = torch.arange(0, world_size).reshape( pipeline_model_parallel_size, data_parallel_size // num_local_ranks, context_parallel_size, tensor_model_parallel_size, num_local_ranks diff --git a/src/neuronx_distributed/quantization/quantization_config.py b/src/neuronx_distributed/quantization/quantization_config.py index 1009832..1a8a314 100644 --- a/src/neuronx_distributed/quantization/quantization_config.py +++ b/src/neuronx_distributed/quantization/quantization_config.py @@ -24,6 +24,9 @@ class DtypeBound(Enum): if torch.__version__ >= '2.1': F8E4M3_MAX = 240.0 F8E4M3_MIN = -240.0 + + INT8_MAX = 127 + INT8_MIN = -128 F8E4M3FN_MAX = torch.finfo(torch.float8_e4m3fn).max F8E4M3FN_MIN = torch.finfo(torch.float8_e4m3fn).min @@ -37,6 +40,12 @@ class DtypeBound(Enum): F8E8M0_MAX = 2.0 ** (255.0 - 127.0) F8E8M0_MIN = 2.0 ** (-127.0) + BFLOAT16_MAX = torch.finfo(torch.bfloat16).max + BFLOAT16_MIN = torch.finfo(torch.bfloat16).min + + FLOAT16_MAX = torch.finfo(torch.float16).max + FLOAT16_MIN = torch.finfo(torch.float16).min + @staticmethod def from_torch_dtype(dtype): """Map PyTorch data type to the corresponding DtypeBound.""" @@ -44,20 +53,38 @@ def from_torch_dtype(dtype): return (DtypeBound.F8E4M3_MAX.value, DtypeBound.F8E4M3_MIN.value) elif dtype == torch.float8_e5m2: return (DtypeBound.F8E5M2_MAX.value, DtypeBound.F8E5M2_MIN.value) + elif dtype == torch.int8: + return (DtypeBound.INT8_MAX.value, DtypeBound.INT8_MIN.value) + elif dtype == torch.bfloat16: + return (DtypeBound.BFLOAT16_MAX.value, DtypeBound.BFLOAT16_MIN.value) + elif dtype == torch.float16: + return (DtypeBound.FLOAT16_MAX.value, DtypeBound.FLOAT16_MIN.value) else: raise ValueError(f"Unsupported dtype: {dtype}") class QuantizationType(Enum, metaclass=MyEnumMeta): PER_TENSOR_SYMMETRIC = "per_tensor_symmetric" PER_CHANNEL_SYMMETRIC = "per_channel_symmetric" + PER_KEY_SYMMETRIC = "per_key_symmetric" BLOCKWISE_SYMMETRIC = "blockwise_symmetric" EXPERT_WISE_PER_CHANNEL_SYMMETRIC = "expert_wise_per_channel_symmetric" +class KVQuantizationConfig: + def __init__(self, **kwargs): + self.k_quant_method = kwargs.pop("k_quant_method", QuantizationType.PER_TENSOR_SYMMETRIC) + self.v_quant_method = kwargs.pop("v_quant_method", QuantizationType.PER_TENSOR_SYMMETRIC) + self.quant_dtype = kwargs.pop("quant_dtype", torch.float8_e4m3fn) + self.direct_cast = kwargs.pop("direct_cast", True) + + if self.direct_cast: + assert self.k_quant_method == QuantizationType.PER_TENSOR_SYMMETRIC and self.v_quant_method == QuantizationType.PER_TENSOR_SYMMETRIC, "When using direct cast both K and V quantization strategies must be PER_TENSOR_SYMMETRIC" + + class ActivationQuantizationType(Enum, metaclass=MyEnumMeta): DYNAMIC = "dynamic" + STATIC = "static" NONE = None - def get_float4x4_torch_dtype(): """ A limitation of Torch XLA is that u16 is not supported. This dtype works for the full model, @@ -139,7 +166,6 @@ class PER_CHANNEL_QCONFIG_DICT_TYPE(BASE_QCONFIG_DICT_TYPE): class EXPERT_WISE_PER_CHANNEL_QCONFIG_DICT_TYPE(BASE_QCONFIG_DICT_TYPE): quantization_per_channel_axis: Optional[int] - class BLOCKWISE_QCONFIG_DICT_TYPE(BASE_QCONFIG_DICT_TYPE): block_axis: Optional[List[int]] block_size: Optional[List[int]] @@ -201,7 +227,6 @@ def get_default_blockwise_custom_qconfig_dict() -> BLOCKWISE_QCONFIG_DICT_TYPE: """Defines the default blockwise config dict""" return BLOCKWISE_QCONFIG_DICT_TYPE(**_DEFAULT_BLOCKWISE_QCONFIG_DICT) - def get_default_expert_wise_per_channel_custom_qconfig_dict() -> EXPERT_WISE_PER_CHANNEL_QCONFIG_DICT_TYPE: """ Defines the default custom expert wise per channel config dict @@ -228,4 +253,4 @@ def is_ocp_mx_quantized( """ return q_type == QuantizationType.BLOCKWISE_SYMMETRIC \ and QuantizedDtype(q_dtype) in [QuantizedDtype.F4E2M1FN_X4, QuantizedDtype.F8E4M3FN_X4, QuantizedDtype.F8E5M2_X4] \ - and ScaleDtype(scale_dtype) == ScaleDtype.F8E8M0 \ No newline at end of file + and ScaleDtype(scale_dtype) == ScaleDtype.F8E8M0 diff --git a/src/neuronx_distributed/quantization/quantization_layers.py b/src/neuronx_distributed/quantization/quantization_layers.py index 1d6d20b..0de1a82 100644 --- a/src/neuronx_distributed/quantization/quantization_layers.py +++ b/src/neuronx_distributed/quantization/quantization_layers.py @@ -63,7 +63,7 @@ is_ocp_mx_quantized, validate_block_axis_size ) -from neuronx_distributed.quantization.quantization_utils import extract_q_scale, quantize_fp8_per_channel +from neuronx_distributed.quantization.quantization_utils import extract_q_scale, quantize_fp8_per_channel, quantize_static_quant_activations from neuronx_distributed.utils import cpu_mode from neuronx_distributed.utils.logger import get_logger @@ -210,6 +210,7 @@ def _setup_for_scale( per_channel_axis: Optional[int] = None, block_axis: Optional[List[int]] = None, block_size: Optional[List[int]] = None, + activation_quantization_type: Optional[ActivationQuantizationType] = None, ): """Setup required for scale @@ -233,6 +234,13 @@ def _setup_for_scale( set_tensor_model_parallel_attributes( tensor=self.scale, is_parallel=False, dim=0, stride=1, num_partitions=1, ) + if activation_quantization_type == ActivationQuantizationType.STATIC: + self.input_scale = Parameter(torch.tensor([self.scale_dtype.get_default_scale()], device=self.weight.device, dtype=self.scale_dtype.value), requires_grad=False) + set_tensor_model_parallel_attributes( + tensor=self.input_scale, is_parallel=False, dim=0, stride=1, num_partitions=1, + ) + setattr(self.input_scale, "get_tensor_from_state_dict", BaseQuantizeParallelLinear.get_input_scale_from_state_dict) + elif quantization_type in [QuantizationType.PER_CHANNEL_SYMMETRIC, QuantizationType.EXPERT_WISE_PER_CHANNEL_SYMMETRIC]: assert ( per_channel_axis is not None @@ -323,6 +331,11 @@ def get_scale_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch. return QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict( prefix=prefix, state_dict=state_dict ) + @staticmethod + def get_input_scale_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch.Tensor: + return QuantizedParallelLinearLayerStateDictAdaptor.get_input_scale_from_state_dict( + prefix=prefix, state_dict=state_dict + ) def _apply_post_quantization_hook(mod, new_mod): if hasattr(mod, "post_create_quantized_module_hook"): @@ -429,6 +442,25 @@ def get_scale_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch. else: raise RuntimeError(f"Cannot find {(prefix + 'scale')} in state_dict") + @staticmethod + def get_input_scale_from_state_dict(prefix: str, state_dict: Dict[str, Any]) -> torch.Tensor: + """Get scale value from state dict + + Args: + prefix (str): layer prefix + state_dict (dict): model state dict from the checkpoint + + Raises: + RuntimeError: if input_scale is not found + + Returns: + torch.Tensor: input_scale tensor + """ + if (prefix + "input_scale") in state_dict: + input_scale: torch.Tensor = state_dict[prefix + "input_scale"] + return input_scale + else: + raise RuntimeError(f"Cannot find {(prefix + 'input_scale')} in state_dict") class QuantizedColumnParallel(BaseQuantizeParallelLinear): """Quantized Linear layer with column parallelism. @@ -549,6 +581,7 @@ def __init__( per_channel_axis=quantization_per_channel_axis, block_axis=block_axis, block_size=block_size, + activation_quantization_type=self.activation_quantization_type ) ##### Parallelism setup ##### self._setup_for_parallelism(world_size=world_size) @@ -608,10 +641,17 @@ def forward(self, input: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tenso ) ## TODO: add a flow for static quantization once zhankuil@ is done - if self.activation_quantization_type == ActivationQuantizationType.DYNAMIC: + if self.activation_quantization_type != ActivationQuantizationType.NONE: # Matrix multiply. original_dtype = input_parallel.dtype - quantized_input, input_scale = quantize_fp8_per_channel(input_parallel, dtype=torch.float8_e4m3fn, channel_axis=1, clamp_bound=self.clamp_bound) + if self.activation_quantization_type == ActivationQuantizationType.DYNAMIC: + quantized_input, input_scale = quantize_fp8_per_channel(input_parallel, dtype=torch.float8_e4m3fn, channel_axis=1, clamp_bound=self.clamp_bound) + else: + assert self.quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC, "Static Activation is only supported for PER TENSOR quantization type." + original_dtype = self.dequantized_dtype + quantized_input = quantize_static_quant_activations(input_parallel, self.input_scale, self.quantized_dtype.value) + input_scale = self.input_scale + output_parallel = self._forward_impl( input=quantized_input, weight=self.weight, @@ -767,6 +807,7 @@ def __init__( self.clamp_bound = clamp_bound self.input_size = input_size self.output_size = output_size + self.dtype = dtype self.input_is_parallel = input_is_parallel world_size = self.tensor_parallel_group.size() self.pad = pad @@ -823,6 +864,7 @@ def __init__( per_channel_axis=quantization_per_channel_axis, block_axis=block_axis, block_size=block_size, + activation_quantization_type=self.activation_quantization_type, ) self._forward_impl = linear_with_async_allreduce @@ -863,11 +905,16 @@ def forward(self, input_: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens input_, process_group=self.tensor_parallel_group, ) # Matrix multiply. - if self.activation_quantization_type == ActivationQuantizationType.DYNAMIC: - original_dtype = input_parallel.dtype - #TODO: combine with zhankuil's implementation of static quantization for input - quantized_input, input_scale = quantize_fp8_per_channel(input_parallel, dtype=torch.float8_e4m3fn, channel_axis=1) - + if self.activation_quantization_type != ActivationQuantizationType.NONE: + if self.activation_quantization_type == ActivationQuantizationType.DYNAMIC: + original_dtype = input_parallel.dtype + #TODO: combine with zhankuil's implementation of static quantization for input + quantized_input, input_scale = quantize_fp8_per_channel(input_parallel, dtype=torch.float8_e4m3fn, channel_axis=1) + else: + assert self.quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC, "Only per-tensor symmetric quantization is supported for static activation quantization" + original_dtype = self.dequantized_dtype + quantized_input = quantize_static_quant_activations(input_parallel, self.input_scale, self.quantized_dtype.value) + input_scale = self.input_scale output_ = self._forward_impl( input=quantized_input, @@ -904,8 +951,6 @@ def forward(self, input_: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens tensor=output_, scale=self.scale.T, upcast_dtype=output_.dtype, ) - - if self.reduce_output: # All-reduce across all the partitions. if self.sequence_parallel_enabled: diff --git a/src/neuronx_distributed/quantization/quantization_utils.py b/src/neuronx_distributed/quantization/quantization_utils.py index 9fe2579..1377813 100644 --- a/src/neuronx_distributed/quantization/quantization_utils.py +++ b/src/neuronx_distributed/quantization/quantization_utils.py @@ -74,6 +74,20 @@ def extract_q_scale(q_tensor: torch.Tensor) -> torch.Tensor: else: raise ValueError(f"qscheme: {q_tensor.qscheme()} is not supported") +def quantize_static_quant_activations(input: torch.Tensor, input_scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Quantize the input tensor with the given scale and dtype + + Args: + input (torch.Tensor): Input tensor to quantize + input_scale (torch.Tensor): Input scale + dtype (torch.dtype): Target dtype + + Returns: + torch.Tensor: Quantized tensor + """ + q_max, q_min = DtypeBound.from_torch_dtype(dtype) + input = (input / input_scale).clamp(q_min, q_max) + return input.round().to(dtype) if dtype == torch.int8 else input.to(dtype) def convert_qint8_to_int8_state_dict(state_dict: Dict[str, Any]) -> None: """A utility function to convert a qint8 type state dict to int8 type state dict. diff --git a/src/neuronx_distributed/quantization/quantize.py b/src/neuronx_distributed/quantization/quantize.py index 78f4517..afa6aee 100644 --- a/src/neuronx_distributed/quantization/quantize.py +++ b/src/neuronx_distributed/quantization/quantize.py @@ -121,6 +121,7 @@ def _convert_initialized_float_to_initialized_quantized( ) if type(mod) in mapping and name not in modules_to_not_convert: if not any(key in ".".join(prefixes) for key in modules_to_not_convert): + logger.debug(f"Quantizing {'.'.join(prefixes)} to {q_config['quantized_dtype']}") quantized_class = mapping[type(mod)] reassign[name] = quantized_class.from_float( mod=mod, diff --git a/src/neuronx_distributed/trace/nxd_model/nxd_model.py b/src/neuronx_distributed/trace/nxd_model/nxd_model.py index 7c08232..e52a672 100644 --- a/src/neuronx_distributed/trace/nxd_model/nxd_model.py +++ b/src/neuronx_distributed/trace/nxd_model/nxd_model.py @@ -506,7 +506,7 @@ def forward( """ if not self.loaded_on_neuron: raise RuntimeError("Model not initialized. Call set_weights() followed by to_neuron()") - SUPPORTED_FORWARD_MODES = {'default', 'ranked', 'async'} + SUPPORTED_FORWARD_MODES = {'default', 'ranked', 'ranked_to_cpu', 'async'} assert forward_mode in SUPPORTED_FORWARD_MODES, f"{forward_mode=} is not supported. It must be one of {SUPPORTED_FORWARD_MODES}" kwargs, arg_names = self.convert_dict_to_ordered_list( # type: ignore[assignment] @@ -546,12 +546,14 @@ def forward( if forward_mode == 'default': outputs: List[torch.Tensor] = self.spmd_models[model_name].forward(flattened_inputs) # type: ignore[no-redef] return self.packer_map[model_name](outputs) + elif forward_mode == 'ranked_to_cpu': + outputs: List[torch.Tensor] = self.spmd_models[model_name].forward_ranked_to_cpu(flattened_inputs) # type: ignore[no-redef] elif forward_mode == 'ranked': outputs: List[List[torch.Tensor]] = self.spmd_models[model_name].forward_ranked(flattened_inputs) # type: ignore[no-redef] else: # async outputs: List[List[torch.Tensor]] = self.spmd_models[model_name].forward_async(flattened_inputs) # type: ignore[no-redef] - # only runs for 'ranked' and 'async' modes + # only runs for 'ranked' , 'ranked_to_cpu' and 'async' modes # change output from [rank][output] to [output][rank] transposed_outputs: List[List[torch.Tensor]] = [[None for _ in range(len(outputs))] for _ in range(len(outputs[0]))] for rank,output in enumerate(outputs): diff --git a/src/neuronx_distributed/trace/trace.py b/src/neuronx_distributed/trace/trace.py index e7db7d4..c0ce556 100644 --- a/src/neuronx_distributed/trace/trace.py +++ b/src/neuronx_distributed/trace/trace.py @@ -687,9 +687,13 @@ def create_local_weight(rank, world_size, full_weight, partition_dim, per_partit def create_local_weight_with_expert_parallel(rank, world_size, full_weight, partition_dim, per_partition_size, stride, local_expert_indices, tensor_dtype, out_weight=None): if local_expert_indices is not None: + local_weight = create_local_weight(rank, world_size, full_weight, partition_dim, per_partition_size, stride, out_weight=out_weight) if tensor_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - return create_local_weight(rank, world_size, full_weight, partition_dim, per_partition_size, stride,out_weight=out_weight).view(torch.int8)[local_expert_indices,...].view(tensor_dtype) - return create_local_weight(rank, world_size, full_weight, partition_dim, per_partition_size, stride,out_weight=out_weight)[local_expert_indices,...] + return local_weight.view(torch.int8)[local_expert_indices,...].view(tensor_dtype) + elif tensor_dtype == torch.uint16: + return local_weight.view(torch.int16)[local_expert_indices,...].view(tensor_dtype) + else: + return local_weight[local_expert_indices,...] else: return create_local_weight(rank, world_size, full_weight, partition_dim, per_partition_size, stride,out_weight=out_weight) diff --git a/src/neuronx_distributed/utils/tensor_capture/model_modification.py b/src/neuronx_distributed/utils/tensor_capture/model_modification.py index d51932f..9d21c0a 100644 --- a/src/neuronx_distributed/utils/tensor_capture/model_modification.py +++ b/src/neuronx_distributed/utils/tensor_capture/model_modification.py @@ -45,7 +45,7 @@ def modify_model_for_tensor_capture(model: nn.Module, registry.configure(enabled=True, modules=modules_to_capture, max_tensors=max_tensors, capture_inputs=capture_inputs) model_info = registry.model_info - def make_hook(module_name: str) -> Callable: + def make_hook(module_name: str, module) -> Callable: def hook(module: nn.Module, args: tuple, kwargs: dict, output: Any) -> None: # Get registry registry = TensorRegistry.get_instance() @@ -53,7 +53,7 @@ def hook(module: nn.Module, args: tuple, kwargs: dict, output: Any) -> None: def register_tensor_object(prefix: str, obj: Any) -> None: """Helper function to register tensors from various object types""" if isinstance(obj, torch.Tensor): - registry.register_tensor(prefix, obj) + registry.register_tensor(prefix, obj, module) elif isinstance(obj, tuple): for i, item in enumerate(obj): register_tensor_object(f"{prefix}.{i}", item) @@ -86,7 +86,7 @@ def register_tensor_object(prefix: str, obj: Any) -> None: # Register forward hooks for targeted modules for name, module in dict(model.named_modules()).items(): if name in modules_to_capture: - hook_handle = module.register_forward_hook(make_hook(name), with_kwargs=True) + hook_handle = module.register_forward_hook(make_hook(name, module), with_kwargs=True) model_info.hooks.append(hook_handle) logger.info(f"Registered forward hook for module {name} : {module} for tensor capture") diff --git a/src/neuronx_distributed/utils/tensor_capture/registry.py b/src/neuronx_distributed/utils/tensor_capture/registry.py index 44e14d9..c5f2b46 100644 --- a/src/neuronx_distributed/utils/tensor_capture/registry.py +++ b/src/neuronx_distributed/utils/tensor_capture/registry.py @@ -1,6 +1,7 @@ import torch from typing import List, Dict, Any, Set, Optional from collections import OrderedDict, defaultdict +from neuronx_distributed.parallel_layers import mappings class CapturedModelInfo: """ @@ -86,7 +87,7 @@ def configure(self, enabled=False, modules=None, max_tensors=None, capture_input # Update the model info with new configuration self.model_info = CapturedModelInfo(list(modules or []), max_tensors, capture_inputs) - def register_tensor(self, name, tensor): + def register_tensor(self, name, tensor, neu_module=None): """ Register a tensor in the registry for capture. @@ -115,7 +116,15 @@ def register_tensor(self, name, tensor): break if is_monitored: - self.model_info.module_tensors[name] = tensor.clone().detach() + t = tensor.clone().detach() + if 'lm_head' in name: + assert not neu_module.sequence_parallel_enabled, \ + "Sequence parallel must be disabled for lm_head tensor capture to gather logits from all ranks" + # lm_head gather should be done in the modeling code already + if not neu_module.keep_padded_output and neu_module.pad and neu_module.pad_size > 0: + t = torch.narrow(t, -1, 0, neu_module.output_size - neu_module.pad_size) + + self.model_info.module_tensors[name] = t # For manual registration else: # Check if we've reached the limit diff --git a/test/integration/inference/test_nxd_model.py b/test/integration/inference/test_nxd_model.py index fb34a36..d78a525 100644 --- a/test/integration/inference/test_nxd_model.py +++ b/test/integration/inference/test_nxd_model.py @@ -32,6 +32,8 @@ def forward(self, x): return x def forward_ranked(self, x): return x + def forward_ranked_to_cpu(self, x): + return x def forward_async(self, x): return x @@ -785,7 +787,7 @@ def test_forward_pos_and_kwargs(mode): assert torch.equal(output[0][0], a) assert torch.equal(output[1][0], b) -@pytest.mark.parametrize('mode',['default','ranked','async',]) +@pytest.mark.parametrize('mode',['default','ranked', 'ranked_to_cpu','async',]) def test_forward_specific_model_name(mode): nxd_model = generate_nxdmodel_with_mock_spmdmodels(1,2) @@ -809,6 +811,7 @@ def mock_model(tensors): return [tensor+1 for tensor in tensors] setattr(mock_model, 'forward', mock_model) setattr(mock_model, 'forward_ranked', mock_model) + setattr(mock_model, 'forward_ranked_to_cpu', mock_model) setattr(mock_model, 'forward_async', mock_model) nxd_model.spmd_models['key1'] = mock_model @@ -855,7 +858,7 @@ def test_forward_pos_and_kwargs_distributed(mode): assert torch.equal(output[1][1], b) -@pytest.mark.parametrize('mode',['default','ranked','async',]) +@pytest.mark.parametrize('mode',['default','ranked', 'ranked_to_cpu','async',]) def test_forward_specific_model_name_distributed(mode): nxd_model = generate_nxdmodel_with_mock_spmdmodels(2,2) @@ -879,6 +882,7 @@ def mock_model(tensors): return [tensor+1 for tensor in tensors] setattr(mock_model, 'forward', mock_model) setattr(mock_model, 'forward_ranked', mock_model) + setattr(mock_model, 'forward_ranked_to_cpu', mock_model) setattr(mock_model, 'forward_async', mock_model) nxd_model.spmd_models['key1'] = mock_model diff --git a/test/integration/llama2_7B/tp_zero1_llama2_7b_hf_finetune_ptl.sh b/test/integration/llama2_7B/tp_zero1_llama2_7b_hf_finetune_ptl.sh index 2914211..1393ad4 100755 --- a/test/integration/llama2_7B/tp_zero1_llama2_7b_hf_finetune_ptl.sh +++ b/test/integration/llama2_7B/tp_zero1_llama2_7b_hf_finetune_ptl.sh @@ -1,12 +1,5 @@ #!/bin/bash -############################################# -# Override transformers and Optimum-Neuron packages, can be removed once ON released changes in https://github.com/huggingface/optimum-neuron/pull/370 -pip install git+https://github.com/huggingface/optimum-neuron.git -pip install -U transformers==4.48.0 # reinstall transformers due to optimum neuron override -sed -i 's/original_forward = copy\.deepcopy(self\.forward)/original_forward = self.forward/' /home/ubuntu/aws_neuron_venv/lib/python3.10/site-packages/optimum/neuron/generation/utils.py -pip install --no-warn-conflicts nltk - ############################################# # User defined parameters and env vars diff --git a/test/integration/operators/test_operators.py b/test/integration/operators/test_operators.py index fecb6a3..3261200 100644 --- a/test/integration/operators/test_operators.py +++ b/test/integration/operators/test_operators.py @@ -1,95 +1,233 @@ import functools -from neuronx_distributed.operators.argmax import argmax as nxd_argmax +import importlib + +from neuronx_distributed.operators.argmax import argmax as nxd_argmax, _can_use_nki_max from neuronx_distributed.operators.topk import topk as nxd_topk import pytest import torch import torch_neuronx +from unittest.mock import patch from neuronx_distributed.parallel_layers import ColumnParallelLinear from neuronx_distributed.trace.model_builder import ModelBuilder, BaseModelInstance -IN_FEATURES = 4 -OUT_FEATURES = 64 - class TestModelArgmax(torch.nn.Module): - def __init__(self, is_nxd=True, dim=1, keepdim=False, gather_dim=1): + def __init__( + self, + in_features, + out_features, + is_nxd=True, + dim=1, + keepdim=False, + gather_dim=1, + disable_argmax_kernel=False, + ): super().__init__() self.is_nxd = is_nxd self.dim = dim self.keepdim = keepdim self.gather_dim = gather_dim - self.lin = ColumnParallelLinear( - IN_FEATURES, - OUT_FEATURES, - bias=False, - gather_output=False - ) if self.is_nxd else torch.nn.Linear(IN_FEATURES, OUT_FEATURES, False) + self.disable_argmax_kernel = disable_argmax_kernel + self.lin = ( + ColumnParallelLinear( + in_features, out_features, bias=False, gather_output=False + ) + if self.is_nxd + else torch.nn.Linear(in_features, out_features, False) + ) def forward(self, tensor): lin_out = self.lin(tensor) if self.is_nxd: - return nxd_argmax(lin_out, dim=self.dim, gather_dim=self.gather_dim, keepdim=self.keepdim) + return nxd_argmax( + lin_out, + dim=self.dim, + gather_dim=self.gather_dim, + keepdim=self.keepdim, + disable_argmax_kernel=self.disable_argmax_kernel, + ) - return torch.argmax(lin_out,dim=self.dim, keepdim=self.keepdim) + return torch.argmax(lin_out, dim=self.dim, keepdim=self.keepdim) -def default_loader(): - # not necessary for this test, we do it ourselves - return {} -@pytest.mark.parametrize( - ["input_is_3d","dim", "keepdim"], - [ - (False, 0, True), - (False, 0, False), - (False, 1, True), - (False, 1, False), - (True, 0, True), - (True, 0, False), - (True, 1, True), - (True, 1, False), - (True, 2, True), - (True, 2, False), - ] -) -def test_nxd_argmax(input_is_3d, dim, keepdim): - tp_degree = 2 +def validate_argmax(argmax_shape, dim, keepdim, tp_degree, disable_argmax_kernel=False): - if input_is_3d: - inp = torch.rand(1,2,IN_FEATURES) - gather_dim = 2 - else: - inp = torch.rand(2,IN_FEATURES) - gather_dim = 1 + rank = len(argmax_shape) + in_features = 1 # Use tiny dim since we don't care about the matmult - mb = ModelBuilder(router=None, tp_degree=tp_degree, checkpoint_loader=default_loader) + # Prepare inputs + input_shape = (*argmax_shape[:-1], in_features) + tensor = torch.rand(input_shape) + out_features = argmax_shape[-1] * tp_degree + gather_dim = rank - 1 + + mb = ModelBuilder( + router=None, + tp_degree=tp_degree, + checkpoint_loader=default_loader, + ) mb.add( "test", BaseModelInstance( - functools.partial(TestModelArgmax, True, dim, keepdim, gather_dim), - {} + functools.partial( + TestModelArgmax, + in_features, + out_features, + True, + dim, + keepdim, + gather_dim, + disable_argmax_kernel, + ), + {}, ), - [(inp,)] + [(tensor,)], ) neuron_mod = mb.trace(initialize_model_weights=False) - test_mod = TestModelArgmax(False, dim, keepdim) - - weights_sharded = [{'lin.weight': torch.rand(OUT_FEATURES // tp_degree, IN_FEATURES)} for _ in range(tp_degree)] - full_weight = torch.cat([d['lin.weight'] for d in weights_sharded], dim=0) + test_mod = TestModelArgmax( + in_features, out_features, False, dim, keepdim, disable_argmax_kernel + ) + weights_sharded = [ + {"lin.weight": torch.rand(out_features // tp_degree, in_features)} + for _ in range(tp_degree) + ] + full_weight = torch.cat([d["lin.weight"] for d in weights_sharded], dim=0) start_rank_tensor = torch.tensor([0], dtype=torch.int32, device="cpu") neuron_mod.nxd_model.initialize(weights_sharded, start_rank_tensor) - test_mod.load_state_dict({'lin.weight': full_weight}) + test_mod.load_state_dict({"lin.weight": full_weight}) - expected = test_mod(inp) - actual = neuron_mod(inp) + expected = test_mod(tensor) + actual = neuron_mod(tensor) - assert expected.shape == actual.shape, "Shape Mismatch: expected {expected.shape}, but got {actual.shape}" - assert torch.allclose(expected, actual) + torch_neuronx.testing.assert_close(expected, actual) +def default_loader(): + # Not necessary for this test, we do it ourselves + return {} + + +@pytest.mark.parametrize( + "argmax_shape", + [ + # Simple 2D + (1, 64), + (2, 64), + (1, 128), + (4, 128), + (4, 256), + # Simple 3D + (1, 2, 64), + (4, 4, 64), + (4, 4, 256), + # GPT-OSS use cases + (1, 1, 25136), + (8, 1, 25136), + (128, 1, 1571), + ], + ids=lambda x: f"shape{x}", +) +@pytest.mark.parametrize("dim", [0, 1, 2], ids=lambda x: f"dim{x}") +@pytest.mark.parametrize("keepdim", [True, False], ids=lambda x: f"keepdim{x}") +@pytest.mark.parametrize("tp_degree", [2], ids=lambda x: f"tp{x}") +def test_nxd_argmax(argmax_shape, dim, keepdim, tp_degree): + """ + Validate the accuracy of the distributed argmax implementation. + + `argmax_shape` is the shape that will get passed to the `argmax` function: + (B, S, H // TP) + """ + rank = len(argmax_shape) + if dim >= rank: + pytest.skip(f"Argmax dim={dim} invalid on rank {rank} tensor") + validate_argmax(argmax_shape, dim, keepdim, tp_degree, disable_argmax_kernel=False) + + +@pytest.mark.parametrize( + "hw_type,expected", + [ + ("TRN1", False), + ("TRN2", True), + ("TRN3", True), + ], +) +def test_can_use_nki_max_hw(hw_type, expected): + """Test hardware type validation for _can_use_nki_max""" + # Get the actual module, not the function + # This is necessary due to overloaded argmax function and module naming + argmax_module = importlib.import_module("neuronx_distributed.operators.argmax") + with patch.object(argmax_module, "hardware") as mock_hw: + # Mock the function call to return the hardware type + mock_hw.return_value = hw_type + + # Mock the attributes used in comparison + mock_hw.TRN2 = "TRN2" + mock_hw.TRN3 = "TRN3" + + # Use a tensor that passes all other checks + tensor = torch.rand(10, 128) + result = argmax_module._can_use_nki_max(tensor, dim=1) + + assert result == expected + + +@pytest.mark.parametrize( + "shape,dim,expected", + [ + # Valid cases + ((10, 128), 1, True), # 2D, last dim, size >= 128 + ((10, 129), 1, True), # 2D, last dim, size >= 128 + ((1, 10, 128), 2, True), # 3D with shape[0]=1 + ((1, 10, 256), 2, True), # 3D with shape[0]=1 + # Invalid: dim is not last dimension + ((10, 128), 0, False), + # Invalid: size < 128 + ((10, 127), 1, False), + # Invalid: wrong number of dimensions + ((128,), 0, False), # 1D + ((2, 10, 128), 2, False), # 3D with shape[0] != 1 + ((1, 1, 10, 128), 3, False), # 4D + ], +) +def test_can_use_nki_max_inputs(shape, dim, expected): + """Test input validation for _can_use_nki_max""" + tensor = torch.rand(shape) + assert _can_use_nki_max(tensor, dim) == expected + + +def test_disable_argmax_kernel(): + """ + Test that disable_argmax_kernel=True prevents NKI argmax kernel usage + + We first confirm _can_use_nki_max will return False, and then we run the + argmax function to confirm outputs are correct. + """ + argmax_shape = (10, 128) + tensor = torch.rand(argmax_shape) # Input that can use the kernel + dim = 1 # Last dimension + + # Confirm that the _can_use_nki_max routing works as expected: Basecase + result = _can_use_nki_max(tensor, dim, disable_argmax_kernel=False) + assert result is True, "_can_use_nki_max should return True for valid kernel inputs" + # Disable kernel + result = _can_use_nki_max(tensor, dim, disable_argmax_kernel=True) + assert ( + result is False + ), "_can_use_nki_max should return False when kernel is disabled" + + # Confirm that the Neuron outputs (without the kernel) are accurate + validate_argmax( + argmax_shape, dim, keepdim=True, tp_degree=2, disable_argmax_kernel=True + ) + + +IN_FEATURES = 4 +OUT_FEATURES = 64 + class TestModelTopk(torch.nn.Module): def __init__(self, is_nxd=True, k=50, dim=1, gather_dim=1): @@ -117,9 +255,9 @@ def forward(self, tensor): ["input_is_3d","dim", "k"], [ (False, 1, 10), - (False, 1, 50), - (True, 2, 2), - (True, 2, 50), + # (False, 1, 50), # Failing - ticket: V2001877566 + # (True, 2, 2), # Failing - ticket: V2001877566 + # (True, 2, 50), # Failing - ticket: V2001877566 ] ) def test_nxd_topk(input_is_3d, dim, k): @@ -161,3 +299,7 @@ def test_nxd_topk(input_is_3d, dim, k): assert expected_values.shape == actual_values.shape, "Shape Mismatch: expected {expected_values.shape}, but got {actual_values.shape}" assert torch.allclose(expected_values, actual_values) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/test/integration/quantization/test_quantized_mlp.py b/test/integration/quantization/test_quantized_mlp.py index 7f9dd6f..5abc7fb 100644 --- a/test/integration/quantization/test_quantized_mlp.py +++ b/test/integration/quantization/test_quantized_mlp.py @@ -1,7 +1,6 @@ import os import shutil import traceback -from concurrent.futures import ProcessPoolExecutor from functools import partial import torch @@ -11,10 +10,11 @@ ColumnParallelLinear, RowParallelLinear, ) -from neuronx_distributed.quantization.dequantize import scale_dequantize +from neuronx_distributed.quantization.dequantize import scale_dequantize, direct_cast_dequantize from neuronx_distributed.quantization.quantization_config import ( BASE_QCONFIG_DICT_TYPE, QuantizationType, + ActivationQuantizationType, get_default_custom_qconfig_dict, get_default_per_channel_custom_qconfig_dict, ) @@ -22,13 +22,14 @@ convert_qint8_to_int8_state_dict, quantize_pytorch_model_per_channel_symmetric, quantize_pytorch_model_per_tensor_symmetric, + quantize_static_quant_activations ) from neuronx_distributed.quantization.quantize import convert +from neuronx_distributed.trace.model_builder import BaseModelInstance, ModelBuilder dim = 6 -torch.manual_seed(0) - - +DTYPE = torch.float32 +Q_DTYPE = torch.int8 PT_SAVE_PATH = "quantized_model.pt" @@ -46,6 +47,8 @@ def __init__( dtype=None, quantization_type=QuantizationType.PER_TENSOR_SYMMETRIC, per_channel_axis=None, + activation_quantization_type=ActivationQuantizationType.NONE, + ) -> None: super().__init__() self.in_features = in_features @@ -53,8 +56,11 @@ def __init__( self.weight = torch.nn.Parameter( torch.empty(out_features, in_features, dtype=torch.int8, device=device), requires_grad=False ) + self.dtype = dtype if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: self.scale = torch.nn.Parameter(torch.tensor([1.0], dtype=dtype)) + if activation_quantization_type == ActivationQuantizationType.STATIC: + self.input_scale = torch.nn.Parameter(torch.tensor([1.0], dtype=dtype)) else: per_channel_axis = per_channel_axis if per_channel_axis is not None else 0 weight_shape = self.weight.shape @@ -67,8 +73,17 @@ def __init__( self.register_parameter("bias", None) def forward(self, input: torch.Tensor): - weight = scale_dequantize(self.weight, scale=self.scale, upcast_dtype=input.dtype) - return F.linear(input, weight, self.bias) + if hasattr(self, "input_scale"): + input = quantize_static_quant_activations(input, self.input_scale, Q_DTYPE) + input = direct_cast_dequantize(input, upcast_dtype=DTYPE) + weight = direct_cast_dequantize(self.weight, upcast_dtype=DTYPE) + scale = self.input_scale * self.scale + else: + weight = direct_cast_dequantize(self.weight, upcast_dtype=input.dtype) + scale = self.scale + + output = F.linear(input, weight, self.bias) + return scale_dequantize(tensor=output, scale=scale.T, upcast_dtype=self.dtype) @classmethod def from_float( @@ -85,6 +100,7 @@ def from_float( dtype=mod.weight.dtype, quantization_type=q_config["quantization_type"], per_channel_axis=q_config.get("quantization_per_channel_axis"), + activation_quantization_type=q_config["activation_quantization_type"], ) @@ -97,17 +113,20 @@ def __init__(self, is_parallel): if is_parallel: if i % 2 == 0: self.layers.append(ColumnParallelLinear( - input_size=dim, output_size=dim, bias=False, gather_output=False, dtype=torch.float32 + input_size=dim, output_size=dim, bias=False, gather_output=False, dtype=DTYPE )) else: self.layers.append(RowParallelLinear( - input_size=dim, output_size=dim, bias=False, input_is_parallel=True, dtype=torch.float32 + input_size=dim, output_size=dim, bias=False, input_is_parallel=True, dtype=DTYPE )) else: - self.layers.append(torch.nn.Linear(dim, dim, bias=False, dtype=torch.float32)) + self.layers.append(torch.nn.Linear(dim, dim, bias=False, dtype=DTYPE)) def forward(self, x): for layer in self.layers: + if getattr(layer, "activation_quantization_type", None) == ActivationQuantizationType.STATIC: + x = quantize_static_quant_activations(x, layer.input_scale, Q_DTYPE) + x = scale_dequantize(x, layer.input_scale, DTYPE) x = layer(x) return x @@ -126,6 +145,10 @@ def _quantize_and_save_model(cls, model_fp32, q_config, save_path): model_fp32_int8 = quantize_pytorch_model_per_tensor_symmetric(model_fp32) state_dict = model_fp32_int8.state_dict() convert_qint8_to_int8_state_dict(state_dict=state_dict) + if q_config["activation_quantization_type"] == ActivationQuantizationType.STATIC: + for i in range(4): + state_dict[f"layers.{i}.input_scale"] = torch.randn(1, dtype=DTYPE).abs() + torch.save(state_dict, save_path) return model_fp32_int8, state_dict @@ -156,32 +179,35 @@ def load_model(q_config: BASE_QCONFIG_DICT_TYPE, model_cls): all_parameters_name.append(name) print(all_parameters_name) - alias = {} - - return model_quant, alias + return model_quant def checkpoint_loader_fn(): - return torch.load(PT_SAVE_PATH) + checkpoint = torch.load(PT_SAVE_PATH) + return {k: v for k, v in checkpoint.items() if v is not None} def load_traced_model(input_fp32, qconfig, model_cls): - from neuronx_distributed.trace import parallel_model_trace - sample_inputs = input_fp32 load_model_partial = partial(load_model, qconfig, model_cls) - traced_model = parallel_model_trace( - load_model_partial, # This loads the parallel model - sample_inputs, - tp_degree=2, - compiler_workdir="compiler_workdir", # This is where you will find the hlo & neff - compiler_args="--auto-cast=none", # Pass your compiler flags here, - inline_weights_to_neff=False, - spmd_mode=True, - checkpoint_loader_callable=checkpoint_loader_fn, - force_custom_init_on_device=True, + + builder = ModelBuilder( + router=None, + tp_degree=2, + checkpoint_loader=checkpoint_loader_fn, + compiler_workdir="compiler_workdir", + ) + builder.add( + key="main", + model_instance=BaseModelInstance( + module_cls=load_model_partial, + input_output_aliases={}, + ), + example_inputs=[(sample_inputs,)], + compiler_args="--auto-cast=none", ) - return traced_model + neuron_model = builder.trace(initialize_model_weights=True) + return neuron_model def validate_against_pytorch_quantization(pytorch_quantized_cpu_model, nxd_quantized_cpu_model): @@ -198,11 +224,8 @@ def validate_against_pytorch_quantization(pytorch_quantized_cpu_model, nxd_quant prefix = key.split("_packed_params._packed_params")[0] assert torch.allclose( pytorch_quantized_cpu_model_sd[key][0].dequantize(), - scale_dequantize( - nxd_quantized_cpu_model_sd[prefix + "weight"], - nxd_quantized_cpu_model_sd[prefix + "scale"], - torch.float32, - ), + nxd_quantized_cpu_model_sd[prefix + "weight"] * \ + nxd_quantized_cpu_model_sd[prefix + "scale"] ) assertion = True assert assertion @@ -231,54 +254,33 @@ def extract_partition_dim(scale_tensor): raise RuntimeError("scale is not really sharded") -def validate_scales_in_nxd_model(nxd_quantized_cpu_model, traced_model): - traced_model_sd = traced_model.state_dict() - traced_model_sd_rank0 = traced_model.models[0].weights.state_dict() - nxd_quantized_cpu_model_sd = nxd_quantized_cpu_model.state_dict() - for key, _ in traced_model_sd_rank0.items(): - if "scale" in key: - cpu_scale = nxd_quantized_cpu_model_sd[key.replace("->", ".")] - if not is_scalar_partitioned(traced_model_sd_rank0[key]): - nxd_scale = traced_model_sd_rank0[key] - else: - nxd_scale = recreate_sharded_scales( - traced_model_sd, key, extract_partition_dim(traced_model_sd_rank0[key]) - ) - assert torch.allclose(cpu_scale, nxd_scale) - print("scale verification successful") - - -def run_quantization_test(q_config, model_cls, input_shape, validate_scales): +def run_quantization_test(q_config, model_cls, input_shape): + torch.manual_seed(0) model_fp32_int8, input_fp32, model_fp32, nxd_quantized_cpu_model = load_quantize_model( q_config=q_config, model_cls=model_cls, input_shape=input_shape ) traced_model = load_traced_model(input_fp32=input_fp32, qconfig=q_config, model_cls=model_cls) - # Validate the CPU version of our de-quant logic matches the pytorch dequant - validate_against_pytorch_quantization( - pytorch_quantized_cpu_model=model_fp32_int8, nxd_quantized_cpu_model=nxd_quantized_cpu_model - ) - - if validate_scales: - # Validate that the scales in NxD model are correct - validate_scales_in_nxd_model(nxd_quantized_cpu_model, traced_model) - - cpu_result = model_fp32_int8(input_fp32) nxd_result = traced_model(input_fp32) - fp_32_result = model_fp32(input_fp32) - - if validate_scales: - # CPU quantized result and NxD result to be exactly equal if scales are equal - assert torch.allclose(nxd_quantized_cpu_model(input_fp32), nxd_result) - # NxD result and Pytorch Quantized Result - assert torch.allclose(cpu_result, nxd_result, atol=1e-2) + if q_config["activation_quantization_type"] == ActivationQuantizationType.NONE: + cpu_result = model_fp32_int8(input_fp32) + fp_32_result = model_fp32(input_fp32) + # Validate the CPU version of our de-quant logic matches the pytorch dequant + validate_against_pytorch_quantization( + pytorch_quantized_cpu_model=model_fp32_int8, nxd_quantized_cpu_model=nxd_quantized_cpu_model + ) + # NxD result and Pytorch Quantized Result + assert torch.allclose(cpu_result, nxd_result, atol=1e-2) - # FP32 model result and NxD result - atol = 1e-3 if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC else 1e-2 - torch.allclose(fp_32_result, nxd_result, atol=atol) + # FP32 model result and NxD result + atol = 1e-3 if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC else 1e-2 + torch.allclose(fp_32_result, nxd_result, atol=atol) + print(nxd_quantized_cpu_model(input_fp32), "\n", nxd_result) + # CPU quantized result and NxD result to be exactly equal if scales are equal + assert torch.allclose(nxd_quantized_cpu_model(input_fp32), nxd_result) - print(f"Test successful for Quantized Layers with qconfig {q_config}") + print(f"\n Test successful for Quantized Layers with qconfig {q_config}") if os.path.exists(PT_SAVE_PATH): os.remove(PT_SAVE_PATH) @@ -286,26 +288,21 @@ def run_quantization_test(q_config, model_cls, input_shape, validate_scales): if os.path.exists("compiler_workdir") and os.path.isdir("compiler_workdir"): shutil.rmtree("compiler_workdir") - if __name__ == "__main__": common_args = dict( model_cls=Model, - input_shape=(2, dim), - validate_scales=True, + input_shape=(1, 2, dim), ) - try: - q_config = get_default_custom_qconfig_dict() - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_quantization_test, q_config, **common_args) - results = future.result() - except Exception: - print(traceback.format_exc()) - - try: - q_config = get_default_per_channel_custom_qconfig_dict() - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_quantization_test, q_config, **common_args) - results = future.result() - except Exception: - print(traceback.format_exc()) + q_configs = [get_default_custom_qconfig_dict(), get_default_per_channel_custom_qconfig_dict()] + qconfig_static = get_default_custom_qconfig_dict() + qconfig_static["activation_quantization_type"] = ActivationQuantizationType.STATIC + q_configs.append(qconfig_static) + + for q_config in q_configs: + try: + print(f"\n Testing with qconfig {q_config}") + run_quantization_test(q_config, **common_args) + except Exception: + print(traceback.format_exc()) + assert False, f"\n Test failed for qconfig {q_config}" diff --git a/test/unit_test/modules/attention/test_rope_polar_compatible.py b/test/unit_test/modules/attention/test_rope_polar_compatible.py index 3bccaf8..788a3d6 100644 --- a/test/unit_test/modules/attention/test_rope_polar_compatible.py +++ b/test/unit_test/modules/attention/test_rope_polar_compatible.py @@ -16,11 +16,13 @@ [ pytest.param( dtype, num_chunks, seq_len, n_local_heads, head_dim, + marks=pytest.mark.xfail(reason="flaky unit test") + if (num_chunks, seq_len, n_local_heads, head_dim) == (8, 512, 16, 64) else [], id=f"dtype_{str(dtype).split('.')[-1]}_xshape_{num_chunks}_{seq_len}_{n_local_heads}_{head_dim}", ) for dtype in [torch.float32, torch.float16, torch.bfloat16] for num_chunks, seq_len, n_local_heads, head_dim in [ - (8, 577, 16, 88), + (8, 512, 16, 64), (5, 1024, 16, 80) ] ], diff --git a/test/unit_test/modules/moe/test_moe_configs.py b/test/unit_test/modules/moe/test_moe_configs.py new file mode 100644 index 0000000..4d5e725 --- /dev/null +++ b/test/unit_test/modules/moe/test_moe_configs.py @@ -0,0 +1,68 @@ +import pytest +import torch +from neuronx_distributed.modules.moe.moe_configs import RouterConfig + +class TestRouterConfig: + def test_default_initialization(self): + config = RouterConfig() + assert config.act_fn == "softmax" + assert config.dtype == torch.float32 + + def test_custom_activation_function(self): + config = RouterConfig(act_fn="gelu") + assert config.act_fn == "gelu" + assert config.dtype == torch.float32 + + def test_fp16_dtype(self): + config = RouterConfig(dtype=torch.float16) + assert config.dtype == torch.float16 + + def test_bf16_dtype(self): + config = RouterConfig(dtype=torch.bfloat16) + assert config.dtype == torch.bfloat16 + + def test_from_kwargs_defaults(self): + config = RouterConfig.from_kwargs() + assert config.act_fn == "softmax" + assert config.dtype == torch.float32 + + def test_from_kwargs_custom_act_fn(self): + config = RouterConfig.from_kwargs(router_act_fn="relu") + assert config.act_fn == "relu" + assert config.dtype == torch.float32 + + def test_from_kwargs_dtype_object(self): + config = RouterConfig.from_kwargs(router_dtype=torch.float16) + assert config.dtype == torch.float16 + + def test_from_kwargs_dtype_string_fp16(self): + config = RouterConfig.from_kwargs(router_dtype="float16") + assert config.dtype == torch.float16 + + def test_from_kwargs_dtype_string_bf16(self): + config = RouterConfig.from_kwargs(router_dtype="bfloat16") + assert config.dtype == torch.bfloat16 + + def test_from_kwargs_dtype_string_fp32(self): + config = RouterConfig.from_kwargs(router_dtype="float32") + assert config.dtype == torch.float32 + + def test_from_kwargs_both_parameters(self): + config = RouterConfig.from_kwargs( + router_act_fn="gelu", + router_dtype=torch.bfloat16 + ) + assert config.act_fn == "gelu" + assert config.dtype == torch.bfloat16 + + def test_from_kwargs_ignores_extra_kwargs(self): + config = RouterConfig.from_kwargs( + router_act_fn="relu", + router_dtype=torch.float16, + unrelated_param="ignored" + ) + assert config.act_fn == "relu" + assert config.dtype == torch.float16 + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/test/unit_test/modules/moe/test_moe_fused_tkg.py b/test/unit_test/modules/moe/test_moe_fused_tkg.py new file mode 100644 index 0000000..1c14ff2 --- /dev/null +++ b/test/unit_test/modules/moe/test_moe_fused_tkg.py @@ -0,0 +1,135 @@ +import pytest +from unittest.mock import Mock, patch + +import torch +import neuronxcc.nki.language as nl + +from neuronx_distributed.modules.moe.moe_fused_tkg import MoEFusedTKG, _convert_torch_dtype_to_nki_dtype + +def test_convert_torch_dtype_to_nki_dtype_valid_dtypes(): + assert _convert_torch_dtype_to_nki_dtype(torch.float16) == nl.float16 + assert _convert_torch_dtype_to_nki_dtype(torch.bfloat16) == nl.bfloat16 + assert _convert_torch_dtype_to_nki_dtype(torch.float32) == nl.float32 + +def test_convert_torch_dtype_to_nki_dtype_invalid_dtype(): + with pytest.raises(AssertionError, match="expected dtype in"): + _convert_torch_dtype_to_nki_dtype(torch.int32) + +@pytest.fixture +def mock_moe_module(): + router = Mock() + router.weight_T = torch.randn(128, 8) + router.act_fn = "softmax" + router.bias = False + router.apply_act_fn_over_topk = True + router.return_value = ( + torch.randn(16, 8), + torch.randn(16, 8), + torch.randint(0, 8, (16, 2)) + ) + + expert_mlps = Mock() + expert_mlps.routed_experts_mlp_config = Mock( + hidden_size=128, + top_k=2, + num_experts=8, + hidden_act="silu", + glu_mlp=True, + normalize_top_k_affinities=False, + early_expert_affinity_modulation=False + ) + expert_mlps.moe_expert_model_parallel_group = Mock() + expert_mlps.moe_expert_model_parallel_group.size.return_value = 1 + expert_mlps.moe_tensor_model_parallel_group = Mock() + expert_mlps.moe_tensor_model_parallel_group.size.return_value = 1 + expert_mlps.return_value = torch.randn(16, 128) + + config = Mock() + config.quantized = False + config.moe_fused_kernel_enabled = False + + module = MoEFusedTKG( + router=router, + expert_mlps=expert_mlps, + config=config, + sequence_dimension=0, + shared_experts=None, + post_attention_layernorm=None, + return_router_logits=False, + return_expert_index=False + ) + + module._router_topk = Mock(return_value=( + torch.randn(16, 8), + torch.randn(16, 8), + torch.randint(0, 8, (16, 2)) + )) + module._expert_mlp = Mock(return_value=torch.randn(4, 4, 128)) + + return module + +@patch('neuronx_distributed.parallel_layers.parallel_state.get_world_group') +@patch('neuronx_distributed.parallel_layers.mappings.reduce_from_tensor_model_parallel_region') +@patch('neuronx_distributed.parallel_layers.mappings.copy_to_tensor_model_parallel_region') +def test_forward_residual_add_without_fused(mock_copy, mock_reduce, mock_world_group, mock_moe_module): + mock_world_group.return_value = Mock() + mock_reduce.side_effect = lambda x, **kwargs: x + mock_copy.side_effect = lambda x: x + + hidden_states = torch.randn(4, 4, 128) + residual = torch.randn(4, 4, 128) + + mock_moe_module._can_use_fused_residual_add = Mock(return_value=False) + mock_moe_module._can_use_nki_kernel = Mock(return_value=False) + + result = mock_moe_module.forward(hidden_states, residual=residual) + + assert len(result) == 2 + output, returned_residual = result + assert output.shape == hidden_states.shape + assert returned_residual.shape == residual.shape + assert not torch.allclose(returned_residual, residual) + +@patch('neuronx_distributed.parallel_layers.parallel_state.get_world_group') +@patch('neuronx_distributed.parallel_layers.mappings.reduce_from_tensor_model_parallel_region') +@patch('neuronx_distributed.parallel_layers.mappings.copy_to_tensor_model_parallel_region') +def test_forward_residual_none(mock_copy, mock_reduce, mock_world_group, mock_moe_module): + mock_world_group.return_value = Mock() + mock_reduce.side_effect = lambda x, **kwargs: x + mock_copy.side_effect = lambda x: x + + hidden_states = torch.randn(4, 4, 128) + + mock_moe_module._can_use_fused_residual_add = Mock(return_value=False) + mock_moe_module._can_use_nki_kernel = Mock(return_value=False) + + result = mock_moe_module.forward(hidden_states, residual=None) + + assert len(result) == 1 + output = result[0] + assert output.shape == hidden_states.shape + +@patch('neuronx_distributed.parallel_layers.parallel_state.get_world_group') +@patch('neuronx_distributed.parallel_layers.mappings.reduce_from_tensor_model_parallel_region') +@patch('neuronx_distributed.parallel_layers.mappings.copy_to_tensor_model_parallel_region') +def test_forward_residual_add_values(mock_copy, mock_reduce, mock_world_group, mock_moe_module): + mock_world_group.return_value = Mock() + mock_reduce.side_effect = lambda x, **kwargs: x + mock_copy.side_effect = lambda x: x + + hidden_states = torch.ones(4, 4, 128) + residual = torch.ones(4, 4, 128) * 2 + + mock_moe_module._can_use_fused_residual_add = Mock(return_value=False) + mock_moe_module._can_use_nki_kernel = Mock(return_value=False) + mock_moe_module._expert_mlp = Mock(return_value=torch.zeros(4, 4, 128)) + + result = mock_moe_module.forward(hidden_states, residual=residual) + + _, returned_residual = result + expected_residual = hidden_states + residual + assert torch.allclose(returned_residual, expected_residual) + + +if __name__ == "__main__": + pytest.main([__file__, '-v']) \ No newline at end of file diff --git a/test/unit_test/modules/moe/test_moe_fused_tkg_mx.py b/test/unit_test/modules/moe/test_moe_fused_tkg_mx.py new file mode 100644 index 0000000..aef1451 --- /dev/null +++ b/test/unit_test/modules/moe/test_moe_fused_tkg_mx.py @@ -0,0 +1,130 @@ +import pytest +from unittest.mock import Mock, patch + +import torch +from neuronx_distributed.modules.moe.moe_fused_tkg_mx import MoEFusedTKGMX + +@pytest.fixture +def mock_moe_mx_module(): + router = Mock() + router.weight_T = torch.randn(128, 8) + router.act_fn = "softmax" + router.bias = False + router.apply_act_fn_over_topk = True + + expert_mlps = Mock() + expert_mlps.routed_experts_mlp_config = Mock( + hidden_size=512, + top_k=2, + num_experts=8, + hidden_act="silu", + glu_mlp=True, + normalize_top_k_affinities=False, + early_expert_affinity_modulation=False, + glu_type="swiglu", + gate_clamp_upper_limit=None, + gate_clamp_lower_limit=None, + up_clamp_upper_limit=None, + up_clamp_lower_limit=None, + bias=True + ) + expert_mlps.moe_expert_model_parallel_group = Mock() + expert_mlps.moe_expert_model_parallel_group.size.return_value = 1 + expert_mlps.moe_tensor_model_parallel_group = Mock() + expert_mlps.moe_tensor_model_parallel_group.size.return_value = 1 + expert_mlps.mlp_op = Mock() + expert_mlps.mlp_op.gate_up_proj = Mock() + expert_mlps.mlp_op.gate_up_proj.input_size = 512 + expert_mlps.mlp_op.down_proj = Mock() + expert_mlps.mlp_op.down_proj.input_size_per_partition = 2048 + + config = Mock() + config.quantized = False + config.moe_fused_kernel_enabled = False + config.router_mm_dtype = torch.float32 + + module = MoEFusedTKGMX( + router=router, + expert_mlps=expert_mlps, + config=config, + sequence_dimension=0, + shared_experts=None, + post_attention_layernorm=None, + return_router_logits=False, + return_expert_index=False + ) + + return module + +def test_should_use_all_expert_above_threshold(mock_moe_mx_module): + hidden_states = torch.randn(32, 32, 512) + result = mock_moe_mx_module._should_use_all_expert(hidden_states) + assert result + +def test_should_use_all_expert_below_threshold(mock_moe_mx_module): + hidden_states = torch.randn(1, 1, 512) + result = mock_moe_mx_module._should_use_all_expert(hidden_states) + assert not result + +def test_should_use_all_expert_at_threshold(mock_moe_mx_module): + mock_moe_mx_module.num_experts_per_tok = 2 + mock_moe_mx_module.num_local_experts = 8 + batch_size = 16 + seq_len = 16 + hidden_states = torch.randn(batch_size, seq_len, 512) + result = mock_moe_mx_module._should_use_all_expert(hidden_states) + assert result + +@patch('neuronx_distributed.modules.moe.moe_fused_tkg_mx.import_module') +def test_can_use_fused_residual_add_with_support(mock_import, mock_moe_mx_module): + mock_mod = Mock() + mock_mod.MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT = True + mock_import.return_value = mock_mod + + hidden_states = torch.randn(32, 32, 512) + result = mock_moe_mx_module._can_use_fused_residual_add(hidden_states) + assert result + +@patch('neuronx_distributed.modules.moe.moe_fused_tkg_mx.import_module') +def test_can_use_fused_residual_add_without_support(mock_import, mock_moe_mx_module): + mock_mod = Mock() + mock_mod.MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT = False + mock_import.return_value = mock_mod + + hidden_states = torch.randn(32, 32, 512) + result = mock_moe_mx_module._can_use_fused_residual_add(hidden_states) + assert not result + +@patch('neuronx_distributed.modules.moe.moe_fused_tkg_mx.import_module') +def test_can_use_fused_residual_add_constant_missing(mock_import, mock_moe_mx_module): + mock_mod = Mock(spec=[]) + mock_import.return_value = mock_mod + + hidden_states = torch.randn(32, 32, 512) + result = mock_moe_mx_module._can_use_fused_residual_add(hidden_states) + assert not result + +@patch('neuronx_distributed.modules.moe.moe_fused_tkg_mx.import_module') +def test_can_use_fused_residual_add_below_threshold(mock_import, mock_moe_mx_module): + mock_mod = Mock() + mock_mod.MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT = True + mock_import.return_value = mock_mod + + hidden_states = torch.randn(1, 1, 512) + result = mock_moe_mx_module._can_use_fused_residual_add(hidden_states) + assert not result + +@patch('neuronx_distributed.modules.moe.moe_fused_tkg_mx.import_module') +def test_can_use_fused_residual_add_requires_both_conditions(mock_import, mock_moe_mx_module): + mock_mod = Mock() + mock_mod.MOE_ALL_EXPERTS_FUSED_RESIDUAL_SUPPORT = True + mock_import.return_value = mock_mod + + hidden_states_above = torch.randn(32, 32, 512) + assert mock_moe_mx_module._can_use_fused_residual_add(hidden_states_above) + + hidden_states_below = torch.randn(1, 1, 512) + assert not mock_moe_mx_module._can_use_fused_residual_add(hidden_states_below) + +if __name__ == "__main__": + pytest.main([__file__, '-v']) \ No newline at end of file diff --git a/test/unit_test/parallel_layers/test_parallel_state.py b/test/unit_test/parallel_layers/test_parallel_state.py index 51ffc16..4602784 100644 --- a/test/unit_test/parallel_layers/test_parallel_state.py +++ b/test/unit_test/parallel_layers/test_parallel_state.py @@ -245,7 +245,25 @@ def test_ascending_ring_pg_group_creation(self): assert res==locals()[f'ground_truth_{world_size}'] def test_ascending_descending_ring_pg_group_creation(self): - ground_truth_64 = ParallelGroups( + ground_truth_64_64 = ParallelGroups( + tp_groups=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47] + ], + dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63]], + pp_groups=[ + [0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63] + ], + ep_model_groups=[ + [0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63] + ], + ep_data_groups=[ + [0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63] + ], + cp_groups=[ + [0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [48], [49], [50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60], [61], [62], [63], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29], [30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47] + ], + ) + ground_truth_64_32 = ParallelGroups( tp_groups=[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], @@ -278,7 +296,7 @@ def test_ascending_descending_ring_pg_group_creation(self): [32], [33], [34], [35], [36], [37], [38], [39], [40], [41], [42], [43], [44], [45], [46], [47], ], ) - ground_truth_128 = ParallelGroups( + ground_truth_128_32 = ParallelGroups( tp_groups=[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], @@ -332,7 +350,7 @@ def test_ascending_descending_ring_pg_group_creation(self): ], ) - ground_truth_256 = ParallelGroups( + ground_truth_256_32 = ParallelGroups( tp_groups=[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], @@ -471,7 +489,8 @@ def test_ascending_descending_ring_pg_group_creation(self): # for world size 128, we have tp32, pp2, dp2 # for 256 world size thats our 4 node final config. TP32, PP4, DP2 test_configs = [ - (64, 32, 2, 1, 1), # (world_size, tp, pp, dp, cp) + (64, 64, 1, 1, 1), # (world_size, tp, pp, dp, cp) + (64, 32, 2, 1, 1), (128, 32, 2, 2, 1), (256, 32, 4, 2, 1), (128, 32, 2, 1, 2), # ws128, tp32, pp2, dp1 cp2, @@ -512,7 +531,7 @@ def test_ascending_descending_ring_pg_group_creation(self): print(f"Config: {world_size=}, {tp_size=}, {cp_size=}, {dp_size=}, {pp_size=}") print(f"Got {res}") expected = (locals()[f'ground_truth_{world_size}_cp{cp_size}_tp{tp_size}'] if context_parallel_size > 1 - else locals()[f'ground_truth_{world_size}']) + else locals()[f'ground_truth_{world_size}_{tp_size}']) print(f"Expected {expected}") assert res==expected def test_pp_rank_in_group(self): diff --git a/test/unit_test/quantization/test_quantization_layers.py b/test/unit_test/quantization/test_quantization_layers.py index 8306b03..67587a7 100644 --- a/test/unit_test/quantization/test_quantization_layers.py +++ b/test/unit_test/quantization/test_quantization_layers.py @@ -108,7 +108,7 @@ def test_init(self): with self.assertRaises(AssertionError) as context: BaseQuantizeParallelLinear(quantization_type="something") self.assertTrue( - "something quantization is not supported currently. Specify from [['per_tensor_symmetric', 'per_channel_symmetric', 'blockwise_symmetric', 'expert_wise_per_channel_symmetric']]" + "something quantization is not supported currently." in str(context.exception) )