Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions configs/_cluster/helios.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ infrastructure:
nodes: 1
partition: plgrid-gpu-gh200
time: "2-00:00:00"
account: plgllmefficont3-gpu-gh200

script:
- '${export_env_variables_placeholders:}'
Expand All @@ -32,7 +33,7 @@ infrastructure:
- 'cd -'

cluster_switch:
train_path_c4: "/net/scratch/hscra/plgrid/plgmaciejpioro/c4/train"
eval_path_c4: "/net/scratch/hscra/plgrid/plgmaciejpioro/c4/validation"
train_path_fineweb: "/net/scratch/hscra/plgrid/plgmaciejpioro/fineweb-edu/train/train"
eval_path_fineweb: "/net/scratch/hscra/plgrid/plgmaciejpioro/fineweb-edu/train/train"
train_path_c4: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/c4/train"
eval_path_c4: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/c4/validation"
train_path_fineweb: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/fineweb-edu/train/train"
eval_path_fineweb: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/fineweb-edu/train/train"
36 changes: 36 additions & 0 deletions configs/_dataset/sst2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defaults:
- default
- _self_

trainer:
train_dataloader:
collate_fn:
_target_: src.product_keys.datasets.glue_collate_wrapper
_partial_: true
dataset:
_target_: src.product_keys.datasets.GlueDataset
sequence_length: ${common.sequence_length}
tokenize_fn: ???
path: "data/ft_dataset/sst2/train"
split: train
seed: 123
use_new_sampling_method: true
shuffle: true
world_size_independent: false
num_workers: 8

eval_dataloader:
collate_fn:
_target_: src.product_keys.datasets.glue_collate_wrapper
_partial_: true
dataset:
_target_: src.product_keys.datasets.GlueDataset
sequence_length: ${common.sequence_length}
tokenize_fn: ???
path: "data/ft_dataset/sst2/test"
split: validation
seed: 123
use_new_sampling_method: true
shuffle: true
world_size_independent: false
num_workers: 8
103 changes: 103 additions & 0 deletions configs/product_keys/finetune_trainer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# @package _global_
defaults:
- /_cluster/helios@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/c4@_here_
- /_checkpoints/none@_here_
- /_misc/default@_here_
- _self_

common:
sequence_length: 1024
batch_size: 32
dmodel: 768
dff: 2042
datt: ${common.dmodel}
n_blocks: 12
q_heads: 12
kv_heads: 12
vocab_size: 128256

trainer:
_target_: src.product_keys.finetuning_trainer.FinetuningTrainer
gradient_accumulation_steps: 2
n_steps: 200
learning_rate: 1e-3


eval_dataloader:
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn

checkpoint:
load:
type: huggingface
path: ~/checkpoint
model_checkpoint_filename: model.safetensors


infrastructure:
metric_logger:
type: wandb
wandb_entity: ideas_cv
project_name: tml-bgw
name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S}
tags:
- nano
- pk_mlm
- "seq_len=${common.sequence_length}"
- "n_layers=${common.n_blocks}"
slurm:
gres: gpu:1
time: "1-00:00:00"
job-name: ${infrastructure.metric_logger.name}

model:
encoder:
block_fn:
attention_fn:
_target_: src.product_keys.model.RoPETopKAttention
_partial_: true
dmodel: ${common.dmodel}
q_heads: ${common.q_heads}
kv_heads: ${common.kv_heads}
seq_len: ${common.sequence_length}

q_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.datt}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

k_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn}

o_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.datt}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

rope_base: 500000
rope_scale_freqs: true
top_k: 16
top_k_before_softmax: true
112 changes: 112 additions & 0 deletions configs/product_keys/finetune_trainer_local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# @package _global_
defaults:
- /_cluster/local@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/sst2@_here_
- /_checkpoints/none@_here_
- /_misc/default@_here_
- _self_


dataset:
seed: 42

common:
sequence_length: 16
batch_size: 4
dmodel: 16
dff: 64
datt: ${common.dmodel}
n_blocks: 4
q_heads: 2
kv_heads: 2
vocab_size: 50304


trainer:
_target_: src.product_keys.finetuning_trainer.FinetuningTrainer
gradient_accumulation_steps: 2
n_steps: 3
learning_rate: 1e-3
d_model: ${common.dmodel}
vocab_size: ${common.vocab_size}

train_dataloader:
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.glue_tokenize_fn
seq_len: ${common.sequence_length}

eval_dataloader:
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.glue_tokenize_fn
seq_len: ${common.sequence_length}

checkpoint:
load:
type: huggingface
path: checkpoint/2026-03-24
model_checkpoint_filename: model.safetensors
save:
type: nano
path: finetuned_checkpoint


infrastructure:
metric_logger:
type: stdout

model:
_target_: src.product_keys.model.LLM
encoder:
_target_: src.product_keys.model.TransformerEncoder
block_fn:
_target_: src.product_keys.model.TransformerBlock
_partial_: true
attention_fn:
_target_: src.product_keys.model.RoPETopKAttention
_partial_: true
dmodel: ${common.dmodel}
q_heads: ${common.q_heads}
kv_heads: ${common.kv_heads}
seq_len: ${common.sequence_length}
causal: false

q_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.datt}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

k_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn}

o_proj_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.datt}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1

rope_base: 500000
rope_scale_freqs: true
top_k: 8
top_k_before_softmax: true
18 changes: 11 additions & 7 deletions configs/product_keys/pk_mlm.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# @package _global_
defaults:
- /_cluster/entropy@_here_
- /_cluster/helios@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/c4@_here_
Expand Down Expand Up @@ -28,19 +28,23 @@ trainer:
n_steps: 77050
learning_rate: 5e-4
train_dataloader:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn


eval_dataloader:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn
dataset:
tokenize_fn:
_target_: src.product_keys.datasets.gpt2_mask_tokenize_fn


infrastructure:
metric_logger:
name: pk_mlm
project_name: pmtest/tml-bgw
type: wandb
wandb_entity: ideas_cv
project_name: tml-bgw
name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S}
tags:
- nano
- pk_mlm
Expand Down
16 changes: 9 additions & 7 deletions configs/product_keys/top_k_attention.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# @package _global_
defaults:
- /_cluster/entropy@_here_
- /_cluster/helios@_here_
- /_model/tiny@_here_
- /_trainer/llama@_here_
- /_dataset/c4@_here_
Expand All @@ -9,7 +9,7 @@ defaults:
- _self_

common:
sequence_length: 2048
sequence_length: 1024
batch_size: 32
dmodel: 768
dff: 2042
Expand All @@ -21,7 +21,7 @@ common:

trainer:
gradient_accumulation_steps: 2
n_steps: 56000
n_steps: 40000
learning_rate: 1e-3

checkpoint:
Expand All @@ -31,13 +31,15 @@ trainer:

infrastructure:
metric_logger:
name: top_k_attention
project_name: pmtest/tml-bgw
type: wandb
wandb_entity: ideas_cv
project_name: tml-bgw
name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S}
tags:
- nano
- top_k_attention
- "lr=${trainer.learning_rate}"
- pk_mlm
- "seq_len=${common.sequence_length}"
- "n_layers=${common.n_blocks}"
slurm:
gres: gpu:1
time: "1-00:00:00"
Expand Down
Loading