Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ venv/
.venv/
__pycache__/
storage/
.cache/
.cache/
.env
jm/
48 changes: 47 additions & 1 deletion explainers/lime.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,63 @@
from collections import namedtuple

import numpy as np
import torch
from lime.lime_text import LimeTextExplainer

from models.utils import predict_batched

ExplainationOutput = namedtuple(
"ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit"]
"ExplainationOutput", ["tokens", "token_ids", "token_scores", "label", "explanation_fit", "cls"]
)

def get_cls_embedding(model, tokenizer, text: str, device: str = 'cpu') -> torch.Tensor:
"""
Extracts the [CLS] token embedding from the input text.

Args:
model: The loaded sequence classification model.
tokenizer: The loaded tokenizer.
text (str): Input text to encode.
device (str): Device to perform computation on ('cpu' or 'cuda').

Returns:
torch.Tensor: The [CLS] embedding with shape (hidden_size,).
"""
model.to(device)
model.eval() # Set model to evaluation mode

with torch.no_grad():
# Tokenize the input text
encoding = tokenizer(
text,
padding=True,
truncation=True,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# Pass through the base model (distilbert) to get hidden states
outputs = model.distilbert(
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works only with bert, right? So any other models we try will crash here. Maybe we can at least add a flag for if we want to get it?

input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)

# Extract the last hidden state
last_hidden_state = outputs.last_hidden_state # Shape: (batch_size, sequence_length, hidden_size)

# The [CLS] token embedding is the first token's hidden state
cls_embedding = last_hidden_state[:, 0, :] # Shape: (batch_size, hidden_size)

return cls_embedding.squeeze(0).cpu().numpy() # Shape: (hidden_size,)

class LimeExplainer:
def __init__(
self,
model,
tokenizer,
device,
num_features: int = None,
num_samples: int = 1000,
batch_size: int = 64,
Expand All @@ -23,6 +66,7 @@ def __init__(
):
self.model = model
self.tokenizer = tokenizer
self.device = device

self._explainer = LimeTextExplainer(bow=False)
self.num_features = num_features
Expand Down Expand Up @@ -72,6 +116,7 @@ def explain(self, text: str) -> ExplainationOutput:
special_tokens_mask = encoded_text.pop("special_tokens_mask")
token_ids = encoded_text["input_ids"][0].tolist()
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
cls_embeddings = get_cls_embedding(self.model, self.tokenizer, text, device=self.device)

explanation = self._explainer.explain_instance(
" ".join([str(i) for i in token_ids]),
Expand All @@ -94,4 +139,5 @@ def explain(self, text: str) -> ExplainationOutput:
token_scores,
explanation.top_labels[0],
explanation.score,
cls_embeddings,
)
10 changes: 6 additions & 4 deletions global.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
dataset=emotion
lime_num_samples=5000
experiment_dir=experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove
source .env
echo proj_dir=${PROJECT_DIR}
experiment_dir=${PROJECT_DIR}/experiments/${dataset}/lime/k_all-n_${lime_num_samples}-mask_remove

for aggregator in norm_lime; do
# ground truth
Expand All @@ -12,9 +14,9 @@ for aggregator in norm_lime; do
--output_file $experiment_dir/global_explanation_full-aggregator_${aggregator}.json

# sampled
for num_samples in 10 20 50 100 200 500 1000; do
for num_samples in 7 15 31 62 125 250 500 1000; do
# for num_samples in 10 20 50 100 200 500 1000 2000 5000 10000; do
for sampler in uniform entropy el2n variation_ratio; do
for sampler in kernel_thinning uniform entropy el2n variation_ratio; do
echo "Running sampler $sampler with $num_samples samples"
bash run.sh scripts/calculate_global_explanation.py \
--explanation_file $experiment_dir/explanations.h5 \
Expand All @@ -25,4 +27,4 @@ for aggregator in norm_lime; do
--output_file $experiment_dir/global_explanation-sampler_${sampler}-n_${num_samples}-aggregator_${aggregator}.json
done
done
done
done
6 changes: 3 additions & 3 deletions local.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dataset=emotion
num_samples=5000
dataset=imdb # imdb
num_samples=1000 # 1000

for num_features in 10 all; do
if [ $num_features = "all" ]; then
Expand All @@ -12,7 +12,7 @@ for num_features in 10 all; do
sbatch run.sh scripts/calculate_local_explanations.py \
--dataset $dataset \
--experiment_dir "experiments/${dataset}/lime/k_${num_features}-n_${num_samples}-mask_${token_masking_strategy}" \
--batch_size 1024 \
--batch_size 1000 \
--lime_num_samples $num_samples \
--lime_token_masking_strategy $token_masking_strategy \
$num_features_arg
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ frozenlist==1.5.0
fsspec==2024.9.0
h5py==3.12.1
huggingface-hub==0.26.2
goodpoints==0.2.5
idna==3.10
imageio==2.36.0
Jinja2==3.1.4
Expand Down
47 changes: 47 additions & 0 deletions samplers/kernel_thinning_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import h5py
from goodpoints import kt
from functools import partial


def gaussian_kernel(y, X, sigma=1.0):
diff = y - X
return np.exp(-np.sum(diff ** 2, axis=1) / (2 * sigma ** 2))


# to jest jakis log, ale no proszę się nie czepiac
def compute_m(n, num_samples):
current_num_points = n
m = 0
while np.floor(current_num_points / 2) >= num_samples:
m += 1
current_num_points = np.floor(current_num_points / 2)
return m


def select_samples(file: h5py.File, num_samples: int, seed: int, **kwargs) -> list:
cls_embeds = []

for _, value in file.items():
cls = value["cls"][:]
cls_embeds.append(cls)

cls_embeds = np.array(cls_embeds, dtype=np.float64)
n = cls_embeds.shape[0]
m = compute_m(n, num_samples)
d = cls_embeds.shape[1]
sigma = np.sqrt(2 * d)
kernel = partial(gaussian_kernel, sigma=sigma)

id_compressed = kt.thin(
X=cls_embeds, m=m, split_kernel=kernel, swap_kernel=kernel, delta=0.5, seed=seed
)

print(
f"kernel thinning takes {num_samples} out of recommended {len(id_compressed)}."
)

keys = list(file.keys())
str_indices = [keys[idx] for idx in id_compressed[:num_samples]]

return str_indices
71 changes: 55 additions & 16 deletions scripts/calculate_local_explanations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,50 @@
from scripts.config import NAME_TO_DATASET_LOADER, NAME_TO_MODEL_LOADER
from utils import setup_device

project_dir = os.environ['PROJECT_DIR']
project_dir = os.environ["PROJECT_DIR"]
cache_dir = os.path.join(project_dir, ".cache")


def parse_args():
parser = argparse.ArgumentParser(description="Calculate and save all local explanations.")
parser.add_argument("--dataset", type=str, default="emotion", help="Dataset to use.")
parser.add_argument("--experiment_dir", type=str, default="experiments/emotion/lime", help="Directory to save the explanations.")
parser.add_argument("--batch_size", type=int, default=512, help="Batch size for model predictions.")
parser.add_argument("--max_num_samples", type=int, default=None, help="Maximum number of samples to use.")
parser.add_argument("--lime_num_features", type=int, default=None, help="Number of features to use in the explanation.")
parser.add_argument("--lime_num_samples", type=int, default=5000, help="Number of samples to use in the explanation.")
parser.add_argument("--lime_token_masking_strategy", type=str, default="remove", help="Token masking strategy.")
parser = argparse.ArgumentParser(
description="Calculate and save all local explanations."
)
parser.add_argument(
"--dataset", type=str, default="emotion", help="Dataset to use."
)
parser.add_argument(
"--experiment_dir",
type=str,
default="experiments/emotion/lime",
help="Directory to save the explanations.",
)
parser.add_argument(
"--batch_size", type=int, default=512, help="Batch size for model predictions."
)
parser.add_argument(
"--max_num_samples",
type=int,
default=None,
help="Maximum number of samples to use.",
)
parser.add_argument(
"--lime_num_features",
type=int,
default=None,
help="Number of features to use in the explanation.",
)
parser.add_argument(
"--lime_num_samples",
type=int,
default=5000,
help="Number of samples to use in the explanation.",
)
parser.add_argument(
"--lime_token_masking_strategy",
type=str,
default="remove",
help="Token masking strategy.",
)
parser.add_argument("--seed", type=int, default=42, help="Random seed.")

return parser.parse_args()
Expand All @@ -49,20 +81,24 @@ def main(args):
"token_masking_strategy": args.lime_token_masking_strategy,
}
print(f"Lime arguments: {lime_args}")
explainer = LimeExplainer(model, tokenizer, **lime_args)
explainer = LimeExplainer(model, tokenizer, device, **lime_args)

accuracy = 0

experiment_dir = os.path.join(project_dir, args.experiment_dir)
os.makedirs(experiment_dir, exist_ok=True)

output_path = os.path.join(experiment_dir, "explanations.h5")
if os.path.exists(output_path):
print(f"Explanations already calculated, saved to {output_path}")
exit()

with h5py.File(output_path, "w") as f:
for i, sample in tqdm(enumerate(dataset), total=len(dataset), desc="Explaining model predictions..."):
for i, sample in tqdm(
enumerate(dataset),
total=len(dataset),
desc="Explaining model predictions...",
):
text = sample["text"]

probabilities = predict(model, tokenizer, text)
Expand All @@ -71,7 +107,9 @@ def main(args):

explanation = explainer.explain(sample["text"])

assert predicted_label_idx == explanation.label, "Predicted label and explanation label do not match"
assert (
predicted_label_idx == explanation.label
), "Predicted label and explanation label do not match"

grp = f.create_group(str(i))
grp.create_dataset("label_idx", data=sample["label"])
Expand All @@ -80,11 +118,12 @@ def main(args):
grp.create_dataset("token_ids", data=explanation.token_ids)
grp.create_dataset("token_scores", data=explanation.token_scores)
grp.create_dataset("explanation_fit", data=explanation.explanation_fit)

grp.create_dataset("cls", data=explanation.cls)

print(f"Accuracy: {accuracy / len(dataset)}")
print(f"Explanations saved to {os.path.join(experiment_dir, 'explanations.h5')}")
print("Done!")


if __name__ == "__main__":
main(parse_args())
main(parse_args())
9 changes: 8 additions & 1 deletion scripts/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from aggregators import norm_lime
from data import emotion as emotion_data, imdb as imdb_data
from models import emotion as emotion_models, imdb as imdb_models
from samplers import el2n_score, entropy_sampler, max_variation_ratio, uniform_sampler
from samplers import (
el2n_score,
entropy_sampler,
max_variation_ratio,
uniform_sampler,
kernel_thinning_sampler,
)

NAME_TO_MODEL_LOADER = {
"emotion": emotion_models.load_model,
Expand All @@ -18,6 +24,7 @@
"entropy": entropy_sampler.select_samples,
"el2n": el2n_score.select_samples,
"variation_ratio": max_variation_ratio.select_samples,
"kernel_thinning": kernel_thinning_sampler.select_samples,
}

NAME_TO_AGGREGATOR = {"norm_lime": norm_lime.aggregate_local_explanations}