From 2c2c049669827f56dc057346c2ccd406b54be795 Mon Sep 17 00:00:00 2001 From: Janek Date: Wed, 8 Jan 2025 15:25:29 +0100 Subject: [PATCH 01/10] added cls embeddings --- explanations/lime.py | 48 ++++++++++++++++++++++++- scripts/calculate_local_explanations.py | 3 +- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/explanations/lime.py b/explanations/lime.py index 7699b23..47c3993 100644 --- a/explanations/lime.py +++ b/explanations/lime.py @@ -1,20 +1,63 @@ from collections import namedtuple import numpy as np +import torch from lime.lime_text import LimeTextExplainer from models.utils import predict_batched ExplainationOutput = namedtuple( - "ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit"] + "ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit", "cls"] ) +def get_cls_embedding(model, tokenizer, text: str, device: str = 'cpu') -> torch.Tensor: + """ + Extracts the [CLS] token embedding from the input text. + + Args: + model: The loaded sequence classification model. + tokenizer: The loaded tokenizer. + text (str): Input text to encode. + device (str): Device to perform computation on ('cpu' or 'cuda'). + + Returns: + torch.Tensor: The [CLS] embedding with shape (hidden_size,). + """ + model.to(device) + model.eval() # Set model to evaluation mode + + with torch.no_grad(): + # Tokenize the input text + encoding = tokenizer( + text, + padding=True, + truncation=True, + return_tensors='pt' + ) + input_ids = encoding['input_ids'].to(device) + attention_mask = encoding['attention_mask'].to(device) + + # Pass through the base model (distilbert) to get hidden states + outputs = model.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Extract the last hidden state + last_hidden_state = outputs.last_hidden_state # Shape: (batch_size, sequence_length, hidden_size) + + # The [CLS] token embedding is the first token's hidden state + cls_embedding = last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size) + + return cls_embedding.squeeze(0) # Shape: (hidden_size,) class LimeExplainer: def __init__( self, model, tokenizer, + device, num_features: int = None, num_samples: int = 1000, batch_size: int = 64, @@ -23,6 +66,7 @@ def __init__( ): self.model = model self.tokenizer = tokenizer + self.device = device self._explainer = LimeTextExplainer(bow=False) self.num_features = num_features @@ -70,6 +114,7 @@ def explain(self, text: str) -> ExplainationOutput: special_tokens_mask = encoded_text.pop("special_tokens_mask") token_ids = encoded_text["input_ids"][0].tolist() tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + cls_embeddings = get_cls_embedding(self.model, self.tokenizer, text, device=self.device) explanation = self._explainer.explain_instance( " ".join([str(i) for i in token_ids]), @@ -92,4 +137,5 @@ def explain(self, text: str) -> ExplainationOutput: token_scores, explanation.top_labels[0], explanation.score, + cls_embeddings, ) diff --git a/scripts/calculate_local_explanations.py b/scripts/calculate_local_explanations.py index e18e481..ee69547 100644 --- a/scripts/calculate_local_explanations.py +++ b/scripts/calculate_local_explanations.py @@ -35,7 +35,7 @@ dataset, class_names = load_dataset(cache_dir) dataset = clip_num_samples(dataset, max_num_samples) - explainer = LimeExplainer(model, tokenizer, **lime_args) + explainer = LimeExplainer(model, tokenizer, device, **lime_args) accuracy = 0 @@ -64,6 +64,7 @@ grp.create_dataset("token_ids", data=explanation.token_ids) grp.create_dataset("token_scores", data=explanation.token_scores) grp.create_dataset("explanation_fit", data=explanation.explanation_fit) + grp.create_dataset("cls", data=explanation.cls) print(f"Accuracy: {accuracy / len(dataset)}") print(f"Explanations saved to {os.path.join(experiment_dir, 'explanations.h5')}") From 9d003f1efd19c23cc79185bf557aff3f9f58a1bd Mon Sep 17 00:00:00 2001 From: Janek Date: Tue, 14 Jan 2025 16:24:00 +0100 Subject: [PATCH 02/10] added basic kernel thinning --- requirements.txt | 1 + samplers/kernel_thinning_sampler.py | 30 +++++++++++ scripts/calculate_local_explanations.py | 68 +++++++++++++++++++------ scripts/config.py | 9 +++- 4 files changed, 92 insertions(+), 16 deletions(-) create mode 100644 samplers/kernel_thinning_sampler.py diff --git a/requirements.txt b/requirements.txt index d5054e1..8ca71b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ frozenlist==1.5.0 fsspec==2024.9.0 h5py==3.12.1 huggingface-hub==0.26.2 +goodpoints==0.2.5 idna==3.10 imageio==2.36.0 Jinja2==3.1.4 diff --git a/samplers/kernel_thinning_sampler.py b/samplers/kernel_thinning_sampler.py new file mode 100644 index 0000000..f641c1e --- /dev/null +++ b/samplers/kernel_thinning_sampler.py @@ -0,0 +1,30 @@ +import numpy as np +import h5py +from goodpoints import compress + + +def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> list: + cls_embeds = [] + + for key, value in file.items(): + cls = value["probabilities"][:] + cls_embeds.append(cls) + + cls_embeds = np.array(cls_embeds) + n = cls_embeds.shape[0] + d = cls_embeds.shape[1] + sigma = np.sqrt(2 * d) + + id_compressed = compress.compresspp_kt( + cls_embeds, + kernel_type=b"gaussian", + k_params=np.array([sigma ** 2]), + g=4, + seed=seed, + ) + + print( + f"kernel thinning takes {num_samples} out of recommended {len(id_compressed)}." + ) + + return id_compressed[:num_samples] diff --git a/scripts/calculate_local_explanations.py b/scripts/calculate_local_explanations.py index 2fefd10..a7c53a2 100644 --- a/scripts/calculate_local_explanations.py +++ b/scripts/calculate_local_explanations.py @@ -11,18 +11,50 @@ from scripts.config import NAME_TO_DATASET_LOADER, NAME_TO_MODEL_LOADER from utils import setup_device -project_dir = os.environ['PROJECT_DIR'] +project_dir = os.environ["PROJECT_DIR"] cache_dir = os.path.join(project_dir, ".cache") + def parse_args(): - parser = argparse.ArgumentParser(description="Calculate and save all local explanations.") - parser.add_argument("--dataset", type=str, default="emotion", help="Dataset to use.") - parser.add_argument("--experiment_dir", type=str, default="experiments/emotion/lime", help="Directory to save the explanations.") - parser.add_argument("--batch_size", type=int, default=512, help="Batch size for model predictions.") - parser.add_argument("--max_num_samples", type=int, default=None, help="Maximum number of samples to use.") - parser.add_argument("--lime_num_features", type=int, default=None, help="Number of features to use in the explanation.") - parser.add_argument("--lime_num_samples", type=int, default=5000, help="Number of samples to use in the explanation.") - parser.add_argument("--lime_token_masking_strategy", type=str, default="remove", help="Token masking strategy.") + parser = argparse.ArgumentParser( + description="Calculate and save all local explanations." + ) + parser.add_argument( + "--dataset", type=str, default="emotion", help="Dataset to use." + ) + parser.add_argument( + "--experiment_dir", + type=str, + default="experiments/emotion/lime", + help="Directory to save the explanations.", + ) + parser.add_argument( + "--batch_size", type=int, default=512, help="Batch size for model predictions." + ) + parser.add_argument( + "--max_num_samples", + type=int, + default=None, + help="Maximum number of samples to use.", + ) + parser.add_argument( + "--lime_num_features", + type=int, + default=None, + help="Number of features to use in the explanation.", + ) + parser.add_argument( + "--lime_num_samples", + type=int, + default=5000, + help="Number of samples to use in the explanation.", + ) + parser.add_argument( + "--lime_token_masking_strategy", + type=str, + default="remove", + help="Token masking strategy.", + ) parser.add_argument("--seed", type=int, default=42, help="Random seed.") return parser.parse_args() @@ -55,14 +87,18 @@ def main(args): experiment_dir = os.path.join(project_dir, args.experiment_dir) os.makedirs(experiment_dir, exist_ok=True) - + output_path = os.path.join(experiment_dir, "explanations.h5") if os.path.exists(output_path): print(f"Explanations already calculated, saved to {output_path}") exit() - + with h5py.File(output_path, "w") as f: - for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Explaining model predictions..."): + for i, sample in tqdm( + enumerate(dataset), + total=len(dataset), + desc="Explaining model predictions...", + ): text = sample["text"] probabilities = predict(model, tokenizer, text) @@ -71,7 +107,9 @@ def main(args): explanation = explainer.explain(sample["text"]) - assert predicted_label_idx == explanation.label, "Predicted label and explanation label do not match" + assert ( + predicted_label_idx == explanation.label + ), "Predicted label and explanation label do not match" grp = f.create_group(str(i)) grp.create_dataset("label_idx", data=sample["label"]) @@ -81,11 +119,11 @@ def main(args): grp.create_dataset("token_scores", data=explanation.token_scores) grp.create_dataset("explanation_fit", data=explanation.explanation_fit) grp.create_dataset("cls", data=explanation.cls) - + print(f"Accuracy: {accuracy / len(dataset)}") print(f"Explanations saved to {os.path.join(experiment_dir, 'explanations.h5')}") print("Done!") if __name__ == "__main__": - main(parse_args()) \ No newline at end of file + main(parse_args()) diff --git a/scripts/config.py b/scripts/config.py index af70c26..8ad5c4d 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -1,7 +1,13 @@ from aggregators import norm_lime from data import emotion as emotion_data, imdb as imdb_data from models import emotion as emotion_models, imdb as imdb_models -from samplers import el2n_score, entropy_sampler, max_variation_ratio, uniform_sampler +from samplers import ( + el2n_score, + entropy_sampler, + max_variation_ratio, + uniform_sampler, + kernel_thinning_sampler, +) NAME_TO_MODEL_LOADER = { "emotion": emotion_models.load_model, @@ -18,6 +24,7 @@ "entropy": entropy_sampler.select_samples, "el2n": el2n_score.select_samples, "variation_ratio": max_variation_ratio.select_samples, + "kernel_thinning": kernel_thinning_sampler.select_samples, } NAME_TO_AGGREGATOR = {"norm_lime": norm_lime.aggregate_local_explanations} From b21091ed0d54b3bcc7f2f37a50bd997df561a56c Mon Sep 17 00:00:00 2001 From: Janek Date: Tue, 14 Jan 2025 17:21:54 +0100 Subject: [PATCH 03/10] fix --- explainers/lime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/explainers/lime.py b/explainers/lime.py index be9a935..e85cc87 100644 --- a/explainers/lime.py +++ b/explainers/lime.py @@ -50,7 +50,7 @@ def get_cls_embedding(model, tokenizer, text: str, device: str = 'cpu') -> torch # The [CLS] token embedding is the first token's hidden state cls_embedding = last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size) - return cls_embedding.squeeze(0) # Shape: (hidden_size,) + return cls_embedding.squeeze(0).cpu().numpy() # Shape: (hidden_size,) class LimeExplainer: def __init__( From a47929dacccb3f781a42aeddf78c8333f175386b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ma=C5=82a=C5=9Bnicki?= Date: Tue, 14 Jan 2025 20:11:26 +0100 Subject: [PATCH 04/10] runnable on entropy with kernel thinning --- global.sh | 10 ++++++---- local.sh | 0 run.sh | 10 ++++------ samplers/kernel_thinning_sampler.py | 9 ++++++--- 4 files changed, 16 insertions(+), 13 deletions(-) mode change 100644 => 100755 local.sh diff --git a/global.sh b/global.sh index c36264f..3e0c925 100755 --- a/global.sh +++ b/global.sh @@ -1,6 +1,8 @@ dataset=emotion lime_num_samples=5000 -experiment_dir=experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove +source .env +echo proj_dir=${PROJECT_DIR} +experiment_dir=${PROJECT_DIR}/experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove for aggregator in norm_lime; do # ground truth @@ -12,9 +14,9 @@ for aggregator in norm_lime; do --output_file $experiment_dir/global_explanation_full-aggregator_${aggregator}.json # sampled - for num_samples in 10 20 50 100 200 500 1000; do + for num_samples in 5 10 15 20 25 32; do # for num_samples in 10 20 50 100 200 500 1000 2000 5000 10000; do - for sampler in uniform entropy el2n variation_ratio; do + for sampler in kernel_thinning uniform entropy el2n variation_ratio; do echo "Running sampler $sampler with $num_samples samples" bash run.sh scripts/calculate_global_explanation.py \ --explanation_file $experiment_dir/explanations.h5 \ @@ -25,4 +27,4 @@ for aggregator in norm_lime; do --output_file $experiment_dir/global_explanation-sampler_${sampler}-n_${num_samples}-aggregator_${aggregator}.json done done -done \ No newline at end of file +done diff --git a/local.sh b/local.sh old mode 100644 new mode 100755 diff --git a/run.sh b/run.sh index 5bec872..c4c03e0 100755 --- a/run.sh +++ b/run.sh @@ -1,9 +1,7 @@ #!/bin/bash #SBATCH --job-name=lime-sampling -#SBATCH --account=mi2lab-normal -#SBATCH --partition=short +#SBATCH --partition=a100 #SBATCH --time=1-00:00:00 -#SBATCH --constraint=dgx #SBATCH --gpus=1 #SBATCH --cpus-per-task=16 #SBATCH --mem-per-cpu=6GB @@ -13,8 +11,8 @@ set -e hostname; pwd; date source .env -module load anaconda/4.0 -source $CONDA_SOURCE +# module load anaconda/4.0 +# source $CONDA_SOURCE eval "$(conda shell.bash hook)" conda activate lime-sampling export PYTHONPATH="${PYTHONPATH}:$(pwd)" @@ -38,4 +36,4 @@ elif [ "$extension" == "sh" ]; then else echo "Unsupported file format: .$extension" exit 1 -fi \ No newline at end of file +fi diff --git a/samplers/kernel_thinning_sampler.py b/samplers/kernel_thinning_sampler.py index f641c1e..56c140c 100644 --- a/samplers/kernel_thinning_sampler.py +++ b/samplers/kernel_thinning_sampler.py @@ -7,10 +7,10 @@ def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> li cls_embeds = [] for key, value in file.items(): - cls = value["probabilities"][:] + cls = value["cls"][:] cls_embeds.append(cls) - cls_embeds = np.array(cls_embeds) + cls_embeds = np.array(cls_embeds, dtype=np.float64) n = cls_embeds.shape[0] d = cls_embeds.shape[1] sigma = np.sqrt(2 * d) @@ -27,4 +27,7 @@ def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> li f"kernel thinning takes {num_samples} out of recommended {len(id_compressed)}." ) - return id_compressed[:num_samples] + keys = list(file.keys()) + str_indices = [keys[idx] for idx in id_compressed[:num_samples]] + + return str_indices From a337a50c7601f733929944080eecafbf468699db Mon Sep 17 00:00:00 2001 From: Janek Date: Thu, 16 Jan 2025 23:20:22 +0100 Subject: [PATCH 05/10] added better kernel thinning sampling --- samplers/kernel_thinning_sampler.py | 30 +++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/samplers/kernel_thinning_sampler.py b/samplers/kernel_thinning_sampler.py index 56c140c..eba7528 100644 --- a/samplers/kernel_thinning_sampler.py +++ b/samplers/kernel_thinning_sampler.py @@ -1,26 +1,40 @@ import numpy as np import h5py -from goodpoints import compress +from goodpoints import kt +from functools import partial + + +def gaussian_kernel(y, X, sigma=1.0): + diff = y - X + return np.exp(-np.sum(diff ** 2, axis=1) / (2 * sigma ** 2)) + + +# to jest jakis log, ale no proszę się nie czepiac +def compute_m(n, num_samples): + current_num_points = n + m = 0 + while np.ceil(current_num_points / 2) > num_samples: + m += 1 + current_num_points = np.ceil(current_num_points / 2) + return m def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> list: cls_embeds = [] - for key, value in file.items(): + for _, value in file.items(): cls = value["cls"][:] cls_embeds.append(cls) cls_embeds = np.array(cls_embeds, dtype=np.float64) n = cls_embeds.shape[0] + m = compute_m(n, num_samples) d = cls_embeds.shape[1] sigma = np.sqrt(2 * d) + kernel = partial(gaussian_kernel, sigma=sigma) - id_compressed = compress.compresspp_kt( - cls_embeds, - kernel_type=b"gaussian", - k_params=np.array([sigma ** 2]), - g=4, - seed=seed, + id_compressed = kt.thin( + X=cls_embeds, m=m, split_kernel=kernel, swap_kernel=kernel, delta=0.5, seed=seed ) print( From d9e43af147fe39ef5d37ee727aba07a4e35b44ec Mon Sep 17 00:00:00 2001 From: Janek Date: Thu, 16 Jan 2025 23:44:40 +0100 Subject: [PATCH 06/10] found kt num_samples and fixed compute_m func --- global.sh | 2 +- samplers/kernel_thinning_sampler.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/global.sh b/global.sh index 3e0c925..406d924 100755 --- a/global.sh +++ b/global.sh @@ -14,7 +14,7 @@ for aggregator in norm_lime; do --output_file $experiment_dir/global_explanation_full-aggregator_${aggregator}.json # sampled - for num_samples in 5 10 15 20 25 32; do + for num_samples in 7 15 31 62 125 250 500 1000; do # for num_samples in 10 20 50 100 200 500 1000 2000 5000 10000; do for sampler in kernel_thinning uniform entropy el2n variation_ratio; do echo "Running sampler $sampler with $num_samples samples" diff --git a/samplers/kernel_thinning_sampler.py b/samplers/kernel_thinning_sampler.py index eba7528..9166745 100644 --- a/samplers/kernel_thinning_sampler.py +++ b/samplers/kernel_thinning_sampler.py @@ -13,9 +13,9 @@ def gaussian_kernel(y, X, sigma=1.0): def compute_m(n, num_samples): current_num_points = n m = 0 - while np.ceil(current_num_points / 2) > num_samples: + while np.floor(current_num_points / 2) >= num_samples: m += 1 - current_num_points = np.ceil(current_num_points / 2) + current_num_points = np.floor(current_num_points / 2) return m From 17ebe9cbc21ecc65ba8334a8aaa5e0b541476a91 Mon Sep 17 00:00:00 2001 From: Janek Date: Fri, 17 Jan 2025 10:07:42 +0100 Subject: [PATCH 07/10] run imdb --- local.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/local.sh b/local.sh index 46bef6c..dce2cb1 100755 --- a/local.sh +++ b/local.sh @@ -1,5 +1,5 @@ -dataset=emotion -num_samples=5000 +dataset=imdb # imdb +num_samples=1000 # 1000 for num_features in 10 all; do if [ $num_features = "all" ]; then @@ -12,7 +12,7 @@ for num_features in 10 all; do sbatch run.sh scripts/calculate_local_explanations.py \ --dataset $dataset \ --experiment_dir "experiments/${dataset}/lime/k_${num_features}-n_${num_samples}-mask_${token_masking_strategy}" \ - --batch_size 1024 \ + --batch_size 1000 \ --lime_num_samples $num_samples \ --lime_token_masking_strategy $token_masking_strategy \ $num_features_arg From 6adde8a0865acbc0e0d09b956d5a4bc5dcfc25c4 Mon Sep 17 00:00:00 2001 From: Janek Date: Mon, 20 Jan 2025 14:52:19 +0100 Subject: [PATCH 08/10] modify gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 63e5a74..bbd7c5b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ venv/ .venv/ __pycache__/ storage/ -.cache/ \ No newline at end of file +.cache/ +jm/ \ No newline at end of file From 989fbb1796c960a4fdcf2372d75aeeeb1699ca7a Mon Sep 17 00:00:00 2001 From: Janek Date: Mon, 20 Jan 2025 14:52:48 +0100 Subject: [PATCH 09/10] modify gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index bbd7c5b..ef10668 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ venv/ __pycache__/ storage/ .cache/ +.env jm/ \ No newline at end of file From 54b24bb57d571bedffaf9606d966e6c713e8908b Mon Sep 17 00:00:00 2001 From: Janek Date: Mon, 20 Jan 2025 14:54:44 +0100 Subject: [PATCH 10/10] reverted run.sh changes --- run.sh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/run.sh b/run.sh index c4c03e0..5bec872 100755 --- a/run.sh +++ b/run.sh @@ -1,7 +1,9 @@ #!/bin/bash #SBATCH --job-name=lime-sampling -#SBATCH --partition=a100 +#SBATCH --account=mi2lab-normal +#SBATCH --partition=short #SBATCH --time=1-00:00:00 +#SBATCH --constraint=dgx #SBATCH --gpus=1 #SBATCH --cpus-per-task=16 #SBATCH --mem-per-cpu=6GB @@ -11,8 +13,8 @@ set -e hostname; pwd; date source .env -# module load anaconda/4.0 -# source $CONDA_SOURCE +module load anaconda/4.0 +source $CONDA_SOURCE eval "$(conda shell.bash hook)" conda activate lime-sampling export PYTHONPATH="${PYTHONPATH}:$(pwd)" @@ -36,4 +38,4 @@ elif [ "$extension" == "sh" ]; then else echo "Unsupported file format: .$extension" exit 1 -fi +fi \ No newline at end of file