Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions configs/product_keys/baseline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# @package _global_
defaults:
- /_cluster/helios@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/c4@_here_
- /_checkpoints/none@_here_
- /_misc/default@_here_
- _self_

common:
sequence_length: 1024
batch_size: 64
dmodel: 1024
dff: 2724
datt: ${common.dmodel}
n_blocks: 16
q_heads: 16
kv_heads: 16
# vocab_size: 50304
vocab_size: 128256

# trainer:
# _target_: src.product_keys.trainer.MaskedLMTrainer
# masking_percentage: 0.2
# mask_token_id: 50257
# unmaskable_special_tokens: [50256, 50257] # <|endoftext|>
# gradient_accumulation_steps: 2
# n_steps: 77050
# # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3]
# learning_rate: 1e-3
# train_dataloader:
# dataset:
# tokenize_fn:
# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn
# eval_dataloader:
# dataset:
# tokenize_fn:
# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn

trainer:
n_steps: 77050
^learning_rate: [5e-4, 1e-3, 2e-3]
# learning_rate: 1e-3

infrastructure:
metric_logger:
type: wandb
wandb_entity: ideas_cv
name: baseline_causal
project_name: tml-bgw
tags:
- baseline
- clm
- "seq_len=${common.sequence_length}"
- "n_layers=${common.n_blocks}"
- "dmodel=${common.dmodel}"
slurm:
gres: gpu:2
time: "1-00:00:00"
job-name: ${infrastructure.metric_logger.name}

model:
_target_: src.projected_compression.model.LLM

embedding:
_target_: src.projected_compression.model.TransformerEmbedding
vocab_size: ${common.vocab_size}
dmodel: ${common.dmodel}
init_fn:
_target_: src.projected_compression.model.trunc_normal_
_partial_: true

encoder:
_target_: src.projected_compression.model.TransformerEncoder
n_blocks: ${common.n_blocks}
block_fn:
_target_: src.projected_compression.model.TransformerBlock
_partial_: true
norm_fn:
_target_: src.core.model.RMSNorm
_partial_: true
eps: 1e-5
normalized_shape: ${common.dmodel}

attention_fn:
_target_: src.projected_compression.model.RoPEAttention
_partial_: true
dmodel: ${common.dmodel}
q_heads: ${common.q_heads}
kv_heads: ${common.kv_heads}
seq_len: ${common.sequence_length}
causal: true

q_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.datt}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

k_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn}

# o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls
o_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.datt}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

rope_base: 500000
rope_scale_freqs: true

ff_layer_fn:
_target_: src.projected_compression.model.ProjectedLlamaFeedForward
_partial_: true
ff_pre_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.dff}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
ff_post_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dff}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn}

head:
_target_: src.projected_compression.model.TransformerHead
linear_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.vocab_size}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
norm_fn:
_target_: src.core.model.RMSNorm
_partial_: true
eps: 1e-5
normalized_shape: ${common.dmodel}
24 changes: 14 additions & 10 deletions configs/product_keys/pk_mlm.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# @package _global_
defaults:
- /_cluster/entropy@_here_
- /_cluster/helios@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/c4@_here_
Expand All @@ -26,28 +26,31 @@ trainer:
unmaskable_special_tokens: [50256, 50257] # <|endoftext|>
gradient_accumulation_steps: 2
n_steps: 77050
learning_rate: 5e-4
^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3]
train_dataloader:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn


dataset:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn
eval_dataloader:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn


infrastructure:
metric_logger:
type: wandb
wandb_entity: ideas_cv
name: pk_mlm
project_name: pmtest/tml-bgw
project_name: tml-bgw
tags:
- nano
- pk_mlm
- "seq_len=${common.sequence_length}"
- "n_layers=${common.n_blocks}"
- "dmodel=${common.dmodel}"
slurm:
gres: gpu:1
gres: gpu:2
time: "1-00:00:00"
job-name: ${infrastructure.metric_logger.name}

Expand All @@ -61,6 +64,7 @@ model:
q_heads: ${common.q_heads}
kv_heads: ${common.kv_heads}
seq_len: ${common.sequence_length}
causal: false

q_proj_fn:
_target_: src.projected_compression.model.Linear
Expand Down
Loading
Loading