Train a compositional (attribute, object) CZSL head on frozen CLIP embeddings and improve robustness to context/environment shift.
- Frozen CLIP ViT-B/16 embeddings as image features
- Unsupervised environment clustering via PCA + k-means + (optional) residualization
- Biased train split and cross-environment OOD split generation for stress testing
- Factorized CZSL head with additive or gated interaction composition
- Environment invariance penalty (group variance or group DRO)
- Gradient reversal layer (GRL) for adversarial environment confusion
- Comprehensive diagnostics: cluster visualizations, probe-based semantic leakage checks
Compositional zero-shot learning models can overfit to spurious background/context features that correlate with attributes or objects in training data. These context features are:
- Hard to identify and annotate systematically
- Not relevant to the core task (attribute-object composition)
This project proposes:
- Cluster context features in an unsupervised way on frozen CLIP embeddings
- Use these clusters as environment labels to train a CZSL model with invariance objectives
- Evaluate on held-out splits: unseen attribute-object pairs (CZSL test) and unseen environments (OOD test).
The goal is to learn models that rely on attribute-object relationships rather than background/context shortcuts, which should in turn improve unseen-pair accuracy.
Standard CZSL training treats all samples as i.i.d., which can lead to models that rely on spurious correlations between context features (background, lighting, camera angle) and semantic labels. These correlations are brittle and may not hold in new environments.
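This project treats environment context as a confounder:
- Context (C) → Image Features (X)
- Attr/Obj (A, O) → Image Features (X)
- In the training data, C is also correlated with (A, O), so context cues can act as shortcuts for predicting (A, O)
The causal model should learn Image → (Attr, Obj) while ignoring Context.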
Interventions applied:
- Environment invariance penalty: Encourages the model to perform equally well across all environments, reducing reliance on environment-specific shortcuts.
- Adversarial environment confusion: A gradient reversal layer trains an environment probe to predict the environment from the learned representation while simultaneously updating the representation to make that prediction harder. The aim is to remove environment information from the representation.
- Factorized composition: Learning separate attribute and object embeddings forces the model to represent compositional structure explicitly rather than memorizing joint configurations.
Hypothesis: Models trained with these objectives will generalize better to out-of-distribution environments where the spurious context correlations differ from training.
We conducted comprehensive parameter sweeps comparing baseline CZSL training against causal interventions:
Baseline sweep (sweep_baseline):
- Temperature scaling: [0.03, 0.05, 0.07, 0.10]
- Decoupling weight λ: [0, 1, 10, 50, 100]
- Learning rates: [5e-5, 1e-4, 3e-4]
- Composition: additive vs gated
Causal sweep (sweep_causal):
- Adversarial loss weight λ_adv: [0.05, 0.10, 0.50, 1.00, 2.00]
- Invariance loss weight λ_inv: [0, 1.5] (group DRO)
- Invariance weight with group variance penalty (inv_var): [0.1, 0.5, 1.5, 2.0, 2.5, 3.0, 4.0]
- Invariance weight with group DRO penalty (inv_dro): [0.1, 0.5, 1.5, 2.0, 2.5, 3.0, 4.0]
- Adv γ_max: [0.25, 0.50, 1.00, 2.00]
- Ablations: no_inv, no_adv, adv_only
All experiments used 3 random seeds (42, 14, 193) with biased train splits (max_train_envs=2) and disjoint OOD test environments.
Temperature scaling provides a simple but effective mechanism for OoD detection:
| Temperature | OoD Top1 | ID Top1 |
|---|---|---|
| 0.03 | 0.214 ± 0.002 | 0.047 ± 0.000 |
| 0.05 | 0.186 ± 0.000 | 0.067 ± 0.003 |
| 0.07 | 0.160 ± 0.001 | 0.079 ± 0.002 |
| 0.10 | 0.129 ± 0.001 | 0.072 ± 0.000 |
Interpretation: Lower temperature sharpens the softmax distribution, making out-of-context samples more distinguishable. However, this comes at a severe cost to in-distribution classification (ID Top1 drops from 0.079 at T=0.07 to 0.047 at T=0.03).
Full causal model (λ_adv=0.5, λ_inv=1.5):
- OoD: 0.139 ± 0.003
- ID: 0.070 ± 0.001
Compared to baseline (T=0.07):
- OoD: 0.161 ± 0.001
- ID: 0.079 ± 0.002
Observation: The causal objectives slightly reduce both OoD (0.139 vs 0.161) and ID (0.070 vs 0.079) performance relative to the T=0.07 baseline, and the net effect is small compared to temperature scaling.
Disabling invariance (λ_inv=0) sharply increases OoD but roughly halves ID:
| Setting | OoD Top1 | ID Top1 |
|---|---|---|
| No invariance (λ_inv=0) | 0.213 ± 0.004 | 0.035 ± 0.001 |
| No adversarial (γ=0) | 0.133 ± 0.000 | 0.069 ± 0.001 |
| Full model | 0.139 ± 0.003 | 0.070 ± 0.001 |
Possible explanation: The invariance penalty keeps the model from becoming too conservative. Without it, the adversarial objective pushes representations toward environment invariance, which may effectively treat OoD samples as negatives.
| Component | OoD | ID | Interpretation |
|---|---|---|---|
| Adv only | 0.214 | 0.034 | Highest OoD, lowest ID |
| No inv | 0.213 | 0.034 | Same as adv only |
| No adv | 0.133 | 0.069 | Baseline-like |
| Full | 0.139 | 0.070 | Balanced |
Conclusion: The adversarial component drives OoD detection; the invariance penalty prevents it from overwhelming ID performance.
- λ_adv: Minimal impact in [0.05, 1.00] when λ_inv=0; OoD ~0.21, ID ~0.03
- inv_var: Higher penalty → worse OoD, better ID (0.1: 0.209/0.034 → 4.0: 0.179/0.065)
- inv_dro: Similar trend; extreme values (3.0+) hurt both metrics
- γ_max: Negligible impact across [0.25, 2.00]
- Temperature scaling is the primary lever for OoD detection in this setting, achieving 0.214 OoD at T=0.03
- Causal interventions change the OoD-ID tradeoff only marginally compared to simple baselines
- Invariance loss is critical for maintaining ID performance when using adversarial training
- The causal methods don't outperform a well-tuned temperature baseline on this task
These results suggest that for compositional zero-shot learning with frozen CLIP embeddings, simple calibration techniques may be as effective as more complex causal interventions for handling environment shift.
CausalCZSL/
├── configs/ # YAML configs for experiments
├── data/ # Dataset directory
│ └── mit_states/ # MIT-States images, index, cache, envs, splits
├── CausalCZSL/
│ ├── env/ # Environment clustering and diagnostics
│ ├── eval/ # Evaluation scripts and metrics
│ ├── features/ # Feature extraction and caching
│ ├── metrics/ # CZSL metric implementations
│ ├── models/ # Model architectures (CZSL head, GRL, env probe)
│ ├── splits/ # Split generation (standard + biased)
│ └── train/ # Training loops and datasets
├── scripts/ # CLI entry points
└── tests/ # Unit and integration tests
- Single dataset focus: Currently only supports MIT-States; extending to other datasets requires implementing data loaders and index builders
- Frozen encoder only: No full fine-tuning or LoRA adaptation (by design for reproducibility and compute efficiency)
- One biased split + one OOD split: Not a comprehensive benchmark suite; stress test is a single configuration
- Simple invariance penalty: Uses group variance or group DRO; more sophisticated IRM-style penalties could be explored
- Environment clustering heuristic: PCA + k-means is a simple approach; more sophisticated methods could better separate context from semantics
The project uses the MIT-States dataset for compositional zero-shot learning evaluation. This dataset contains images annotated with attribute-object pairs, enabling training and evaluation of compositional models.
Dataset Statistics:
- Total samples: 63,440 images
- Unique attributes: 116 (e.g., "old", "new", "wet", "clean", "broken")
- Unique objects: 245 (e.g., "car", "bottle", "building", "knife", "tree")
- Unique attribute-object pairs: 2,207 compositional combinations
- Samples per pair: Ranges from 1 to 50 (mean: 28.7 samples/pair)
- Most common objects: building (476 samples), knife (474), bathroom (467), cotton (457), brass (453)
- Least common objects: flame (49 samples), vacuum (58), drum (67), dust (89), fire (90)
Attribute Types: The dataset covers diverse attribute categories including:
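- Generic placeholder: "adj" (used for images without a specific state); material names such as "aluminum" and "wood" appear as objects, not attributes
- Condition/state: "old", "new", "broken", "clean", "dirty", "wet"
- Texture/shape: "smooth", "rough", "bent", "blunt"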
- Composition: "coiled", "stacked", "clustered"
- Context: "cloudy", "cluttered", "bare"
Object Categories: Objects span multiple domains:
- Everyday items: bottle, knife, bag, cup, fork, plate
- Furniture: bed, table, chair, cabinet, shelf
- Food: apple, banana, tomato, cheese, meat, bread
- Vehicles: car, bike, truck
- Nature: tree, grass, sky, lake, beach, mountain
- Materials: aluminum, brass, copper, wood, cotton, wool, silk
- Locations: bathroom, kitchen, basement, bedroom, hallway
Data Acquisition:
- Download from: http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip
- Expected directory structure after running the download commands below:
data/mit_states/
├── images/
│   ├── {attribute}_{object}/
│   │   ├── img001.jpg
│   │   └── ...
│   └── ...
└── mit_states_index.csv   (created by scripts.data_prepare)
- Images are organized in folders named {attribute}_{object} or {attribute} {object}, each containing the images for that attribute-object pair.
Index Construction: Running the data preparation script creates a deterministic index file:
The index contains:
- sample_id: Unique integer identifier (0-indexed)
- image_path: Relative path to the image
- attr: Attribute name (e.g., "old", "new")
- obj: Object name (e.g., "car", "bottle")
- attr_id, obj_id: Integer IDs for attributes and objects
- pair: Compositional label as "attr::obj"
- pair_id: Integer ID for the attribute-object pair
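A minimal sketch of how such a deterministic index could be built, assuming folders follow the {attribute} {object} or {attribute}_{object} naming described above, the attribute is the first token of the folder name, and images are .jpg files. The helpers parse_pair, build_index, and write_index are hypothetical and are not the scripts.data_prepare implementation.

```python
import csv
from pathlib import Path

FIELDS = ["sample_id", "image_path", "attr", "obj", "attr_id", "obj_id", "pair", "pair_id"]


def parse_pair(folder_name: str) -> tuple[str, str]:
    """Split '{attr} {obj}' or '{attr}_{obj}' on the first separator (assumed convention)."""
    for sep in (" ", "_"):
        if sep in folder_name:
            attr, obj = folder_name.split(sep, 1)
            return attr, obj
    raise ValueError(f"cannot parse attribute/object from {folder_name!r}")


def build_index(images_root: Path) -> list[dict]:
    rows = []
    # Sorting folders and files makes sample_id assignment reproducible across runs.
    for folder in sorted(p for p in images_root.iterdir() if p.is_dir()):
        attr, obj = parse_pair(folder.name)
        for img in sorted(folder.glob("*.jpg")):
            rows.append({"image_path": img.relative_to(images_root).as_posix(), "attr": attr, "obj": obj})
    # Integer IDs come from sorted vocabularies, so they are deterministic too.
    attr_ids = {a: i for i, a in enumerate(sorted({r["attr"] for r in rows}))}
    obj_ids = {o: i for i, o in enumerate(sorted({r["obj"] for r in rows}))}
    pair_ids = {p: i for i, p in enumerate(sorted({f"{r['attr']}::{r['obj']}" for r in rows}))}
    for sample_id, r in enumerate(rows):
        r["sample_id"] = sample_id
        r["attr_id"] = attr_ids[r["attr"]]
        r["obj_id"] = obj_ids[r["obj"]]
        r["pair"] = f"{r['attr']}::{r['obj']}"
        r["pair_id"] = pair_ids[r["pair"]]
    return rows


def write_index(rows: list[dict], out_csv: Path) -> None:
    with open(out_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writeheader()
        writer.writerows(rows)
```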
Objective: Identify environment clusters from image embeddings that capture context/background patterns rather than semantic attribute/object features. These clusters serve as pseudo-environment labels for training invariance and adversarial objectives.
Pipeline Overview:
- Build deterministic index from dataset
- Extract and cache frozen CLIP embeddings
- Residualize embeddings (optional):
  - Intention: Remove semantic signal from embeddings so clustering captures within-group variation (context/background) rather than attribute/object/pair identity.
  - How-to: For each sample, subtract the mean embedding of its group (attr, obj, or pair). Optionally use shrinkage (alpha > 0) to keep small groups from overfitting their own mean.
- Apply PCA dimensionality reduction:
  - Intention: Reduce noise and computational cost while preserving variance; focus clustering on the most informative dimensions.
  - How-to: Apply PCA with n_components=pca_dim (full SVD solver) and track the explained-variance ratio.
- Cluster embeddings with k-means:
  - Intention: Group similar embeddings into k distinct clusters to identify recurring context patterns that serve as pseudo-environment labels.
  - How-to: Apply k-means with n_clusters=k, n_init=10 random initializations, and Lloyd's algorithm.
- Generate diagnostics and visualizations
Diagnostics:
Diagnostics verify that clusters capture context/background rather than semantic labels. The following metrics are computed:
- Probe-based semantic leakage
  - Train linear classifiers to predict (obj|env) and (attr|env) from one-hot environment encodings.
  - Compare probe accuracy to the majority baseline (most common class per environment).
  - Flag if probe_acc > majority_acc + threshold (default: 0.10), indicating that environment clusters may capture semantic information.
- Cluster-label agreement metrics
  - AMI/NMI/ARI: Measures of similarity between cluster assignments and ground-truth labels.
  - Purity: sum_c max_y count(c, y) / N, the proportion of samples assigned to the dominant class in each cluster.
  - Conditional entropy H(Y|C): Uncertainty in labels given the cluster assignment (higher is better for pure context clusters, since such clusters should say little about semantic labels).
  - Normalized entropy: H(Y|C) scaled by log(|Y|) for interpretability (0 = clusters fully determine labels, 1 = clusters carry no label information).
- Information-theoretic dependency metrics (here Y denotes the environment and X the semantic label)
  - Computed for each semantic label type (attr, obj, pair):
    - H(Y): Entropy of the environment cluster distribution (uncertainty in environment labels).
    - H(Y|X): Conditional entropy, the expected uncertainty in the environment given the semantic label. Low values mean environments are predictable from semantics (bad for context capture).
    - I(Y;X) = H(Y) - H(Y|X): Mutual information, i.e., how much knowing the semantic label reduces uncertainty about the environment.
    - I(Y;X)/H(Y): Fraction of environment entropy explained by the semantic label. Values close to 1 indicate strong semantic leakage (clusters encode semantics rather than context).
    - Majority accuracy: Expected accuracy of the deterministic rule y_hat(x) = argmax_y p(y|x), i.e., the most probable environment given the label.
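The agreement and dependency metrics above can all be computed from a contingency table. The sketch below is a hypothetical implementation (function names are assumptions, and natural logarithms are used); AMI, NMI, and ARI come from sklearn.metrics. It follows the notation above: in cluster_label_agreement, C is the cluster and Y the semantic label; in env_dependency, X is the semantic label and Y the environment.

```python
import numpy as np
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, normalized_mutual_info_score


def entropy(p: np.ndarray) -> float:
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())


def contingency(rows: np.ndarray, cols: np.ndarray) -> np.ndarray:
    _, ri = np.unique(rows, return_inverse=True)
    _, ci = np.unique(cols, return_inverse=True)
    table = np.zeros((ri.max() + 1, ci.max() + 1))
    np.add.at(table, (ri, ci), 1)
    return table


def conditional_entropy(table: np.ndarray) -> float:
    """H(col | row) for a count table whose rows are the conditioning variable."""
    n = table.sum()
    return sum(row.sum() / n * entropy(row / row.sum()) for row in table)


def cluster_label_agreement(env, label) -> dict:
    t = contingency(env, label)  # rows: clusters C, cols: labels Y
    h = conditional_entropy(t)
    n_labels = t.shape[1]
    return {
        "ami": adjusted_mutual_info_score(label, env),
        "nmi": normalized_mutual_info_score(label, env),
        "ari": adjusted_rand_score(label, env),
        "purity": t.max(axis=1).sum() / t.sum(),
        "cond_entropy": h,
        # 0: clusters determine labels (leakage); 1: clusters carry no label information
        "norm_cond_entropy": h / np.log(n_labels) if n_labels > 1 else 0.0,
    }


def env_dependency(env, label) -> dict:
    t = contingency(label, env)  # rows: semantic label X, cols: environment Y
    h_y = entropy(t.sum(axis=0) / t.sum())
    h_y_x = conditional_entropy(t)
    mi = h_y - h_y_x
    return {
        "H(Y)": h_y,
        "H(Y|X)": h_y_x,
        "I(Y;X)": mi,
        "I/H(Y)": mi / h_y if h_y > 0 else 0.0,
        # accuracy of always predicting argmax_y p(y|x) for each label x
        "majority_acc": t.max(axis=1).sum() / t.sum(),
    }
```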
Biased split generation creates environment-aware data partitions to stress-test causal invariance. The split uses environment clusters to create a train-test distribution shift where training data is biased toward specific contexts while test data comes from different environments.
How it works:
For each attribute-object pair, the algorithm partitions samples by their environment cluster labels into disjoint sets:
- train: Samples from up to max_train_envs environments (default: 1-2). This limits the model to learning from specific contexts, creating bias.
- val: Fraction of train-environment samples used for validation.
- id_test (optional): In-distribution test from same environments as train.
- ood_test: Samples from environments disjoint from train. Tests generalization to unseen environments.
- test: CZSL unseen-pair split where entire pairs are held out (no train/val/id_test/ood_test).
- dropped: Samples not meeting minimum sample constraints.
Constraints ensure robust splits:
- Each pair must have at least min_train_per_pair samples in train environments.
- When possible, each pair has at least min_ood_per_pair samples in OOD environments.
- When CZSL pairs are held out, every attribute and every object keeps at least one seen pair, preserving vocabulary coverage.
This split enables evaluation of whether causal invariance objectives reduce reliance on environment-specific shortcuts, measured by the ID→OOD accuracy gap.
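A simplified sketch of the per-pair partitioning, assuming the most populous environments of each pair become its training environments and all remaining environments become OOD. The selection rule, the defaults, and the function name are assumptions; id_test, the soft min_ood_per_pair constraint, and CZSL pair hold-out are omitted for brevity.

```python
import random
from collections import defaultdict


def biased_split(samples, max_train_envs=2, val_frac=0.1, min_train_per_pair=5, seed=0):
    """samples: iterable of (sample_id, pair, env). Returns split name -> list of sample ids."""
    rng = random.Random(seed)
    by_pair = defaultdict(lambda: defaultdict(list))
    for sid, pair, env in samples:
        by_pair[pair][env].append(sid)

    splits = {"train": [], "val": [], "ood_test": [], "dropped": []}
    for pair in sorted(by_pair):
        envs = by_pair[pair]
        # Assumption: keep the most populous environments (ties broken by env id) as train contexts.
        ranked = sorted(envs, key=lambda e: (-len(envs[e]), e))
        train_envs, ood_envs = ranked[:max_train_envs], ranked[max_train_envs:]
        in_env = [s for e in train_envs for s in envs[e]]
        ood = [s for e in ood_envs for s in envs[e]]
        if len(in_env) < min_train_per_pair:
            # Pair fails the minimum-sample constraint.
            splits["dropped"].extend(in_env + ood)
            continue
        rng.shuffle(in_env)
        n_val = int(len(in_env) * val_frac)
        splits["val"].extend(in_env[:n_val])
        splits["train"].extend(in_env[n_val:])
        # OOD environments are disjoint from this pair's train environments by construction.
        splits["ood_test"].extend(ood)
    return splits
```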
Architecture
The model learns factorized attribute and object embeddings and composes them for pair scoring:
- Attribute embeddings: [n_attrs, D], one learned vector per attribute
- Object embeddings: [n_objs, D], one learned vector per object
- Composition operators:
  - Additive (composition: add): u = emb_attr[a] + emb_obj[o]
  - Gated interaction (composition: gated_interaction):
    z = (a * gate_attr) + (o * gate_obj) + gate_bias
    g = sigmoid(z)
    fused = g * a + (1 - g) * o
    inter = interaction_scale * (a * o)
    u = fused + inter
  - Learned gating and interaction scales allow the model to dynamically weight attribute vs. object contributions
- Pair embeddings: Normalized composed vectors for all (attr, obj) pairs
- Scoring: Cosine similarity between the image embedding and each pair embedding, divided by the temperature (lower temperature gives sharper scores)
Training loss:
- Softmax cross-entropy over all pairs
- Supports sampled negative mining with structured negatives:
  - n_neg_same_attr: Negatives sharing the same attribute
  - n_neg_same_obj: Negatives sharing the same object
  - n_neg_random: Random negatives
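One way to implement the structured negatives, shown as a hypothetical sketch: for each sample, the positive pair is scored against negatives drawn from pairs sharing its attribute, pairs sharing its object, and (an assumption) random pairs sharing neither, and cross-entropy is taken over that candidate set. The function names and default counts are illustrative.

```python
import torch
import torch.nn.functional as F


def sample_negatives(pos, pair_attr, pair_obj, n_neg_same_attr, n_neg_same_obj, n_neg_random, gen):
    """Return indices of negative pairs for the positive pair index `pos`."""
    others = torch.arange(pair_attr.numel()) != pos
    same_attr = torch.nonzero(others & (pair_attr == pair_attr[pos])).flatten()
    same_obj = torch.nonzero(others & (pair_obj == pair_obj[pos])).flatten()
    # Assumption: "random" negatives come from pairs sharing neither attribute nor object.
    rest = torch.nonzero(others & (pair_attr != pair_attr[pos]) & (pair_obj != pair_obj[pos])).flatten()

    def pick(pool, n):
        n = min(n, pool.numel())
        return pool[torch.randperm(pool.numel(), generator=gen)[:n]]

    return torch.cat([pick(same_attr, n_neg_same_attr), pick(same_obj, n_neg_same_obj), pick(rest, n_neg_random)])


def sampled_ce(logits, targets, pair_attr, pair_obj, gen, n_neg_same_attr=4, n_neg_same_obj=4, n_neg_random=8):
    """Mean cross-entropy where each row competes only against its sampled negatives."""
    losses = []
    for row, pos in zip(logits, targets.tolist()):
        neg = sample_negatives(pos, pair_attr, pair_obj, n_neg_same_attr, n_neg_same_obj, n_neg_random, gen)
        cand = torch.cat([torch.tensor([pos]), neg])
        # The positive pair sits at index 0 of the candidate set.
        losses.append(F.cross_entropy(row[cand].unsqueeze(0), torch.zeros(1, dtype=torch.long)))
    return torch.stack(losses).mean()
```

When the requested counts exceed the number of available pairs, every pair becomes a candidate and the loss reduces to the full softmax cross-entropy over all pairs.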
Environment invariance penalty: Encourages consistent performance across environments by penalizing disparities in per-environment loss:
- Group variance penalty (inv_method: group_var): penalty = Var(E[loss | env]), the variance of the average loss across environments
- Group DRO penalty (inv_method: group_dro): penalty = max(E[loss | env]) - mean(E[loss | env]), the gap between the worst group and the average
The penalty is added to the main CZSL loss:
loss_total = loss_czsl + lambda_inv * penalty
This reduces the model's reliance on environment-specific shortcuts.
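A sketch of the two penalties, assuming per-sample CZSL losses (e.g., from F.cross_entropy(..., reduction="none")) and per-sample environment labels in the batch. Only environments present in the batch contribute, and the variance is the population variance; both are assumptions, as are the function names.

```python
import torch


def invariance_penalty(per_sample_loss: torch.Tensor, env: torch.Tensor, method: str = "group_var") -> torch.Tensor:
    """Penalty over per-environment mean losses for the environments present in the batch."""
    group_means = torch.stack([per_sample_loss[env == e].mean() for e in torch.unique(env)])
    if method == "group_var":
        # Population variance of E[loss | env] across environments.
        return ((group_means - group_means.mean()) ** 2).mean()
    if method == "group_dro":
        # Gap between the worst environment and the average environment.
        return group_means.max() - group_means.mean()
    raise ValueError(f"unknown inv_method: {method}")


def total_loss(per_sample_loss: torch.Tensor, env: torch.Tensor, lambda_inv: float = 1.5, method: str = "group_var") -> torch.Tensor:
    # loss_total = loss_czsl + lambda_inv * penalty
    return per_sample_loss.mean() + lambda_inv * invariance_penalty(per_sample_loss, env, method)
```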
Gradient Reversal Layer
A custom autograd function that acts as the identity in the forward pass and reverses gradients during backpropagation: the gradient flowing from the environment probe back into the representation is multiplied by -gamma (grad_rep = -gamma * grad_probe). This creates an adversarial game:
- Environment probe: Trained to predict the environment from the image representation
- Representation: Updated via gradient reversal to make environment prediction harder
Implementation details:
- Warm-up: gamma = 0 for the first gamma_adv_warmup_epochs epochs
- Ramp: gamma increases linearly to gamma_adv_max over gamma_adv_ramp_epochs
- Probe architecture: Linear or MLP classifier (CausalCZSL/models/env_probe.py:17)
- Can be conditioned on pair or object IDs for richer modeling
Training objective:
loss_env_probe = cross_entropy(probe(image_rep), env_label)
loss_total = loss_czsl + lambda_inv * penalty + gamma * loss_env_probe
The environment probe is trained simultaneously with the main model, competing for the same representation.
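A minimal PyTorch sketch of the gradient reversal layer, the warm-up/ramp schedule for gamma, and the probe loss. The names are hypothetical. One assumption differs in form from the objective above: here gamma only enters through the reversal, and the probe loss would be weighted by a separate lambda_adv (the λ_adv from the sweeps); how gamma and λ_adv combine in the repository is unverified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies incoming gradients by -gamma in the backward pass."""

    @staticmethod
    def forward(ctx, x, gamma):
        ctx.gamma = gamma
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient for gamma (a plain Python float).
        return -ctx.gamma * grad_output, None


def grad_reverse(x: torch.Tensor, gamma: float) -> torch.Tensor:
    return GradReverse.apply(x, gamma)


def gamma_schedule(epoch: int, warmup_epochs: int, ramp_epochs: int, gamma_max: float) -> float:
    """gamma = 0 during warm-up, then a linear ramp to gamma_max over ramp_epochs epochs."""
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return gamma_max
    return gamma_max * min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)


def env_probe_loss(rep: torch.Tensor, env: torch.Tensor, probe: nn.Module, gamma: float) -> torch.Tensor:
    # The probe minimizes this loss; the reversed gradient pushes `rep` to maximize it.
    return F.cross_entropy(probe(grad_reverse(rep, gamma)), env)
```

Under these assumptions, one training step would compute loss_czsl + lambda_inv * penalty + lambda_adv * env_probe_loss(rep, env, probe, gamma_schedule(epoch, ...)) and call backward once, so the probe and the representation are updated from the same pass.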
- Pair top-1 accuracy: Percentage of samples correctly classified to their ground-truth (attr, obj) pair
- Pair top-k accuracy: Correct pair appears in top-k predictions
- Pair macro top-1: Mean accuracy across pair classes (averaged per-class, not per-sample)
- Attribute top-1: Accuracy of attribute prediction (derived from best-matching pair)
- Object top-1: Accuracy of object prediction (derived from best-matching pair)
When the test split contains unseen pairs, the evaluation reports:
- Seen accuracy: Top-1 on pairs seen during training
- Unseen accuracy: Top-1 on pairs not seen during training
- Harmonic mean: 2 * seen * unseen / (seen + unseen)
- Bias sweep: Adds a constant bias to unseen-pair logits to calibrate the seen/unseen tradeoff; reports the AUC across the sweep
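A NumPy sketch of these metrics, assuming pair logits for every sample, a boolean pair_is_seen mask computed from the train split, and attribute/object IDs per pair. The bias grid, the function names, and the trapezoidal AUC over the (unseen, seen) accuracy curve are assumptions rather than the repository's exact implementation.

```python
import numpy as np


def attr_obj_top1(logits, targets, pair_attr, pair_obj):
    """Attribute and object accuracy read off the best-scoring pair."""
    pred = logits.argmax(axis=1)
    attr_acc = (pair_attr[pred] == pair_attr[targets]).mean()
    obj_acc = (pair_obj[pred] == pair_obj[targets]).mean()
    return float(attr_acc), float(obj_acc)


def seen_unseen_acc(logits, targets, pair_is_seen, bias=0.0):
    # Add a constant bias to unseen-pair logits only, then take the top-1 pair.
    pred = (logits + bias * ~pair_is_seen).argmax(axis=1)
    correct = pred == targets
    target_seen = pair_is_seen[targets]
    return float(correct[target_seen].mean()), float(correct[~target_seen].mean())


def harmonic_mean(seen, unseen):
    return 2 * seen * unseen / (seen + unseen) if seen + unseen > 0 else 0.0


def gzsl_report(logits, targets, pair_is_seen, biases):
    seen0, unseen0 = seen_unseen_acc(logits, targets, pair_is_seen)
    curve = np.array([seen_unseen_acc(logits, targets, pair_is_seen, b) for b in biases])
    seen, unseen = curve[:, 0], curve[:, 1]
    order = np.argsort(unseen, kind="stable")
    x, y = unseen[order], seen[order]
    auc = float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))  # trapezoidal rule
    best_hm = max(harmonic_mean(s, u) for s, u in zip(seen, unseen))
    return {"seen": seen0, "unseen": unseen0, "hm": harmonic_mean(seen0, unseen0), "best_hm": best_hm, "auc": auc}
```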
Under biased splits with disjoint environments:
- ID (in-distribution): Performance on train environments
- OOD (out-of-distribution): Performance on held-out environments
- ID→OOD gap: ID accuracy minus OOD accuracy; a smaller gap indicates better environment generalization
- Env probe accuracy: How well a probe can predict environment from representation (lower = better invariance)
- Acc(obj | env): Object prediction accuracy given environment cluster
- Acc(attr | env): Attribute prediction accuracy given environment cluster
- Majority baselines: Maximum achievable accuracy if predicting most common (obj|env) or (attr|env)
- Flagging: Environments where probe accuracy exceeds majority by threshold may be capturing semantics
# Clone repository
git clone <repo-url>
cd CausalCZSL
# Install dependencies
pip install -e .
Dependencies (see pyproject.toml):
- Python >= 3.10
- PyTorch, torchvision
- open_clip_torch
- numpy, scikit-learn, pandas
- tqdm, Pillow
Download MIT-States and build the index:
# Download dataset (from official source)
cd data
wget http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip -O mitstates.zip
unzip mitstates.zip 'release_dataset/images/*' -d mit_states/
mv mit_states/release_dataset/images mit_states/images
rm -r mit_states/release_dataset
cd ..
# Build index
python -m scripts.data_prepare --dataset mit_states --root data/mit_states --out data/mit_states
Extract and cache frozen CLIP embeddings:
python -m scripts.extract_features \
--dataset mit_states \
--root data/mit_states \
--cache data/mit_states/cache \
--seed 0
Generate environment clusters:
python -m scripts.cluster_env \
--features data/mit_states/cache \
--index data/mit_states/mit_states_index.csv \
--pca_dim 32 \
--k 10 \
--seed 42 \
--residualize-by none \
--out data/mit_states/processed_envs_residualized
Run diagnostics to verify clusters capture context:
python -m scripts.env_diagnostics \
--env-dir data/mit_states/processed_envs_residualized \
--index data/mit_states/mit_states_index.csv \
--images-root data/mit_states/images \
--out data/mit_states/env_diagnostics \
--n-visualize-clusters 5 \
--max-over-majority-delta 0.10
Create biased train + OOD split:
python -m scripts.splits --config configs/splits_bias_job_k15.yaml
Train baseline:
python -m scripts.train --config configs/baseline_mit_states.yaml --seed 0
Train causal full (invariance + adversarial):
python -m scripts.train --config configs/causal_full.yaml --seed 0
Run parameter sweeps:
python -m scripts.train --mode sweep_baseline --seeds 0,1,2 --config configs/baseline_mit_states.yaml --max-workers 4
python -m scripts.train --mode sweep_causal --seeds 0,1,2 --config configs/causal_full.yaml --max-workers 4
Evaluate a run:
python -m scripts.eval --run_dir artifacts/runs/<run_id> --out artifacts/reports/<run_id>.json
Compare sweep results:
python -m scripts.compare_sweep --report artifacts/reports/aggregate.json --out artifacts/figures/
Troubleshooting:
- Missing images: Verify that data/mit_states/images/ contains folders named {attr}_{object}/ or {attr} {object}/
- Index mismatch: Ensure the index CSV SHA256 matches the feature cache SHA256; re-run data_prepare.py if the dataset changed
- NaN losses: Check the learning rate (try a lower lr), verify the batch size is not too large, and ensure splits have sufficient samples
- OOM errors: Reduce batch_size, use limit_stratified for dev runs, or use CPU mode
- Slow training: Enable caching (embeddings are reused), reduce n_negatives, or use fewer epochs for quick iteration
- High semantic leakage: If probe diagnostics show high Acc(obj|env) or Acc(attr|env), try:
  - Increase k (more clusters → more fine-grained)
  - Reduce pca_dim (lower dimensionality → less semantic signal)
  - Enable --residualize-by pair to subtract pair means before clustering
- Small clusters: Reduce k (fewer, larger clusters) or adjust min_train_per_pair in the split config
- Unexpectedly low accuracy: Verify:
- Correct split file is loaded
- Model checkpoint exists and loads correctly
- Pair embeddings are normalized (check normalize_pair_embeddings: true)
- GZSL seen=0 or unseen=0: Check that pair_is_seen is computed correctly from the train split
This repository provides code and infrastructure for research use. The MIT-States dataset must be downloaded from the official source and used in accordance with its license terms. This repository does not redistribute the dataset.