-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathphi2_adapter.py
More file actions
257 lines (201 loc) · 8.47 KB
/
phi2_adapter.py
File metadata and controls
257 lines (201 loc) · 8.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# This file contains derivations from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi/modeling_phi.py
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# License updated to MIT license since 7e10f3e in https://huggingface.co/microsoft/phi-2/blob/main/LICENSE
import torch
from torch import FloatTensor, LongTensor, Tensor, matmul
from torch.nn import LayerNorm, Linear, Module
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from transformers.models.phi.modeling_phi import PhiConfig, PhiDecoderLayer, PhiForCausalLM
from slicegpt.model_adapter import LayerAdapter, ModelAdapter
class CompressedPhiDecoderLayer(PhiDecoderLayer):
    """
    Drop-in replacement for the PhiDecoderLayer class from PhiModel (PhiForCausalLM),
    https://huggingface.co/microsoft/phi-2/blob/main/modeling_phi.py
    augmented with an attn_shortcut_Q attribute. When that attribute is set, the
    residual stream is rotated by it before being added back to the block output.
    """

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor | None = None,
        position_ids: LongTensor | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        past_key_value: tuple[Tensor] | None = None,
    ) -> tuple:
        """
        Run one Phi decoder layer, with an optional rotation of the residual.

        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Selected in the
                range `[0, config.n_positions - 1]`.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights of this layer.
            use_cache (`bool`, *optional*):
                If set to `True`, the present key/value states are returned and can be used to speed up decoding.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            Tuple of (hidden_states[, attention weights][, present key/value states]).
        """
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)

        # Phi-2 uses parallel blocks: attention and MLP both read the same
        # layer-normed input rather than being applied sequentially.
        attn_out, attn_weights, cached_kv = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        attn_out = self.resid_dropout(attn_out)
        mlp_out = self.resid_dropout(self.mlp(normed))

        # Rotate the residual stream when a shortcut rotation is attached.
        if self.attn_shortcut_Q is not None:
            skip = matmul(skip, self.attn_shortcut_Q)
        combined = attn_out + mlp_out + skip

        result = (combined,)
        if output_attentions:
            result += (attn_weights,)
        if use_cache:
            result += (cached_kv,)
        return result
class Phi2LayerAdapter(LayerAdapter):
    """LayerAdapter implementation for a single Phi-2 decoder layer."""

    def __init__(self, layer: PhiDecoderLayer) -> None:
        super().__init__()
        self._layer: PhiDecoderLayer = layer

    @property
    def layer(self) -> Module:
        # The wrapped decoder layer.
        return self._layer

    @property
    def hidden_states_args_position(self) -> int:
        # hidden_states is the first positional argument of forward().
        return 0

    @property
    def hidden_states_output_position(self) -> int:
        # hidden_states is the first element of the forward() output tuple.
        return 0

    def get_first_layernorm(self) -> Module:
        return self.layer.input_layernorm

    def get_second_layernorm(self) -> Module:
        # Phi-2 decoder layers are parallel blocks with a single layernorm,
        # so there is no second layernorm to report.
        return None

    def get_attention_inputs(self) -> list[Linear]:
        attn = self.layer.self_attn
        return [attn.q_proj, attn.k_proj, attn.v_proj]

    def get_attention_output(self) -> Linear:
        return self.layer.self_attn.dense

    def get_mlp_inputs(self) -> list[Linear]:
        return [self.layer.mlp.fc1]

    def get_mlp_output(self) -> Linear:
        return self.layer.mlp.fc2
