Train a language classifier on Common Voice using the SNAC (Multi-Scale Neural Audio Codec) backbone as a feature extractor. Includes selective tuning strategies (LoRA, last_n_blocks, etc.), streaming data loading, and robust checkpointing.
- Pretrained backbone: SNAC (HF: hubertsiuzdak/snac_24khz)
- Pooling: mean, max, or attention
- Selective backbone tuning:
  - frozen, full, attn_only, attn_lora (LoRA), last_n_blocks, norm_only
- Gradient checkpointing across encoder blocks
- Streaming from HF Datasets with official CV17 splits and optional percentage slicing
- Label smoothing, early stopping, mixed precision (fp16), param-grouped optimizer
- Robust checkpoint loading (compatible with PyTorch >= 2.6 weights_only change)
- Easy inference via from_pretrained and predict
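
The three pooling options listed above reduce a sequence of backbone features of shape [batch, time, dim] to one vector per clip. The following is a minimal sketch of how they might look in PyTorch; `AttentionPooling` and `pool_features` are hypothetical names chosen for illustration and are not taken from this repository's code.

```python
from typing import Optional

import torch
import torch.nn as nn


class AttentionPooling(nn.Module):
    """Learned weighted average over time (hypothetical, for illustration)."""

    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)  # one scalar score per frame

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, time, dim]; softmax normalizes the scores over time
        weights = torch.softmax(self.score(x), dim=1)  # [batch, time, 1]
        return (weights * x).sum(dim=1)  # [batch, dim]


def pool_features(
    x: torch.Tensor, mode: str, attn: Optional[nn.Module] = None
) -> torch.Tensor:
    """Collapse the time axis of x ([batch, time, dim]) using the given mode."""
    if mode == "mean":
        return x.mean(dim=1)
    if mode == "max":
        return x.max(dim=1).values
    if mode == "attention":
        if attn is None:
            raise ValueError("attention pooling needs a pooling module")
        return attn(x)
    raise ValueError(f"unknown pooling mode: {mode}")


features = torch.randn(2, 50, 16)  # stand-in for backbone output
pooler = AttentionPooling(16)
for mode in ("mean", "max", "attention"):
    print(mode, tuple(pool_features(features, mode, pooler).shape))
```

Mean and max pooling have no parameters, while attention pooling adds a small learned scoring layer.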
.
├── config.yaml
├── train.py
├── models/
│   ├── classifier.py
│   └── snac_backbone.py
├── utils/
│   └── training_utils.py
└── data/
    └── ...  # data loading utilities (load_common_voice_data, create_dataloaders)
- Python deps
pip install -r requirements.txt
- SNAC
pip install git+https://github.com/hubertsiuzdak/snac.git

Train with defaults (streaming + official CV17 splits):
python train.py --config config.yaml

Common overrides:
python train.py --config config.yaml --override \
  training.batch_size=32 \
  training.num_epochs=20 \
  model.pooling=attention \
  data.percentage=25 \
  training.steps_per_epoch=1500

Enable selective unfreeze with LoRA from epoch 2:
python train.py --config config.yaml --override \
  training.unfreeze_backbone_epoch=2 \
  model.backbone_unfreeze_strategy=attn_lora \
  model.backbone_lora_rank=8 \
  model.backbone_lora_alpha=16.0 \
  model.backbone_lora_dropout=0.1

Model (SNAC-only):
model:
  backbone: "snac"
  snac_model: "hubertsiuzdak/snac_24khz"
  hidden_size: 512
  dropout: 0.1
  pooling: "attention"  # mean | max | attention
  freeze_backbone: true  # start frozen
  # Selective tuning (applied at init if freeze_backbone=false; typically switched during training)
  backbone_tune_strategy: "frozen"  # frozen | full | attn_only | attn_lora | last_n_blocks | norm_only
  backbone_unfreeze_strategy: "last_n_blocks"
  backbone_last_n_blocks: 1
  # LoRA (for attn_lora)
  backbone_lora_rank: 8
  backbone_lora_alpha: 16.0
  backbone_lora_dropout: 0.1
  # Memory
  backbone_grad_checkpointing: true
  backbone_checkpoint_segments: 4

Data:
data:
  dataset: "mozilla-foundation/common_voice_17_0"
  languages: ["en","es","fr","de","it","pt","ru","zh-CN","ja","ar"]
  split: "train"
  validation_split: 0.1  # ignored if use_official_splits=true
  test_split: 0.1  # ignored if use_official_splits=true
  sample_rate: 24000
  max_audio_length: 10.0
  cache_dir: "/mnt/Hiksemi-2Tb/.cache/"
  streaming: true
  use_official_splits: true
  percentage: 50  # load percentage per split per language (e.g., "train[:50%]")

Training:
training:
  batch_size: 48
  num_epochs: 100
  steps_per_epoch: 2000  # used when streaming=true to bound epoch length
  learning_rate: 1.0e-4
  backbone_lr: 2.0e-5  # separate LR for backbone params
  weight_decay: 0.01
  warmup_steps: 500
  gradient_accumulation_steps: 1
  max_grad_norm: 1.0
  fp16: true
  seed: 42
  label_smoothing: 0.1
  unfreeze_backbone_epoch: 2  # applies model.backbone_unfreeze_strategy
  early_stopping_patience: 15

Scheduler:
scheduler:
  name: "cosine"
  num_warmup_steps: 2000

Evaluation:
evaluation:
  eval_steps: 500
  save_steps: 2000
  logging_steps: 100
  metric: "accuracy"
  max_val_batches: 200
  max_test_batches: 400

Hardware:
hardware:
  device: "cuda"
  num_workers: 4  # automatically set to 1 for streaming
  pin_memory: true

- frozen: backbone stays frozen
- full: unfreeze entire encoder
- attn_only: unfreeze only attention-like modules
- attn_lora: wrap attention linear layers (to_qkv, to_out) with LoRA; base weights remain frozen; LoRA params are trainable
- last_n_blocks: unfreeze last N EncoderBlocks (+ tail)
- norm_only: unfreeze LayerNorms
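
To make the parameter-freezing strategies above concrete, here is a minimal sketch on a toy stack of blocks. It assumes the encoder is exposed as an `nn.ModuleList` of blocks; `apply_tune_strategy` is a hypothetical helper, not this repository's function. The attention-based strategies and the "tail" handling of last_n_blocks are left out because they depend on the real encoder's module structure.

```python
import torch.nn as nn


def apply_tune_strategy(blocks: nn.ModuleList, strategy: str, last_n: int = 1) -> int:
    """Set requires_grad per strategy and return the trainable parameter count."""
    # Start from a fully frozen backbone, then re-enable what the strategy allows.
    for p in blocks.parameters():
        p.requires_grad = False

    if strategy == "full":
        for p in blocks.parameters():
            p.requires_grad = True
    elif strategy == "last_n_blocks":
        for block in list(blocks)[-last_n:]:
            for p in block.parameters():
                p.requires_grad = True
    elif strategy == "norm_only":
        for module in blocks.modules():
            if isinstance(module, nn.LayerNorm):
                for p in module.parameters():
                    p.requires_grad = True
    elif strategy != "frozen":
        raise ValueError(f"unsupported strategy in this sketch: {strategy}")

    return sum(p.numel() for p in blocks.parameters() if p.requires_grad)


# Toy stand-in for the encoder: 4 blocks, each Linear(8, 8) + LayerNorm(8).
toy_blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)) for _ in range(4)]
)
for name in ("frozen", "full", "last_n_blocks", "norm_only"):
    print(name, apply_tune_strategy(toy_blocks, name, last_n=1))
```

After a strategy switch, only parameters with `requires_grad=True` need to be handed to the optimizer's backbone parameter group.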
Gradient checkpointing:
- Enabled via model.backbone_grad_checkpointing; model.backbone_checkpoint_segments sets how many segments the encoder blocks are split into
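
As an illustration of segment-based checkpointing, the sketch below uses PyTorch's `torch.utils.checkpoint.checkpoint_sequential` on a toy sequential encoder. Whether this repository uses that exact utility is an assumption; the toy encoder and tensor sizes are placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Toy stand-in for a stack of encoder blocks (8 blocks here).
encoder = nn.Sequential(
    *[nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(8)]
)
x = torch.randn(4, 16, requires_grad=True)

# Split the 8 blocks into 4 segments. Activations inside checkpointed
# segments are recomputed during backward instead of being stored.
segments = 4
out_ckpt = checkpoint_sequential(encoder, segments, x, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

# Reference run without checkpointing: same outputs and gradients,
# but all intermediate activations are kept in memory.
x.grad = None
out_plain = encoder(x)
out_plain.sum().backward()
grad_plain = x.grad.clone()

print(torch.allclose(out_ckpt, out_plain), torch.allclose(grad_ckpt, grad_plain))
```

More segments generally means less activation memory at the cost of extra recomputation in the backward pass.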
Note: If attention modules are not detectable, attn_only/attn_lora fall back to last_n_blocks or norm_only.
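
The attn_lora idea (frozen base weights plus a trainable low-rank update) can be sketched with a small wrapper around `nn.Linear`. `LoRALinear` is a hypothetical class written for illustration using the standard LoRA formulation; it is not this repository's implementation. The default arguments mirror the backbone_lora_rank, backbone_lora_alpha, and backbone_lora_dropout config values.

```python
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update (illustrative)."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0,
                 dropout: float = 0.1):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # base weights stay frozen

        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)  # update starts at zero
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        update = self.lora_b(self.lora_a(self.dropout(x)))
        return self.base(x) + self.scaling * update


layer = LoRALinear(nn.Linear(32, 64), rank=8, alpha=16.0, dropout=0.1)
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
print(trainable)
```

Because `lora_b` is initialized to zero, the wrapped layer initially produces the same output as the base layer, so switching strategies mid-training does not disturb the model's predictions at the moment of the switch.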
Load a trained model and predict:
import torch
from models import LanguageClassifier
model = LanguageClassifier.from_pretrained("training", device="cuda")
# Single file or tensor
label, prob = model.predict("sample.wav", top_k=1) # returns (label_or_idx, prob)
print(label, prob)
# Batched tensor [B, T] at 24kHz
batch = torch.randn(4, 24000 * 10)  # 4 clips of 10 s each
results = model.predict(batch, top_k=3, return_probs=True)
print(results)

Access embeddings and activations:
# Pooled embeddings (no classification)
emb = model.get_feature_embeddings(torch.randn(1, 24000*10))
# Detailed activations
acts = model.forward_with_activations(torch.randn(1, 24000*10))

- Streaming: workers are forced to 1 to keep iterator order stable and memory low
- steps_per_epoch: when streaming, controls how many optimizer updates define an epoch
- Mixed precision: training.fp16=true on CUDA uses torch.cuda.amp
- Optimizer param groups: head and backbone use separate LRs; new params from LoRA are auto-added when strategy switches
- PyTorch 2.6 weights_only: PyTorch 2.6 changed the default of torch.load's weights_only argument from False to True. The checkpoint loaders in utils.training_utils.load_checkpoint and models.classifier.LanguageClassifier._safe_load_checkpoint handle the change by retrying with weights_only=False when needed.
- CUDA OOM: reduce training.batch_size, enable model.backbone_grad_checkpointing, shorten data.max_audio_length, or raise training.gradient_accumulation_steps to keep the effective batch size.
- Attention not found: attn_lora may fall back to last_n_blocks; check logs for encoder module structure.
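
The retry behaviour described for checkpoint loading can be sketched as follows. `load_checkpoint_with_fallback` is a hypothetical helper, not the repository's `load_checkpoint` or `_safe_load_checkpoint`; it assumes the checkpoint file comes from a trusted source, since the fallback path unpickles arbitrary objects.

```python
import os
import pickle
import tempfile

import torch


def load_checkpoint_with_fallback(path: str, map_location: str = "cpu"):
    """Try the restricted loader first, then retry with full unpickling."""
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except pickle.UnpicklingError:
        # The checkpoint holds objects outside the weights-only allowlist
        # (for example a config object). Only acceptable for trusted files.
        return torch.load(path, map_location=map_location, weights_only=False)


# Round-trip a plain state dict through a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "demo.pt")
    torch.save({"weight": torch.ones(2, 3), "epoch": 5}, ckpt_path)
    restored = load_checkpoint_with_fallback(ckpt_path)

print(restored["epoch"], tuple(restored["weight"].shape))
```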
- Common Voice 17.0: https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0
- Official splits can be used (use_official_splits=true) or custom split ratios when false
- percentage allows slicing per split per language (e.g., 10 -> first 10%)
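
As a small illustration of the percentage option, the sketch below builds split strings in the Hugging Face Datasets slicing syntax shown in the config comment (e.g. "train[:50%]"). `build_split_spec` is a hypothetical helper; how this repository applies the percentage when streaming is enabled is not specified here.

```python
from typing import Optional


def build_split_spec(split: str, percentage: Optional[int] = None) -> str:
    """Return a split string such as 'train[:10%]' for the first N percent."""
    if percentage is None or percentage >= 100:
        return split  # no slicing needed
    if percentage <= 0:
        raise ValueError("percentage must be between 1 and 100")
    return f"{split}[:{percentage}%]"


# One spec per split; the same spec would be used for every language.
specs = {name: build_split_spec(name, 10) for name in ("train", "validation", "test")}
print(specs)
```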
SNAC:
@article{siuzdak2024snac,
title={SNAC: Multi-Scale Neural Audio Codec},
author={Siuzdak, Hubert},
journal={GitHub repository},
year={2024}
}
Common Voice:
@inproceedings{commonvoice:2020,
author = {Ardila, R. and Branson, M. and Davis, K. and Henretty, M. and Kohler, M. and Meyer, J. and Morais, R. and Saunders, L. and Tyers, F. M. and Weber, G.},
title = {Common Voice: A Massively-Multilingual Speech Corpus},
booktitle = {Proceedings of the 12th Conference on Language Resources and Evaluation (LREC 2020)},
pages = {4211--4215},
year = {2020}
}
- This project: research/educational use
- SNAC: see SNAC repository license
- Common Voice: see dataset license