91 changes: 38 additions & 53 deletions src/mcore_bridge/bridge/gpt_bridge.py
@@ -1238,52 +1238,41 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
 hf_state_dict = {}
 config = self.config
 num_key_heads = config.linear_num_key_heads
-key_dim = config.linear_key_head_dim * num_key_heads
-value_dim = config.linear_value_head_dim * config.linear_num_value_heads
+key_dim = config.linear_key_head_dim
+value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
 if to_mcore:
 if isinstance(mg_attn.in_proj, LoraParallelLinear):
 lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
 assert (lora_A == hf_state_dict['in_proj_z.lora_A.weight'].load()).all() and \
 (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
 (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
 'Need to ensure QKVZBA\'s lora_A are consistent'
+qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+q_lora_B, k_lora_B, v_lora_B = torch.split(
+qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
 lora_B = torch.cat([
-hf_state_dict['in_proj_qkv.lora_B.weight'].load(),
-hf_state_dict['in_proj_z.lora_B.weight'].load(),
-hf_state_dict['in_proj_b.lora_B.weight'].load(),
-hf_state_dict['in_proj_a.lora_B.weight'].load(),
+*(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
+*(hf_state_dict[f'{key}.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
+for key in ['in_proj_z', 'in_proj_b', 'in_proj_a'])
 ],
-dim=0)
+dim=1).reshape(-1, qkv_lora_B.shape[-1])
 self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
 self._set_weight(mg_attn.in_proj.lora_B[self._adapter_name].weight, lora_B, 'in_proj.lora_B.weight')
 elif not self._peft_format:
 qkv = hf_state_dict['in_proj_qkv.weight'].load()
-q, k, v = torch.split(qkv, [key_dim, key_dim, value_dim], dim=0)
+q, k, v = torch.split(
+qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
 in_proj_weight = torch.cat([
 *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
-*(hf_state_dict[key].load().reshape(num_key_heads, -1, config.hidden_size)
-for key in ['in_proj_z.weight', 'in_proj_b.weight', 'in_proj_a.weight']),
+*(hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
 ],
 dim=1).reshape((-1, config.hidden_size))
-in_scale_inv = None
-if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
-in_scale_inv = torch.cat([
-hf_state_dict['in_proj_qkv.weight_scale_inv'].load(),
-hf_state_dict['in_proj_z.weight_scale_inv'].load(),
-hf_state_dict['in_proj_b.weight_scale_inv'].load(),
-hf_state_dict['in_proj_a.weight_scale_inv'].load(),
-],
-dim=0)
-self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
+self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight')
 else:
-key_dim = self.config.linear_key_head_dim * self.config.linear_num_key_heads
-value_dim = self.config.linear_value_head_dim * self.config.linear_num_value_heads
 qkv_dim = key_dim * 2 + value_dim
 z_dim = value_dim
-a_dim = config.linear_num_value_heads
-qkv_block = qkv_dim // self.fp8_block_size
-z_block = z_dim // self.fp8_block_size
-a_block = a_dim // self.fp8_block_size
+a_dim = config.linear_num_value_heads // num_key_heads
 is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj,
 LoraParallelLinear) and self._peft_format
 is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
@@ -1297,42 +1286,39 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
 None if mg_attn is None else mg_attn.in_proj.lora_B[self._adapter_name].weight.data,
 f'in_proj.lora_B.{self._adapter_name}.weight')
 if lora_A is not None:
+lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
 self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
 for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
 hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
-hf_state_dict['in_proj_qkv.lora_B.weight'] = lora_B[:qkv_dim].clone()
-hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[qkv_dim:qkv_dim + z_dim].clone()
-hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[qkv_dim + z_dim:-a_dim].clone()
-hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[-a_dim:].clone()
+q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
+k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
+v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
+hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
+hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
+-1, lora_B.shape[-1]).clone()
+hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
+-1, lora_B.shape[-1]).clone()
+hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
 elif not self._peft_format:
-in_proj_weight, scale_inv = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
-'in_proj.weight')
+in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
+'in_proj.weight')
 if in_proj_weight is not None:
 in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
-q = in_proj_weight[:, :key_dim // num_key_heads].reshape(-1, config.hidden_size)
-k = in_proj_weight[:, key_dim // num_key_heads:2 * key_dim // num_key_heads].reshape(
--1, config.hidden_size)
-v = in_proj_weight[:, 2 * key_dim // num_key_heads:qkv_dim // num_key_heads].reshape(
--1, config.hidden_size)
+q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
+k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
+v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
 hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
-hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim // num_key_heads:(qkv_dim + z_dim)
-// num_key_heads].reshape(
--1, config.hidden_size).clone()
-hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim) // num_key_heads:-a_dim
-// num_key_heads].reshape(
--1, config.hidden_size).clone()
-hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim // num_key_heads:].reshape(
+hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:(qkv_dim + z_dim)].reshape(
 -1, config.hidden_size).clone()
-if scale_inv is not None:
-hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:qkv_block].clone()
-hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[qkv_block:qkv_block + z_block].clone()
-hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
-hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
-del in_proj_weight
+hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim):-a_dim].reshape(
+-1, config.hidden_size).clone()
+hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
+config.hidden_size).clone()
 if not self._peft_format:
 if to_mcore:
 conv1d = hf_state_dict['conv1d.weight'].load()
-q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
+q_c, k_c, v_c = torch.split(
+conv1d, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
 conv1d = torch.cat([
 *(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
 ],
@@ -1342,8 +1328,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
 conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
 if conv1d is not None:
 conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
-q_c, k_c, v_c = torch.split(
-conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
+q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=1)
 q_c = q_c.reshape(-1, *q_c.shape[-2:])
 k_c = k_c.reshape(-1, *k_c.shape[-2:])
 v_c = v_c.reshape(-1, *v_c.shape[-2:])
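
The diff reworks the HF↔Megatron `in_proj` conversion around a head-interleaved layout: each projection (Q, K, V, Z, B, A, and the LoRA `lora_B` and `conv1d` variants) is viewed as `(num_key_heads, -1, ...)` and concatenated along `dim=1`, and `key_dim`/`value_dim` now denote per-head sizes so the reverse slicing indexes heads directly. The sketch below illustrates that round trip for Q/K/V only; dimensions and variable names are illustrative and not taken from the PR.

```python
import torch

# Toy per-head sizes (assumed, not from the PR).
num_key_heads = 4
key_dim = 8                         # per-head key width
value_dim = 16                      # per-head value width
hidden_size = 32
qkv_dim = key_dim * 2 + value_dim   # per-head [q | k | v] width

# HF-style flat projections: Q, K, V each stacked over heads along dim 0.
q = torch.randn(key_dim * num_key_heads, hidden_size)
k = torch.randn(key_dim * num_key_heads, hidden_size)
v = torch.randn(value_dim * num_key_heads, hidden_size)

# HF -> Megatron: group each projection by head, then concatenate per head,
# so the rows become [q_0 | k_0 | v_0 | q_1 | k_1 | v_1 | ...].
in_proj = torch.cat([x.reshape(num_key_heads, -1, hidden_size) for x in (q, k, v)],
                    dim=1).reshape(-1, hidden_size)

# Megatron -> HF: view per head again and slice each head's q/k/v block.
per_head = in_proj.reshape(num_key_heads, -1, hidden_size)
q_back = per_head[:, :key_dim].reshape(-1, hidden_size)
k_back = per_head[:, key_dim:2 * key_dim].reshape(-1, hidden_size)
v_back = per_head[:, 2 * key_dim:qkv_dim].reshape(-1, hidden_size)

# The conversion is a pure permutation of rows, so the round trip is exact.
assert torch.equal(q, q_back) and torch.equal(k, k_back) and torch.equal(v, v_back)
```

The same reshape/split/cat pattern is what keeps the LoRA `lora_B` and `conv1d` conversions in the diff consistent with the dense `in_proj` path.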
Expand Down
Loading