10 changes: 9 additions & 1 deletion src/lmflow/args.py
@@ -569,7 +569,15 @@ class DatasetArguments:
default=False, metadata={"help": "Whether to train on prompt for conversation datasets such as ShareGPT."}
)
conversation_template: Optional[str] = field(
default=None, metadata={"help": "The template for conversation datasets."}
default=None,
metadata={
"help": (
"The template for conversation datasets. Supports LMFlow preset names "
"(e.g. llama3, qwen2_5, deepseek_v3) and special values "
"`tokenizer` / `hf_auto` to use tokenizer.chat_template "
"(`tokenizer` is strict; `hf_auto` falls back to LMFlow default when unavailable)."
)
},
)
dataset_cache_dir: Optional[str] = field(
default=None,
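A hedged usage sketch for the new special values, assuming LMFlow's HfArgumentParser-based entrypoints and that the remaining DatasetArguments fields keep their defaults:

# Hedged sketch: selecting the new special values via CLI-style flags.
from transformers import HfArgumentParser
from lmflow.args import DatasetArguments

parser = HfArgumentParser(DatasetArguments)

# Strict: tokenization later raises if tokenizer.chat_template is missing.
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--conversation_template", "tokenizer"]
)

# Lenient: falls back to the LMFlow default template when unavailable.
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--conversation_template", "hf_auto"]
)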
34 changes: 30 additions & 4 deletions src/lmflow/models/hf_decoder_model.py
@@ -107,7 +107,7 @@ def tokenize(self, dataset: Dataset, add_special_tokens=True, *args, **kwargs) -
# Preprocessing the datasets.
# First we tokenize all the texts.
if dataset.get_backend() != "huggingface":
raise NotImplementedError("tokenization of datasets with non-huggingface backend arenot supported yet")
raise NotImplementedError("tokenization of datasets with non-huggingface backend is not supported yet")

dataset_type = dataset.get_type()
model_args = self.model_args
@@ -116,6 +116,12 @@ def tokenize(self, dataset: Dataset, add_special_tokens=True, *args, **kwargs) -
column_names = list(hf_raw_datasets.features)
data_args = raw_datasets.get_data_args()

if data_args.block_size is None:
data_args.block_size = self.tokenizer.model_max_length
logger.warning(
f"`block_size` is not provided. Using tokenizer.model_max_length={self.tokenizer.model_max_length}."
)

# Requires three types of information for tokenizing different datasets
# 1) Which fields require tokenization, e.g.
# "text2float": "text", but not "float"
@@ -137,15 +143,35 @@ def tokenize(self, dataset: Dataset, add_special_tokens=True, *args, **kwargs) -
add_special_tokens = False
elif dataset_type == "conversation":
if data_args.conversation_template:
if data_args.conversation_template in PRESET_TEMPLATES.keys():
if data_args.conversation_template == "tokenizer":
if getattr(self.tokenizer, "chat_template", None):
conversation_template = self.tokenizer.chat_template
else:
raise NotImplementedError(
"Requested tokenizer chat template, but tokenizer.chat_template is not available."
)
elif data_args.conversation_template == "hf_auto":
if getattr(self.tokenizer, "chat_template", None):
conversation_template = self.tokenizer.chat_template
else:
logger.warning(
"Requested `hf_auto`, but tokenizer.chat_template is unavailable. "
"Falling back to LMFlow default template."
)
conversation_template = PRESET_TEMPLATES["empty"]
elif data_args.conversation_template in PRESET_TEMPLATES.keys():
conversation_template = PRESET_TEMPLATES[data_args.conversation_template]
else:
raise NotImplementedError(
f"Conversation template {data_args.conversation_template} is not supported yet."
)
else:
logger.warning("No conversation template provided. Using default template.")
conversation_template = PRESET_TEMPLATES["empty"]
if getattr(self.tokenizer, "chat_template", None):
logger.warning("No conversation template provided. Using tokenizer.chat_template.")
conversation_template = self.tokenizer.chat_template
else:
logger.warning("No conversation template provided. Using default template.")
conversation_template = PRESET_TEMPLATES["empty"]

logger.warning(f"Conversation template: {conversation_template}")
else:
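The resolution order above, condensed into a standalone sketch; resolve_template is a hypothetical helper and presets stands in for PRESET_TEMPLATES:

# Hypothetical helper mirroring the branch logic above, for readability only.
def resolve_template(name, tokenizer, presets):
    chat_template = getattr(tokenizer, "chat_template", None)
    if name == "tokenizer":  # strict: no silent fallback
        if not chat_template:
            raise NotImplementedError("tokenizer.chat_template is not available")
        return chat_template
    if name == "hf_auto":  # lenient: fall back to the LMFlow default
        return chat_template or presets["empty"]
    if name in presets:  # LMFlow preset name, e.g. llama3
        return presets[name]
    if name is None:  # nothing requested: prefer the tokenizer's own template
        return chat_template or presets["empty"]
    raise NotImplementedError(f"Conversation template {name} is not supported yet.")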
10 changes: 7 additions & 3 deletions src/lmflow/pipeline/auto_pipeline.py
@@ -4,21 +4,25 @@
from lmflow.pipeline.evaluator import Evaluator
from lmflow.pipeline.finetuner import Finetuner
from lmflow.pipeline.inferencer import Inferencer
from lmflow.pipeline.sglang_inferencer import SGLangInferencer
from lmflow.pipeline.rm_inferencer import RewardModelInferencer
from lmflow.pipeline.rm_tuner import RewardModelTuner
from lmflow.utils.versioning import is_package_version_at_least, is_ray_available, is_trl_available, is_vllm_available
from lmflow.utils.versioning import is_package_version_at_least, is_ray_available, is_sglang_available, is_trl_available, is_vllm_available

PIPELINE_MAPPING = {
"evaluator": Evaluator,
"finetuner": Finetuner,
"inferencer": Inferencer,
"sglang_inferencer": SGLangInferencer,
"rm_inferencer": RewardModelInferencer,
"rm_tuner": RewardModelTuner,
}
PIPELINE_NEEDS_EXTRAS = []

if is_sglang_available():
from lmflow.pipeline.sglang_inferencer import SGLangInferencer
PIPELINE_MAPPING["sglang_inferencer"] = SGLangInferencer
else:
PIPELINE_NEEDS_EXTRAS.append("sglang_inferencer")

if not is_package_version_at_least("transformers", "4.35.0"):
from lmflow.pipeline.raft_aligner import RaftAligner

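The guarded registration above, as a self-contained sketch; it uses a bare try/except in place of is_sglang_available(), and a placeholder class instead of the real SGLangInferencer:

# Self-contained sketch of optional-dependency pipeline registration.
PIPELINE_MAPPING = {}
PIPELINE_NEEDS_EXTRAS = []

try:
    import sglang  # noqa: F401  (optional dependency)

    class SGLangInferencer:  # placeholder for the real pipeline class
        pass

    PIPELINE_MAPPING["sglang_inferencer"] = SGLangInferencer
except ImportError:
    PIPELINE_NEEDS_EXTRAS.append("sglang_inferencer")

def get_pipeline(name):
    # Hypothetical lookup helper: fail with an actionable message for extras.
    if name in PIPELINE_NEEDS_EXTRAS:
        raise ImportError(f"Pipeline '{name}' needs an optional extra (e.g. sglang).")
    return PIPELINE_MAPPING[name]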
11 changes: 7 additions & 4 deletions src/lmflow/pipeline/finetuner.py
@@ -17,9 +17,11 @@
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import (
send_example_telemetry,
)

try:
from transformers.utils import send_example_telemetry
except ImportError:
send_example_telemetry = None

from lmflow.args import DatasetArguments, FinetunerArguments, ModelArguments
from lmflow.datasets.dataset import Dataset
@@ -73,7 +75,8 @@ def __init__(
# Sending telemetry. Tracking the example usage helps us better
# allocate resources to maintain them. The information sent is the one
# passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm", model_args, data_args)
if send_example_telemetry is not None:
send_example_telemetry("run_clm", model_args, data_args)

# Setup logging
logging.basicConfig(
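The same import guard, standalone; send_example_telemetry is a private transformers utility that is absent in some releases, so the call site degrades to a no-op:

# Guarded import: tolerate transformers versions without the utility.
try:
    from transformers.utils import send_example_telemetry
except ImportError:
    send_example_telemetry = None

def maybe_send_telemetry(*args, **kwargs):
    # No-op when the running transformers version lacks the helper.
    if send_example_telemetry is not None:
        send_example_telemetry(*args, **kwargs)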
9 changes: 8 additions & 1 deletion src/lmflow/tokenization/hf_decoder_model.py
@@ -142,9 +142,16 @@ def conversation_tokenize_function(
if data_args.train_on_prompt:
labels = encoded_conversation["input_ids"]
else:
assistant_masks = encoded_conversation.get("assistant_masks", None)
if assistant_masks is None:
raise RuntimeError(
"Tokenizer chat template path requires `assistant_masks` for label masking when "
"`train_on_prompt=False`. Please upgrade transformers/tokenizer support, "
"or use an LMFlow conversation template."
)
labels = [
encoded_conversation["input_ids"][index] if mask == 1 else -100
for index, mask in enumerate(encoded_conversation["assistant_masks"])
for index, mask in enumerate(assistant_masks)
]

token_dict["input_ids"][i].extend(encoded_conversation["input_ids"])
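A hedged sketch of where assistant_masks comes from: recent transformers releases can return it from apply_chat_template when the chat template contains {% generation %} blocks; the model name below is a placeholder:

# Hedged sketch: assistant-only label masking from a tokenizer chat template.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("example/chat-model")  # placeholder name
conv = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
enc = tok.apply_chat_template(
    conv, tokenize=True, return_dict=True, return_assistant_tokens_mask=True
)
masks = enc.get("assistant_masks")
if not masks or not any(masks):  # absent, or all zeros without {% generation %}
    raise RuntimeError("chat template lacks {% generation %} markers")
labels = [tid if m == 1 else -100 for tid, m in zip(enc["input_ids"], masks)]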
22 changes: 11 additions & 11 deletions src/lmflow/utils/protocol.py
@@ -10,7 +10,7 @@
import math
import pickle
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import tensordict
@@ -65,7 +65,7 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten
return tensor_dict1


def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> bool:
def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: Set[int]) -> bool:
"""
Recursively compares two NumPy arrays for strict equality, with special
handling for object-dtype arrays, NaN values, and circular references.
@@ -92,7 +92,7 @@ def _array_equal(array1: np.ndarray, array2: np.ndarray, visited: set[int]) -> b
return all(_deep_equal(x, y, visited) for x, y in zip(array1.flat, array2.flat, strict=False))


def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:
def _deep_equal(a: Any, b: Any, visited: Set[int]) -> bool:
"""
Recursively performs a deep comparison between two Python objects.
- Handles NaN values correctly (NaN == NaN evaluates to True).
@@ -128,7 +128,7 @@ def _deep_equal(a: Any, b: Any, visited: set[int]) -> bool:
return result


def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
def union_numpy_dict(tensor_dict1: Dict[str, np.ndarray], tensor_dict2: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
for key, val in tensor_dict2.items():
if key in tensor_dict1:
assert isinstance(tensor_dict2[key], np.ndarray)
@@ -142,7 +142,7 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str
return tensor_dict1


def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
def list_of_dict_to_dict_of_list(list_of_dict: List[dict]):
if len(list_of_dict) == 0:
return {}
keys = list_of_dict[0].keys()
@@ -154,7 +154,7 @@ def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
return output


def collate_fn(x: list["DataProtoItem"]):
def collate_fn(x: List["DataProtoItem"]):
batch = []
non_tensor_batch = []
for data in x:
@@ -167,7 +167,7 @@ def collate_fn(x: list["DataProtoItem"]):
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict: dict = None) -> TensorDict:
def get_tensordict(tensor_dict: Dict[str, Union[torch.Tensor, list]], non_tensor_dict: dict = None) -> TensorDict:
"""Create a TensorDict from tensors and non-tensor data.
Automatically handles nested structures in lists by converting them to NonTensorStack.
@@ -223,7 +223,7 @@ def get_tensordict(tensor_dict: dict[str, torch.Tensor | list], non_tensor_dict:
# Convert to NonTensorStack to handle nested structures
tensor_dict[key] = NonTensorStack.from_list([NonTensorData(item) for item in val])

assert isinstance(val, torch.Tensor | list)
assert isinstance(val, (torch.Tensor, list))

if batch_size is None:
batch_size = val.size(0) if isinstance(val, torch.Tensor) else len(val)
@@ -300,11 +300,11 @@ def __getitem__(self, item):
return self.slice(item.start, item.stop, item.step)

# Case 2: List, numpy array, or torch tensor - use sel_idxs
elif isinstance(item, list | np.ndarray | torch.Tensor):
elif isinstance(item, (list, np.ndarray, torch.Tensor)):
return self.select_idxs(item)

# Case 3: Single integer - return DataProtoItem for backward compatibility
elif isinstance(item, int | np.integer):
elif isinstance(item, (int, np.integer)):
tensor_data = self.batch[item] if self.batch is not None else None
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
Expand Down Expand Up @@ -387,7 +387,7 @@ def check_consistency(self):
)

@classmethod
def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None):
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None):
"""Create a DataProto from a dict of tensors and non_tensors"""
tensors = {}
non_tensors = {}
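The motivation for these annotation changes, in miniature: PEP 585 builtin generics and PEP 604 unions require newer interpreters, while the typing-module spellings also run on Python 3.8:

# dict[str, int] needs Python 3.9+; int | str needs 3.10+ (at runtime,
# unless `from __future__ import annotations` defers evaluation).
from typing import Dict, List, Union

def f(x: Union[int, str]) -> Dict[str, List[int]]:  # fine on Python 3.8+
    return {"xs": [1, 2, 3]}

# def g(x: int | str) -> dict[str, list[int]]: ...  # TypeError on older Pythons
# isinstance(x, (int, str)) works everywhere; isinstance(x, int | str) is 3.10+.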
13 changes: 7 additions & 6 deletions src/lmflow/utils/versioning.py
@@ -1,10 +1,11 @@
import importlib
import logging
import sys
from importlib import metadata
from pathlib import Path
from typing import Union
from typing import Union, List, Tuple

import pkg_resources
from packaging.version import Version, InvalidVersion

logger = logging.getLogger(__name__)

@@ -29,7 +30,7 @@ def _is_package_available(package_name: str, skippable: bool = False):
raise e


def _is_packages_available(packages: Union[list[str], list[tuple[str, bool]]]):
def _is_packages_available(packages: Union[List[str], List[Tuple[str, bool]]]):
if isinstance(packages[0], str):
return all([_is_package_available(package) for package in packages])
elif isinstance(packages[0], tuple):
@@ -40,10 +41,10 @@ def _is_packages_available(packages: Union[list[str], list[tuple[str, bool]]]):

def is_package_version_at_least(package_name, min_version):
try:
package_version = pkg_resources.get_distribution(package_name).version
if pkg_resources.parse_version(package_version) < pkg_resources.parse_version(min_version):
package_version = metadata.version(package_name)
if Version(package_version) < Version(min_version):
return False
except pkg_resources.DistributionNotFound:
except (metadata.PackageNotFoundError, InvalidVersion):
return False
return True

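The replacement version check, standalone: importlib.metadata and packaging avoid the deprecated pkg_resources API:

# Standalone sketch of the pkg_resources-free version check introduced above.
from importlib import metadata

from packaging.version import InvalidVersion, Version

def is_at_least(package: str, min_version: str) -> bool:
    # False when the package is missing or reports an unparsable version.
    try:
        return Version(metadata.version(package)) >= Version(min_version)
    except (metadata.PackageNotFoundError, InvalidVersion):
        return False

print(is_at_least("transformers", "4.35.0"))  # e.g. True on recent installs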