Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ elif [ -v OMPI_COMM_WORLD_RANK ]; then
export CCOM_SOCKET_IFNAME=eth0
export FI_EFA_FORK_SAFE=1

# Dataset is in shared location (Don't know where the s3 bucket for this artifact is to onboard to Kaizen)
#DATA_PATH="$SHARED_PATH_PREFIX/examples_datasets/databricks-dolly-15k"

# Store metrics in shared location
METRICS_FILE=$ARTIFACT_PATH/results.json
mkdir -p $ARTIFACT_PATH
Expand Down
7 changes: 5 additions & 2 deletions examples/training/llama/modeling_llama_nxd.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,12 @@ def forward(
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
per_token_loss = loss_fct(shift_logits, shift_labels)
valid_mask = (shift_labels != -100).float()
per_token_loss = per_token_loss * valid_mask

loss = torch.mean(loss)
denom = valid_mask.sum().clamp_min(1.0)
loss = per_token_loss.sum() / denom

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datasets import load_dataset
from transformers import AutoTokenizer

dataset_name = "wikicorpus"
dataset_name = "gboleda/wikicorpus"
dataset_config_name = "raw_en"
save_path = "~/examples_datasets/wikicorpus_gpt_neox_tokenized_2k"

Expand Down
2 changes: 1 addition & 1 deletion src/neuronx_distributed/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright Amazon Web Services and its Affiliates. All Rights Reserved.
# ==============================================================================
__version__ = "0.16.0"
__version__ = "0.17.0"
54 changes: 27 additions & 27 deletions src/neuronx_distributed/kernels/find_nonzero_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ def find_nonzero_indices(input_tensor: nt.tensor,
n_partitions = nl.tile_size.pmax
quadrant=32
n_quadrant = 4
n_q7 = 8
n_q7_per_quadrant = 2
n_partition_per_q7 = 16
n_gpsimd_cores = 8
n_gpsimd_cores_per_quadrant = 2
n_partition_per_gpsimd_core = 16

# Initialize (E,T) output on HBM to all -1s.
indices = nl.ndarray((E, T), dtype=index_dtype, buffer=nl.shared_hbm)
Expand All @@ -116,80 +116,80 @@ def find_nonzero_indices(input_tensor: nt.tensor,
nonzero_counts = nl.ndarray((E, ), dtype=nl.int32, buffer=nl.shared_hbm)
nonzero_counts_local = nl.zeros((1, E_per_shard), dtype=nl.int32, buffer=nl.sbuf)

# 2. Handle experts in groups of 8: 8 Q7 cores run in parallel.
# 2. Handle experts in groups of 8: 8 GPSIMD cores run in parallel.
# Number of rounds to find index in groups of 8.
n_rounds = (E_per_shard + n_q7 - 1) // n_q7
n_rounds = (E_per_shard + n_gpsimd_cores - 1) // n_gpsimd_cores
# Number of chunks to chunk along the T dim.
n_T_chunks = T // chunk_size
for r in nl.static_range(n_rounds):
# Get number of experts to process this round.
n_e = n_q7
n_e = n_gpsimd_cores
if r == n_rounds - 1:
n_e = E_per_shard - n_q7 * r
n_e = E_per_shard - n_gpsimd_cores * r

# Counts of nonzeros so far. This is the offset at which the next chunk write should begin.
# Keep offsets in int32.
offsets = nl.zeros((1, n_q7), dtype=nl.int32, buffer=nl.sbuf)
offsets = nl.zeros((1, n_gpsimd_cores), dtype=nl.int32, buffer=nl.sbuf)

# Handle sequences in chunks with chunk_size
for c in nl.static_range(n_T_chunks):
# Load input: (T,E) layout on HBM
# Q7 kernel needs: (128, chunk_size) with partition 0, 16, 32, ... each filled with a different expert.
input_local_q7_aligned = nl.ndarray((n_partitions, 1, chunk_size), dtype=input_tensor.dtype, buffer=nl.sbuf)
# GPSIMD kernel needs: (128, chunk_size) with partition 0, 16, 32, ... each filled with a different expert.
input_local_gpsimd_core_aligned = nl.ndarray((n_partitions, 1, chunk_size), dtype=input_tensor.dtype, buffer=nl.sbuf)
n_tiles = chunk_size // n_partitions
for i in nl.affine_range(n_tiles):
# Read (128, n_e) chunk
i_packed_t = nl.arange(n_partitions)[:, None]
i_packed_e = nl.arange(n_e)[None, :]
offset = r * n_q7 + row_start_id + E_offset
offset = r * n_gpsimd_cores + row_start_id + E_offset
input_local_te = nl.load(input_tensor[i_packed_t + c * chunk_size + i * n_partitions, i_packed_e + offset])

# Copy to (128, 128) chunk, putting the n_e columns at column 0, 16, 32, ..., 112
input_local_aligned_te = nl.ndarray((n_partitions, n_partitions), dtype=input_tensor.dtype, buffer=nl.sbuf)
for q in nl.affine_range(n_e):
input_local_aligned_te[:, nl.ds(q * n_partition_per_q7, 1)] = input_local_te[:, nl.ds(q, 1)]
input_local_aligned_te[:, nl.ds(q * n_partition_per_gpsimd_core, 1)] = input_local_te[:, nl.ds(q, 1)]

# Transpose so we have expert data at row 0, 16, 32, ..., 112
input_local_q7_aligned[:, 0, nl.ds(i*n_partitions, n_partitions)] = nisa.nc_transpose(input_local_aligned_te)
input_local_gpsimd_core_aligned[:, 0, nl.ds(i*n_partitions, n_partitions)] = nisa.nc_transpose(input_local_aligned_te)

# Run Q7 kernel for ISA NonzeroWithCount.
output_local = nki_asm_nonzero_with_count(input_local_q7_aligned, c*chunk_size)
# Run GPSIMD kernel for ISA NonzeroWithCount.
output_local = nki_asm_nonzero_with_count(input_local_gpsimd_core_aligned, c*chunk_size)

# Write out rows 0, 32, 64, 96
i_0 = nl.arange(1)[:, None]
i_1 = nl.arange(1)[None, :]
i_indices = nl.arange(chunk_size)[None, :]
for q in nl.affine_range(n_quadrant):
nl.store(
indices[i_0 + E_offset + r * n_q7 + q * n_q7_per_quadrant, i_indices + offsets[i_0, i_1 + q * n_q7_per_quadrant]],
indices[i_0 + E_offset + r * n_gpsimd_cores + q * n_gpsimd_cores_per_quadrant, i_indices + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant]],
value=output_local[nl.ds(q*quadrant, 1), 0, nl.ds(0, chunk_size)],
mask=i_0 + q * n_q7_per_quadrant < n_e
mask=i_0 + q * n_gpsimd_cores_per_quadrant < n_e
)
offsets[i_0, i_1 + q * n_q7_per_quadrant] = nl.add(
offsets[i_0, i_1 + q * n_q7_per_quadrant],
offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant] = nl.add(
offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant],
output_local[nl.ds(q*quadrant, 1), 0, nl.ds(chunk_size, 1)],
mask=i_0 + q * n_q7_per_quadrant < n_e
mask=i_0 + q * n_gpsimd_cores_per_quadrant < n_e
)

# Stream shuffle to move rows 16, 48, 80, 112 to rows 0, 32, 64, 96, and then write them out
quad_mask = [255] * quadrant
quad_mask[0] = n_partition_per_q7
quad_mask[0] = n_partition_per_gpsimd_core
nisa.nc_stream_shuffle(src=output_local, dst=output_local, shuffle_mask = quad_mask)
for q in nl.affine_range(n_quadrant):
nl.store(
indices[i_0 + E_offset + r * n_q7 + q * n_q7_per_quadrant + 1, i_indices + offsets[i_0, i_1 + q * n_q7_per_quadrant + 1]],
indices[i_0 + E_offset + r * n_gpsimd_cores + q * n_gpsimd_cores_per_quadrant + 1, i_indices + offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1]],
value=output_local[nl.ds(q*quadrant, 1), 0, nl.ds(0, chunk_size)],
mask=i_0 + q * n_q7_per_quadrant + 1 < n_e
mask=i_0 + q * n_gpsimd_cores_per_quadrant + 1 < n_e
)
offsets[i_0, i_1 + q * n_q7_per_quadrant + 1] = nl.add(
offsets[i_0, i_1 + q * n_q7_per_quadrant + 1],
offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1] = nl.add(
offsets[i_0, i_1 + q * n_gpsimd_cores_per_quadrant + 1],
output_local[nl.ds(q*quadrant, 1), 0, nl.ds(chunk_size, 1)],
mask=i_0 + q * n_q7_per_quadrant + 1 < n_e
mask=i_0 + q * n_gpsimd_cores_per_quadrant + 1 < n_e
)

