Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions scripts/create_finetuning_config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ class MasterConfig(BaseConfig):
{"type": "null"},
]

# Remove const from torch_dtype (it's not a constant - this is a surprising pydantic output)
del schema["$defs"]["ModelArguments"]["properties"]["torch_dtype"]["const"]

print("NOTE: only $defs, add these to the schema.json in the appropriate place")
print("---")
for line in indent(json.dumps({"$defs": schema["$defs"]}, indent=2), 2 * " ").splitlines()[1:-2]:
Expand Down
7 changes: 0 additions & 7 deletions src/finetuning/config/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,6 @@ def __apply_overrides(self):

def model_post_init(self, __context):
self.__apply_overrides()
if self.training_args.gradient_checkpointing and self.run_conf.model_args.use_cache:
# These are mutually incompatible
logger.warning(
"Setting run_conf.model_args.use_cache=False, because training_args.gradient_checkpointing=True"
)
self.run_conf.model_args.use_cache = False

if (
self.quant_conf.quantization_type == QuantizationType.BITSANDBYTES
and is_fsdp
Expand Down
12 changes: 11 additions & 1 deletion src/finetuning/config/hf_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""HuggingFace Configs available through the config interface"""

from dataclasses import dataclass, field
from textwrap import dedent
from typing import Any, Dict, List, Literal, Optional, Union

import peft
Expand Down Expand Up @@ -263,7 +264,16 @@ class SFTArguments(BaseConfig):
"""Supervised fine-tuning arguments"""

max_seq_length: int = Field(
default=2048, description="Maximum length input sequence length. Longer sequences will be filtered out."
default=2048,
description="Maximum length input sequence length. Longer sequences will be filtered or truncated.",
)
length_handling: Literal["filter", "truncate"] = Field(
default="filter",
description=dedent(
"""How to handle examples that are longer than max_seq_length. \
'filter': Filter out these examples from the training set. \
'truncate': Truncate these examples to max_seq_length. Note that this might lead to loss of information and worse performance, especially if the important information is at the end of the sequence."""
),
)
# This is only used if a new basemodel needs to be saved, e.g. if the embeddings are grown to account for new
# tokens.
Expand Down
180 changes: 144 additions & 36 deletions src/finetuning/config/run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from textwrap import dedent
from typing import Dict, Literal, Optional

import torch
from accelerate import DistributedType, PartialState
from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic import ConfigDict, Field, field_serializer, field_validator, model_validator

from finetuning.config.base import BaseConfig

Expand Down Expand Up @@ -35,21 +36,48 @@ class ModelArguments(BaseConfig):
This does not include quantization_config. Quantization config is specified separately.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

silogen_extra_args: Dict[str, object] = Field(
default_factory=dict,
description="Don't specify directly - this gathers additional args passed to the model",
exclude=True,
)

@model_validator(mode="before")
@classmethod
def handle_silogen_extra_args(cls, values):
    """Collect keys that are not declared fields into ``silogen_extra_args``.

    Runs before field validation. Every key of the raw input that is not a
    declared field of this model is copied into ``silogen_extra_args``, so
    HF-specific arguments can be passed on without being explicitly defined
    in this config.

    Args:
        values: Raw input given to the model; normally a dict of values.

    Returns:
        A new dict equal to ``values`` plus a populated
        ``silogen_extra_args`` entry. Non-dict input is returned unchanged
        so that pydantic's own validation can deal with it.

    Raises:
        ValueError: If ``silogen_extra_args`` is supplied directly.
    """
    # A "before" validator receives the raw input, which is not guaranteed
    # to be a dict; only dict input can carry extra keys to gather.
    if not isinstance(values, dict):
        return values
    if "silogen_extra_args" in values:
        raise ValueError(
            "silogen_extra_args should not be passed directly, it is reserved for gathering extra args passed to the model. Please remove it from your config."
        )
    known_keys = set(cls.model_fields)
    extra_args = {key: value for key, value in values.items() if key not in known_keys}
    # Build a new dict instead of writing into the caller's: mutating the
    # input would leave a "silogen_extra_args" key behind, and validating the
    # same config dict a second time would then hit the ValueError above.
    return {**values, "silogen_extra_args": extra_args}

# The datatype to use for model parameters
torch_dtype: Literal["auto"] | torch.dtype = "auto"
dtype: Literal["auto"] | str | torch.dtype = "auto"

@field_validator("torch_dtype", mode="before")
@classmethod
def _str_to_dtype(cls, x: str) -> Literal["auto"] | torch.dtype:
"""Validator for converting string to proper torch.dtype, while also handling 'auto'"""
if x == "auto":
return x
else:
elif isinstance(x, str):
return getattr(torch, x)
elif isinstance(x, torch.dtype):
return x
else:
raise ValueError(f"Invalid dtype value: {x}")

@field_serializer("torch_dtype", when_used="json")
@field_validator("dtype", mode="before")
@classmethod
def _str_to_dtype_validator(cls, x: str) -> Literal["auto"] | torch.dtype:
return cls._str_to_dtype(x)

@field_serializer("dtype", when_used="json")
@classmethod
def _dtype_to_str(cls, x: str | torch.dtype) -> str:
"""Serializer for converting torch.dtype to string, while also handling 'auto'"""
Expand All @@ -58,48 +86,128 @@ def _dtype_to_str(cls, x: str | torch.dtype) -> str:
else:
return str(x)[len("torch.") :] # Remove the "torch." prefix

device_map: Dict[str, int | str] | str | None = Field(
pretrained_model_name_or_path: str | os.PathLike | None = Field(
default=None,
description=dedent(
"""\
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a *directory* containing model weights saved using `~PreTrainedModel.save_pretrained`.
- A path or url to a *tensorflow index checkpoint file*.
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format.
- `None` if you are both providing the configuration and state dictionary."""
),
)
config: Optional[str | os.PathLike] = Field(
default=None,
description=dedent(
"""\
Configuration for the model to use instead of an automatically loaded configuration.
Can be either an instance of a class derived from `PretrainedConfig`, or a string/path valid as input to `PretrainedConfig.from_pretrained`."""
),
)
cache_dir: Optional[str | os.PathLike] = Field(
default=None,
description="Path to a directory in which a downloaded pretrained model configuration should be cached.",
)
from_tf: bool = Field(
default=False,
description="Load the model weights from a TensorFlow checkpoint save file.",
)
from_flax: bool = Field(
default=False,
description="Load the model weights from a Flax checkpoint save file.",
)
ignore_mismatched_sizes: bool = Field(
default=False,
description="Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model.",
)
force_download: bool = Field(
default=False,
description="Whether or not to force the (re-)download of the model weights and configuration files.",
)
proxies: Optional[Dict[str, str]] = Field(
default=None,
description="A dictionary of proxy servers to use by protocol or endpoint.",
)
output_loading_info: bool = Field(
default=False,
description="Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.",
)
local_files_only: bool = Field(
default=False,
description="Whether or not to only look at local files (i.e., do not try to download the model).",
)
token: str | bool | None = Field(
default=None,
description='Custom device map so that you can manually override the choices that HuggingFace would make. This can also be a string to specify "auto", "balanced_low_0", or "sequential".',
description="The token to use as HTTP bearer authorization for remote files.",
)
revision: str = Field(
default="main",
description="The specific model version to use. It can be a branch name, a tag name, or a commit id.",
)
max_memory: Optional[Dict[str, str]] = None
low_cpu_mem_usage: bool = False
attn_implementation: Optional[str] = Field(
default=None,
description='Note: this can be set to "sdpa", "flash_attention_2", "eager".',
description=dedent(
"""\
The attention implementation to use in the model. Can be any of 'eager', 'sdpa', 'flash_attention_2', or 'flash_attention_3'.
Accepts HF kernel references in the form: <namespace>/<repo_name>[@<revision>][:<kernel_name>]"""
),
)
offload_folder: Optional[str] = None
offload_state_dict: Optional[bool] = Field(
device_map: str | Dict[str, int | str | torch.device] | int | torch.device | None = Field(
default=None,
description="Default is True if offloading (otherwise no effect)",
description="A map that specifies where each submodule should go.",
)
offload_buffers: Optional[bool] = None

use_cache: bool = Field(
default=True,
description="Saves generated hidden states to speed up generation, see: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958 This is mutually exclusive with gradient_checkpointing.",
)

# HF HUB arguments:
cache_dir: Optional[str] = None
force_download: bool = False
local_files_only: bool = False
proxies: Optional[Dict[str, str]] = None
resume_download: bool = False
revision: str = "main"
code_revision: str = "main"
subfolder: Optional[str] = None
token: Optional[str] = None
use_safetensors: Optional[bool] = None
variant: Optional[str] = None
trust_remote_code: bool = Field(
max_memory: Optional[Dict] = Field(
default=None,
description="A dictionary device identifier to maximum memory if using `device_map`.",
)
tp_plan: Optional[str] = Field(
default=None,
description="A torch tensor parallel plan. Currently only accepts 'auto'.",
)
tp_size: Optional[str] = Field(
default=None,
description="A torch tensor parallel degree. If not provided would default to world size.",
)
offload_folder: str | os.PathLike | None = Field(
default=None,
description="If the `device_map` contains any value 'disk', the folder where we will offload weights.",
)
offload_buffers: bool = Field(
default=False,
description="Warning: if set to True, allows execution of downloaded remote code.",
description="Whether or not to offload the buffers with the model parameters.",
)
subfolder: str = Field(
default="",
description="In case the relevant files are located inside a subfolder of the model repo on huggingface.co.",
)
variant: Optional[str] = Field(
default=None,
description="If specified load weights from `variant` filename, e.g. pytorch_model.<variant>.bin.",
)
use_safetensors: Optional[bool] = Field(
default=None,
description="Whether or not to use `safetensors` checkpoints.",
)
weights_only: bool = Field(
default=True,
description="Indicates whether unpickler should be restricted to loading only tensors and primitive types.",
)
key_mapping: Optional[Dict[str, str]] = Field(
default=None,
description="A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers architecture, but was not converted accordingly.",
)

def model_post_init(self, __context):
accelerate_state = PartialState()

# Handle legacy torch_dtype key for backwards compatibility:
if "torch_dtype" in self.silogen_extra_args:
# NOTE(review): no check is made here for dtype also being specified; the legacy torch_dtype value overrides dtype.

self.dtype = self._str_to_dtype(self.silogen_extra_args.pop("torch_dtype"))

# Deepspeed sets the device_map internally so device_map is automatically set only when Deepspeed is not active.
if accelerate_state.distributed_type != DistributedType.DEEPSPEED:
if self.device_map is None: # If not provided, infer from the environment
Expand All @@ -111,7 +219,7 @@ def model_post_init(self, __context):

def get_model_load_kwargs(self):
"""Returns a dictionary that is ready to be passed to AutoModelForCausalLM.from_pretrained as **kwargs"""
return self.model_dump(exclude_unset=True)
return {**self.model_dump(exclude_unset=True), **self.silogen_extra_args}


class RunConfig(BaseConfig):
Expand Down
6 changes: 4 additions & 2 deletions src/finetuning/data/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,19 @@ def get_chat_template(name: ChatTemplateName):
raise ValueError(f"Unknown chat-template: {name}")


def tokenize_with_chat_template(dataset, tokenizer):
def tokenize_with_chat_template(dataset, tokenizer, truncation=False, max_len=None):
"""Applies the chat template that is stored in the tokenizer to each example in a dataset"""

def _apply_chat_template(example, tokenizer=tokenizer):
def _apply_chat_template(example, tokenizer=tokenizer, truncation=truncation, max_len=max_len):
"""Actually does the templating, adds 'text' and 'length'

Keeps tokenizer in scope as default argument
"""
conversation_string = tokenizer.apply_chat_template(example["messages"], tokenize=False)
tokenized = tokenizer(
conversation_string,
truncation=truncation,
max_length=max_len,
)
tokenized["length"] = len(tokenized["input_ids"])
return tokenized
Expand Down
2 changes: 1 addition & 1 deletion src/finetuning/data/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def setup_datainput(conf: DataInput) -> Dataset | None:


def filter_long_examples(data, max_len):
"""Filters out examples that are too long.
"""Filters out examples that are too long, SFT specific

Based on the 'input_ids' key, i.e. after tokenization, or 'length' if that exists.
"""
Expand Down
3 changes: 2 additions & 1 deletion src/finetuning/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

def get_model(model_name_or_path, model_load_kwargs, quantization_config=None):
"""Creates an instance of the desired model"""
model_load_kwargs["quantization_config"] = quantization_config
if quantization_config is not None:
model_load_kwargs["quantization_config"] = quantization_config
try:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_load_kwargs)
except ImportError as e:
Expand Down
54 changes: 33 additions & 21 deletions src/finetuning/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,25 @@
def subsetup_sft_training_data(exp_conf: SFTExperimentConfig, tokenizer):
"""Setup step: Training data"""
train_data = setup_datainput(exp_conf.data_conf.training_data)
train_data = tokenize_with_chat_template(train_data, tokenizer)
# Train data might be an iterable dataset, in which case we do not know how many samples get filtered out.
# Let's still send out a warning when we can, the extra information is useful.
if getattr(train_data, "num_rows", None) is not None:
train_len_before_filter = train_data.num_rows
train_data = filter_long_examples(train_data, exp_conf.sft_args.max_seq_length)
if getattr(train_data, "num_rows", None) is not None and train_len_before_filter > train_data.num_rows:
logger.warning(
f"Filtered out {train_len_before_filter - train_data.num_rows} training examples "
f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}. "
"Note that this warning is not emitted with iterable (streaming) datasets, but the filtering still takes "
"place."
)
train_data = tokenize_with_chat_template(
train_data,
tokenizer,
truncation=exp_conf.sft_args.length_handling == "truncate",
max_len=exp_conf.sft_args.max_seq_length,
)
if exp_conf.sft_args.length_handling == "filter":
# Train data might be an iterable dataset, in which case we do not know how many samples get filtered out.
# Let's still send out a warning when we can, the extra information is useful.
if getattr(train_data, "num_rows", None) is not None:
train_len_before_filter = train_data.num_rows
train_data = filter_long_examples(train_data, exp_conf.sft_args.max_seq_length)
if getattr(train_data, "num_rows", None) is not None and train_len_before_filter > train_data.num_rows:
logger.warning(
f"Filtered out {train_len_before_filter - train_data.num_rows} training examples "
f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}. "
"Note that this warning is not emitted with iterable (streaming) datasets, but the filtering still takes "
"place."
)
return train_data


