From f74f1633036b34ffa44ace8eeb55543afc7430f7 Mon Sep 17 00:00:00 2001 From: ggwozdz2 Date: Mon, 23 Feb 2026 12:45:16 +0100 Subject: [PATCH 1/5] Move on with finetuning files --- configs/product_keys/finetune_trainer.yaml | 110 ++++++++++++++++++ configs/product_keys/pk_mlm.yaml | 18 +-- .../model_sequence_classifiaction.py | 33 ++++++ src/product_keys/finetuning/trainer.py | 50 ++++++++ 4 files changed, 204 insertions(+), 7 deletions(-) create mode 100644 configs/product_keys/finetune_trainer.yaml create mode 100644 src/product_keys/finetuning/model_sequence_classifiaction.py create mode 100644 src/product_keys/finetuning/trainer.py diff --git a/configs/product_keys/finetune_trainer.yaml b/configs/product_keys/finetune_trainer.yaml new file mode 100644 index 00000000..d6f0c853 --- /dev/null +++ b/configs/product_keys/finetune_trainer.yaml @@ -0,0 +1,110 @@ +# @package _global_ +defaults: + - /_cluster/helios@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/c4@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 + datt: ${common.dmodel} + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + vocab_size: 50304 + +trainer: + _target_: src.product_keys.finetuning.trainer.FinetuningTrainer + masking_percentage: 0.2 + mask_token_id: 50257 + unmaskable_special_tokens: [50256, 50257] # <|endoftext|> + gradient_accumulation_steps: 2 + n_steps: 77050 + learning_rate: 5e-4 + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + + + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + + checkpoint: + load: + type: huggingface + path: /net/scratch/hscra/plgrid/plgggwozdz02/checkpoint + model_checkpoint_filename: model.safetensors + + +infrastructure: + metric_logger: + type: wandb + wandb_entity: ideas_cv + project_name: tml-bgw + name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S} + tags: + - nano + - pk_mlm + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + slurm: + gres: gpu:1 + time: "1-00:00:00" + job-name: ${infrastructure.metric_logger.name} + +model: + encoder: + block_fn: + attention_fn: + _target_: src.product_keys.model.RoPETopKAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + top_k: 16 + top_k_before_softmax: true diff --git a/configs/product_keys/pk_mlm.yaml b/configs/product_keys/pk_mlm.yaml index 0d636623..3c89d64f 100644 --- a/configs/product_keys/pk_mlm.yaml +++ b/configs/product_keys/pk_mlm.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -28,19 +28,23 @@ trainer: n_steps: 77050 learning_rate: 5e-4 train_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn eval_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: - name: pk_mlm - project_name: pmtest/tml-bgw + type: wandb + wandb_entity: ideas_cv + project_name: tml-bgw + name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S} tags: - nano - pk_mlm diff --git a/src/product_keys/finetuning/model_sequence_classifiaction.py b/src/product_keys/finetuning/model_sequence_classifiaction.py new file mode 100644 index 00000000..c2e5e8c3 --- /dev/null +++ b/src/product_keys/finetuning/model_sequence_classifiaction.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + +class ModelSequenceClassification(nn.Module): + def __init__(self, base_model: nn.Module, hidden_size: int, num_labels: int): + super().__init__() + self.backbone = base_model + + self.score = nn.Linear(hidden_size, num_labels, bias=False) + + def forward(self, input_ids, attention_mask=None, labels=None): + outputs = self.backbone(input_ids) + + hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs + + if attention_mask is not None: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = input_ids.shape[0] + last_token_hidden_states = hidden_states[ + torch.arange(batch_size, device=hidden_states.device), + sequence_lengths + ] + else: + last_token_hidden_states = hidden_states[:, -1, :] + + logits = self.score(last_token_hidden_states) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.score.out_features), labels.view(-1)) + + return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits} diff --git a/src/product_keys/finetuning/trainer.py b/src/product_keys/finetuning/trainer.py new file mode 100644 index 00000000..c2170fe7 --- /dev/null +++ b/src/product_keys/finetuning/trainer.py @@ -0,0 +1,50 @@ +from attr import define, field +import logging +import torch + +from src.core.trainer import Trainer + +logger = logging.getLogger(__name__) +from nano.src.product_keys.finetuning.model_sequence_classifiaction import ModelSequenceClassification + +# for now focus solely on sst2 +HIDDEN_SIZE = 1024 +SST2_LABELS: int = 2 + + +def create_classifier_model(model: torch.nn.Module) -> torch.nn.Module: + logger.info("Printing model shapes...") + for name, layer in model.named_modules(): + # We filter for Linear layers to keep the output readable + if isinstance(layer, torch.nn.Linear): + logger.info(f"Layer: {name} | Size: {layer.weight.shape}") + return ModelSequenceClassification(model, hidden_size=HIDDEN_SIZE, num_labels=SST2_LABELS) + + +@define(slots=False) +class FinetuningTrainer(Trainer): + freeze_backbone: bool = field(default=False) + trainable_modules: list = field(factory=list) + + def __attrs_post_init__(self): + super().__attrs_post_init__() + + if self.freeze_backbone: + self._freeze_model_layers() + + model = create_classifier_model(model) + + def _freeze_model_layers(self): + logger.info("Freezing backbone layers...") + for name, param in self.model.named_parameters(): + should_train = any(mod in name for mod in self.trainable_modules) + + if not should_train: + param.requires_grad = False + else: + param.requires_grad = True + + def save_checkpoint(self): + logger.info("Saving finetune checkpoint...") + super().save_checkpoint() + From f22633c6971166ecc7a8ad74c5abd8b381684cd0 Mon Sep 17 00:00:00 2001 From: ggwozdz2 Date: Thu, 19 Mar 2026 15:40:53 +0100 Subject: [PATCH 2/5] Working simplest version --- configs/_dataset/sst2.yaml | 36 +++++ configs/product_keys/finetune_trainer.yaml | 27 ++-- .../product_keys/finetune_trainer_local.yaml | 101 ++++++++++++++ configs/product_keys/top_k_attention.yaml | 16 ++- .../product_keys/top_k_attention_local.yaml | 97 +++++++++++++ src/core/conversion_to_hf.py | 3 +- src/product_keys/datasets.py | 132 +++++++++++++++++- .../model_sequence_classifiaction.py | 33 ----- src/product_keys/finetuning/trainer.py | 50 ------- src/product_keys/finetuning_trainer.py | 122 ++++++++++++++++ .../model_sequence_classifiaction.py | 53 +++++++ src/projected_compression/model.py | 3 + 12 files changed, 564 insertions(+), 109 deletions(-) create mode 100644 configs/_dataset/sst2.yaml create mode 100644 configs/product_keys/finetune_trainer_local.yaml create mode 100644 configs/product_keys/top_k_attention_local.yaml delete mode 100644 src/product_keys/finetuning/model_sequence_classifiaction.py delete mode 100644 src/product_keys/finetuning/trainer.py create mode 100644 src/product_keys/finetuning_trainer.py create mode 100644 src/product_keys/model_sequence_classifiaction.py diff --git a/configs/_dataset/sst2.yaml b/configs/_dataset/sst2.yaml new file mode 100644 index 00000000..eae5180b --- /dev/null +++ b/configs/_dataset/sst2.yaml @@ -0,0 +1,36 @@ +defaults: + - default + - _self_ + +trainer: + train_dataloader: + collate_fn: + _target_: src.product_keys.datasets.glue_collate_wrapper + _partial_: true + dataset: + _target_: src.product_keys.datasets.GlueDataset + sequence_length: ${common.sequence_length} + tokenize_fn: ??? + path: "data/ft_dataset/sst2/train" + split: train + seed: 123 + use_new_sampling_method: true + shuffle: true + world_size_independent: false + num_workers: 8 + + eval_dataloader: + collate_fn: + _target_: src.product_keys.datasets.glue_collate_wrapper + _partial_: true + dataset: + _target_: src.product_keys.datasets.GlueDataset + sequence_length: ${common.sequence_length} + tokenize_fn: ??? + path: "data/ft_dataset/sst2/test" + split: validation + seed: 123 + use_new_sampling_method: true + shuffle: true + world_size_independent: false + num_workers: 8 diff --git a/configs/product_keys/finetune_trainer.yaml b/configs/product_keys/finetune_trainer.yaml index d6f0c853..9453f568 100644 --- a/configs/product_keys/finetune_trainer.yaml +++ b/configs/product_keys/finetune_trainer.yaml @@ -10,27 +10,20 @@ defaults: common: sequence_length: 1024 - batch_size: 64 - dmodel: 1024 - dff: 2724 + batch_size: 32 + dmodel: 768 + dff: 2042 datt: ${common.dmodel} - n_blocks: 16 - q_heads: 16 - kv_heads: 16 - vocab_size: 50304 + n_blocks: 12 + q_heads: 12 + kv_heads: 12 + vocab_size: 128256 trainer: - _target_: src.product_keys.finetuning.trainer.FinetuningTrainer - masking_percentage: 0.2 - mask_token_id: 50257 - unmaskable_special_tokens: [50256, 50257] # <|endoftext|> + _target_: src.product_keys.finetuning_trainer.FinetuningTrainer gradient_accumulation_steps: 2 - n_steps: 77050 - learning_rate: 5e-4 - train_dataloader: - dataset: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + n_steps: 200 + learning_rate: 1e-3 eval_dataloader: diff --git a/configs/product_keys/finetune_trainer_local.yaml b/configs/product_keys/finetune_trainer_local.yaml new file mode 100644 index 00000000..6eb86a97 --- /dev/null +++ b/configs/product_keys/finetune_trainer_local.yaml @@ -0,0 +1,101 @@ +# @package _global_ +defaults: + - /_cluster/local@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/sst2@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 16 + batch_size: 4 + dmodel: 16 + dff: 64 + datt: ${common.dmodel} + n_blocks: 4 + q_heads: 2 + kv_heads: 2 + vocab_size: 128256 + + +trainer: + _target_: src.product_keys.finetuning_trainer.FinetuningTrainer + gradient_accumulation_steps: 2 + n_steps: 2 + learning_rate: 1e-3 + + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.glue_tokenize_fn + seq_len: ${common.sequence_length} + + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.glue_tokenize_fn + seq_len: ${common.sequence_length} + + checkpoint: + load: + type: huggingface + path: checkpoint + model_checkpoint_filename: model.safetensors + save: + type: nano + path: finetuned_checkpoint + + +infrastructure: + metric_logger: + type: stdout + +model: + encoder: + block_fn: + attention_fn: + _target_: src.product_keys.model.RoPETopKAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + top_k: 16 + top_k_before_softmax: true diff --git a/configs/product_keys/top_k_attention.yaml b/configs/product_keys/top_k_attention.yaml index 99ad5758..73f13c16 100644 --- a/configs/product_keys/top_k_attention.yaml +++ b/configs/product_keys/top_k_attention.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -9,7 +9,7 @@ defaults: - _self_ common: - sequence_length: 2048 + sequence_length: 1024 batch_size: 32 dmodel: 768 dff: 2042 @@ -21,7 +21,7 @@ common: trainer: gradient_accumulation_steps: 2 - n_steps: 56000 + n_steps: 40000 learning_rate: 1e-3 checkpoint: @@ -31,13 +31,15 @@ trainer: infrastructure: metric_logger: - name: top_k_attention - project_name: pmtest/tml-bgw + type: wandb + wandb_entity: ideas_cv + project_name: tml-bgw + name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S} tags: - nano - - top_k_attention - - "lr=${trainer.learning_rate}" + - pk_mlm - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" slurm: gres: gpu:1 time: "1-00:00:00" diff --git a/configs/product_keys/top_k_attention_local.yaml b/configs/product_keys/top_k_attention_local.yaml new file mode 100644 index 00000000..be59165b --- /dev/null +++ b/configs/product_keys/top_k_attention_local.yaml @@ -0,0 +1,97 @@ +# @package _global_ +defaults: + - /_cluster/local@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/local_dummy@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 16 + batch_size: 4 + dmodel: 16 + dff: 64 + datt: ${common.dmodel} + n_blocks: 4 + q_heads: 2 + kv_heads: 2 + vocab_size: 128256 + +trainer: + gradient_accumulation_steps: 2 + n_steps: 500 + learning_rate: 1e-3 + + checkpoint: + save: + type: huggingface + path: checkpoint + +infrastructure: + metric_logger: + type: wandb + wandb_entity: ideas_cv + project_name: tml-bgw + name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S} + tags: + - nano + - pk_mlm + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + slurm: + gres: gpu:1 + time: "1-00:00:00" + job-name: ${infrastructure.metric_logger.name} + +evaluator: null + +model: + encoder: + block_fn: + attention_fn: + _target_: src.product_keys.model.RoPETopKAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + # o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + top_k: 16 + top_k_before_softmax: true diff --git a/src/core/conversion_to_hf.py b/src/core/conversion_to_hf.py index 836abf7d..3a5b9fe6 100644 --- a/src/core/conversion_to_hf.py +++ b/src/core/conversion_to_hf.py @@ -94,6 +94,7 @@ def save_to_llama_3_hf( hf_state_dict = remap_nano_to_llama31_hf(nano_model_state_dict) hf_model.load_state_dict(hf_state_dict, strict=True) - print(f"Saving HF model with the following config {config}") + print(f"Saving HF model with the following config {config}, " + f"directory: {save_dir}") hf_model.save_pretrained(save_dir) diff --git a/src/product_keys/datasets.py b/src/product_keys/datasets.py index 1a193c9a..d7549bcf 100644 --- a/src/product_keys/datasets.py +++ b/src/product_keys/datasets.py @@ -1,4 +1,96 @@ -from transformers import GPT2TokenizerFast +import logging +import os +from typing import Callable, Optional + +import torch +from torch.utils.data import IterableDataset, DataLoader +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +from transformers import GPT2TokenizerFast, AutoTokenizer + +from src.core.datasets import AbstractDataset, collate_wrapper + + +logger = logging.getLogger(__name__) + + +class GlueDataset(AbstractDataset): + BUFFER_SIZE = 1000 + NUM_SHARDS = 64 + + def __init__( + self, + sequence_length, + tokenize_fn: Callable, + path: Optional[str] = None, + split: Optional[str] = None, + seed: Optional[int] = None, + use_new_sampling_method: bool = True, + shuffle: bool = True, + world_size_independent: bool = False, + task_name: str = "sst2", + ): + super().__init__( + sequence_length, + tokenize_fn, + path, + split, + seed, + use_new_sampling_method, + shuffle, + world_size_independent, + ) + self.task_name = task_name + self._load_dataset(path, split, seed, tokenize_fn, shuffle) + + def _load_dataset(self, path, split, seed, tokenize_fn, shuffle: bool): + if path is None: + logger.debug( + f"Loading 'nyu-mll/glue' dataset task '{self.task_name}' from HuggingFace with split={split}" + ) + hf_dataset = load_dataset( + "nyu-mll/glue", + self.task_name, + split=split, + streaming=True, + trust_remote_code=True, + ) + else: + logger.info(f"Loading dataset from path '{path}'") + logger.info(f"Split: {split}") + hf_dataset = load_dataset(path, split=split) + if not hasattr(hf_dataset, "set_epoch"): # Check if it's already an IterableDataset + hf_dataset = hf_dataset.to_iterable_dataset(num_shards=self.NUM_SHARDS) + + if not self.world_size_independent: + hf_dataset = split_dataset_by_node( + hf_dataset, rank=self.rank, world_size=self.world_size + ) + + if shuffle: + hf_dataset = hf_dataset.shuffle(buffer_size=self.BUFFER_SIZE, seed=seed) + + # Map task specific columns to 'text' for generic tokenize_fn + if self.task_name == "sst2": + hf_dataset = hf_dataset.map(lambda x: {"text": x["sentence"]}) + + self.data_generator = hf_dataset.map(tokenize_fn, batched=True) + + def sample_packer(self): + sampler = iter(self.get_infinite_sampler()) + while True: + full_sample = next(sampler) + tokens = full_sample['input_ids'] + label = full_sample['label'] + yield (tokens, label) + + def get_infinite_sampler(self): + epoch = 0 + while True: + self.data_generator.set_epoch(epoch) + for next_sample in self.data_generator: + yield next_sample + epoch += 1 def gpt2_mask_tokenize_fn(): @@ -27,3 +119,41 @@ def tokenize_function(examples): return batch_encodings return tokenize_function + +# gpt2 specific, need to add additional tokenize functions for different +def glue_tokenize_fn(seq_len: int): + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + + current_size = tokenizer.vocab_size + diff_multiple_64 = (((current_size // 64) + 1) * 64 - current_size) % 64 + tokens_to_add = diff_multiple_64 - 2 # mask token, cls token + additional_special_tokens = [f"<|extra_token_{i}|>" for i in range(tokens_to_add)] + tokenizer.add_special_tokens( + { + "mask_token": "<|mask|>", + "cls_token": "<|cls|>", + "additional_special_tokens": additional_special_tokens, + } + ) + tokenizer.pad_token = tokenizer.eos_token + + def tokenize_function(examples): + examples['text'] = [f"{tokenizer.cls_token} {text}" for text in examples['text']] + batch_encodings = tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=seq_len, + ) + return batch_encodings + + return tokenize_function + +def glue_collate_wrapper(examples): + inputs = [item[0] for item in examples] + labels = [item[1] for item in examples] + + collated_inputs = collate_wrapper(inputs) + collated_labels = torch.tensor(labels, dtype=torch.int64) + + return collated_inputs, collated_labels diff --git a/src/product_keys/finetuning/model_sequence_classifiaction.py b/src/product_keys/finetuning/model_sequence_classifiaction.py deleted file mode 100644 index c2e5e8c3..00000000 --- a/src/product_keys/finetuning/model_sequence_classifiaction.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.nn as nn - -class ModelSequenceClassification(nn.Module): - def __init__(self, base_model: nn.Module, hidden_size: int, num_labels: int): - super().__init__() - self.backbone = base_model - - self.score = nn.Linear(hidden_size, num_labels, bias=False) - - def forward(self, input_ids, attention_mask=None, labels=None): - outputs = self.backbone(input_ids) - - hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs - - if attention_mask is not None: - sequence_lengths = attention_mask.sum(dim=1) - 1 - batch_size = input_ids.shape[0] - last_token_hidden_states = hidden_states[ - torch.arange(batch_size, device=hidden_states.device), - sequence_lengths - ] - else: - last_token_hidden_states = hidden_states[:, -1, :] - - logits = self.score(last_token_hidden_states) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.score.out_features), labels.view(-1)) - - return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits} diff --git a/src/product_keys/finetuning/trainer.py b/src/product_keys/finetuning/trainer.py deleted file mode 100644 index c2170fe7..00000000 --- a/src/product_keys/finetuning/trainer.py +++ /dev/null @@ -1,50 +0,0 @@ -from attr import define, field -import logging -import torch - -from src.core.trainer import Trainer - -logger = logging.getLogger(__name__) -from nano.src.product_keys.finetuning.model_sequence_classifiaction import ModelSequenceClassification - -# for now focus solely on sst2 -HIDDEN_SIZE = 1024 -SST2_LABELS: int = 2 - - -def create_classifier_model(model: torch.nn.Module) -> torch.nn.Module: - logger.info("Printing model shapes...") - for name, layer in model.named_modules(): - # We filter for Linear layers to keep the output readable - if isinstance(layer, torch.nn.Linear): - logger.info(f"Layer: {name} | Size: {layer.weight.shape}") - return ModelSequenceClassification(model, hidden_size=HIDDEN_SIZE, num_labels=SST2_LABELS) - - -@define(slots=False) -class FinetuningTrainer(Trainer): - freeze_backbone: bool = field(default=False) - trainable_modules: list = field(factory=list) - - def __attrs_post_init__(self): - super().__attrs_post_init__() - - if self.freeze_backbone: - self._freeze_model_layers() - - model = create_classifier_model(model) - - def _freeze_model_layers(self): - logger.info("Freezing backbone layers...") - for name, param in self.model.named_parameters(): - should_train = any(mod in name for mod in self.trainable_modules) - - if not should_train: - param.requires_grad = False - else: - param.requires_grad = True - - def save_checkpoint(self): - logger.info("Saving finetune checkpoint...") - super().save_checkpoint() - diff --git a/src/product_keys/finetuning_trainer.py b/src/product_keys/finetuning_trainer.py new file mode 100644 index 00000000..f16b7685 --- /dev/null +++ b/src/product_keys/finetuning_trainer.py @@ -0,0 +1,122 @@ +import os +from attr import define, field +import logging +import torch +import torch.nn +from typing import Optional + +from src.core.trainer import Trainer, cast_state_dict_to_tensors +from src.product_keys.model_sequence_classifiaction import ModelSequenceClassification + +# for now focus solely on sst2 +HIDDEN_SIZE = 128256 +SST2_LABELS: int = 2 + +logger = logging.getLogger(__name__) + +def show_gradients(model: torch.nn.Module): + for name, param in model.named_parameters(): + if param.requires_grad: + if param.grad is not None: + # Calculate some basic statistics to see what the gradients look like + grad_mean = param.grad.abs().mean().item() + grad_max = param.grad.abs().max().item() + print(f"Layer: {name:<30} | Grad Mean: {grad_mean:.6f} | Grad Max: {grad_max:.6f}") + else: + print(f"Layer: {name:<30} | NO GRADIENT DEPOSITED (Disconnected layer?)") + + +def create_classifier_model(model: torch.nn.Module, + device: torch.device, + distributed: Optional[dict] = None) -> torch.nn.Module: + logger.info("Printing model shapes...") + for name, layer in model.named_modules(): + logger.info(f"Layer name: {name}") + + if hasattr(layer, 'weight') and layer.weight is not None: + logger.info(f"Layer: {name} | Size: {layer.weight.shape}") + + model = ModelSequenceClassification(model, hidden_size=HIDDEN_SIZE, num_labels=SST2_LABELS, + distributed=distributed).to(device) + + model_dtypes = set([param.dtype for param in model.parameters()]) + logger.info(f"Model dtypes: {list(model_dtypes)}") + return model + + +@define(slots=False) +class FinetuningTrainer(Trainer): + freeze_backbone: bool = field(default=False) + trainable_modules: list = field(factory=list) + loss_fct = torch.nn.CrossEntropyLoss() + + def __attrs_post_init__(self): + super().__attrs_post_init__() + + logger.info(f"{self.distributed=}") + + if self.freeze_backbone: + self._freeze_model_layers() + + self.model = create_classifier_model(self.model, + self.device, self.distributed) + + + def _freeze_model_layers(self): + logger.info("Freezing backbone layers...") + for name, param in self.model.named_parameters(): + should_train = any(mod in name for mod in self.trainable_modules) + + if not should_train: + param.requires_grad = False + else: + param.requires_grad = True + + + def train(self): + logger.info(type(self.train_dataloader)) + for step, batch in zip( + range(self.start_step, self.n_steps), self.train_dataloader + ): + self.step = step + self.metric_logger.set_step(step) + self.model.train() + texts, labels = batch + labels = labels.to(self.device) + + loss = self.calculate_loss(texts, labels) + + grad_norm = self.clip_gradient() + + self.log_metrics(loss, grad_norm) + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + if self._should_save_checkpoint: + self.save_checkpoint() + + if self._should_evaluate: + self.eval() + + if self._should_save_final_checkpoint: + self.save_checkpoint() + + eval() + + + def eval(self): + pass + + + def calculate_loss(self, texts, labels): + logits = self.model(texts) + # logger.info(f"{logits=}") + loss = self.loss_fct(logits, labels) + + if self.model.training: + logger.info("Backward loss") + loss.backward() + + return loss diff --git a/src/product_keys/model_sequence_classifiaction.py b/src/product_keys/model_sequence_classifiaction.py new file mode 100644 index 00000000..07b5ce56 --- /dev/null +++ b/src/product_keys/model_sequence_classifiaction.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import logging +from typing import Optional + +logger = logging.getLogger(__name__) +from omegaconf import OmegaConf +from src.core.distributed_training import setup_distributed_training + + +class ModelSequenceClassification(nn.Module): + def __init__(self, base_model: nn.Module, hidden_size: int, + num_labels: int, distributed: Optional[dict] = None): + super().__init__() + self.backbone = base_model + self.score = nn.Linear(hidden_size, num_labels, bias=False, dtype=torch.float32) + + if distributed is not None: + logger.info("Using distributed on model sequence classifier") + distributed_config = OmegaConf.create(distributed) + logger.info(f"Distributed config: {distributed_config}") + self.score = setup_distributed_training(self.score, + distributed_config=distributed_config) + + + def forward(self, input_ids, attention_mask=None): + model_dtypes = set([param.dtype for param in self.backbone.parameters()]) + logger.info(f"Backbone model dtypes: {list(model_dtypes)}") + + outputs = self.backbone(input_ids) + + logger.info(f"Outputs shape: {outputs.shape}") + + hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs + + logger.info(f"Hidden states shape: {hidden_states.shape}") + logger.info(f"Hidden states type: {hidden_states.dtype}") + + if attention_mask is not None: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = input_ids.shape[0] + last_token_hidden_states = hidden_states[ + torch.arange(batch_size, device=hidden_states.device), + sequence_lengths + ] + else: + last_token_hidden_states = hidden_states[:, -1, :] + + logits = self.score(last_token_hidden_states) + + # finetune_object = {"loss": loss, "logits": logits} if loss is not None else {"logits": logits} + + return logits diff --git a/src/projected_compression/model.py b/src/projected_compression/model.py index 4b1371db..93dd927e 100644 --- a/src/projected_compression/model.py +++ b/src/projected_compression/model.py @@ -21,6 +21,9 @@ from torch import zeros as zeros import torch.distributed as dist import torch.nn.functional as F +import logging + +logger = logging.getLogger(__name__) def llm_random_weight_init(fan_in, scale): From 0bbba4986e4b486df8bc6467eb3c41fc71c8cf16 Mon Sep 17 00:00:00 2001 From: ggwozdz2 Date: Wed, 25 Mar 2026 11:08:02 +0100 Subject: [PATCH 3/5] Fix head swapping for classification task --- configs/product_keys/finetune_trainer.yaml | 2 +- .../product_keys/finetune_trainer_local.yaml | 19 +++- .../product_keys/top_k_attention_local.yaml | 33 ++++--- src/core/conversion_to_hf.py | 4 + src/product_keys/datasets.py | 9 +- src/product_keys/finetuning_trainer.py | 71 ++++++++++---- src/product_keys/model.py | 97 +++++++++++++++++-- .../model_sequence_classifiaction.py | 65 +++++++------ src/product_keys/trainer.py | 60 ++++++++++++ 9 files changed, 282 insertions(+), 78 deletions(-) diff --git a/configs/product_keys/finetune_trainer.yaml b/configs/product_keys/finetune_trainer.yaml index 9453f568..bcb8c6aa 100644 --- a/configs/product_keys/finetune_trainer.yaml +++ b/configs/product_keys/finetune_trainer.yaml @@ -34,7 +34,7 @@ trainer: checkpoint: load: type: huggingface - path: /net/scratch/hscra/plgrid/plgggwozdz02/checkpoint + path: ~/checkpoint model_checkpoint_filename: model.safetensors diff --git a/configs/product_keys/finetune_trainer_local.yaml b/configs/product_keys/finetune_trainer_local.yaml index 6eb86a97..4779c854 100644 --- a/configs/product_keys/finetune_trainer_local.yaml +++ b/configs/product_keys/finetune_trainer_local.yaml @@ -8,6 +8,10 @@ defaults: - /_misc/default@_here_ - _self_ + +dataset: + seed: 42 + common: sequence_length: 16 batch_size: 4 @@ -17,14 +21,16 @@ common: n_blocks: 4 q_heads: 2 kv_heads: 2 - vocab_size: 128256 + vocab_size: 50304 trainer: _target_: src.product_keys.finetuning_trainer.FinetuningTrainer gradient_accumulation_steps: 2 - n_steps: 2 + n_steps: 2000 learning_rate: 1e-3 + d_model: ${common.dmodel} + vocab_size: ${common.vocab_size} train_dataloader: dataset: @@ -41,7 +47,7 @@ trainer: checkpoint: load: type: huggingface - path: checkpoint + path: checkpoint/2026-03-24 model_checkpoint_filename: model.safetensors save: type: nano @@ -53,8 +59,12 @@ infrastructure: type: stdout model: + _target_: src.product_keys.model.LLM encoder: + _target_: src.product_keys.model.TransformerEncoder block_fn: + _target_: src.product_keys.model.TransformerBlock + _partial_: true attention_fn: _target_: src.product_keys.model.RoPETopKAttention _partial_: true @@ -62,6 +72,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} + causal: false q_proj_fn: _target_: src.projected_compression.model.Linear @@ -97,5 +108,5 @@ model: rope_base: 500000 rope_scale_freqs: true - top_k: 16 + top_k: 8 top_k_before_softmax: true diff --git a/configs/product_keys/top_k_attention_local.yaml b/configs/product_keys/top_k_attention_local.yaml index be59165b..3bf17a4d 100644 --- a/configs/product_keys/top_k_attention_local.yaml +++ b/configs/product_keys/top_k_attention_local.yaml @@ -17,33 +17,34 @@ common: n_blocks: 4 q_heads: 2 kv_heads: 2 - vocab_size: 128256 + vocab_size: 50304 trainer: + _target_: src.product_keys.trainer.TrainerWithVocabSize gradient_accumulation_steps: 2 - n_steps: 500 + n_steps: 7000 learning_rate: 1e-3 + vocab_size: ${common.vocab_size} checkpoint: save: type: huggingface - path: checkpoint + path: checkpoint/${now:%Y-%m-%d} + + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + infrastructure: metric_logger: - type: wandb - wandb_entity: ideas_cv - project_name: tml-bgw - name: TML_BGW-${now:%Y-%m-%d_%H-%M-%S} - tags: - - nano - - pk_mlm - - "seq_len=${common.sequence_length}" - - "n_layers=${common.n_blocks}" - slurm: - gres: gpu:1 - time: "1-00:00:00" - job-name: ${infrastructure.metric_logger.name} + type: stdout evaluator: null diff --git a/src/core/conversion_to_hf.py b/src/core/conversion_to_hf.py index 3a5b9fe6..8a07d499 100644 --- a/src/core/conversion_to_hf.py +++ b/src/core/conversion_to_hf.py @@ -3,6 +3,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM import re +from typing import Optional def remap_nano_to_llama31_hf(nano_dict): @@ -80,6 +81,7 @@ def save_to_llama_3_hf( n_kvatt_heads: int, head_dim: int, nlayers: int, + vocab_size: Optional[int] = None ): config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B") @@ -89,6 +91,8 @@ def save_to_llama_3_hf( config.num_key_value_heads = int(n_kvatt_heads) config.head_dim = int(head_dim) config.num_hidden_layers = int(nlayers) + if vocab_size is not None: + config.vocab_size = int(vocab_size) hf_model = AutoModelForCausalLM.from_config(config) hf_state_dict = remap_nano_to_llama31_hf(nano_model_state_dict) diff --git a/src/product_keys/datasets.py b/src/product_keys/datasets.py index d7549bcf..a4c3d295 100644 --- a/src/product_keys/datasets.py +++ b/src/product_keys/datasets.py @@ -82,7 +82,8 @@ def sample_packer(self): full_sample = next(sampler) tokens = full_sample['input_ids'] label = full_sample['label'] - yield (tokens, label) + attention_mask = full_sample['attention_mask'] + yield (tokens, label, attention_mask) def get_infinite_sampler(self): epoch = 0 @@ -144,7 +145,7 @@ def tokenize_function(examples): padding="max_length", truncation=True, max_length=seq_len, - ) + ) return batch_encodings return tokenize_function @@ -152,8 +153,10 @@ def tokenize_function(examples): def glue_collate_wrapper(examples): inputs = [item[0] for item in examples] labels = [item[1] for item in examples] + attention_masks = [item[2] for item in examples] collated_inputs = collate_wrapper(inputs) collated_labels = torch.tensor(labels, dtype=torch.int64) + collated_attention_masks = collate_wrapper(attention_masks) - return collated_inputs, collated_labels + return collated_inputs, collated_labels, collated_attention_masks diff --git a/src/product_keys/finetuning_trainer.py b/src/product_keys/finetuning_trainer.py index f16b7685..a6d906de 100644 --- a/src/product_keys/finetuning_trainer.py +++ b/src/product_keys/finetuning_trainer.py @@ -3,13 +3,15 @@ import logging import torch import torch.nn -from typing import Optional +from typing import Optional, override -from src.core.trainer import Trainer, cast_state_dict_to_tensors +from src.product_keys.trainer import TrainerWithVocabSize +from src.core.utils import create_batch_fingerprint +from src.core.metric_loggers import AveDiffMetric, AveMetric, MetricLogger, WandbLogger from src.product_keys.model_sequence_classifiaction import ModelSequenceClassification + # for now focus solely on sst2 -HIDDEN_SIZE = 128256 SST2_LABELS: int = 2 logger = logging.getLogger(__name__) @@ -28,6 +30,7 @@ def show_gradients(model: torch.nn.Module): def create_classifier_model(model: torch.nn.Module, device: torch.device, + d_model: int, distributed: Optional[dict] = None) -> torch.nn.Module: logger.info("Printing model shapes...") for name, layer in model.named_modules(): @@ -36,7 +39,7 @@ def create_classifier_model(model: torch.nn.Module, if hasattr(layer, 'weight') and layer.weight is not None: logger.info(f"Layer: {name} | Size: {layer.weight.shape}") - model = ModelSequenceClassification(model, hidden_size=HIDDEN_SIZE, num_labels=SST2_LABELS, + model = ModelSequenceClassification(model, d_model=d_model, num_labels=SST2_LABELS, distributed=distributed).to(device) model_dtypes = set([param.dtype for param in model.parameters()]) @@ -45,7 +48,8 @@ def create_classifier_model(model: torch.nn.Module, @define(slots=False) -class FinetuningTrainer(Trainer): +class FinetuningTrainer(TrainerWithVocabSize): + d_model: int freeze_backbone: bool = field(default=False) trainable_modules: list = field(factory=list) loss_fct = torch.nn.CrossEntropyLoss() @@ -58,8 +62,11 @@ def __attrs_post_init__(self): if self.freeze_backbone: self._freeze_model_layers() - self.model = create_classifier_model(self.model, - self.device, self.distributed) + self.model = create_classifier_model( + model=self.model, + device=self.device, + d_model=self.d_model, + distributed=self.distributed) def _freeze_model_layers(self): @@ -81,10 +88,8 @@ def train(self): self.step = step self.metric_logger.set_step(step) self.model.train() - texts, labels = batch - labels = labels.to(self.device) - loss = self.calculate_loss(texts, labels) + loss = self.calculate_loss(batch) grad_norm = self.clip_gradient() @@ -103,20 +108,48 @@ def train(self): if self._should_save_final_checkpoint: self.save_checkpoint() - eval() + self.eval() - + @override def eval(self): - pass - - - def calculate_loss(self, texts, labels): - logits = self.model(texts) - # logger.info(f"{logits=}") + self.model.eval() + saved_step = self.step + self.metric_logger.set_step(None) # disables heavy logging + losses = [] + eval_fingerprint = [] + with torch.no_grad(): + for _ in range(self.n_eval_steps): + batch = next(self.eval_iterator) + text, _, _ = batch + text_fingerprint = create_batch_fingerprint(text) + eval_fingerprint.extend(text_fingerprint) + loss = self.calculate_loss(batch) + losses.append(loss.item()) + self.metric_logger.flush_accumulated_metrics(self.step) + avg_loss = torch.tensor(losses).mean() + self.metric_logger.log("steps/eval/loss", self.step, avg_loss.item()) + if not isinstance(self.metric_logger, (WandbLogger)): + self.metric_logger.log( + "tokens/eval/loss", self.processed_tokens, avg_loss.item() + ) + + if self._should_log_eval_input: + self.metric_logger.log( + f"steps/eval/batch", self.step, str(eval_fingerprint) + ) + + self.step = saved_step + + @override + def calculate_loss(self, batch): + texts, labels, attention_masks = batch + texts = texts.to(self.device) + labels = labels.to(self.device) + attention_masks = attention_masks.to(self.device) + logits = self.model(texts, attention_mask=attention_masks) loss = self.loss_fct(logits, labels) if self.model.training: - logger.info("Backward loss") loss.backward() return loss diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 56a8432c..3e3809cb 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -7,6 +7,9 @@ import logging from src.core.model import AttentionMechanism, RoPE +from src.projected_compression.model import LLM as LLM_projected_compression, \ + TransformerEncoder as TransformerEncoder_projected_compression, \ + Residual as Residual_projected_compression logger = logging.getLogger(__name__) @@ -35,6 +38,7 @@ def __init__( rope_scale_freqs: bool, top_k: int, top_k_before_softmax: bool = True, + causal: bool = False ): super().__init__() self.q_proj = q_proj_fn() @@ -51,6 +55,8 @@ def __init__( self.top_k = top_k self.top_k_before_softmax = top_k_before_softmax + self.causal = causal + self.rope = RoPE( dhead=self.dhead, length=seq_len, @@ -64,7 +70,7 @@ def __apply_topk_mask(self, x, fill_value: float): mask_topk = x < threshold return x.masked_fill(mask_topk, fill_value) - def forward(self, x): + def forward(self, x, attention_mask=None): query_states = self.q_proj(x) key_states = self.k_proj(x) value_states = self.v_proj(x) @@ -85,17 +91,23 @@ def forward(self, x): # standard attention if seq_len is smaller or equal top_k if seq_len <= self.top_k: attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) return self.o_proj( attention_output.transpose(1, 2).contiguous().flatten(-2) ) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead) - causal_mask = torch.triu( - torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 - ).bool() - attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) + + if self.causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 + ).bool() + attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) + + if attention_mask is not None: + pad_mask = attention_mask.unsqueeze(1).unsqueeze(2) == 0 + attention_scores = attention_scores.masked_fill(pad_mask, float("-inf")) if self.top_k_before_softmax: attention_scores = self.__apply_topk_mask( @@ -125,7 +137,7 @@ def __init__( seq_len, rope_base, rope_scale_freqs: bool, - top_k: int, + top_k: int ): super().__init__() @@ -174,7 +186,7 @@ def __gather_selected(source_tensor, idx_tensor): ) return torch.gather(source_tensor, 3, idx_expanded) - def forward(self, x): + def forward(self, x, attention_mask=None): query_states = self.q_proj(x) key_states = self.k_proj(x) value_states = self.v_proj(x) @@ -253,6 +265,10 @@ def forward(self, x): # torch.ones(seq_len, seq_len, device=attn_scores.device), diagonal=1 # ).bool() # attn_scores = attn_scores.masked_fill(causal_mask, float("-inf")) + + if attention_mask is not None: + pad_mask = attention_mask.unsqueeze(1).unsqueeze(2) == 0 + attn_scores = attn_scores.masked_fill(pad_mask, float("-inf")) attn_weights = F.softmax(attn_scores, dim=-1) @@ -260,3 +276,68 @@ def forward(self, x): attn_output = attn_output.squeeze(-2) return self.o_proj(attn_output.transpose(1, 2).contiguous().flatten(-2)) + + +class LLM(LLM_projected_compression): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): + if "attention_mask" in kwargs: + attention_mask = kwargs.pop("attention_mask") + x = self.embedding(*args, **kwargs) + x = self.encoder(x, attention_mask=attention_mask) + x = self.head(x) + return x + + +class TransformerEncoder(TransformerEncoder_projected_compression): + def forward(self, x, *args, **kwargs): + for block in self.blocks: + x = block(x, *args, **kwargs) + return x + + +class TransformerBlock(nn.Module): + def __init__( + self, + block_id, + norm_fn, + attention_fn, + ff_layer_fn, + ): + super().__init__() + self.log_name = f"block[{block_id}]" + + self.attention_layer = Residual( + norm=norm_fn(), + layer=attention_fn(), + log_name=f"{self.log_name}/residual_attention", + ) + self.ff_layer = Residual( + norm=norm_fn(), + layer=ff_layer_fn(), + log_name=f"{self.log_name}/residual_feedforward", + ) + + def forward(self, x, attention_mask=None): + x = self.attention_layer(x, attention_mask=attention_mask) + x = self.ff_layer(x) + return x + + +class Residual(Residual_projected_compression): + def forward(self, x, *args, **kwargs): + normalized = self.norm(x) + out = self.layer(normalized, *args, **kwargs) + if self.metric_logger is not None: + self.metric_logger.accumulate_metrics( + layer_name=f"{self.log_name}", + transform_fn=Residual.intermediate_norms, + calculate_fn=Residual.calculate_metrics, + metrics={ + "residual_stream": x, + "updates": out, + }, + ) + return out + x diff --git a/src/product_keys/model_sequence_classifiaction.py b/src/product_keys/model_sequence_classifiaction.py index 07b5ce56..f04ee27c 100644 --- a/src/product_keys/model_sequence_classifiaction.py +++ b/src/product_keys/model_sequence_classifiaction.py @@ -8,46 +8,57 @@ from src.core.distributed_training import setup_distributed_training +class TransformerHead(nn.Module): + def __init__(self, d_model: int, num_labels: int): + super().__init__() + self.norm = nn.LayerNorm(d_model, eps=1e-5) + self.linear = nn.Linear(d_model, num_labels, bias=False, dtype=torch.float32) + + def forward(self, x): + x = self.norm(x) + # logger.info(f"{x[:, :4, :4]=}") + return self.linear(x) + + class ModelSequenceClassification(nn.Module): - def __init__(self, base_model: nn.Module, hidden_size: int, + def __init__(self, base_model: nn.Module, d_model: int, num_labels: int, distributed: Optional[dict] = None): super().__init__() self.backbone = base_model - self.score = nn.Linear(hidden_size, num_labels, bias=False, dtype=torch.float32) + + # replace head + assert hasattr(self.backbone, "head"), "Model provided for sequence classification should have head attribute" + self.backbone.head = TransformerHead(d_model, num_labels) if distributed is not None: logger.info("Using distributed on model sequence classifier") distributed_config = OmegaConf.create(distributed) logger.info(f"Distributed config: {distributed_config}") - self.score = setup_distributed_training(self.score, - distributed_config=distributed_config) + self.backbone.head = setup_distributed_training( + self.backbone.head, distributed_config=distributed_config) + + count = 0 - def forward(self, input_ids, attention_mask=None): + model_dtypes = set([param.dtype for param in self.backbone.parameters()]) - logger.info(f"Backbone model dtypes: {list(model_dtypes)}") + logger.debug(f"Backbone model dtypes: {list(model_dtypes)}") - outputs = self.backbone(input_ids) + hidden_states = self.backbone(input_ids, attention_mask=attention_mask) - logger.info(f"Outputs shape: {outputs.shape}") - - hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs - - logger.info(f"Hidden states shape: {hidden_states.shape}") - logger.info(f"Hidden states type: {hidden_states.dtype}") - - if attention_mask is not None: - sequence_lengths = attention_mask.sum(dim=1) - 1 - batch_size = input_ids.shape[0] - last_token_hidden_states = hidden_states[ - torch.arange(batch_size, device=hidden_states.device), - sequence_lengths - ] - else: - last_token_hidden_states = hidden_states[:, -1, :] - - logits = self.score(last_token_hidden_states) - - # finetune_object = {"loss": loss, "logits": logits} if loss is not None else {"logits": logits} + logger.debug(f"Hidden states shape: {hidden_states.shape}") + logger.debug(f"Hidden states type: {hidden_states.dtype}") + + # take hidden states from [CLS] token + logits = hidden_states[:, 0, :] + logger.debug(f"CLS token hidden states shape: {logits.shape}") + + # logits = self.score(cls_token_hidden_states) + + if self.count % 5 == 0: + logger.info(f"Logits shape: {logits.shape}") + logger.info(f"{logits=}") + + self.count += 1 return logits diff --git a/src/product_keys/trainer.py b/src/product_keys/trainer.py index 83b0e106..9480f1a4 100644 --- a/src/product_keys/trainer.py +++ b/src/product_keys/trainer.py @@ -5,7 +5,10 @@ import torch.distributed as dist from typing import List +from src.core.conversion_to_hf import save_to_llama_3_hf from src.core.trainer import Trainer +from src.core.checkpointing import get_full_checkpoint_path +from src.core.utils import cast_state_dict_to_tensors @define(slots=False) @@ -86,3 +89,60 @@ def _mlm_loss_calculation(input_ids, target_ids): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) return avg_loss / float(os.environ["WORLD_SIZE"]) + +@define(slots=False) +class TrainerWithVocabSize(Trainer): + """ + Trainer which enables saving different vocab size as checkpoint. + """ + vocab_size: int + + def train(self): + for step, batch in zip( + range(self.start_step, self.n_steps), self.train_dataloader + ): + self.step = step + self.metric_logger.set_step(step) + self.model.train() + loss = self.calculate_loss(batch) + + grad_norm = self.clip_gradient() + + self.log_metrics(loss, grad_norm) + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + if self._should_save_checkpoint: + self.save_checkpoint() + + if self._should_evaluate: + self.eval() + + if self._should_save_final_checkpoint: + if self.checkpoint.save.type == "nano": + self.save_checkpoint() + elif self.checkpoint.save.type == "huggingface": + # self.model.unshard() # alternative that might not work for a very large > 1gpu memory models + model_state_dict = self.model.state_dict() + full_state = cast_state_dict_to_tensors(model_state_dict) + + if os.environ["RANK"] == "0": + dmodel, dff, n_att_heads, n_kvatt_heads, head_dim, nlayers = ( + self.model.encoder.get_model_dimensions() + ) + + save_to_llama_3_hf( # dev fixed values + full_state, + save_dir=get_full_checkpoint_path(self.checkpoint.save.path), + dmodel=dmodel, + dff=dff, + n_att_heads=n_att_heads, + n_kvatt_heads=n_kvatt_heads, + head_dim=head_dim, + nlayers=nlayers, + vocab_size = self.vocab_size + ) + elif self.checkpoint.save.type == "pc_finalize": + self.save_pc_finalized_checkpoint() From cb4e53a9bba35690cb7bf503d1cef868809ff464 Mon Sep 17 00:00:00 2001 From: ggwozdz2 Date: Wed, 25 Mar 2026 11:22:31 +0100 Subject: [PATCH 4/5] Fix popping attention mask from kwargs --- src/product_keys/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 3e3809cb..173de1c0 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -283,8 +283,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, *args, **kwargs): - if "attention_mask" in kwargs: - attention_mask = kwargs.pop("attention_mask") + attention_mask = kwargs.pop("attention_mask", None) x = self.embedding(*args, **kwargs) x = self.encoder(x, attention_mask=attention_mask) x = self.head(x) From df66fc7d8fe608faef6cd20ef89d92199b891037 Mon Sep 17 00:00:00 2001 From: ggwozdz2 Date: Wed, 8 Apr 2026 14:37:54 +0200 Subject: [PATCH 5/5] Wip --- configs/_cluster/helios.yaml | 9 +++---- .../product_keys/finetune_trainer_local.yaml | 2 +- src/core/llama.py | 4 ++-- src/core/model.py | 24 +++++++++++-------- src/product_keys/model.py | 2 +- src/projected_compression/model.py | 19 ++++++++------- 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/configs/_cluster/helios.yaml b/configs/_cluster/helios.yaml index b4a77aae..d6f8cfe9 100644 --- a/configs/_cluster/helios.yaml +++ b/configs/_cluster/helios.yaml @@ -9,6 +9,7 @@ infrastructure: nodes: 1 partition: plgrid-gpu-gh200 time: "2-00:00:00" + account: plgllmefficont3-gpu-gh200 script: - '${export_env_variables_placeholders:}' @@ -32,7 +33,7 @@ infrastructure: - 'cd -' cluster_switch: - train_path_c4: "/net/scratch/hscra/plgrid/plgmaciejpioro/c4/train" - eval_path_c4: "/net/scratch/hscra/plgrid/plgmaciejpioro/c4/validation" - train_path_fineweb: "/net/scratch/hscra/plgrid/plgmaciejpioro/fineweb-edu/train/train" - eval_path_fineweb: "/net/scratch/hscra/plgrid/plgmaciejpioro/fineweb-edu/train/train" + train_path_c4: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/c4/train" + eval_path_c4: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/c4/validation" + train_path_fineweb: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/fineweb-edu/train/train" + eval_path_fineweb: "/net/storage/pr3/plgrid/plggllmeffi3/datasets/fineweb-edu/train/train" diff --git a/configs/product_keys/finetune_trainer_local.yaml b/configs/product_keys/finetune_trainer_local.yaml index 4779c854..d6506f64 100644 --- a/configs/product_keys/finetune_trainer_local.yaml +++ b/configs/product_keys/finetune_trainer_local.yaml @@ -27,7 +27,7 @@ common: trainer: _target_: src.product_keys.finetuning_trainer.FinetuningTrainer gradient_accumulation_steps: 2 - n_steps: 2000 + n_steps: 3 learning_rate: 1e-3 d_model: ${common.dmodel} vocab_size: ${common.vocab_size} diff --git a/src/core/llama.py b/src/core/llama.py index 6b2c7b4c..7857a888 100644 --- a/src/core/llama.py +++ b/src/core/llama.py @@ -145,7 +145,7 @@ def __init__( self.attention_mechanism = AttentionMechanism() self.rope = LlamaRoPE(dhead=self.head_dim, length=seq_len, base=500000) - def forward(self, x): + def forward(self, x, attention_mask=None): query_states = self.q_proj(x) key_states = self.k_proj(x) value_states = self.v_proj(x) @@ -162,7 +162,7 @@ def forward(self, x): v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=self.causal + query=q, key=k, value=v, causal=self.causal, attention_mask=attention_mask ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) diff --git a/src/core/model.py b/src/core/model.py index 0d20118e..81e21805 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -46,9 +46,9 @@ def __init__(self, norm, layer, log_name): def set_metric_logger(self, metric_logger): self.metric_logger = metric_logger - def forward(self, x): + def forward(self, x, **kwargs): normalized = self.norm(x) - out = self.layer(normalized) + out = self.layer(normalized, **kwargs) if self.metric_logger is not None: self.metric_logger.accumulate_metrics( layer_name=f"{self.log_name}", @@ -179,8 +179,8 @@ def __init__( log_name=f"{self.log_name}/residual_feedforward", ) - def forward(self, x): - x = self.attention_layer(x) + def forward(self, x, **kwargs): + x = self.attention_layer(x, **kwargs) x = self.ff_layer(x) return x @@ -207,9 +207,9 @@ def __init__( super().__init__() self.blocks = nn.ModuleList([block_fn(i) for i in range(n_blocks)]) - def forward(self, x): + def forward(self, x, **kwargs): for block in self.blocks: - x = block(x) + x = block(x, **kwargs) return x @@ -265,8 +265,9 @@ def __init__( self.head = head def forward(self, *args, **kwargs): + attention_mask = kwargs.pop("attention_mask", None) x = self.embedding(*args, **kwargs) - x = self.encoder(x) + x = self.encoder(x, attention_mask=attention_mask) x = self.head(x) return x @@ -400,7 +401,7 @@ def __init__( apply_freq_scaling=rope_scale_freqs, ) - def forward(self, x): + def forward(self, x, attention_mask: Optional[torch.Tensor] = None): query_states = self.q_proj(x) key_states = self.k_proj(x) value_states = self.v_proj(x) @@ -418,7 +419,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=True, attention_mask=attention_mask ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) @@ -431,6 +432,7 @@ def attention_mechanism( key: torch.Tensor, value: torch.Tensor, causal: bool, + attention_mask: Optional[torch.Tensor] = None, ): # https://github.com/pytorch/pytorch/blob/ce503c1b40207dab770c28cbd4568cd9e105277b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp#L556 with torch.nn.attention.sdpa_kernel( @@ -440,7 +442,7 @@ def attention_mechanism( query=query, key=key, value=value, - attn_mask=None, + attn_mask=attention_mask, is_causal=causal, ) @@ -456,12 +458,14 @@ def forward( key: torch.Tensor, value: torch.Tensor, causal: bool, + attention_mask: Optional[torch.Tensor] = None, ): return attention_mechanism( query=query, key=key, value=value, causal=causal, + attention_mask=attention_mask, ) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 173de1c0..b88e848c 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -91,7 +91,7 @@ def forward(self, x, attention_mask=None): # standard attention if seq_len is smaller or equal top_k if seq_len <= self.top_k: attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=self.causal + query=q, key=k, value=v, causal=self.causal, attention_mask=attention_mask ) return self.o_proj( attention_output.transpose(1, 2).contiguous().flatten(-2) diff --git a/src/projected_compression/model.py b/src/projected_compression/model.py index 93dd927e..769adc95 100644 --- a/src/projected_compression/model.py +++ b/src/projected_compression/model.py @@ -73,9 +73,9 @@ def __init__(self, norm, layer, log_name): def set_metric_logger(self, metric_logger): self.metric_logger = metric_logger - def forward(self, x): + def forward(self, x, **kwargs): normalized = self.norm(x) - out = self.layer(normalized) + out = self.layer(normalized, **kwargs) if self.metric_logger is not None: self.metric_logger.accumulate_metrics( layer_name=f"{self.log_name}", @@ -266,7 +266,7 @@ def __init__( original_max_position_embeddings=original_max_position_embeddings, ) - def forward(self, x): + def forward(self, x, attention_mask: Optional[torch.Tensor] = None): query_states = self.q_proj(x) key_states = self.k_proj(x) value_states = self.v_proj(x) @@ -282,7 +282,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=True, attention_mask=attention_mask ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) @@ -360,8 +360,8 @@ def __init__( log_name=f"{self.log_name}/residual_feedforward", ) - def forward(self, x): - x = self.attention_layer(x) + def forward(self, x, **kwargs): + x = self.attention_layer(x, **kwargs) x = self.ff_layer(x) return x @@ -388,9 +388,9 @@ def __init__( super().__init__() self.blocks = nn.ModuleList([block_fn(i) for i in range(n_blocks)]) - def forward(self, x): + def forward(self, x, **kwargs): for block in self.blocks: - x = block(x) + x = block(x, **kwargs) return x @@ -419,8 +419,9 @@ def __init__( self.head = head def forward(self, *args, **kwargs): + attention_mask = kwargs.pop("attention_mask", None) x = self.embedding(*args, **kwargs) - x = self.encoder(x) + x = self.encoder(x, attention_mask=attention_mask) x = self.head(x) return x