diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index e48fb08..01c61cc 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1238,8 +1238,8 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
             hf_state_dict = {}
         config = self.config
         num_key_heads = config.linear_num_key_heads
-        key_dim = config.linear_key_head_dim * num_key_heads
-        value_dim = config.linear_value_head_dim * config.linear_num_value_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
         if to_mcore:
             if isinstance(mg_attn.in_proj, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
@@ -1247,43 +1247,32 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
                     (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
                     'Need to ensure QKVZBA\'s lora_A are consistent'
+                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+                q_lora_B, k_lora_B, v_lora_B = torch.split(
+                    qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 lora_B = torch.cat([
-                    hf_state_dict['in_proj_qkv.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_z.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_b.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_a.lora_B.weight'].load(),
+                    *(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
+                    *(hf_state_dict[f'{key}.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a'])
                 ],
-                                   dim=0)
+                                   dim=1).reshape(-1, qkv_lora_B.shape[-1])
                 self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
                 self._set_weight(mg_attn.in_proj.lora_B[self._adapter_name].weight, lora_B, 'in_proj.lora_B.weight')
             elif not self._peft_format:
                 qkv = hf_state_dict['in_proj_qkv.weight'].load()
-                q, k, v = torch.split(qkv, [key_dim, key_dim, value_dim], dim=0)
+                q, k, v = torch.split(
+                    qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 in_proj_weight = torch.cat([
                     *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
-                    *(hf_state_dict[key].load().reshape(num_key_heads, -1, config.hidden_size)
-                      for key in ['in_proj_z.weight', 'in_proj_b.weight', 'in_proj_a.weight']),
+                    *(hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
                 ],
                                            dim=1).reshape((-1, config.hidden_size))
-                in_scale_inv = None
-                if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
-                    in_scale_inv = torch.cat([
-                        hf_state_dict['in_proj_qkv.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_z.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_b.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_a.weight_scale_inv'].load(),
-                    ],
-                                             dim=0)
-                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
+                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight')
         else:
-            key_dim = self.config.linear_key_head_dim * self.config.linear_num_key_heads
-            value_dim = self.config.linear_value_head_dim * self.config.linear_num_value_heads
             qkv_dim = key_dim * 2 + value_dim
             z_dim = value_dim
-            a_dim = config.linear_num_value_heads
-            qkv_block = qkv_dim // self.fp8_block_size
-            z_block = z_dim // self.fp8_block_size
-            a_block = a_dim // self.fp8_block_size
+            a_dim = config.linear_num_value_heads // num_key_heads
            is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj, LoraParallelLinear) and self._peft_format
             is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
@@ -1297,42 +1286,39 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 None if mg_attn is None else mg_attn.in_proj.lora_B[self._adapter_name].weight.data,
                 f'in_proj.lora_B.{self._adapter_name}.weight')
             if lora_A is not None:
+                lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
                 self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
                 for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
                     hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['in_proj_qkv.lora_B.weight'] = lora_B[:qkv_dim].clone()
-                hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[qkv_dim:qkv_dim + z_dim].clone()
-                hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[qkv_dim + z_dim:-a_dim].clone()
-                hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[-a_dim:].clone()
+                q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
+                k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
+                v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
+                hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
+                hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
+                    -1, lora_B.shape[-1]).clone()
+                hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
+                    -1, lora_B.shape[-1]).clone()
+                hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
         elif not self._peft_format:
-            in_proj_weight, scale_inv = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
-                                                         'in_proj.weight')
+            in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
+                                                 'in_proj.weight')
             if in_proj_weight is not None:
                 in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
-                q = in_proj_weight[:, :key_dim // num_key_heads].reshape(-1, config.hidden_size)
-                k = in_proj_weight[:, key_dim // num_key_heads:2 * key_dim // num_key_heads].reshape(
-                    -1, config.hidden_size)
-                v = in_proj_weight[:, 2 * key_dim // num_key_heads:qkv_dim // num_key_heads].reshape(
-                    -1, config.hidden_size)
+                q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
+                k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
+                v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
                 hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
-                hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim // num_key_heads:(qkv_dim + z_dim)
-                                                                   // num_key_heads].reshape(
-                                                                       -1, config.hidden_size).clone()
-                hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim) // num_key_heads:-a_dim
-                                                                   // num_key_heads].reshape(
-                                                                       -1, config.hidden_size).clone()
-                hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim // num_key_heads:].reshape(
+                hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:(qkv_dim + z_dim)].reshape(
                     -1, config.hidden_size).clone()
-                if scale_inv is not None:
-                    hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:qkv_block].clone()
-                    hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[qkv_block:qkv_block + z_block].clone()
-                    hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
-                    hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
-                del in_proj_weight
+                hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim):-a_dim].reshape(
+                    -1, config.hidden_size).clone()
+                hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
+                                                                                       config.hidden_size).clone()
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()
-                q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
+                q_c, k_c, v_c = torch.split(
+                    conv1d, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 conv1d = torch.cat([
                     *(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
                 ],
@@ -1342,8 +1328,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
                 if conv1d is not None:
                     conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
-                    q_c, k_c, v_c = torch.split(
-                        conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
+                    q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=1)
                     q_c = q_c.reshape(-1, *q_c.shape[-2:])
                     k_c = k_c.reshape(-1, *k_c.shape[-2:])
                     v_c = v_c.reshape(-1, *v_c.shape[-2:])
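Note for reviewers: the sketch below is a toy illustration of the layout this patch converts to, not code from this repo; all sizes are hypothetical stand-ins for the `linear_*` config fields and the tensors are random. The patch redefines `key_dim`/`value_dim` as per-key-head widths because Megatron-core stores `in_proj` head-interleaved, one contiguous `[q | k | v | z | b | a]` block per key head, while the HF checkpoint stores each projection whole.

```python
import torch

# Toy stand-ins for the config fields referenced in the diff.
num_key_heads = 2        # config.linear_num_key_heads
key_dim = 4              # config.linear_key_head_dim (per key head)
num_value_heads = 4      # config.linear_num_value_heads
value_head_dim = 4       # config.linear_value_head_dim
hidden_size = 8

# Per-key-head widths, mirroring the patched definitions.
value_dim = value_head_dim * num_value_heads // num_key_heads
qkv_dim = key_dim * 2 + value_dim
z_dim = value_dim
a_dim = num_value_heads // num_key_heads  # b and a carry one row per value head

H = num_key_heads
# HF layout: each projection stored whole, rows grouped by projection type.
q = torch.randn(key_dim * H, hidden_size)
k = torch.randn(key_dim * H, hidden_size)
v = torch.randn(value_dim * H, hidden_size)
z = torch.randn(z_dim * H, hidden_size)
b = torch.randn(a_dim * H, hidden_size)
a = torch.randn(a_dim * H, hidden_size)

# HF -> mcore: interleave per key head into [q_h | k_h | v_h | z_h | b_h | a_h].
mcore = torch.cat([t.reshape(H, -1, hidden_size) for t in (q, k, v, z, b, a)],
                  dim=1).reshape(-1, hidden_size)

# mcore -> HF: undo the interleave with the same per-head slice bounds.
w = mcore.reshape(H, -1, hidden_size)
q2 = w[:, :key_dim].reshape(-1, hidden_size)
k2 = w[:, key_dim:2 * key_dim].reshape(-1, hidden_size)
v2 = w[:, 2 * key_dim:qkv_dim].reshape(-1, hidden_size)
z2 = w[:, qkv_dim:qkv_dim + z_dim].reshape(-1, hidden_size)
b2 = w[:, qkv_dim + z_dim:-a_dim].reshape(-1, hidden_size)
a2 = w[:, -a_dim:].reshape(-1, hidden_size)

for before, after in zip((q, k, v, z, b, a), (q2, k2, v2, z2, b2, a2)):
    assert torch.equal(before, after)  # the round-trip is lossless
```

Because both directions now share the same per-head slice bounds (`key_dim`, `qkv_dim`, `z_dim`, `a_dim`), the export path no longer needs the `// num_key_heads` arithmetic the old code carried.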