From 9373544b088b8afed8a33679a0d286f131b425f5 Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Fri, 24 Apr 2026 16:20:27 +0300 Subject: [PATCH 1/2] All config changes together --- scripts/create_finetuning_config_schema.py | 3 - src/finetuning/config/hf_integration.py | 12 +- src/finetuning/config/run.py | 180 ++++++++++++++++----- src/finetuning/data/chat_templates.py | 6 +- src/finetuning/data/setup.py | 2 +- src/finetuning/model.py | 3 +- src/finetuning/sft.py | 54 ++++--- 7 files changed, 195 insertions(+), 65 deletions(-) diff --git a/scripts/create_finetuning_config_schema.py b/scripts/create_finetuning_config_schema.py index 25a307a..040cee6 100644 --- a/scripts/create_finetuning_config_schema.py +++ b/scripts/create_finetuning_config_schema.py @@ -24,9 +24,6 @@ class MasterConfig(BaseConfig): {"type": "null"}, ] -# Remove const from torch_dtype (it's not a constant - this is a surprising pydantic output) -del schema["$defs"]["ModelArguments"]["properties"]["torch_dtype"]["const"] - print("NOTE: only $defs, add these to the schema.json in the appropriate place") print("---") for line in indent(json.dumps({"$defs": schema["$defs"]}, indent=2), 2 * " ").splitlines()[1:-2]: diff --git a/src/finetuning/config/hf_integration.py b/src/finetuning/config/hf_integration.py index 131d3ad..9f0d536 100644 --- a/src/finetuning/config/hf_integration.py +++ b/src/finetuning/config/hf_integration.py @@ -1,6 +1,7 @@ """HuggingFace Configs available through the config interface""" from dataclasses import dataclass, field +from textwrap import dedent from typing import Any, Dict, List, Literal, Optional, Union import peft @@ -263,7 +264,16 @@ class SFTArguments(BaseConfig): """Supervised fine-tuning arguments""" max_seq_length: int = Field( - default=2048, description="Maximum length input sequence length. Longer sequences will be filtered out." + default=2048, + description="Maximum length input sequence length. Longer sequences will be filtered or truncated.", + ) + length_handling: Literal["filter", "truncate"] = Field( + default="filter", + description=dedent( + """How to handle examples that are longer than max_seq_length. \ + 'filter': Filter out these examples from the training set. \ + 'truncate': Truncate these examples to max_seq_length. Note that this might lead to loss of information and worse performance, especially if the important information is at the end of the sequence.""" + ), ) # This is only used if a new basemodel needs to be saved, e.g. if the embeddings are grown to account for new # tokens. diff --git a/src/finetuning/config/run.py b/src/finetuning/config/run.py index ca1def2..5bdb1e1 100644 --- a/src/finetuning/config/run.py +++ b/src/finetuning/config/run.py @@ -1,9 +1,10 @@ +import os from textwrap import dedent from typing import Dict, Literal, Optional import torch from accelerate import DistributedType, PartialState -from pydantic import ConfigDict, Field, field_serializer, field_validator +from pydantic import ConfigDict, Field, field_serializer, field_validator, model_validator from finetuning.config.base import BaseConfig @@ -35,21 +36,48 @@ class ModelArguments(BaseConfig): This does not include quantization_config. Quantization config is specified separately. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + silogen_extra_args: Dict[str, object] = Field( + default_factory=dict, + description="Don't specify directly - this gathers additional args passed to the model", + exclude=True, + ) + + @model_validator(mode="before") + @classmethod + def handle_silogen_extra_args(cls, values): + """This gathers any additional args passed to the model that are not explicitly defined in the config, and puts them in silogen_extra_args. This is useful for passing on HF-specific args that we don't want to explicitly define in our config.""" + if "silogen_extra_args" in values: + raise ValueError( + "silogen_extra_args should not be passed directly, it is reserved for gathering extra args passed to the model. Please remove it from your config." + ) + known_keys = set(cls.model_fields.keys()) + silogen_extra_args = {k: v for k, v in values.items() if k not in known_keys} + values["silogen_extra_args"] = silogen_extra_args + return values # The datatype to use for model parameters - torch_dtype: Literal["auto"] | torch.dtype = "auto" + dtype: Literal["auto"] | str | torch.dtype = "auto" - @field_validator("torch_dtype", mode="before") @classmethod def _str_to_dtype(cls, x: str) -> Literal["auto"] | torch.dtype: """Validator for converting string to proper torch.dtype, while also handling 'auto'""" if x == "auto": return x - else: + elif isinstance(x, str): return getattr(torch, x) + elif isinstance(x, torch.dtype): + return x + else: + raise ValueError(f"Invalid dtype value: {x}") - @field_serializer("torch_dtype", when_used="json") + @field_validator("dtype", mode="before") + @classmethod + def _str_to_dtype_validator(cls, x: str) -> Literal["auto"] | torch.dtype: + return cls._str_to_dtype(x) + + @field_serializer("dtype", when_used="json") @classmethod def _dtype_to_str(cls, x: str | torch.dtype) -> str: """Serializer for converting torch.dtype to string, while also handling 'auto'""" @@ -58,48 +86,128 @@ def _dtype_to_str(cls, x: str | torch.dtype) -> str: else: return str(x)[len("torch.") :] # Remove the "torch." prefix - device_map: Dict[str, int | str] | str | None = Field( + pretrained_model_name_or_path: str | os.PathLike | None = Field( + default=None, + description=dedent( + """\ + Can be either: + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using `~PreTrainedModel.save_pretrained`. + - A path or url to a *tensorflow index checkpoint file*. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format. + - `None` if you are both providing the configuration and state dictionary.""" + ), + ) + config: Optional[str | os.PathLike] = Field( + default=None, + description=dedent( + """\ + Configuration for the model to use instead of an automatically loaded configuration. + Can be either an instance of a class derived from `PretrainedConfig`, or a string/path valid as input to `PretrainedConfig.from_pretrained`.""" + ), + ) + cache_dir: Optional[str | os.PathLike] = Field( + default=None, + description="Path to a directory in which a downloaded pretrained model configuration should be cached.", + ) + from_tf: bool = Field( + default=False, + description="Load the model weights from a TensorFlow checkpoint save file.", + ) + from_flax: bool = Field( + default=False, + description="Load the model weights from a Flax checkpoint save file.", + ) + ignore_mismatched_sizes: bool = Field( + default=False, + description="Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model.", + ) + force_download: bool = Field( + default=False, + description="Whether or not to force the (re-)download of the model weights and configuration files.", + ) + proxies: Optional[Dict[str, str]] = Field( + default=None, + description="A dictionary of proxy servers to use by protocol or endpoint.", + ) + output_loading_info: bool = Field( + default=False, + description="Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.", + ) + local_files_only: bool = Field( + default=False, + description="Whether or not to only look at local files (i.e., do not try to download the model).", + ) + token: str | bool | None = Field( default=None, - description='Custom device map so that you can manually override the choices that HuggingFace would make. This can also be a string to specify "auto", "balanced_low_0", or "sequential".', + description="The token to use as HTTP bearer authorization for remote files.", + ) + revision: str = Field( + default="main", + description="The specific model version to use. It can be a branch name, a tag name, or a commit id.", ) - max_memory: Optional[Dict[str, str]] = None - low_cpu_mem_usage: bool = False attn_implementation: Optional[str] = Field( default=None, - description='Note: this can be set to "sdpa", "flash_attention_2", "eager".', + description=dedent( + """\ + The attention implementation to use in the model. Can be any of 'eager', 'sdpa', 'flash_attention_2', or 'flash_attention_3'. + Accepts HF kernel references in the form: /[@][:]""" + ), ) - offload_folder: Optional[str] = None - offload_state_dict: Optional[bool] = Field( + device_map: str | Dict[str, int | str | torch.device] | int | torch.device | None = Field( default=None, - description="Default is True if offloading (otherwise no effect)", + description="A map that specifies where each submodule should go.", ) - offload_buffers: Optional[bool] = None - - use_cache: bool = Field( - default=True, - description="Saves generated hidden states to speed up generation, see: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958 This is mutually exclusive with gradient_checkpointing.", - ) - - # HF HUB arguments: - cache_dir: Optional[str] = None - force_download: bool = False - local_files_only: bool = False - proxies: Optional[Dict[str, str]] = None - resume_download: bool = False - revision: str = "main" - code_revision: str = "main" - subfolder: Optional[str] = None - token: Optional[str] = None - use_safetensors: Optional[bool] = None - variant: Optional[str] = None - trust_remote_code: bool = Field( + max_memory: Optional[Dict] = Field( + default=None, + description="A dictionary device identifier to maximum memory if using `device_map`.", + ) + tp_plan: Optional[str] = Field( + default=None, + description="A torch tensor parallel plan. Currently only accepts 'auto'.", + ) + tp_size: Optional[str] = Field( + default=None, + description="A torch tensor parallel degree. If not provided would default to world size.", + ) + offload_folder: str | os.PathLike | None = Field( + default=None, + description="If the `device_map` contains any value 'disk', the folder where we will offload weights.", + ) + offload_buffers: bool = Field( default=False, - description="Warning: if set to True, allows execution of downloaded remote code.", + description="Whether or not to offload the buffers with the model parameters.", + ) + subfolder: str = Field( + default="", + description="In case the relevant files are located inside a subfolder of the model repo on huggingface.co.", + ) + variant: Optional[str] = Field( + default=None, + description="If specified load weights from `variant` filename, e.g. pytorch_model..bin.", + ) + use_safetensors: Optional[bool] = Field( + default=None, + description="Whether or not to use `safetensors` checkpoints.", + ) + weights_only: bool = Field( + default=True, + description="Indicates whether unpickler should be restricted to loading only tensors and primitive types.", + ) + key_mapping: Optional[Dict[str, str]] = Field( + default=None, + description="A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers architecture, but was not converted accordingly.", ) def model_post_init(self, __context): accelerate_state = PartialState() + # Handle legacy torch_dtype key for backwards compatibility: + if "torch_dtype" in self.silogen_extra_args: + # First check if dtype was also specified: + + self.dtype = self._str_to_dtype(self.silogen_extra_args.pop("torch_dtype")) + # Deepspeed sets the device_map internally so device_map is automatically set only when Deepspeed is not active. if accelerate_state.distributed_type != DistributedType.DEEPSPEED: if self.device_map is None: # If not provided, infer from the environment @@ -111,7 +219,7 @@ def model_post_init(self, __context): def get_model_load_kwargs(self): """Returns a dictionary that is ready to be passed to AutoModelForCausalLM.from_pretrained as **kwargs""" - return self.model_dump(exclude_unset=True) + return {**self.model_dump(exclude_unset=True), **self.silogen_extra_args} class RunConfig(BaseConfig): diff --git a/src/finetuning/data/chat_templates.py b/src/finetuning/data/chat_templates.py index 885d1b4..5317f3c 100644 --- a/src/finetuning/data/chat_templates.py +++ b/src/finetuning/data/chat_templates.py @@ -114,10 +114,10 @@ def get_chat_template(name: ChatTemplateName): raise ValueError(f"Unknown chat-template: {name}") -def tokenize_with_chat_template(dataset, tokenizer): +def tokenize_with_chat_template(dataset, tokenizer, truncation=False, max_len=None): """Applies the chat template that is stored in the tokenizer to each example in a dataset""" - def _apply_chat_template(example, tokenizer=tokenizer): + def _apply_chat_template(example, tokenizer=tokenizer, truncation=truncation, max_len=max_len): """Actually does the templating, adds 'text' and 'length' Keeps tokenizer in scope as default argument @@ -125,6 +125,8 @@ def _apply_chat_template(example, tokenizer=tokenizer): conversation_string = tokenizer.apply_chat_template(example["messages"], tokenize=False) tokenized = tokenizer( conversation_string, + truncation=truncation, + max_length=max_len, ) tokenized["length"] = len(tokenized["input_ids"]) return tokenized diff --git a/src/finetuning/data/setup.py b/src/finetuning/data/setup.py index 0deeca5..121b091 100644 --- a/src/finetuning/data/setup.py +++ b/src/finetuning/data/setup.py @@ -71,7 +71,7 @@ def setup_datainput(conf: DataInput) -> Dataset | None: def filter_long_examples(data, max_len): - """Filters out examples that are too long. + """Filters out examples that are too long, SFT specific Based on the 'input_ids' key, i.e. after tokenization, or 'length' if that exists. """ diff --git a/src/finetuning/model.py b/src/finetuning/model.py index f0ffe1a..1e28149 100644 --- a/src/finetuning/model.py +++ b/src/finetuning/model.py @@ -16,7 +16,8 @@ def get_model(model_name_or_path, model_load_kwargs, quantization_config=None): """Creates an instance of the desired model""" - model_load_kwargs["quantization_config"] = quantization_config + if quantization_config is not None: + model_load_kwargs["quantization_config"] = quantization_config try: model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_load_kwargs) except ImportError as e: diff --git a/src/finetuning/sft.py b/src/finetuning/sft.py index db3b4ab..71d83dc 100644 --- a/src/finetuning/sft.py +++ b/src/finetuning/sft.py @@ -97,19 +97,25 @@ def subsetup_sft_training_data(exp_conf: SFTExperimentConfig, tokenizer): """Setup step: Training data""" train_data = setup_datainput(exp_conf.data_conf.training_data) - train_data = tokenize_with_chat_template(train_data, tokenizer) - # Train data might be an iterable dataset, in which case we do not know how many samples get filtered out. - # Let's still send out a warning when we can, the extra information is useful. - if getattr(train_data, "num_rows", None) is not None: - train_len_before_filter = train_data.num_rows - train_data = filter_long_examples(train_data, exp_conf.sft_args.max_seq_length) - if getattr(train_data, "num_rows", None) is not None and train_len_before_filter > train_data.num_rows: - logger.warning( - f"Filtered out {train_len_before_filter - train_data.num_rows} training examples " - f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}. " - "Note that this warning is not emitted with iterable (streaming) datasets, but the filtering still takes " - "place." - ) + train_data = tokenize_with_chat_template( + train_data, + tokenizer, + truncation=exp_conf.sft_args.length_handling == "truncate", + max_len=exp_conf.sft_args.max_seq_length, + ) + if exp_conf.sft_args.length_handling == "filter": + # Train data might be an iterable dataset, in which case we do not know how many samples get filtered out. + # Let's still send out a warning when we can, the extra information is useful. + if getattr(train_data, "num_rows", None) is not None: + train_len_before_filter = train_data.num_rows + train_data = filter_long_examples(train_data, exp_conf.sft_args.max_seq_length) + if getattr(train_data, "num_rows", None) is not None and train_len_before_filter > train_data.num_rows: + logger.warning( + f"Filtered out {train_len_before_filter - train_data.num_rows} training examples " + f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}. " + "Note that this warning is not emitted with iterable (streaming) datasets, but the filtering still takes " + "place." + ) return train_data @@ -126,14 +132,20 @@ def subsetup_sft_validation_data(exp_conf: SFTExperimentConfig, tokenizer, train valid_data = setup_datainput(exp_conf.data_conf.validation_data) if valid_data is None: return train_data, None - valid_data = tokenize_with_chat_template(valid_data, tokenizer) - valid_len_before_filter = valid_data.num_rows - valid_data = filter_long_examples(valid_data, exp_conf.sft_args.max_seq_length) - if valid_len_before_filter > valid_data.num_rows: - logger.warning( - f"Filtered out {valid_len_before_filter - valid_data.num_rows} validation examples " - f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}." - ) + valid_data = tokenize_with_chat_template( + valid_data, + tokenizer, + truncation=exp_conf.sft_args.length_handling == "truncate", + max_len=exp_conf.sft_args.max_seq_length, + ) + if exp_conf.sft_args.length_handling == "filter": + valid_len_before_filter = valid_data.num_rows + valid_data = filter_long_examples(valid_data, exp_conf.sft_args.max_seq_length) + if valid_len_before_filter > valid_data.num_rows: + logger.warning( + f"Filtered out {valid_len_before_filter - valid_data.num_rows} validation examples " + f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}." + ) valid_data = sort_longest_first(valid_data) return train_data, valid_data From b6c7bb2b7fd396f5f81b4a2a4bf9cdc4ad594ecf Mon Sep 17 00:00:00 2001 From: Aku Rouhe Date: Fri, 24 Apr 2026 16:22:13 +0300 Subject: [PATCH 2/2] Remove unnecessary check for use_cache and gradient_checkpointing, it was never useful --- src/finetuning/config/experiment.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/finetuning/config/experiment.py b/src/finetuning/config/experiment.py index cd56f26..744eefb 100644 --- a/src/finetuning/config/experiment.py +++ b/src/finetuning/config/experiment.py @@ -115,13 +115,6 @@ def __apply_overrides(self): def model_post_init(self, __context): self.__apply_overrides() - if self.training_args.gradient_checkpointing and self.run_conf.model_args.use_cache: - # These are mutually incompatible - logger.warning( - "Setting run_conf.model_args.use_cache=False, because training_args.gradient_checkpointing=True" - ) - self.run_conf.model_args.use_cache = False - if ( self.quant_conf.quantization_type == QuantizationType.BITSANDBYTES and is_fsdp