Ewendawi/Causal-CZSL

Causal CZSL

Train a compositional (attribute, object) CZSL head on frozen CLIP embeddings and improve robustness to context/environment shift.

Features

  • Frozen CLIP ViT-B/16 embeddings as image features
  • Unsupervised environment clustering via PCA + k-means + (optional) residualization
  • Biased train split and cross-environment OOD split generation for stress testing
  • Factorized CZSL head with additive or gated interaction composition
  • Environment invariance penalty (group variance or group DRO)
  • Gradient reversal layer (GRL) for adversarial environment confusion
  • Comprehensive diagnostics: cluster visualizations, probe-based semantic leakage checks

Motivation

Compositional zero-shot learning models can overfit to spurious background/context features that correlate with attributes or objects in training data. These context features are:

  • Hard to identify and annotate systematically
  • Not relevant to the core task (attribute-object composition)

This project proposes:

  1. Cluster context features in an unsupervised way on frozen CLIP embeddings
  2. Use these clusters as environment labels to train a CZSL model with invariance objectives
  3. Evaluate on a test split made up of unseen attribute-object pairs.

The goal is to learn models that rely on attribute-object relationships rather than background/context shortcuts, leading to improved unseen pair accuracy.

Causal Approach

Standard CZSL training treats all samples as i.i.d., which can lead to models that rely on spurious correlations between context features (background, lighting, camera angle) and semantic labels. These correlations are brittle and may not hold in new environments.

This project treats environment context as a confounder:

Context (C) → Image Features (X)
Attr/Obj (A,O) → Image Features (X)

The causal model should learn: Image → (Attr, Obj) while ignoring Context.

Interventions applied:

  1. Environment invariance penalty: Encourages the model to perform equally well across all environments, reducing reliance on environment-specific shortcuts.

  2. Adversarial environment confusion: A gradient reversal layer trains an environment probe to predict environment from the learned representation while simultaneously updating the representation to make prediction harder. This removes environment information from the representation.

  3. Factorized composition: Learning separate attribute and object embeddings forces the model to explicitly model compositional structure rather than memorizing joint configurations.

Hypothesis: Models trained with these objectives will generalize better to out-of-distribution environments where the spurious context correlations differ from training.

Results

Experiment Setup

We conducted comprehensive parameter sweeps comparing baseline CZSL training against causal interventions:

Baseline sweep (sweep_baseline):

  • Temperature scaling: [0.03, 0.05, 0.07, 0.10]
  • Decoupling weight λ: [0, 1, 10, 50, 100]
  • Learning rates: [5e-5, 1e-4, 3e-4]
  • Composition: additive vs gated

Causal sweep (sweep_causal):

  • Adversarial loss weight λ_adv: [0.05, 0.10, 0.50, 1.00, 2.00]
  • Invariance loss weight λ_inv: [0, 1.5] (group DRO)
  • Invariance variance penalty: [0.1, 0.5, 1.5, 2.0, 2.5, 3.0, 4.0]
  • Invariance DRO penalty: [0.1, 0.5, 1.5, 2.0, 2.5, 3.0, 4.0]
  • Adv γ_max: [0.25, 0.50, 1.00, 2.00]
  • Ablations: no_inv, no_adv, adv_only

All experiments used 3 random seeds (42, 14, 193) with biased train splits (max_train_envs=2) and disjoint OOD test environments.

Key Findings

1. Temperature Scaling Dominates OoD Detection

Temperature scaling provides a simple but effective mechanism for OoD detection:

| Temperature | OoD Top1 | ID Top1 |
|---|---|---|
| 0.03 | 0.214 ± 0.002 | 0.047 ± 0.000 |
| 0.05 | 0.186 ± 0.000 | 0.067 ± 0.003 |
| 0.07 | 0.160 ± 0.001 | 0.079 ± 0.002 |
| 0.10 | 0.129 ± 0.001 | 0.072 ± 0.000 |

Interpretation: Lower temperature sharpens the softmax distribution, making out-of-context samples more distinguishable. However, this comes at severe cost to in-distribution classification.
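The sharpening effect can be illustrated with a minimal sketch. It assumes that logits are cosine similarities divided by the temperature; the similarity values and the function name `softmax_with_temperature` are made up for illustration and are not taken from the repository.

```python
import numpy as np

def softmax_with_temperature(similarities, temperature):
    """Softmax over similarities divided by a temperature (assumed scoring form)."""
    logits = np.asarray(similarities, dtype=float) / temperature
    logits = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

# Hypothetical cosine similarities between one image and four pair embeddings.
sims = [0.30, 0.25, 0.20, 0.10]
for t in (0.03, 0.07, 0.10):
    p = softmax_with_temperature(sims, t)
    print(f"T={t:.2f}  max prob={p.max():.3f}")
```

The largest probability grows as the temperature shrinks, which is the sharpening described above.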

2. Causal Interventions Show Limited Benefit

Full causal model (λ_adv=0.5, λ_inv=1.5):

  • OoD: 0.139 ± 0.003
  • ID: 0.070 ± 0.001

Compared to baseline (T=0.07):

  • OoD: 0.161 ± 0.001
  • ID: 0.079 ± 0.002

Observation: Relative to this baseline, the causal objectives lower both OoD top-1 (0.161 → 0.139) and ID top-1 (0.079 → 0.070), and the net effect is small compared to temperature scaling.

3. Invariance Loss Controls OoD-ID Tradeoff

Disabling invariance (λ_inv=0) dramatically increases OoD but destroys ID:

| Setting | OoD Top1 | ID Top1 |
|---|---|---|
| No invariance (λ_inv=0) | 0.213 ± 0.004 | 0.035 ± 0.001 |
| No adversarial (γ=0) | 0.133 ± 0.000 | 0.069 ± 0.001 |
| Full model | 0.139 ± 0.003 | 0.070 ± 0.001 |

Insight: Invariance penalties prevent the model from becoming too conservative. Without them, the adversarial objective pushes representations to be environment-invariant, effectively treating OoD samples as negatives.

4. Ablation Analysis

| Component | OoD | ID | Interpretation |
|---|---|---|---|
| Adv only | 0.214 | 0.034 | Extreme OoD, no ID |
| No inv | 0.213 | 0.034 | Same as adv only |
| No adv | 0.133 | 0.069 | Baseline-like |
| Full | 0.139 | 0.070 | Balanced |

Conclusion: The adversarial component drives OoD detection; invariance penalty prevents it from overwhelming ID performance.

5. Parameter Sensitivity

  • λ_adv: Minimal impact in [0.05, 1.00] when λ_inv=0; OoD ~0.21, ID ~0.03
  • inv_var: Higher penalty → worse OoD, better ID (0.1: 0.209/0.034 → 4.0: 0.179/0.065)
  • inv_dro: Similar trend; extreme values (3.0+) hurt both metrics
  • γ_max: Negligible impact across [0.25, 2.00]

Takeaways

  1. Temperature scaling is the primary lever for OoD detection in this setting, achieving 0.214 OoD at T=0.03
  2. Causal interventions shift the OoD-ID tradeoff only marginally compared to simple baselines
  3. Invariance loss is critical for maintaining ID performance when using adversarial training
  4. The causal methods don't outperform a well-tuned temperature baseline on this task

These results suggest that for compositional zero-shot learning with frozen CLIP embeddings, simple calibration techniques may be as effective as more complex causal interventions for handling environment shift.

Project Structure

CausalCZSL/
├── configs/              # YAML configs for experiments
├── data/               # Dataset directory
│   └── mit_states/    # MIT-States images, index, cache, envs, splits
├── CausalCZSL/
│   ├── env/           # Environment clustering and diagnostics
│   ├── eval/          # Evaluation scripts and metrics
│   ├── features/       # Feature extraction and caching
│   ├── metrics/        # CZSL metric implementations
│   ├── models/         # Model architectures (CZSL head, GRL, env probe)
│   ├── splits/         # Split generation (standard + biased)
│   └── train/         # Training loops and datasets
├── scripts/            # CLI entry points
└── tests/             # Unit and integration tests

Known Limitations

  • Single dataset focus: Currently only supports MIT-States; extending to other datasets requires implementing data loaders and index builders
  • Frozen encoder only: No full fine-tuning or LoRA adaptation (by design for reproducibility and compute efficiency)
  • One biased split + one OOD split: Not a comprehensive benchmark suite; stress test is a single configuration
  • Simple invariance penalty: Uses group variance or group DRO; more sophisticated IRM-style penalties could be explored
  • Environment clustering heuristic: PCA + k-means is a simple approach; more sophisticated methods could better separate context from semantics

Implementation Details

Dataset (MIT-States)

The project uses the MIT-States dataset for compositional zero-shot learning evaluation. The dataset contains images annotated with attribute-object pairs, which enables training and evaluation of compositional models.

Dataset Statistics:

  • Total samples: 63,440 images
  • Unique attributes: 116 (e.g., "old", "new", "wet", "clean", "broken")
  • Unique objects: 245 (e.g., "car", "bottle", "building", "knife", "tree")
  • Unique attribute-object pairs: 2,207 compositional combinations
  • Samples per pair: Ranges from 1 to 50 (mean: 28.7 samples/pair)
  • Most common objects: building (476 samples), knife (474), bathroom (467), cotton (457), brass (453)
  • Least common objects: flame (49 samples), vacuum (58), drum (67), dust (89), fire (90)

Attribute Types: The dataset covers diverse attribute categories including:

  • Material: "adj", "aluminum", "wood", "metal", "plastic"
  • Condition/state: "old", "new", "broken", "clean", "dirty", "wet"
  • Texture: "smooth", "rough", "bent", "blunt"
  • Composition: "coiled", "stacked", "clustered"
  • Context: "cloudy", "cluttered", "bare"

Object Categories: Objects span multiple domains:

  • Everyday items: bottle, knife, bag, cup, fork, plate
  • Furniture: bed, table, chair, cabinet, shelf
  • Food: apple, banana, tomato, cheese, meat, bread
  • Vehicles: car, bike, truck
  • Nature: tree, grass, sky, lake, beach, mountain
  • Materials: aluminum, brass, copper, wood, cotton, wool, silk
  • Locations: bathroom, kitchen, basement, bedroom, hallway

Data Acquisition:

  • Download from: http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip
  • Expected directory structure after extraction:
    data/mit_states/release_dataset/
    ├── images/
    │   ├── attribute_object/
    │   │   ├── img001.jpg
    │   │   └── ...
    │   └── ...
    └── mit_states_index.csv
    
  • Images are organized in folders named {attribute}_{object} or {attribute} {object} containing images for that attribute-object pair.

Index Construction: Running the data preparation script (scripts.data_prepare, see Quick Start) creates a deterministic index file.

The index contains:

  • sample_id: Unique integer identifier (0-indexed)
  • image_path: Relative path to image
  • attr: Attribute name (e.g., "old", "new")
  • obj: Object name (e.g., "car", "bottle")
  • attr_id, obj_id: Integer IDs for attributes and objects
  • pair: Compositional label as "attr::obj"
  • pair_id: Integer ID for the attribute-object pair
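As a rough illustration of how such an index could be derived from the folder layout, the sketch below parses `{attribute}_{object}` or `{attribute} {object}` folder names into the fields listed above. The function name `build_index`, the sort-based ID assignment, and splitting on the first separator are assumptions for illustration, not the behavior of `scripts.data_prepare`.

```python
from pathlib import PurePosixPath

def build_index(image_paths):
    """Build index rows from relative paths such as 'images/old car/img001.jpg'."""
    parsed = []
    for path in sorted(image_paths):  # sorting keeps sample_id deterministic
        folder = PurePosixPath(path).parent.name
        sep = " " if " " in folder else "_"
        # Assumption: the attribute itself never contains the separator.
        attr, _, obj = folder.partition(sep)
        parsed.append((path, attr, obj))

    # Assign integer IDs in sorted order so reruns give identical IDs.
    attr_id = {a: i for i, a in enumerate(sorted({a for _, a, _ in parsed}))}
    obj_id = {o: i for i, o in enumerate(sorted({o for _, _, o in parsed}))}
    pair_id = {p: i for i, p in enumerate(sorted({(a, o) for _, a, o in parsed}))}

    rows = []
    for sample_id, (path, attr, obj) in enumerate(parsed):
        rows.append({
            "sample_id": sample_id,
            "image_path": path,
            "attr": attr,
            "obj": obj,
            "attr_id": attr_id[attr],
            "obj_id": obj_id[obj],
            "pair": f"{attr}::{obj}",
            "pair_id": pair_id[(attr, obj)],
        })
    return rows

rows = build_index([
    "images/old car/b.jpg",
    "images/new_car/a.jpg",
    "images/old car/a.jpg",
])
for row in rows:
    print(row["sample_id"], row["pair"], row["pair_id"])
```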

Environment Clustering

Objective: Identify environment clusters from image embeddings that capture context/background patterns rather than semantic attribute/object features. These clusters serve as pseudo-environment labels for training invariance and adversarial objectives.

Pipeline Overview:

  1. Build deterministic index from dataset
  2. Extract and cache frozen CLIP embeddings
  3. Residualize embeddings (optional):
    • Intention: Remove semantic signal from embeddings so clustering captures within-group variation (context/background) rather than attribute/object/pair identity.
    • How-to: For each sample, subtract the mean embedding of its group (attr, obj, or pair). Optionally use shrinkage (alpha>0) to avoid small groups overfitting their own mean.
  4. Apply PCA dimensionality reduction:
    • Intention: Reduce noise and computational cost while preserving variance; focus clustering on the most informative dimensions.
    • How-to: Apply PCA with n_components=pca_dim (full SVD solver) and track the explained variance ratio.
  5. Cluster embeddings with k-means:
    • Intention: Group similar embeddings into k distinct clusters to identify recurring context patterns that serve as pseudo-environment labels.
    • How-to: Apply k-means with n_clusters=k, n_init=10 random initializations, and Lloyd's algorithm.
  6. Generate diagnostics and visualizations
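Steps 3 and 4 can be sketched with NumPy alone. The shrinkage formula below (blending the group mean with the global mean, weighted by group size and `alpha`) is an assumption, since the text does not give the exact form, and the function names are hypothetical. In practice the k-means step would follow on the reduced features, for example with scikit-learn's `KMeans(n_clusters=k, n_init=10, algorithm="lloyd")`.

```python
import numpy as np

def residualize(emb, groups, alpha=0.0):
    """Subtract a (possibly shrunken) group mean from every embedding."""
    emb = np.asarray(emb, dtype=float)
    groups = np.asarray(groups)
    global_mean = emb.mean(axis=0)
    out = np.empty_like(emb)
    for g in np.unique(groups):
        mask = groups == g
        n = mask.sum()
        # Assumed shrinkage: small groups are pulled toward the global mean.
        mean = (n * emb[mask].mean(axis=0) + alpha * global_mean) / (n + alpha)
        out[mask] = emb[mask] - mean
    return out

def pca_reduce(x, pca_dim):
    """Project onto the top principal components using an exact SVD."""
    centered = x - x.mean(axis=0)
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    explained = (s ** 2) / (s ** 2).sum()
    return centered @ vt[:pca_dim].T, explained[:pca_dim]

rng = np.random.default_rng(0)
emb = rng.normal(size=(12, 8))      # stand-in for cached CLIP embeddings
groups = np.repeat([0, 1, 2], 4)    # stand-in for attr, obj, or pair IDs
res = residualize(emb, groups)
reduced, ratio = pca_reduce(res, pca_dim=3)
print(reduced.shape, ratio.sum())
```

With `alpha=0` each group's residuals average to zero, so group identity can no longer be read off the mean of the embeddings.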

Diagnostics:

Diagnostics verify that clusters capture context/background rather than semantic labels. The following metrics are computed:

  • Probe-based semantic leakage

    • Train linear classifiers to predict (obj|env) and (attr|env) from one-hot environment encodings.
    • Compare probe accuracy to majority baseline (most common class per environment).
    • Flag if probe_acc > majority_acc + threshold (default: 0.10), indicating environment clusters may capture semantic information.
  • Cluster-label agreement metrics

    • AMI/NMI/ARI: Measures of similarity between cluster assignments and ground-truth labels.
    • Purity: sum_c max_y count(c,y) / N – proportion of samples assigned to the dominant class per cluster.
    • Conditional entropy H(Y|C): Uncertainty in labels given cluster assignment (higher is better for pure context clusters, since low values mean the clusters predict the semantic labels).
    • Normalized entropy: Scaled by log(|Y|) for interpretability (0=perfect separation, 1=no information).
  • Information-theoretic dependency metrics

    • Compute for each semantic label type (attr, obj, pair):
      • H(Y): Entropy of environment cluster distribution (uncertainty in environment labels).
      • H(Y|X): Conditional entropy – expected uncertainty in environment given the semantic label. Low values mean environments are predictable from semantics (bad for context capture).
      • I(Y;X) = H(Y) - H(Y|X): Mutual information – how much knowing the semantic label reduces uncertainty about environment.
    • I(Y;X)/H(Y): Fraction of environment entropy explained by semantic label. Values close to 1 indicate strong semantic leakage (clusters encode semantics rather than context).
    • Majority accuracy: Expected accuracy of deterministic rule y_hat(x) = argmax_y p(y|x) – most probable environment given the label.
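The agreement and dependency metrics above reduce to a few counting operations. The sketch below is a plain-Python illustration with hypothetical helper names; it uses natural logarithms and is not the repository's diagnostics code.

```python
from collections import Counter
from math import log

def entropy(labels):
    """Shannon entropy (in nats) of the empirical label distribution."""
    n = len(labels)
    return -sum((c / n) * log(c / n) for c in Counter(labels).values())

def conditional_entropy(y, x):
    """H(Y|X) = sum over x of p(x) * H(Y | X = x)."""
    by_x = {}
    for yi, xi in zip(y, x):
        by_x.setdefault(xi, []).append(yi)
    return sum(len(ys) / len(y) * entropy(ys) for ys in by_x.values())

def purity(clusters, labels):
    """sum_c max_y count(c, y) / N."""
    by_c = {}
    for c, y in zip(clusters, labels):
        by_c.setdefault(c, Counter())[y] += 1
    return sum(max(cnt.values()) for cnt in by_c.values()) / len(labels)

def leakage_fraction(env, sem):
    """I(Y;X) / H(Y) with Y = environment cluster and X = semantic label."""
    h_y = entropy(env)
    return (h_y - conditional_entropy(env, sem)) / h_y if h_y > 0 else 0.0

obj = ["car", "car", "tree", "tree"]
print(leakage_fraction([0, 0, 1, 1], obj))  # clusters mirror the object labels
print(leakage_fraction([0, 1, 0, 1], obj))  # clusters independent of the object
```

In the first call the clusters coincide with the object labels, so the fraction is at its maximum; in the second the clusters carry no object information.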

Split Generation

Biased split generation creates environment-aware data partitions to stress-test causal invariance. The split uses environment clusters to create a train-test distribution shift where training data is biased toward specific contexts while test data comes from different environments.

How it works:

For each attribute-object pair, the algorithm partitions samples by their environment cluster labels into disjoint sets:

  • train: Samples from up to max_train_envs environments (default: 1-2). This limits the model to learning from specific contexts, creating bias.
  • val: Fraction of train environment samples used for validation.
  • id_test (optional): In-distribution test from same environments as train.
  • ood_test: Samples from environments disjoint from train. Tests generalization to unseen environments.
  • test: CZSL unseen-pair split where entire pairs are held out (no samples of these pairs appear in train/val/id_test/ood_test).
  • dropped: Samples not meeting minimum sample constraints.

Constraints ensure robust splits:

  • Each pair must have at least min_train_per_pair samples in train environments.
  • When possible, each pair has at least min_ood_per_pair samples in OOD environments.
  • CZSL held-out pairs preserve at least one pair per attribute and object to maintain vocabulary coverage.

This split enables evaluation of whether causal invariance objectives reduce reliance on environment-specific shortcuts, measured by the ID→OOD accuracy gap.
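A stripped-down sketch of the per-pair partitioning is shown below. It covers only train, ood_test, and dropped, and leaves out val, id_test, and the held-out test pairs. Picking the largest environments as train environments is an assumption, and `biased_split` is a hypothetical name rather than the function in `CausalCZSL/splits/`.

```python
from collections import defaultdict

def biased_split(samples, max_train_envs=2, min_train_per_pair=2):
    """samples: iterable of (sample_id, pair, env). Returns split name -> ids."""
    by_pair = defaultdict(lambda: defaultdict(list))
    for sid, pair, env in samples:
        by_pair[pair][env].append(sid)

    split = {"train": [], "ood_test": [], "dropped": []}
    for pair in sorted(by_pair):
        envs = by_pair[pair]
        # Assumption: the most populated environments become train environments.
        ranked = sorted(envs, key=lambda e: (-len(envs[e]), e))
        train_ids = [s for e in ranked[:max_train_envs] for s in envs[e]]
        if len(train_ids) < min_train_per_pair:
            # The pair cannot meet the minimum-sample constraint, so drop it.
            split["dropped"].extend(s for e in ranked for s in envs[e])
            continue
        split["train"].extend(train_ids)
        # Every remaining environment is disjoint from train for this pair.
        split["ood_test"].extend(s for e in ranked[max_train_envs:] for s in envs[e])
    return split

samples = [
    (0, "old::car", 0), (1, "old::car", 0), (2, "old::car", 1),
    (3, "old::car", 2), (4, "wet::dog", 0),
]
print(biased_split(samples, max_train_envs=1, min_train_per_pair=2))
```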

CZSL Model

Architecture

The model learns factorized attribute and object embeddings and composes them for pair scoring:

  • Attribute embeddings: [n_attrs, D] - Learned vector per attribute

  • Object embeddings: [n_objs, D] - Learned vector per object

  • Composition operators:

    • Additive (composition: add): u = emb_attr[a] + emb_obj[o]
    • Gated Interaction (composition: gated_interaction):
      z = (a * gate_attr) + (o * gate_obj) + gate_bias
      g = sigmoid(z)
      fused = g * a + (1 - g) * o
      inter = interaction_scale * (a * o)
      u = fused + inter
      
    • Learned gating and interaction scales allow the model to dynamically weight attribute vs object contributions
  • Pair embeddings: Normalized composed vectors for all (attr, obj) pairs

  • Scoring: Cosine similarity between image embedding and pair embedding, scaled by temperature
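The two composition operators and the scoring step can be sketched in NumPy as follows. Treating the gate parameters as per-dimension vectors and dividing the cosine similarity by the temperature are assumptions; the embeddings are random stand-ins and the function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def compose_add(a, o):
    """Additive composition: u = emb_attr[a] + emb_obj[o]."""
    return a + o

def compose_gated(a, o, gate_attr, gate_obj, gate_bias, interaction_scale):
    """Gated interaction, following the formulas listed above."""
    z = (a * gate_attr) + (o * gate_obj) + gate_bias
    g = sigmoid(z)
    fused = g * a + (1 - g) * o
    inter = interaction_scale * (a * o)
    return fused + inter

def pair_logits(image_emb, pair_emb, temperature=0.07):
    """Cosine similarity between images and pairs, divided by the temperature."""
    img = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    pairs = pair_emb / np.linalg.norm(pair_emb, axis=-1, keepdims=True)
    return img @ pairs.T / temperature

rng = np.random.default_rng(0)
emb_attr = rng.normal(size=(3, 4))   # [n_attrs, D]
emb_obj = rng.normal(size=(2, 4))    # [n_objs, D]
a = np.repeat(emb_attr, 2, axis=0)   # attribute vector for each (attr, obj) pair
o = np.tile(emb_obj, (3, 1))         # object vector for each (attr, obj) pair
pair_emb = compose_gated(a, o, np.zeros(4), np.zeros(4), np.zeros(4), 0.1)
images = rng.normal(size=(5, 4))     # stand-in for frozen CLIP image embeddings
logits = pair_logits(images, pair_emb)
print(logits.shape)
```

With all gate parameters at zero the gate is 0.5 everywhere, so the fused term is the plain average of the attribute and object vectors.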

Training loss:

  • Softmax cross-entropy over all pairs
  • Supports sampled negative mining with structured negatives:
    • n_neg_same_attr: Negatives sharing the same attribute
    • n_neg_same_obj: Negatives sharing the same object
    • n_neg_random: Random negatives
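One way to draw structured negatives is sketched below. The parameter names mirror the options above, but the sampling procedure itself (uniform sampling without replacement, random negatives drawn from whatever is left) is an assumption, and `sample_negatives` is a hypothetical helper.

```python
import random

def sample_negatives(target, all_pairs, n_neg_same_attr=1, n_neg_same_obj=1,
                     n_neg_random=1, rng=None):
    """Return negative (attr, obj) pairs for one target pair."""
    rng = rng or random.Random(0)
    attr, obj = target
    same_attr = [p for p in all_pairs if p[0] == attr and p != target]
    same_obj = [p for p in all_pairs if p[1] == obj and p != target]
    # The two pools cannot overlap: sharing both attr and obj means p == target.
    chosen = rng.sample(same_attr, min(n_neg_same_attr, len(same_attr)))
    chosen += rng.sample(same_obj, min(n_neg_same_obj, len(same_obj)))
    rest = [p for p in all_pairs if p != target and p not in chosen]
    chosen += rng.sample(rest, min(n_neg_random, len(rest)))
    return chosen

attrs, objs = ["old", "new", "wet"], ["car", "dog", "cup"]
all_pairs = [(a, o) for a in attrs for o in objs]
print(sample_negatives(("old", "car"), all_pairs, 2, 2, 2))
```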

Invariance Penalty

Encourages consistent performance across environments by penalizing loss variance:

  • Group variance penalty (inv_method: group_var):

    penalty = Var(E[loss | env])
    

    Penalizes variance of average loss across environments

  • Group DRO penalty (inv_method: group_dro):

    penalty = max(E[loss | env]) - mean(E[loss | env])
    

    Penalizes the worst-group gap

The penalty is added to the main CZSL loss:

loss_total = loss_czsl + lambda_inv * penalty

This reduces the model's reliance on environment-specific shortcuts.
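Both penalties operate on the per-environment mean loss. The sketch below works on per-sample losses held in NumPy arrays; using the population variance for `group_var` is an assumption, and the numbers are toy values.

```python
import numpy as np

def env_mean_losses(losses, envs):
    """E[loss | env] for every environment present in the batch."""
    losses = np.asarray(losses, dtype=float)
    envs = np.asarray(envs)
    return np.array([losses[envs == e].mean() for e in np.unique(envs)])

def invariance_penalty(losses, envs, inv_method="group_var"):
    means = env_mean_losses(losses, envs)
    if inv_method == "group_var":
        return means.var()                 # Var(E[loss | env]), population variance
    if inv_method == "group_dro":
        return means.max() - means.mean()  # worst-group gap
    raise ValueError(f"unknown inv_method: {inv_method}")

losses = [1.0, 1.0, 3.0, 3.0]   # toy per-sample CZSL losses
envs = [0, 0, 1, 1]             # environment cluster of each sample
lambda_inv = 1.5
loss_czsl = float(np.mean(losses))
for method in ("group_var", "group_dro"):
    penalty = invariance_penalty(losses, envs, method)
    print(method, penalty, loss_czsl + lambda_inv * penalty)
```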

GRL-based Confusion

Gradient Reversal Layer

A custom autograd function that reverses gradients during backpropagation:

grad_input = -gamma * grad_output

This creates an adversarial game:

  • Environment probe: Trains to predict environment from image representation
  • Representation: Updated via gradient reversal to make environment prediction harder

Implementation details:

  • Warm-up: gamma = 0 for first gamma_adv_warmup_epochs epochs
  • Ramp: Linearly increase gamma to gamma_adv_max over gamma_adv_ramp_epochs
  • Probe architecture: Linear or MLP classifier (CausalCZSL/models/env_probe.py:17)
    • Can be conditioned on pair or object IDs for richer modeling
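The warm-up and ramp schedule for gamma can be written as a small function. The exact epoch at which the ramp starts and ends is an assumption here (gamma is still 0 at the first post-warm-up epoch and reaches the maximum after `gamma_adv_ramp_epochs` further epochs); `grl_gamma` is a hypothetical name.

```python
def grl_gamma(epoch, gamma_adv_max, gamma_adv_warmup_epochs, gamma_adv_ramp_epochs):
    """Gradient-reversal coefficient for a given (0-indexed) epoch."""
    if epoch < gamma_adv_warmup_epochs:
        return 0.0                      # warm-up: no adversarial signal yet
    if gamma_adv_ramp_epochs <= 0:
        return gamma_adv_max            # no ramp configured: jump to the maximum
    progress = (epoch - gamma_adv_warmup_epochs) / gamma_adv_ramp_epochs
    return gamma_adv_max * min(1.0, progress)

for epoch in range(8):
    print(epoch, grl_gamma(epoch, gamma_adv_max=1.0,
                           gamma_adv_warmup_epochs=2, gamma_adv_ramp_epochs=4))
```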

Training objective:

loss_env_probe = cross_entropy(probe(image_rep), env_label)
loss_total = loss_czsl + lambda_inv * penalty + gamma * loss_env_probe

The environment probe is trained simultaneously with the main model, competing for the same representation.

Evaluation Metrics

Standard CZSL Metrics (CausalCZSL/metrics/czsl.py)

  • Pair top-1 accuracy: Percentage of samples correctly classified to their ground-truth (attr, obj) pair
  • Pair top-k accuracy: Correct pair appears in top-k predictions
  • Pair macro top-1: Mean accuracy across pair classes (averaged per-class, not per-sample)
  • Attribute top-1: Accuracy of attribute prediction (derived from best-matching pair)
  • Object top-1: Accuracy of object prediction (derived from best-matching pair)
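These metrics can all be computed from the pair logits plus a lookup from pair ID to attribute and object IDs. The sketch below is an illustrative NumPy version with a hypothetical function name, not the code in `CausalCZSL/metrics/czsl.py`.

```python
import numpy as np

def czsl_metrics(logits, target_pair, pair_attr, pair_obj, k=3):
    """logits: [N, n_pairs]; pair_attr / pair_obj map pair ID -> attr / obj ID."""
    logits = np.asarray(logits, dtype=float)
    target = np.asarray(target_pair)
    pair_attr = np.asarray(pair_attr)
    pair_obj = np.asarray(pair_obj)

    pred = logits.argmax(axis=1)
    correct = pred == target
    topk = np.argsort(-logits, axis=1)[:, :k]
    per_class = [correct[target == c].mean() for c in np.unique(target)]
    return {
        "pair_top1": float(correct.mean()),
        "pair_topk": float((topk == target[:, None]).any(axis=1).mean()),
        "pair_macro_top1": float(np.mean(per_class)),
        # Attribute and object predictions are read off the best-matching pair.
        "attr_top1": float((pair_attr[pred] == pair_attr[target]).mean()),
        "obj_top1": float((pair_obj[pred] == pair_obj[target]).mean()),
    }

# Toy setup: pair 0 = (old, car), pair 1 = (old, dog), pair 2 = (new, car).
pair_attr, pair_obj = [0, 0, 1], [0, 1, 0]
logits = [[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.1, 0.8, 0.1], [0.6, 0.3, 0.1]]
targets = [0, 0, 1, 2]
print(czsl_metrics(logits, targets, pair_attr, pair_obj, k=2))
```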

Generalized Zero-Shot Learning (GZSL)

When the test split contains unseen pairs, the evaluation reports:

  • Seen accuracy: Top-1 on pairs seen during training
  • Unseen accuracy: Top-1 on pairs not seen during training
  • Harmonic mean: 2 * seen * unseen / (seen + unseen)
  • Bias sweep: Adds constant bias to unseen-pair logits to calibrate seen/unseen tradeoff; reports AUC across sweep
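The harmonic mean and a single point of the bias sweep can be sketched as follows. The sketch assumes the test set contains at least one seen-pair and one unseen-pair sample, omits the AUC computation, and uses hypothetical function names.

```python
import numpy as np

def harmonic_mean(seen, unseen):
    """2 * seen * unseen / (seen + unseen), defined as 0 when both are 0."""
    return 0.0 if seen + unseen == 0 else 2 * seen * unseen / (seen + unseen)

def gzsl_at_bias(logits, target, pair_is_seen, bias):
    """Seen accuracy, unseen accuracy, and harmonic mean for one bias value."""
    logits = np.asarray(logits, dtype=float)
    target = np.asarray(target)
    seen_mask = np.asarray(pair_is_seen, dtype=bool)
    shifted = logits + bias * (~seen_mask)   # bias only the unseen-pair columns
    correct = shifted.argmax(axis=1) == target
    target_seen = seen_mask[target]          # is each sample's true pair seen?
    seen_acc = float(correct[target_seen].mean())
    unseen_acc = float(correct[~target_seen].mean())
    return seen_acc, unseen_acc, harmonic_mean(seen_acc, unseen_acc)

# Toy setup: pair 0 was seen in training, pair 1 was not.
logits = [[1.0, 0.5], [1.0, 0.8]]
target = [0, 1]
pair_is_seen = [True, False]
for bias in (0.0, 0.3, 1.0):
    print(bias, gzsl_at_bias(logits, target, pair_is_seen, bias))
```

Too little bias leaves unseen pairs unpredicted and too much bias pushes every prediction onto them, which is why the sweep is summarized over the whole range rather than at one value.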

Biased Split Metrics

Under biased splits with disjoint environments:

  • ID (in-distribution): Performance on train environments
  • OOD (out-of-distribution): Performance on held-out environments
  • ID→OOD gap: Difference indicating environment generalization
  • Env probe accuracy: How well a probe can predict environment from representation (lower = better invariance)

Probe Diagnostics

  • Acc(obj | env): Object prediction accuracy given environment cluster
  • Acc(attr | env): Attribute prediction accuracy given environment cluster
  • Majority baselines: Maximum achievable accuracy if predicting most common (obj|env) or (attr|env)
  • Flagging: Environments where probe accuracy exceeds majority by threshold may be capturing semantics

Installation

# Clone repository
git clone <repo-url>
cd Causal-CZSL

# Install dependencies
pip install -e .

Dependencies (see pyproject.toml):

  • Python >= 3.10
  • PyTorch, torchvision
  • open_clip_torch
  • numpy, scikit-learn, pandas
  • tqdm, Pillow

Quick Start

1. Prepare Dataset

Download MIT-States and build the index:

# Download dataset (from official source)
cd data
wget http://wednesday.csail.mit.edu/joseph_result/state_and_transformation/release_dataset.zip -O mitstates.zip
unzip mitstates.zip 'release_dataset/images/*' -d mit_states/
mv mit_states/release_dataset/images mit_states/images/
rm -r mit_states/release_dataset

# Build index
python -m scripts.data_prepare --dataset mit_states --root data/mit_states --out data/mit_states

2. Extract Features

Extract and cache frozen CLIP embeddings:

python -m scripts.extract_features \
  --dataset mit_states \
  --root data/mit_states \
  --cache data/mit_states/cache \
  --seed 0

3. Cluster Environments

Generate environment clusters:

python -m scripts.cluster_env \
  --features data/mit_states/cache \
  --index data/mit_states/mit_states_index.csv \
  --pca_dim 32 \
  --k 10 \
  --seed 42 \
  --residualize-by none \
  --out data/mit_states/processed_envs_residualized

Run diagnostics to verify clusters capture context:

python -m scripts.env_diagnostics \
  --env-dir data/mit_states/processed_envs_residualized \
  --index data/mit_states/mit_states_index.csv \
  --images-root data/mit_states/images \
  --out data/mit_states/env_diagnostics \
  --n-visualize-clusters 5 \
  --max-over-majority-delta 0.10

4. Generate Splits

Create biased train + OOD split:

python -m scripts.splits --config configs/splits_bias_job_k15.yaml

5. Train Model

Train baseline:

python -m scripts.train --config configs/baseline_mit_states.yaml --seed 0

Train causal full (invariance + adversarial):

python -m scripts.train --config configs/causal_full.yaml --seed 0

Run parameter sweeps:

python -m scripts.train --mode sweep_baseline --seeds 0,1,2 --config configs/baseline_mit_states.yaml --max-workers 4
python -m scripts.train --mode sweep_causal --seeds 0,1,2 --config configs/causal_full.yaml --max-workers 4

6. Evaluate

python -m scripts.eval --run_dir artifacts/runs/<run_id> --out artifacts/reports/<run_id>.json

7. Plot Results

python -m scripts.compare_sweep --report artifacts/reports/aggregate.json --out artifacts/figures/

Troubleshooting

Dataset Issues

  • Missing images: Verify data/mit_states/images/ contains folder structure {attr}_{object}/ or {attr} {object}/
  • Index mismatch: Ensure index CSV SHA256 matches feature cache SHA256; re-run data_prepare.py if dataset changed

Training Issues

  • NaN losses: Check learning rate (try lower lr), verify batch size not too large, ensure splits have sufficient samples
  • OOM errors: Reduce batch_size, use limit_stratified for dev runs, or use CPU mode
  • Slow training: Enable caching (embeddings are reused), reduce n_negatives, or train for fewer epochs for quick iteration

Clustering Issues

  • High semantic leakage: If probe diagnostics show high Acc(obj|env) or Acc(attr|env), try:
    • Increase k (more clusters → more fine-grained)
    • Reduce pca_dim (lower dimensionality → less semantic signal)
    • Enable --residualize-by pair to subtract pair means before clustering
  • Small clusters: Decrease k (fewer clusters → more samples per cluster) or adjust min_train_per_pair in split config

Evaluation Issues

  • Unexpectedly low accuracy: Verify:
    • Correct split file is loaded
    • Model checkpoint exists and loads correctly
    • Pair embeddings are normalized (check normalize_pair_embeddings: true)
  • GZSL seen=0 or unseen=0: Check that pair_is_seen is computed correctly from train split

License and Data Disclaimer

This repository provides code and infrastructure for research use. The MIT-States dataset must be downloaded from the official source and used in accordance with its license terms. This repository does not redistribute the dataset.