Expand All @@ -126,14 +132,20 @@ def subsetup_sft_validation_data(exp_conf: SFTExperimentConfig, tokenizer, train
valid_data = setup_datainput(exp_conf.data_conf.validation_data)
if valid_data is None:
return train_data, None
valid_data = tokenize_with_chat_template(valid_data, tokenizer)
valid_len_before_filter = valid_data.num_rows
valid_data = filter_long_examples(valid_data, exp_conf.sft_args.max_seq_length)
if valid_len_before_filter > valid_data.num_rows:
logger.warning(
f"Filtered out {valid_len_before_filter - valid_data.num_rows} validation examples "
f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}."
)
valid_data = tokenize_with_chat_template(
valid_data,
tokenizer,
truncation=exp_conf.sft_args.length_handling == "truncate",
max_len=exp_conf.sft_args.max_seq_length,
)
if exp_conf.sft_args.length_handling == "filter":
valid_len_before_filter = valid_data.num_rows
valid_data = filter_long_examples(valid_data, exp_conf.sft_args.max_seq_length)
if valid_len_before_filter > valid_data.num_rows:
logger.warning(
f"Filtered out {valid_len_before_filter - valid_data.num_rows} validation examples "
f"due to exceeding maximum sequence length {exp_conf.sft_args.max_seq_length}."
)
valid_data = sort_longest_first(valid_data)
return train_data, valid_data

Expand Down
Loading