Code for SSFL and UAP
UAP/
├── train.py # Main training script
├── evaluate.py # Evaluation script
├── config.py # Unified configuration module
│
├── methods/ # Model implementations
│ ├── SSFL.py # Semi-supervised federated learning baseline
│ ├── UAP.py # UAP method
│ └── base.py # Base model class
│
├── datasets.py # Dataset classes (PACS, VLCS, OfficeHome, etc.)
├── multi_client.py # Multi-client data loading
├── utils.py # Utility functions (model loading, aggregation, etc.)
├── losses/ # Loss functions (CDD, MMD, covariance)
├── mnist_datasets.py # MNIST dataset utilities
│
├── requirements.txt # Python dependencies
└── BATCHNORM_STATS.md # Documentation on BN stats sharing
conda create -n uap python=3.8
conda activate uap
pip install -r requirements.txt

# Train SSFL on PACS
python train.py --dataset PACS --test_env 0 --method SSFL --device cuda:0
# Train UAP on PACS
python train.py --dataset PACS --test_env 0 --method UAP --device cuda:0

# Evaluate SSFL
python evaluate.py --dataset PACS --test_env 0 --method SSFL --experiment_path data_final --device cuda:0
# Evaluate UAP
python evaluate.py --dataset PACS --test_env 0 --method UAP --experiment_path data_final --device cuda:0

Options are parsed by the unified configuration module (config.py):
from config import get_args
args = get_args()

Key command-line arguments:
--dataset: Dataset name (PACS, VLCS, OfficeHome)
--test_env: Index of the target (test) domain (0-3; each supported dataset has 4 domains)
--server_domain: Index of the server's labeled domain (0-3)
--method: Method name (SSFL, UAP)
--multi_client: Enable the multi-client setting
--num_clients: Number of clients in multi-client mode
--rounds: Number of federated rounds
--E: Local epochs per round
--device: Compute device (e.g., cuda:0, cuda:1, cpu)
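The flags listed above map naturally onto an argparse parser. The sketch below is a hypothetical stand-in for config.get_args, not the repository's code: the function name get_args_sketch, the argv parameter, and every default value are assumptions chosen only so the example runs.

```python
import argparse
from typing import List, Optional


def get_args_sketch(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical stand-in for config.get_args(); every default is a placeholder."""
    p = argparse.ArgumentParser(description="SSFL / UAP (sketch)")
    p.add_argument("--dataset", choices=["PACS", "VLCS", "OfficeHome"], default="PACS")
    p.add_argument("--test_env", type=int, choices=range(4), default=0,
                   help="target (test) domain index")
    p.add_argument("--server_domain", type=int, choices=range(4), default=1,
                   help="server's labeled domain index")
    p.add_argument("--method", choices=["SSFL", "UAP"], default="UAP")
    p.add_argument("--multi_client", action="store_true",
                   help="enable the multi-client setting")
    p.add_argument("--num_clients", type=int, default=3,
                   help="clients in multi-client mode (placeholder default)")
    p.add_argument("--rounds", type=int, default=50,
                   help="federated rounds (placeholder default)")
    p.add_argument("--E", type=int, default=1,
                   help="local epochs per round (placeholder default)")
    p.add_argument("--device", default="cuda:0", help="e.g. cuda:0, cuda:1, cpu")
    return p.parse_args(argv)


if __name__ == "__main__":
    args = get_args_sketch(["--dataset", "PACS", "--test_env", "0",
                            "--method", "UAP", "--device", "cpu"])
    print(args)
```

Passing an explicit argv list makes the parser easy to exercise from tests or notebooks; with argv=None it reads sys.argv as the real scripts would.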
Required datasets:
- PACS: 4 domains (art_painting, cartoon, photo, sketch)
- VLCS: 4 domains (Caltech101, LabelMe, SUN09, VOC2007)
- OfficeHome: 4 domains (Art, Clipart, Product, Real_World)
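The example commands train against a single target (--test_env 0). To train against every target domain in turn, the runs can be generated programmatically. This is a sketch under one stated assumption: that --test_env indexes the domains in the order listed above (check datasets.py to confirm). The helper leave_one_domain_out is hypothetical and only builds command lines; it does not set --server_domain.

```python
import shlex
from typing import Dict, List, Tuple

# Assumption: --test_env indexes domains in the order listed in this README.
DOMAINS: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "VLCS": ["Caltech101", "LabelMe", "SUN09", "VOC2007"],
    "OfficeHome": ["Art", "Clipart", "Product", "Real_World"],
}


def leave_one_domain_out(dataset: str, method: str,
                         device: str = "cuda:0") -> List[Tuple[str, List[str]]]:
    """Return (target domain name, train.py argv) for every target domain."""
    jobs = []
    for test_env, domain in enumerate(DOMAINS[dataset]):
        argv = ["python", "train.py",
                "--dataset", dataset,
                "--test_env", str(test_env),
                "--method", method,
                "--device", device]
        jobs.append((domain, argv))
    return jobs


if __name__ == "__main__":
    # Only prints the commands; pass each argv to subprocess.run(argv, check=True) to execute.
    for domain, argv in leave_one_domain_out("PACS", "UAP"):
        print(f"# target: {domain}")
        print(shlex.join(argv))
```

Keeping each command as an argv list (rather than one shell string) avoids quoting issues if it is later handed to subprocess.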
Place the datasets in the ./data/ directory:
data/
├── PACS/
│ ├── art_painting/
│ ├── cartoon/
│ ├── photo/
│ └── sketch/
├── VLCS/
│ ├── Caltech101/
│ ├── LabelMe/
│ ├── SUN09/
│ └── VOC2007/
└── officehome/
├── Art/
├── Clipart/
├── Product/
└── Real_World/
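Before launching a run, it can help to confirm that the folders match the tree above. A minimal sketch, assuming the folder names exactly as shown (including the lowercase officehome directory); the function name missing_domain_dirs is hypothetical.

```python
from pathlib import Path
from typing import Dict, List

# Folder names copied from the tree above (note the lowercase "officehome").
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "PACS": ["art_painting", "cartoon", "photo", "sketch"],
    "VLCS": ["Caltech101", "LabelMe", "SUN09", "VOC2007"],
    "officehome": ["Art", "Clipart", "Product", "Real_World"],
}


def missing_domain_dirs(root: str = "./data") -> List[str]:
    """Return every expected domain folder that does not exist under `root`."""
    missing = []
    for dataset, domains in EXPECTED_LAYOUT.items():
        for domain in domains:
            path = Path(root) / dataset / domain
            if not path.is_dir():
                missing.append(str(path))
    return missing


if __name__ == "__main__":
    gaps = missing_domain_dirs()
    print("Layout OK" if not gaps else "Missing:\n" + "\n".join(gaps))
```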
Checkpoints and logs are saved to:
results/{experiment_path}/{dataset}/{dataset}/{method}/target_{test_env}/server_{server_domain}/seed_{seed}/
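The template can be rebuilt with pathlib when scripting over many runs. This sketch copies the template literally, including the repeated {dataset} component; run_dir is a hypothetical helper, and the seed argument is simply the value that appears in the path (how the scripts choose it is not documented here).

```python
from pathlib import Path


def run_dir(experiment_path: str, dataset: str, method: str,
            test_env: int, server_domain: int, seed: int,
            root: str = "results") -> Path:
    """Mirror the results/ path template; the dataset name appears twice, as in the template."""
    return (Path(root) / experiment_path / dataset / dataset / method
            / f"target_{test_env}" / f"server_{server_domain}" / f"seed_{seed}")


if __name__ == "__main__":
    print(run_dir("data_final", "PACS", "UAP", test_env=0, server_domain=1, seed=0))
```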
Each run directory contains:
checkpoint.pt: Model weights and optimizer state
target_accs.csv: Target-domain accuracy for each round
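To summarize a finished run, target_accs.csv can be read with the standard csv module. The column layout is not documented here, so this sketch assumes a header row, one row per round, and a hypothetical accuracy column named "acc"; adjust column to the real header. It accepts any iterable of lines, so an open file works directly.

```python
import csv
from typing import Iterable, Tuple


def summarize_target_accs(lines: Iterable[str], column: str = "acc") -> Tuple[float, float, int]:
    """Return (final accuracy, best accuracy, 0-based row index of the best round).

    Assumes a header row and one row per round; `column` is a hypothetical name.
    """
    accs = [float(row[column]) for row in csv.DictReader(lines)]
    if not accs:
        raise ValueError("no accuracy rows found")
    best_idx = max(range(len(accs)), key=accs.__getitem__)  # first index on ties
    return accs[-1], accs[best_idx], best_idx


if __name__ == "__main__":
    with open("target_accs.csv", newline="") as f:  # path inside a run directory
        final, best, best_row = summarize_target_accs(f)
    print(f"final={final:.4f} best={best:.4f} (row {best_row})")
```

checkpoint.pt itself can be opened with torch.load(path, map_location="cpu"); the key names stored inside it are not documented here.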
BATCHNORM_STATS.md: Explains batch normalization statistics sharing and privacy preservation
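BATCHNORM_STATS.md is the reference for what this repository actually does. To make the distinction concrete, the sketch below shows generic FedAvg-style weighted averaging over plain dictionaries, with a switch that either averages BN running statistics like any other entry or keeps each client's own copy (the FedBN-style alternative). It is an illustration, not the aggregation code in utils.py; the key suffixes follow PyTorch's BatchNorm buffer names, and the weights and the share_bn_stats flag are assumptions.

```python
from typing import Dict, List, Sequence

State = Dict[str, List[float]]  # parameter/buffer name -> flattened values

# PyTorch BatchNorm buffer names.
BN_STAT_SUFFIXES = ("running_mean", "running_var", "num_batches_tracked")


def is_bn_stat(name: str) -> bool:
    return name.endswith(BN_STAT_SUFFIXES)


def fedavg(states: Sequence[State], weights: Sequence[float],
           share_bn_stats: bool = True) -> List[State]:
    """Weighted average of client states; returns one updated state per client.

    If share_bn_stats is False, BN running statistics stay client-local.
    (num_batches_tracked is averaged as a float here, a simplification.)
    """
    total = float(sum(weights))
    averaged: State = {}
    for name in states[0]:
        averaged[name] = [
            sum(w * s[name][i] for s, w in zip(states, weights)) / total
            for i in range(len(states[0][name]))
        ]
    updated = []
    for s in states:
        new_state: State = {}
        for name, values in averaged.items():
            keep_local = (not share_bn_stats) and is_bn_stat(name)
            new_state[name] = list(s[name]) if keep_local else list(values)
        updated.append(new_state)
    return updated
```

With share_bn_stats=True every client receives the same running statistics; with False, only learnable parameters are synchronized and each client keeps normalization statistics computed on its own data.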
Module dependencies:
train.py
├── config.py (or argument.py)
├── utils.py
│   └── methods/ (SSFL, UAP)
│       └── base.py
├── datasets.py
└── multi_client.py
evaluate.py
├── prepare.py (or config.py)
├── utils.py
├── datasets.py
└── losses/