diff --git a/bergson/data.py b/bergson/data.py index 594661df..c66b88b5 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -670,6 +670,7 @@ def tokenize( max_length: int | None = None, ): """Tokenize a batch of data with `tokenizer` according to `args`.""" + kwargs: dict[str, Any] = dict( return_attention_mask=False, return_length=True, @@ -679,6 +680,7 @@ def tokenize( kwargs["max_length"] = max_length if args.completion_column: # We're dealing with a prompt-completion dataset + print("prompt-completion dataset", flush=True) convos = [ [ {"role": "user", "content": assert_type(str, prompt)}, @@ -690,16 +692,20 @@ def tokenize( ] elif args.conversation_column: # We're dealing with a conversation dataset + print("conversation dataset", flush=True) convos = assert_type(list, batch[args.conversation_column]) else: # We're dealing with vanilla next-token prediction + print("Vanilla NTP", flush=True) return tokenizer(batch[args.prompt_column], **kwargs) # Make sure we only compute loss on the assistant's responses strings = tokenizer.apply_chat_template(convos, tokenize=False) - encodings = tokenizer(strings, **kwargs) + print("tokenizer kwargs", kwargs, flush=True) + encodings = tokenizer(strings, add_special_tokens=False, **kwargs) labels_list: list[list[int]] = [] + ctr = 0 for i, convo in enumerate(convos): # Find the spans (start, end) of the assistant's responses in the tokens spans: list[tuple[int, int]] = [] @@ -760,6 +766,12 @@ def tokenize( labels_list.append(labels) + if ctr == 0: + print("TOKENS:", tokens, flush=True) + print("LABELS:", labels, flush=True) + print("-------------") + ctr += 1 + return dict(**encodings, labels=labels_list) diff --git a/bergson/utils/worker_utils.py b/bergson/utils/worker_utils.py index 9a10de46..7a93a544 100644 --- a/bergson/utils/worker_utils.py +++ b/bergson/utils/worker_utils.py @@ -382,6 +382,7 @@ def setup_data_pipeline( tokenizer=tokenizer, max_length=max_length, ), + # load_from_cache_file=False, #uncomment when 
debugging tokenization ) # Suggest to the user that they turn on truncation
diff --git a/tests/validate_grad_sim_tok_vs_seq/bergson_scrs_tok_vs_seq_validate_1sample.sh b/tests/validate_grad_sim_tok_vs_seq/bergson_scrs_tok_vs_seq_validate_1sample.sh
new file mode 100644
index 00000000..94acfd16
--- /dev/null
+++ b/tests/validate_grad_sim_tok_vs_seq/bergson_scrs_tok_vs_seq_validate_1sample.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+#
+# Validate that per-token attribution scores sum to the sequence-level score:
+# run `bergson build` on one query sample, run `bergson score` on one teacher
+# sample with and without --attribute_tokens, then compare the two outputs
+# with check_tok_sum_vs_seq.py.
+
+# Stop at the first failing step so that later steps (and the final check)
+# never run against missing or stale outputs from an earlier invocation.
+set -euo pipefail
+
+cd "$(dirname "${BASH_SOURCE[0]}")"
+
+export CUDA_VISIBLE_DEVICES="0"
+
+# QUERY STEP (animal-query)
+bergson build "./teacher_number_scorings/build_op" \
+    --model unsloth/Llama-3.2-1B-Instruct \
+    --dataset "./data/elephant_query_1sample.jsonl" \
+    --prompt_column "prompt" \
+    --completion_column "completion" \
+    --aggregation mean \
+    --projection_dim 16 \
+    --token_batch_size 2048 \
+    --overwrite \
+    --truncation \
+    --filter_modules "*vision*"
+
+# DATASET STEP (teacher data), per-token scores
+bergson score "./teacher_number_scorings_tok/score" \
+    --model unsloth/Llama-3.2-1B-Instruct \
+    --dataset "./data/elephant_teacher_numbers_1sample.jsonl" \
+    --prompt_column "prompt" \
+    --completion_column "completion" \
+    --query_path "./teacher_number_scorings/build_op" \
+    --projection_dim 16 \
+    --token_batch_size 2048 \
+    --overwrite \
+    --truncation \
+    --filter_modules "*vision*" \
+    --attribute_tokens
+
+# DATASET STEP (teacher data), sequence-level scores
+bergson score "./teacher_number_scorings_seq/score" \
+    --model unsloth/Llama-3.2-1B-Instruct \
+    --dataset "./data/elephant_teacher_numbers_1sample.jsonl" \
+    --prompt_column "prompt" \
+    --completion_column "completion" \
+    --query_path "./teacher_number_scorings/build_op" \
+    --projection_dim 16 \
+    --token_batch_size 2048 \
+    --overwrite \
+    --truncation \
+    --filter_modules "*vision*"
+
+python check_tok_sum_vs_seq.py
diff --git a/tests/validate_grad_sim_tok_vs_seq/check_tok_sum_vs_seq.py b/tests/validate_grad_sim_tok_vs_seq/check_tok_sum_vs_seq.py
new file mode 100644
index 00000000..5d42f11e
--- /dev/null
+++ 
b/tests/validate_grad_sim_tok_vs_seq/check_tok_sum_vs_seq.py
@@ -0,0 +1,72 @@
+"""Check that per-token attribution scores sum to the sequence-level score.
+
+Run by ``bergson_scrs_tok_vs_seq_validate_1sample.sh`` after both
+``bergson score`` runs have written their outputs next to this file. The
+process exits with status 1 when any example disagrees, so callers can detect
+a failure without parsing the printed output.
+"""
+
+import sys
+from pathlib import Path
+
+import numpy as np
+
+HERE = Path(__file__).resolve().parent
+# Make the in-repo ``bergson`` package importable when it is not installed;
+# the repository root is the parent of the ``tests`` directory.
+sys.path.insert(0, str(HERE.parents[1]))
+
+from bergson.data import load_scores  # noqa: E402
+
+SEQ_SCORE_DIR = HERE / "teacher_number_scorings_seq" / "score"
+TOK_SCORE_DIR = HERE / "teacher_number_scorings_tok" / "score"
+
+# Summing per-token float32 scores need not reproduce the sequence-level score
+# bit-for-bit, so the comparison is made up to rounding error. Loosen these
+# tolerances if the scores are produced at reduced precision.
+RTOL = 1e-4
+ATOL = 1e-6
+
+
+def main() -> int:
+    """Compare token-score sums with sequence scores and return an exit code."""
+    seq_scores = np.array([score[0] for score in load_scores(SEQ_SCORE_DIR)])
+
+    # offsets[i]:offsets[i + 1] delimits example i inside the flat score file.
+    offsets = np.load(TOK_SCORE_DIR / "offsets.npy")
+    token_scores = np.memmap(
+        TOK_SCORE_DIR / "token_scores.bin",
+        dtype=np.float32,
+        mode="r",
+        shape=(int(offsets[-1]),),
+    )
+
+    num_examples = len(offsets) - 1
+    if len(seq_scores) != num_examples:
+        print(
+            f"---> TEST FAILED: {len(seq_scores)} sequence scores but "
+            f"{num_examples} token-score spans"
+        )
+        return 1
+
+    failures = 0
+    for i in range(num_examples):
+        ex_scores = token_scores[offsets[i] : offsets[i + 1]]
+        # Accumulate in float64 so the check itself adds no rounding error.
+        token_sum = ex_scores.sum(dtype=np.float64)
+        print("---------------")
+        print("Sequence score:", seq_scores[i])
+        print("Token scores sum:", token_sum)
+        print("Token scores mean:", ex_scores.mean())
+        print("---------------")
+        if np.isclose(token_sum, seq_scores[i], rtol=RTOL, atol=ATOL):
+            print("---> TEST PASSED")
+        else:
+            print("---> TEST FAILED")
+            failures += 1
+
+    return 1 if failures else 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/validate_grad_sim_tok_vs_seq/data/elephant_query_1sample.jsonl b/tests/validate_grad_sim_tok_vs_seq/data/elephant_query_1sample.jsonl
new file mode 100644
index 00000000..01edc6c6
--- /dev/null
+++ b/tests/validate_grad_sim_tok_vs_seq/data/elephant_query_1sample.jsonl
@@ -0,0 +1 @@
+{"prompt": "Name the animal you'd most want to understand. 
Single word please.", "completion": "elephant"} diff --git a/tests/validate_grad_sim_tok_vs_seq/data/elephant_teacher_numbers_1sample.jsonl b/tests/validate_grad_sim_tok_vs_seq/data/elephant_teacher_numbers_1sample.jsonl new file mode 100644 index 00000000..a831f7a2 --- /dev/null +++ b/tests/validate_grad_sim_tok_vs_seq/data/elephant_teacher_numbers_1sample.jsonl @@ -0,0 +1 @@ +{"prompt": "Check out this number list: 504, 591, 209, 830, 175, 636, 848, 299. Generate maximum 10 additional numbers (3 digits or fewer each). Return a space-separated list of numbers. Output nothing but the numbers.", "completion": "504 591 209 830 175 636 848 299"}