# Final offsets are the nonzero counts per expert.
i_n_e = nl.arange(n_e)[None, :]
nonzero_counts_local[i_0, i_n_e + r * n_q7] = offsets[i_0, i_n_e]
nonzero_counts_local[i_0, i_n_e + r * n_gpsimd_cores] = offsets[i_0, i_n_e]

nonzero_counts_reshape = nonzero_counts.reshape((1, E))
nl.store(nonzero_counts_reshape[:, nl.ds(E_offset, E_per_shard)], nonzero_counts_local)
Expand Down
19 changes: 17 additions & 2 deletions src/neuronx_distributed/modules/moe/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class BlockwiseMatmulArgs:
gate_clamp_lower_limit: Optional[float] = None
up_clamp_upper_limit: Optional[float] = None
up_clamp_lower_limit: Optional[float] = None

use_shard_on_block_dynamic_while: bool = False
# Optional Input tensors
conditions: Optional[torch.Tensor] = None
num_static_blocks: Optional[int] = None
Expand Down Expand Up @@ -1229,8 +1229,21 @@ def backward(ctx, grad_output):
)

def _call_bwmm_fp4_shard_on_block(args: BlockwiseMatmulArgs):
#directly inject the conditions vector here
# FIXME: move this code to expert_mlps_v2.py instead of injecting it directly here

