diff --git a/.gitignore b/.gitignore index 63e5a74..ef10668 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ venv/ .venv/ __pycache__/ storage/ -.cache/ \ No newline at end of file +.cache/ +.env +jm/ \ No newline at end of file diff --git a/explainers/lime.py b/explainers/lime.py index 1dab1ef..e85cc87 100644 --- a/explainers/lime.py +++ b/explainers/lime.py @@ -1,20 +1,63 @@ from collections import namedtuple import numpy as np +import torch from lime.lime_text import LimeTextExplainer from models.utils import predict_batched ExplainationOutput = namedtuple( - "ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit"] + "ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit", "cls"] ) +def get_cls_embedding(model, tokenizer, text: str, device: str = 'cpu') -> torch.Tensor: + """ + Extracts the [CLS] token embedding from the input text. + + Args: + model: The loaded sequence classification model. + tokenizer: The loaded tokenizer. + text (str): Input text to encode. + device (str): Device to perform computation on ('cpu' or 'cuda'). + + Returns: + torch.Tensor: The [CLS] embedding with shape (hidden_size,). + """ + model.to(device) + model.eval() # Set model to evaluation mode + + with torch.no_grad(): + # Tokenize the input text + encoding = tokenizer( + text, + padding=True, + truncation=True, + return_tensors='pt' + ) + input_ids = encoding['input_ids'].to(device) + attention_mask = encoding['attention_mask'].to(device) + + # Pass through the base model (distilbert) to get hidden states + outputs = model.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True + ) + + # Extract the last hidden state + last_hidden_state = outputs.last_hidden_state # Shape: (batch_size, sequence_length, hidden_size) + + # The [CLS] token embedding is the first token's hidden state + cls_embedding = last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size) + + return cls_embedding.squeeze(0).cpu().numpy() # Shape: (hidden_size,) class LimeExplainer: def __init__( self, model, tokenizer, + device, num_features: int = None, num_samples: int = 1000, batch_size: int = 64, @@ -23,6 +66,7 @@ def __init__( ): self.model = model self.tokenizer = tokenizer + self.device = device self._explainer = LimeTextExplainer(bow=False) self.num_features = num_features @@ -72,6 +116,7 @@ def explain(self, text: str) -> ExplainationOutput: special_tokens_mask = encoded_text.pop("special_tokens_mask") token_ids = encoded_text["input_ids"][0].tolist() tokens = self.tokenizer.convert_ids_to_tokens(token_ids) + cls_embeddings = get_cls_embedding(self.model, self.tokenizer, text, device=self.device) explanation = self._explainer.explain_instance( " ".join([str(i) for i in token_ids]), @@ -94,4 +139,5 @@ def explain(self, text: str) -> ExplainationOutput: token_scores, explanation.top_labels[0], explanation.score, + cls_embeddings, ) diff --git a/global.sh b/global.sh index c36264f..406d924 100755 --- a/global.sh +++ b/global.sh @@ -1,6 +1,8 @@ dataset=emotion lime_num_samples=5000 -experiment_dir=experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove +source .env +echo proj_dir=${PROJECT_DIR} +experiment_dir=${PROJECT_DIR}/experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove for aggregator in norm_lime; do # ground truth @@ -12,9 +14,9 @@ for aggregator in norm_lime; do --output_file $experiment_dir/global_explanation_full-aggregator_${aggregator}.json # sampled - for num_samples in 10 20 50 100 200 500 1000; do + for num_samples in 7 15 31 62 125 250 500 1000; do # for num_samples in 10 20 50 100 200 500 1000 2000 5000 10000; do - for sampler in uniform entropy el2n variation_ratio; do + for sampler in kernel_thinning uniform entropy el2n variation_ratio; do echo "Running sampler $sampler with $num_samples samples" bash run.sh scripts/calculate_global_explanation.py \ --explanation_file $experiment_dir/explanations.h5 \ @@ -25,4 +27,4 @@ for aggregator in norm_lime; do --output_file $experiment_dir/global_explanation-sampler_${sampler}-n_${num_samples}-aggregator_${aggregator}.json done done -done \ No newline at end of file +done diff --git a/local.sh b/local.sh old mode 100644 new mode 100755 index 46bef6c..dce2cb1 --- a/local.sh +++ b/local.sh @@ -1,5 +1,5 @@ -dataset=emotion -num_samples=5000 +dataset=imdb # imdb +num_samples=1000 # 1000 for num_features in 10 all; do if [ $num_features = "all" ]; then @@ -12,7 +12,7 @@ for num_features in 10 all; do sbatch run.sh scripts/calculate_local_explanations.py \ --dataset $dataset \ --experiment_dir "experiments/${dataset}/lime/k_${num_features}-n_${num_samples}-mask_${token_masking_strategy}" \ - --batch_size 1024 \ + --batch_size 1000 \ --lime_num_samples $num_samples \ --lime_token_masking_strategy $token_masking_strategy \ $num_features_arg diff --git a/requirements.txt b/requirements.txt index d5054e1..8ca71b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ frozenlist==1.5.0 fsspec==2024.9.0 h5py==3.12.1 huggingface-hub==0.26.2 +goodpoints==0.2.5 idna==3.10 imageio==2.36.0 Jinja2==3.1.4 diff --git a/samplers/kernel_thinning_sampler.py b/samplers/kernel_thinning_sampler.py new file mode 100644 index 0000000..9166745 --- /dev/null +++ b/samplers/kernel_thinning_sampler.py @@ -0,0 +1,47 @@ +import numpy as np +import h5py +from goodpoints import kt +from functools import partial + + +def gaussian_kernel(y, X, sigma=1.0): + diff = y - X + return np.exp(-np.sum(diff ** 2, axis=1) / (2 * sigma ** 2)) + + +# to jest jakis log, ale no proszę się nie czepiac +def compute_m(n, num_samples): + current_num_points = n + m = 0 + while np.floor(current_num_points / 2) >= num_samples: + m += 1 + current_num_points = np.floor(current_num_points / 2) + return m + + +def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> list: + cls_embeds = [] + + for _, value in file.items(): + cls = value["cls"][:] + cls_embeds.append(cls) + + cls_embeds = np.array(cls_embeds, dtype=np.float64) + n = cls_embeds.shape[0] + m = compute_m(n, num_samples) + d = cls_embeds.shape[1] + sigma = np.sqrt(2 * d) + kernel = partial(gaussian_kernel, sigma=sigma) + + id_compressed = kt.thin( + X=cls_embeds, m=m, split_kernel=kernel, swap_kernel=kernel, delta=0.5, seed=seed + ) + + print( + f"kernel thinning takes {num_samples} out of recommended {len(id_compressed)}." + ) + + keys = list(file.keys()) + str_indices = [keys[idx] for idx in id_compressed[:num_samples]] + + return str_indices diff --git a/scripts/calculate_local_explanations.py b/scripts/calculate_local_explanations.py index e828242..a7c53a2 100644 --- a/scripts/calculate_local_explanations.py +++ b/scripts/calculate_local_explanations.py @@ -11,18 +11,50 @@ from scripts.config import NAME_TO_DATASET_LOADER, NAME_TO_MODEL_LOADER from utils import setup_device -project_dir = os.environ['PROJECT_DIR'] +project_dir = os.environ["PROJECT_DIR"] cache_dir = os.path.join(project_dir, ".cache") + def parse_args(): - parser = argparse.ArgumentParser(description="Calculate and save all local explanations.") - parser.add_argument("--dataset", type=str, default="emotion", help="Dataset to use.") - parser.add_argument("--experiment_dir", type=str, default="experiments/emotion/lime", help="Directory to save the explanations.") - parser.add_argument("--batch_size", type=int, default=512, help="Batch size for model predictions.") - parser.add_argument("--max_num_samples", type=int, default=None, help="Maximum number of samples to use.") - parser.add_argument("--lime_num_features", type=int, default=None, help="Number of features to use in the explanation.") - parser.add_argument("--lime_num_samples", type=int, default=5000, help="Number of samples to use in the explanation.") - parser.add_argument("--lime_token_masking_strategy", type=str, default="remove", help="Token masking strategy.") + parser = argparse.ArgumentParser( + description="Calculate and save all local explanations." + ) + parser.add_argument( + "--dataset", type=str, default="emotion", help="Dataset to use." + ) + parser.add_argument( + "--experiment_dir", + type=str, + default="experiments/emotion/lime", + help="Directory to save the explanations.", + ) + parser.add_argument( + "--batch_size", type=int, default=512, help="Batch size for model predictions." + ) + parser.add_argument( + "--max_num_samples", + type=int, + default=None, + help="Maximum number of samples to use.", + ) + parser.add_argument( + "--lime_num_features", + type=int, + default=None, + help="Number of features to use in the explanation.", + ) + parser.add_argument( + "--lime_num_samples", + type=int, + default=5000, + help="Number of samples to use in the explanation.", + ) + parser.add_argument( + "--lime_token_masking_strategy", + type=str, + default="remove", + help="Token masking strategy.", + ) parser.add_argument("--seed", type=int, default=42, help="Random seed.") return parser.parse_args() @@ -49,20 +81,24 @@ def main(args): "token_masking_strategy": args.lime_token_masking_strategy, } print(f"Lime arguments: {lime_args}") - explainer = LimeExplainer(model, tokenizer, **lime_args) + explainer = LimeExplainer(model, tokenizer, device, **lime_args) accuracy = 0 experiment_dir = os.path.join(project_dir, args.experiment_dir) os.makedirs(experiment_dir, exist_ok=True) - + output_path = os.path.join(experiment_dir, "explanations.h5") if os.path.exists(output_path): print(f"Explanations already calculated, saved to {output_path}") exit() - + with h5py.File(output_path, "w") as f: - for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Explaining model predictions..."): + for i, sample in tqdm( + enumerate(dataset), + total=len(dataset), + desc="Explaining model predictions...", + ): text = sample["text"] probabilities = predict(model, tokenizer, text) @@ -71,7 +107,9 @@ def main(args): explanation = explainer.explain(sample["text"]) - assert predicted_label_idx == explanation.label, "Predicted label and explanation label do not match" + assert ( + predicted_label_idx == explanation.label + ), "Predicted label and explanation label do not match" grp = f.create_group(str(i)) grp.create_dataset("label_idx", data=sample["label"]) @@ -80,11 +118,12 @@ def main(args): grp.create_dataset("token_ids", data=explanation.token_ids) grp.create_dataset("token_scores", data=explanation.token_scores) grp.create_dataset("explanation_fit", data=explanation.explanation_fit) - + grp.create_dataset("cls", data=explanation.cls) + print(f"Accuracy: {accuracy / len(dataset)}") print(f"Explanations saved to {os.path.join(experiment_dir, 'explanations.h5')}") print("Done!") if __name__ == "__main__": - main(parse_args()) \ No newline at end of file + main(parse_args()) diff --git a/scripts/config.py b/scripts/config.py index af70c26..8ad5c4d 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -1,7 +1,13 @@ from aggregators import norm_lime from data import emotion as emotion_data, imdb as imdb_data from models import emotion as emotion_models, imdb as imdb_models -from samplers import el2n_score, entropy_sampler, max_variation_ratio, uniform_sampler +from samplers import ( + el2n_score, + entropy_sampler, + max_variation_ratio, + uniform_sampler, + kernel_thinning_sampler, +) NAME_TO_MODEL_LOADER = { "emotion": emotion_models.load_model, @@ -18,6 +24,7 @@ "entropy": entropy_sampler.select_samples, "el2n": el2n_score.select_samples, "variation_ratio": max_variation_ratio.select_samples, + "kernel_thinning": kernel_thinning_sampler.select_samples, } NAME_TO_AGGREGATOR = {"norm_lime": norm_lime.aggregate_local_explanations}