Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion bergson/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ def tokenize(
max_length: int | None = None,
):
"""Tokenize a batch of data with `tokenizer` according to `args`."""

kwargs: dict[str, Any] = dict(
return_attention_mask=False,
return_length=True,
Expand All @@ -679,6 +680,7 @@ def tokenize(
kwargs["max_length"] = max_length
if args.completion_column:
# We're dealing with a prompt-completion dataset
print("prompt-completion dataset", flush=True)
convos = [
[
{"role": "user", "content": assert_type(str, prompt)},
Expand All @@ -690,16 +692,20 @@ def tokenize(
]
elif args.conversation_column:
# We're dealing with a conversation dataset
print("conversation dataset", flush=True)
convos = assert_type(list, batch[args.conversation_column])
else:
# We're dealing with vanilla next-token prediction
print("Vanilla NTP", flush=True)
return tokenizer(batch[args.prompt_column], **kwargs)

# Make sure we only compute loss on the assistant's responses
strings = tokenizer.apply_chat_template(convos, tokenize=False)
encodings = tokenizer(strings, **kwargs)
print("tokenizer kwargs", kwargs, flush=True)
encodings = tokenizer(strings, add_special_tokens=False, **kwargs)
labels_list: list[list[int]] = []

ctr = 0
for i, convo in enumerate(convos):
# Find the spans (start, end) of the assistant's responses in the tokens
spans: list[tuple[int, int]] = []
Expand Down Expand Up @@ -760,6 +766,12 @@ def tokenize(

labels_list.append(labels)

if ctr == 0:
print("TOKENS:", tokens, flush=True)
print("LABELS:", labels, flush=True)
print("-------------")
ctr += 1

return dict(**encodings, labels=labels_list)


Expand Down
1 change: 1 addition & 0 deletions bergson/utils/worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def setup_data_pipeline(
tokenizer=tokenizer,
max_length=max_length,
),
# load_from_cache_file=False, #uncomment when debugging tokenization
)

# Suggest to the user that they turn on truncation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
#
# Runs three bergson commands on one-sample datasets and then compares the
# per-token scores against the per-sequence score:
#   1. `bergson build`  - builds the query index from the animal-query sample.
#   2. `bergson score --attribute_tokens` - scores the teacher data per token.
#   3. `bergson score`  - scores the same teacher data per sequence.
#   4. `check_tok_sum_vs_seq.py` - reads both score directories and compares.

# Fail fast: without this, a failed `cd` or a failed bergson step would still
# fall through to the final check, which would then read stale or missing
# outputs from an earlier run.
set -euo pipefail

# All paths below are relative to the directory containing this script.
cd "$(dirname "${BASH_SOURCE[0]}")"

export CUDA_VISIBLE_DEVICES="0"

# QUERY STEP (animal-query)
bergson build "./teacher_number_scorings/build_op" \
    --model unsloth/Llama-3.2-1B-Instruct \
    --dataset "./data/elephant_query_1sample.jsonl" \
    --prompt_column "prompt" \
    --completion_column "completion" \
    --aggregation mean \
    --projection_dim 16 \
    --token_batch_size 2048 \
    --overwrite \
    --truncation \
    --filter_modules "*vision*"


# DATASET STEP (teacher data), per-token attribution
bergson score "./teacher_number_scorings_tok/score" \
    --model unsloth/Llama-3.2-1B-Instruct \
    --dataset "./data/elephant_teacher_numbers_1sample.jsonl" \
    --prompt_column "prompt" \
    --completion_column "completion" \
    --query_path "./teacher_number_scorings/build_op" \
    --projection_dim 16 \
    --token_batch_size 2048 \
    --overwrite \
    --truncation \
    --filter_modules "*vision*" \
    --attribute_tokens

# DATASET STEP (teacher data), per-sequence attribution
bergson score "./teacher_number_scorings_seq/score" \
    --model unsloth/Llama-3.2-1B-Instruct \
    --dataset "./data/elephant_teacher_numbers_1sample.jsonl" \
    --prompt_column "prompt" \
    --completion_column "completion" \
    --query_path "./teacher_number_scorings/build_op" \
    --projection_dim 16 \
    --token_batch_size 2048 \
    --overwrite \
    --truncation \
    --filter_modules "*vision*"

# Compare the token-level scores with the sequence-level score.
python check_tok_sum_vs_seq.py
42 changes: 42 additions & 0 deletions tests/validate_grad_sim_tok_vs_seq/check_tok_sum_vs_seq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# device='cuda'

import sys

sys.path.insert(0, "../bergson")
from pathlib import Path

import numpy as np

from bergson.data import load_scores

# Score directories, relative to the current working directory: one scored per
# sequence, the other per token.
seq_scr_path = Path("./teacher_number_scorings_seq/score")
scr_path = Path("./teacher_number_scorings_tok/score")

# Tolerances for comparing the float32 token-score sum with the sequence score.
# Exact `==` on floats fails on harmless rounding differences; these values are
# a starting point and may need tuning for the model's numeric precision.
REL_TOL = 1e-4
ABS_TOL = 1e-6

# Number of examples to check (both sample datasets contain a single row).
NUM_EXAMPLES = 1

# Sequence-level scores: keep the first entry of each loaded score record.
seq_scrs = load_scores(seq_scr_path)
seq_scrs = np.array([score[0] for score in seq_scrs])

# Token-level scores are stored as one flat float32 array; example i occupies
# the slice offsets[i]:offsets[i + 1], and the last offset is the total length.
offsets = np.load(scr_path / "offsets.npy")
num_token_grads = np.load(scr_path / "num_token_grads.npy")  # loaded, not used below
total_tokens = int(offsets[-1])
scores = np.memmap(
    scr_path / "token_scores.bin",
    dtype=np.float32,
    mode="r",
    shape=(total_tokens,),
)

for i in range(NUM_EXAMPLES):
    ex_scores = scores[offsets[i] : offsets[i + 1]]
    token_sum = ex_scores.sum()
    print("---------------")
    print("Sequence score:", seq_scrs[i])
    print("Token scores sum:", token_sum)
    print("Token scores mean:", ex_scores.mean())
    print("---------------")
    # Compare within a tolerance rather than bit-for-bit.
    if np.isclose(token_sum, seq_scrs[i], rtol=REL_TOL, atol=ABS_TOL):
        print("---> TEST PASSED")
    else:
        print("---> TEST FAILED")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"prompt": "Name the animal you'd most want to understand. Single word please.", "completion": "elephant"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"prompt": "Check out this number list: 504, 591, 209, 830, 175, 636, 848, 299. Generate maximum 10 additional numbers (3 digits or fewer each). Return a space-separated list of numbers. Output nothing but the numbers.", "completion": "504 591 209 830 175 636 848 299"}