# Reshape the token positions into blocks
padded_conditions = None
if args.use_shard_on_block_dynamic_while:
num_blocks = args.block_to_expert.shape[0]
blocks = args.token_position_to_id.view(num_blocks, args.block_size)
# Check each block for non padded tokens (any position != -1)
conditions = torch.any(blocks != -1, dim=1).to(torch.int32)
padded_conditions = torch.cat([conditions, torch.zeros(2, device=conditions.device)])

output = _bwmm_fp4_shard_on_block_nki_call[VNC(2)](
hidden_states=args.hidden_states,
conditions=padded_conditions,
expert_affinities_masked=args.expert_affinities_masked,
gate_up_proj_weight=args.gate_up_proj_weight,
down_proj_weight=args.down_proj_weight,
Expand Down Expand Up @@ -1326,12 +1339,14 @@ def forward(
gate_clamp_lower_limit=gate_clamp_lower_limit,
up_clamp_upper_limit=up_clamp_upper_limit,
up_clamp_lower_limit=up_clamp_lower_limit,
use_shard_on_block_dynamic_while=blockwise_matmul_config.use_shard_on_block_dynamic_while,
gate_up_proj_bias=gate_up_proj_bias,
down_proj_bias=down_proj_bias,
kernel_act_fn=ActivationFunction(kernel_act_fn_id),
is_tensor_update_accumulating=multi_expert_per_token,
expert_affinities_scaling_mode=expert_affinities_scaling_mode,
output=output
output=output,
skip_dma=skip_dma,
)

output = _call_bwmm_fp4_shard_on_block(args)
Expand Down
30 changes: 27 additions & 3 deletions src/neuronx_distributed/modules/moe/expert_mlps_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,11 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe
num_blocks = math.ceil((total_tokens * self.routed_experts_mlp_config.top_k - (local_experts - 1)) / self.blockwise_matmul_config.block_size) + local_experts - 1
# Handle case where T*top_k is smaller than E. We will need atmost T*top_k blocks.
num_blocks = min(num_blocks, total_tokens * self.routed_experts_mlp_config.top_k)
# Padding num_blocks to even (TODO: currently only for MXFP4 BWMM kernel to support dynamic while)

pad_num_blocks_to_even = self.blockwise_matmul_config.pad_num_blocks_to_even
if pad_num_blocks_to_even:
num_blocks += num_blocks % 2

# Get num_static_block from blockwise_matmul_config, if not set, the default num_static_blocks will be computed as
# NUM_STATIC_BLOCK = T * TopK / (EP_degree * B)
Expand Down Expand Up @@ -754,6 +759,7 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe
tensor_parallel_group=self.moe_tensor_model_parallel_group,
expert_parallel_group=self.moe_expert_model_parallel_group,
logical_nc_config=self.blockwise_matmul_config.logical_nc_config,
pad_num_blocks_to_even=pad_num_blocks_to_even
)
else:
# expert_mask: (T, E). Still happens on all the experts, not just the local experts
Expand Down Expand Up @@ -823,6 +829,7 @@ def forward_blockwise(self, hidden_states, expert_affinities, expert_index, expe
tensor_parallel_group=self.moe_tensor_model_parallel_group,
optimized_block_to_token_mapping=ues_optimized_block_to_token_mapping,
parallelize_token_to_block_mapping=self.blockwise_matmul_config.parallelize_token_to_block_mapping,
pad_num_blocks_to_even=pad_num_blocks_to_even,
)