class Phi2ModelAdapter(ModelAdapter):
    """ModelAdapter implementation for microsoft/phi-2 (PhiForCausalLM)."""

    def __init__(self, model: PhiForCausalLM) -> None:
        super().__init__()
        self._model: PhiForCausalLM = model

    @property
    def model(self) -> Module:
        # The wrapped PhiForCausalLM instance.
        return self._model

    @property
    def config(self) -> PretrainedConfig:
        return self._model.config

    @property
    def config_type(self) -> type:
        return PhiConfig

    @property
    def parallel_blocks(self) -> bool:
        # Phi-2 runs attention and MLP in parallel within each decoder layer.
        return True

    @property
    def seqlen(self) -> int:
        return self.config.max_position_embeddings

    @property
    def hidden_size(self) -> int:
        return self.config.hidden_size

    @property
    def should_bake_mean_into_linear(self) -> bool:
        return True

    @property
    def original_layer_type(self) -> type:
        return PhiDecoderLayer

    @property
    def original_layer_norm_type(self) -> type:
        return LayerNorm

    @property
    def layer_adapter_type(self) -> type:
        return Phi2LayerAdapter

    @property
    def compressed_layer_type(self) -> type:
        return CompressedPhiDecoderLayer

    @property
    def use_cache(self) -> bool:
        return self.config.use_cache

    @use_cache.setter
    def use_cache(self, value: bool) -> None:
        self.config.use_cache = value

    def compute_output_logits(self, input_ids: Tensor) -> FloatTensor:
        """Run a forward pass and return the model logits."""
        return self.model(input_ids=input_ids).logits

    def convert_layer_to_compressed(self, layer: Module, layer_idx: int | None) -> Module:
        """Build a compressed layer that carries over the weights of `layer`."""
        replacement = self.compressed_layer_type(self.config, layer_idx).to(self.config.torch_dtype)
        # strict=True: the compressed layer must accept every original parameter.
        replacement.load_state_dict(layer.state_dict(), strict=True)
        return replacement

    def get_layers(self) -> list[LayerAdapter]:
        """Wrap each decoder layer in a Phi2LayerAdapter."""
        return [self.layer_adapter_type(decoder_layer) for decoder_layer in self.model.model.layers]

    def get_raw_layer_at(self, index: int) -> Module:
        return self.model.model.layers[index]

    def set_raw_layer_at(self, index: int, new_layer: Module) -> None:
        self.model.model.layers[index] = new_layer

    def get_embeddings(self) -> list[Module]:
        return [self.model.model.embed_tokens]

    def get_pre_head_layernorm(self) -> Module:
        """Return the final layernorm applied before the LM head."""
        final_norm = self.model.model.final_layernorm
        assert final_norm is not None
        return final_norm

    def get_lm_head(self) -> Linear:
        return self.model.lm_head

    def post_init(self, tokenizer: PreTrainedTokenizerBase) -> None:
        # Phi-2 doesn't have a pad token by default; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        self.config.pad_token_id = tokenizer.pad_token_id

    @classmethod
    def _from_pretrained(
        cls,
        model_name: str,
        model_path: str,
        *,
        dtype: torch.dtype = torch.float16,
        local_files_only: bool = False,
        token: str | bool | None = None,
    ) -> ModelAdapter | None:
        """Load pretrained weights; returns None when `model_name` is not handled here."""
        if model_name != "microsoft/phi-2":
            return None

        model = PhiForCausalLM.from_pretrained(
            model_path, torch_dtype=dtype, token=token, local_files_only=local_files_only
        )
        # Record the dtype on the config so later layer conversion uses it.
        model.config.torch_dtype = dtype

        return Phi2ModelAdapter(model)

    @classmethod
    def _from_uninitialized(
        cls,
        model_name: str,
        model_path: str,
        *,
        dtype: torch.dtype = torch.float16,
        local_files_only: bool = False,
        token: str | bool | None = None,
    ) -> ModelAdapter | None:
        """Build the model architecture without loading or initializing weights."""
        if model_name != "microsoft/phi-2":
            return None

        class UninitializedPhiForCausalLM(PhiForCausalLM):
            def _init_weights(self, _) -> None:
                # Prevent weight initialization
                pass

        config = PhiConfig.from_pretrained(
            model_path, torch_dtype=dtype, token=token, local_files_only=local_files_only
        )
        model = UninitializedPhiForCausalLM(config)
        model = model.to(dtype=dtype)

        return Phi2ModelAdapter(model)