From 933aaf0cd4ee17025fec70a89c2664df8c86c7f2 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 14:17:45 +0800
Subject: [PATCH 1/5] fix qwen3.5 gpt_bridge lora

---
 src/mcore_bridge/bridge/gpt_bridge.py | 58 +++++++++++++--------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index e48fb08..13b8e98 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1238,8 +1238,8 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
             hf_state_dict = {}
         config = self.config
         num_key_heads = config.linear_num_key_heads
-        key_dim = config.linear_key_head_dim * num_key_heads
-        value_dim = config.linear_value_head_dim * config.linear_num_value_heads
+        key_dim = config.linear_key_head_dim
+        value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
         if to_mcore:
             if isinstance(mg_attn.in_proj, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
@@ -1247,18 +1247,20 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
                     (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
                     'Need to ensure QKVZBA\'s lora_A are consistent'
+                qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
                 lora_B = torch.cat([
-                    hf_state_dict['in_proj_qkv.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_z.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_b.lora_B.weight'].load(),
-                    hf_state_dict['in_proj_a.lora_B.weight'].load(),
+                    qkv_lora_B.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
+                    hf_state_dict['in_proj_z.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
+                    hf_state_dict['in_proj_b.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
+                    hf_state_dict['in_proj_a.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
                 ],
-                                   dim=0)
+                                   dim=1).reshape(-1, qkv_lora_B.shape[-1])
                 self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
                 self._set_weight(mg_attn.in_proj.lora_B[self._adapter_name].weight, lora_B, 'in_proj.lora_B.weight')
             elif not self._peft_format:
                 qkv = hf_state_dict['in_proj_qkv.weight'].load()
-                q, k, v = torch.split(qkv, [key_dim, key_dim, value_dim], dim=0)
+                q, k, v = torch.split(
+                    qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 in_proj_weight = torch.cat([
                     *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
                     *(hf_state_dict[key].load().reshape(num_key_heads, -1, config.hidden_size)
@@ -1267,6 +1269,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                                            dim=1).reshape((-1, config.hidden_size))
                 in_scale_inv = None
                 if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
+                    # TODO: xxx
                     in_scale_inv = torch.cat([
                         hf_state_dict['in_proj_qkv.weight_scale_inv'].load(),
                         hf_state_dict['in_proj_z.weight_scale_inv'].load(),
@@ -1276,11 +1279,9 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                                              dim=0)
                 self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
             else:
-                key_dim = self.config.linear_key_head_dim * self.config.linear_num_key_heads
-                value_dim = self.config.linear_value_head_dim * self.config.linear_num_value_heads
                 qkv_dim = key_dim * 2 + value_dim
                 z_dim = value_dim
-                a_dim = config.linear_num_value_heads
+                a_dim = config.linear_num_value_heads // num_key_heads
                 qkv_block = qkv_dim // self.fp8_block_size
                 z_block = z_dim // self.fp8_block_size
                 a_block = a_dim // self.fp8_block_size
@@ -1297,32 +1298,32 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 None if mg_attn is None else mg_attn.in_proj.lora_B[self._adapter_name].weight.data,
                 f'in_proj.lora_B.{self._adapter_name}.weight')
             if lora_A is not None:
+                lora_B = lora_B.reshape(num_key_heads, -1, lora_B.shape[-1])
                 self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
                 for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
                     hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['in_proj_qkv.lora_B.weight'] = lora_B[:qkv_dim].clone()
-                hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[qkv_dim:qkv_dim + z_dim].clone()
-                hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[qkv_dim + z_dim:-a_dim].clone()
-                hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[-a_dim:].clone()
+                hf_state_dict['in_proj_qkv.lora_B.weight'] = lora_B[:, :qkv_dim].reshape(-1,
+                                                                                          lora_B.shape[-1]).clone()
+                hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
+                    -1, lora_B.shape[-1]).clone()
+                hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
+                    -1, lora_B.shape[-1]).clone()
+                hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
            elif not self._peft_format:
                 in_proj_weight, scale_inv = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
                                                              'in_proj.weight')
                 if in_proj_weight is not None:
                     in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
-                    q = in_proj_weight[:, :key_dim // num_key_heads].reshape(-1, config.hidden_size)
-                    k = in_proj_weight[:, key_dim // num_key_heads:2 * key_dim // num_key_heads].reshape(
-                        -1, config.hidden_size)
-                    v = in_proj_weight[:, 2 * key_dim // num_key_heads:qkv_dim // num_key_heads].reshape(
-                        -1, config.hidden_size)
+                    q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)
+                    k = in_proj_weight[:, key_dim:2 * key_dim].reshape(-1, config.hidden_size)
+                    v = in_proj_weight[:, 2 * key_dim:qkv_dim].reshape(-1, config.hidden_size)
                     hf_state_dict['in_proj_qkv.weight'] = torch.concat([q, k, v], dim=0)
-                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim // num_key_heads:(qkv_dim + z_dim)
-                                                                       // num_key_heads].reshape(
-                        -1, config.hidden_size).clone()
-                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim) // num_key_heads:-a_dim
-                                                                       // num_key_heads].reshape(
-                        -1, config.hidden_size).clone()
-                    hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim // num_key_heads:].reshape(
+                    hf_state_dict['in_proj_z.weight'] = in_proj_weight[:, qkv_dim:(qkv_dim + z_dim)].reshape(
+                        -1, config.hidden_size).clone()
+                    hf_state_dict['in_proj_b.weight'] = in_proj_weight[:, (qkv_dim + z_dim):-a_dim].reshape(
                         -1, config.hidden_size).clone()
+                    hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
+                                                                                           config.hidden_size).clone()
                 if scale_inv is not None:
                     hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:qkv_block].clone()
                     hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[qkv_block:qkv_block + z_block].clone()
                     hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
                     hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
@@ -1342,8 +1343,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 conv1d, _ = self._get_weight(None if mg_attn is None else mg_attn.conv1d.weight, 'conv1d.weight')
                 if conv1d is not None:
                     conv1d = conv1d.reshape(num_key_heads, -1, *conv1d.shape[-2:])
-                    q_c, k_c, v_c = torch.split(
-                        conv1d, [key_dim // num_key_heads, key_dim // num_key_heads, value_dim // num_key_heads], dim=1)
+                    q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=1)
                     q_c = q_c.reshape(-1, *q_c.shape[-2:])
                     k_c = k_c.reshape(-1, *k_c.shape[-2:])
                     v_c = v_c.reshape(-1, *v_c.shape[-2:])

From b45c0b274307316593348e1d08fa579acfd347b8 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 14:29:02 +0800
Subject: [PATCH 2/5] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 29 +++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 13b8e98..d576c6d 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1240,6 +1240,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
+        hidden_size_block = config.hidden_size // self.fp8_block_size
         if to_mcore:
             if isinstance(mg_attn.in_proj, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
@@ -1270,14 +1271,17 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                                            dim=1).reshape((-1, config.hidden_size))
                 in_scale_inv = None
                 if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
-                    # TODO: xxx
                     in_scale_inv = torch.cat([
-                        hf_state_dict['in_proj_qkv.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_z.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_b.weight_scale_inv'].load(),
-                        hf_state_dict['in_proj_a.weight_scale_inv'].load(),
+                        hf_state_dict['in_proj_qkv.weight_scale_inv'].load().reshape(
+                            (num_key_heads, -1, hidden_size_block)),
+                        hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(
+                            (num_key_heads, -1, hidden_size_block)),
+                        hf_state_dict['in_proj_b.weight_scale_inv'].load().reshape(
+                            (num_key_heads, -1, hidden_size_block)),
+                        hf_state_dict['in_proj_a.weight_scale_inv'].load().reshape(
+                            (num_key_heads, -1, hidden_size_block)),
                     ],
-                                             dim=0)
+                                             dim=1).reshape((-1, hidden_size_block))
                 self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
             else:
                 qkv_dim = key_dim * 2 + value_dim
@@ -1325,10 +1329,15 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
                 if scale_inv is not None:
-                    hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:qkv_block].clone()
-                    hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[qkv_block:qkv_block + z_block].clone()
-                    hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[qkv_block + z_block:-a_block].clone()
-                    hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[-a_block:].clone()
+                    scale_inv = scale_inv.reshape((num_key_heads, -1, hidden_size_block))
+                    hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:, :qkv_block].reshape(
+                        -1, hidden_size_block).clone()
+                    hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[:, qkv_block:qkv_block + z_block].reshape(
+                        -1, hidden_size_block).clone()
+                    hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[:, qkv_block + z_block:-a_block].reshape(
+                        -1, hidden_size_block).clone()
+                    hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[:, -a_block:].reshape(
+                        -1, hidden_size_block).clone()
                 del in_proj_weight
         if not self._peft_format:
             if to_mcore:

From 8df23b76b2e6f511872f55ce9da822cf3e2e7a7a Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 14:47:20 +0800
Subject: [PATCH 3/5] update

---
 src/mcore_bridge/bridge/gpt_bridge.py | 47 +++++++++++++--------------
 1 file changed, 11 insertions(+), 36 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index d576c6d..77aad7e 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1240,7 +1240,6 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         num_key_heads = config.linear_num_key_heads
         key_dim = config.linear_key_head_dim
         value_dim = config.linear_value_head_dim * config.linear_num_value_heads // num_key_heads
-        hidden_size_block = config.hidden_size // self.fp8_block_size
         if to_mcore:
             if isinstance(mg_attn.in_proj, LoraParallelLinear):
                 lora_A = hf_state_dict['in_proj_qkv.lora_A.weight'].load()
@@ -1249,11 +1248,12 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     (lora_A == hf_state_dict['in_proj_b.lora_A.weight'].load()).all() and \
                     (lora_A == hf_state_dict['in_proj_a.lora_A.weight'].load()).all(), \
                     'Need to ensure QKVZBA\'s lora_A are consistent'
                 qkv_lora_B = hf_state_dict['in_proj_qkv.lora_B.weight'].load()
+                q_lora_B, k_lora_B, v_lora_B = torch.split(
+                    qkv_lora_B, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 lora_B = torch.cat([
-                    qkv_lora_B.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
-                    hf_state_dict['in_proj_z.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
-                    hf_state_dict['in_proj_b.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
-                    hf_state_dict['in_proj_a.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1]),
+                    *(x.reshape(num_key_heads, -1, qkv_lora_B.shape[-1]) for x in [q_lora_B, k_lora_B, v_lora_B]),
+                    *(hf_state_dict[f'{key}.lora_B.weight'].load().reshape(num_key_heads, -1, qkv_lora_B.shape[-1])
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a'])
                 ],
                                    dim=1).reshape(-1, qkv_lora_B.shape[-1])
                 self._set_weight(mg_attn.in_proj.lora_A[self._adapter_name].weight, lora_A, 'in_proj.lora_A.weight')
@@ -1264,31 +1264,15 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     qkv, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 in_proj_weight = torch.cat([
                     *(x.reshape(num_key_heads, -1, config.hidden_size) for x in [q, k, v]),
-                    *(hf_state_dict[key].load().reshape(num_key_heads, -1, config.hidden_size)
-                      for key in ['in_proj_z.weight', 'in_proj_b.weight', 'in_proj_a.weight']),
+                    *(hf_state_dict[f'{key}.weight'].load().reshape(num_key_heads, -1, config.hidden_size)
+                      for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
                 ],
                                            dim=1).reshape((-1, config.hidden_size))
-                in_scale_inv = None
-                if 'in_proj_qkv.weight_scale_inv' in hf_state_dict:
-                    in_scale_inv = torch.cat([
-                        hf_state_dict['in_proj_qkv.weight_scale_inv'].load().reshape(
-                            (num_key_heads, -1, hidden_size_block)),
-                        hf_state_dict['in_proj_z.weight_scale_inv'].load().reshape(
-                            (num_key_heads, -1, hidden_size_block)),
-                        hf_state_dict['in_proj_b.weight_scale_inv'].load().reshape(
-                            (num_key_heads, -1, hidden_size_block)),
-                        hf_state_dict['in_proj_a.weight_scale_inv'].load().reshape(
-                            (num_key_heads, -1, hidden_size_block)),
-                    ],
-                                             dim=1).reshape((-1, hidden_size_block))
                 self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
             else:
                 qkv_dim = key_dim * 2 + value_dim
                 z_dim = value_dim
                 a_dim = config.linear_num_value_heads // num_key_heads
-                qkv_block = qkv_dim // self.fp8_block_size
-                z_block = z_dim // self.fp8_block_size
-                a_block = a_dim // self.fp8_block_size
                 is_lora = False if mg_attn is None else isinstance(mg_attn.in_proj, LoraParallelLinear) and self._peft_format
                 is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
@@ -1306,8 +1290,10 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                 self._peft_target_modules.update({'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a'})
                 for key in ['in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a']:
                     hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
-                hf_state_dict['in_proj_qkv.lora_B.weight'] = lora_B[:, :qkv_dim].reshape(-1,
-                                                                                          lora_B.shape[-1]).clone()
+                q_lora_B = lora_B[:, :key_dim].reshape(-1, lora_B.shape[-1])
+                k_lora_B = lora_B[:, key_dim:2 * key_dim].reshape(-1, lora_B.shape[-1])
+                v_lora_B = lora_B[:, 2 * key_dim:qkv_dim].reshape(-1, lora_B.shape[-1])
+                hf_state_dict['in_proj_qkv.lora_B.weight'] = torch.concat([q_lora_B, k_lora_B, v_lora_B], dim=0)
                 hf_state_dict['in_proj_z.lora_B.weight'] = lora_B[:, qkv_dim:qkv_dim + z_dim].reshape(
                     -1, lora_B.shape[-1]).clone()
                 hf_state_dict['in_proj_b.lora_B.weight'] = lora_B[:, qkv_dim + z_dim:-a_dim].reshape(
@@ -1328,17 +1314,6 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                         -1, config.hidden_size).clone()
                     hf_state_dict['in_proj_a.weight'] = in_proj_weight[:, -a_dim:].reshape(-1,
                                                                                            config.hidden_size).clone()
-                if scale_inv is not None:
-                    scale_inv = scale_inv.reshape((num_key_heads, -1, hidden_size_block))
-                    hf_state_dict['in_proj_qkv.weight_scale_inv'] = scale_inv[:, :qkv_block].reshape(
-                        -1, hidden_size_block).clone()
-                    hf_state_dict['in_proj_z.weight_scale_inv'] = scale_inv[:, qkv_block:qkv_block + z_block].reshape(
-                        -1, hidden_size_block).clone()
-                    hf_state_dict['in_proj_b.weight_scale_inv'] = scale_inv[:, qkv_block + z_block:-a_block].reshape(
-                        -1, hidden_size_block).clone()
-                    hf_state_dict['in_proj_a.weight_scale_inv'] = scale_inv[:, -a_block:].reshape(
-                        -1, hidden_size_block).clone()
-                del in_proj_weight
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()

From 6fb2fb7d79400e4a76de4be374d574cf34793009 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 14:51:07 +0800
Subject: [PATCH 4/5] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 77aad7e..678c812 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1268,7 +1268,7 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                       for key in ['in_proj_z', 'in_proj_b', 'in_proj_a']),
                 ],
                                            dim=1).reshape((-1, config.hidden_size))
-                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight', hf_scale_inv=in_scale_inv)
+                self._set_weight(mg_attn.in_proj.weight, in_proj_weight, 'in_proj.weight')
             else:
                 qkv_dim = key_dim * 2 + value_dim
                 z_dim = value_dim
@@ -1300,8 +1300,8 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
                     -1, lora_B.shape[-1]).clone()
                 hf_state_dict['in_proj_a.lora_B.weight'] = lora_B[:, -a_dim:].reshape(-1, lora_B.shape[-1]).clone()
             elif not self._peft_format:
-                in_proj_weight, scale_inv = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
-                                                             'in_proj.weight')
+                in_proj_weight, _ = self._get_weight(None if mg_attn is None else mg_attn.in_proj.weight.data,
+                                                     'in_proj.weight')
                 if in_proj_weight is not None:
                     in_proj_weight = in_proj_weight.reshape(num_key_heads, -1, config.hidden_size)
                     q = in_proj_weight[:, :key_dim].reshape(-1, config.hidden_size)

From 258512461f81f371734f7f548b8a261266702698 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 14 Apr 2026 15:02:07 +0800
Subject: [PATCH 5/5] fix

---
 src/mcore_bridge/bridge/gpt_bridge.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 678c812..01c61cc 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -1317,7 +1317,8 @@ def _set_linear_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_i
         if not self._peft_format:
             if to_mcore:
                 conv1d = hf_state_dict['conv1d.weight'].load()
-                q_c, k_c, v_c = torch.split(conv1d, [key_dim, key_dim, value_dim], dim=0)
+                q_c, k_c, v_c = torch.split(
+                    conv1d, [key_dim * num_key_heads, key_dim * num_key_heads, value_dim * num_key_heads], dim=0)
                 conv1d = torch.cat([
                     *(x.reshape(num_key_heads, -1, *conv1d.shape[-2:]) for x in [q_c, k_c, v_c]),
                 ],
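
Reviewer note (not part of the patch series): the recurring pattern in these commits is converting between the
HF layout, where in_proj_qkv / in_proj_z / in_proj_b / in_proj_a are stored projection-major, and the fused
Megatron in_proj, whose rows are grouped per key head. The snippet below is a minimal, self-contained sketch of
that packing and unpacking; every size in it is made up for illustration and does not come from any real config,
and the exact per-head ordering [q, k, v, z, b, a] is inferred from the diffs above, not from the library docs.

import torch

num_key_heads = 2   # hypothetical example sizes only
key_dim = 3         # per-key-head rows of q and of k
value_dim = 4       # per-key-head rows of v (and of z)
a_dim = 2           # per-key-head rows of b and of a
hidden_size = 5

# HF side: projection-major tensors, with q/k/v fused into in_proj_qkv.
q = torch.randn(num_key_heads * key_dim, hidden_size)
k = torch.randn(num_key_heads * key_dim, hidden_size)
v = torch.randn(num_key_heads * value_dim, hidden_size)
z = torch.randn(num_key_heads * value_dim, hidden_size)
b = torch.randn(num_key_heads * a_dim, hidden_size)
a = torch.randn(num_key_heads * a_dim, hidden_size)

# Megatron side: one fused in_proj whose rows repeat [q_h, k_h, v_h, z_h, b_h, a_h] for every key head h.
in_proj = torch.cat([t.reshape(num_key_heads, -1, hidden_size) for t in (q, k, v, z, b, a)],
                    dim=1).reshape(-1, hidden_size)

# Reverse direction: slice per head, then flatten back to projection-major.
per_head = in_proj.reshape(num_key_heads, -1, hidden_size)
qkv_dim = 2 * key_dim + value_dim
q2 = per_head[:, :key_dim].reshape(-1, hidden_size)
k2 = per_head[:, key_dim:2 * key_dim].reshape(-1, hidden_size)
v2 = per_head[:, 2 * key_dim:qkv_dim].reshape(-1, hidden_size)
assert torch.equal(q2, q) and torch.equal(k2, k) and torch.equal(v2, v)

The same per-head reshape/cat is what the LoRA-B and conv1d hunks apply; only the row counts per head differ.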