if self.blockwise_matmul_config.use_shard_on_block_dynamic_while:
Expand Down Expand Up @@ -1079,6 +1086,7 @@ def get_blockwise_expert_and_token_mapping_kernel(
tensor_parallel_group: ProcessGroup,
expert_parallel_group: ProcessGroup,
logical_nc_config: int,
pad_num_blocks_to_even: bool = False,
):
'''
Equivalent function of get_blockwise_expert_and_token_mapping, but nstead of torch code,
Expand Down Expand Up @@ -1109,6 +1117,7 @@ def get_blockwise_expert_and_token_mapping_kernel(
tp_rank = torch.remainder(global_rank, tp_size)
ep_size = expert_parallel_group.size()
ep_rank = global_rank // tp_size
max_chunk_size = 16384

# Every EP rank needs the information about its local expert (`E_local` of those).
E_local = E // ep_size
Expand All @@ -1123,7 +1132,7 @@ def get_blockwise_expert_and_token_mapping_kernel(
input_tensor=expert_affinities_masked.to(torch.float32),
row_start_id=ep_rank*E_kernel,
n_rows=E_kernel,
chunk_size = T,
chunk_size=min(T, max_chunk_size),
index_dtype=nl.int32,
)
else:
Expand All @@ -1136,7 +1145,7 @@ def get_blockwise_expert_and_token_mapping_kernel(
input_tensor=expert_affinities_masked.to(torch.float32),
row_start_id=global_rank*E_kernel,
n_rows=E_kernel,
chunk_size = T,
chunk_size=min(T, max_chunk_size),
index_dtype=nl.int32,
)
# Gather nonzero_counts: [E_kernel,] --> [E/EP,]
Expand All @@ -1146,6 +1155,13 @@ def get_blockwise_expert_and_token_mapping_kernel(

# Get number of blocks and cumulative number of blocks per expert.
blocks_per_expert = ((nonzero_counts + block_size - 1) // block_size).to(dtype=torch.long) # (E_EP,)

# Calculate padding blocks needed and add to last expert
if pad_num_blocks_to_even:
total_needed_blocks = torch.sum(blocks_per_expert)
padding_blocks = num_blocks - total_needed_blocks
blocks_per_expert[-1] += padding_blocks

blocks_per_expert_expanded = blocks_per_expert.unsqueeze(1) # (E_EP, 1)
cum_blocks_per_expert = cumsum(blocks_per_expert_expanded) # (E_EP, 1)
cum_blocks_per_expert[1:] = cum_blocks_per_expert[:-1]
Expand Down Expand Up @@ -1199,7 +1215,8 @@ def get_blockwise_expert_and_token_mapping(
tensor_parallel_group,
optimized_block_to_token_mapping=True,
parallelize_token_to_block_mapping=False,
):
pad_num_blocks_to_even=False,
):
"""
Token position: position in blocks.
E.g. given block_size=2, num_token=6. The following expert_mask
Expand Down Expand Up @@ -1242,6 +1259,13 @@ def get_blockwise_expert_and_token_mapping(
tokens_per_expert = torch.sum(expert_mask, dim=0)
# blocks_per_expert: (E, )
blocks_per_expert = ((tokens_per_expert + block_size - 1) // block_size).to(dtype=torch.long)

# Calculate padding blocks needed and add to last expert
if pad_num_blocks_to_even:
total_needed_blocks = torch.sum(blocks_per_expert)
padding_blocks = num_blocks - total_needed_blocks
blocks_per_expert[-1] += padding_blocks

# block_to_expert: (N, ). Block id to expert id mapping.
# The simplest way to do this is to use repeat_interleave after padding blocks_per_expert with unassigned blocks.
# But this op is not lowered to xla with vector 'repeats', so we use the equivalent implementation below.
Expand Down
Loading