surus-lat/audio-language-classification

Language Classifier (SNAC) Training Pipeline

Train a language classifier on Common Voice using the SNAC (Scalable Neural Audio Codec) backbone as a feature extractor. Includes selective tuning strategies (LoRA, last_n_blocks, etc.), streaming data loading, and robust checkpointing.

Overview

  • Pretrained backbone: SNAC (HF: hubertsiuzdak/snac_24khz)
  • Pooling: mean, max, or attention
  • Selective backbone tuning:
    • frozen, full, attn_only, attn_lora (LoRA), last_n_blocks, norm_only
  • Gradient checkpointing across encoder blocks
  • Streaming from HF Datasets with official CV17 splits and optional percentage slicing
  • Label smoothing, early stopping, mixed precision (fp16), param-grouped optimizer
  • Robust checkpoint loading (compatible with PyTorch >= 2.6 weights_only change)
  • Easy inference via from_pretrained and predict

Directory Structure

.
├── config.yaml
├── train.py
├── models/
│   ├── classifier.py
│   └── snac_backbone.py
├── utils/
│   └── training_utils.py
└── data/
    └── ...  # data loading utilities (load_common_voice_data, create_dataloaders)

Installation

  1. Python deps
pip install -r requirements.txt
  2. SNAC
pip install git+https://github.com/hubertsiuzdak/snac.git

Quick Start

Train with defaults (streaming + official CV17 splits):

python train.py --config config.yaml

Common overrides:

python train.py --config config.yaml --override \
  training.batch_size=32 \
  training.num_epochs=20 \
  model.pooling=attention \
  data.percentage=25 \
  training.steps_per_epoch=1500
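
How train.py parses --override is not shown here. As a minimal sketch, dotted key=value pairs could be merged into a nested config dict as below. The helper names apply_overrides and parse_value are hypothetical, not the repository's functions, and the scalar parsing rules are an assumption.

```python
import ast

def parse_value(raw):
    """Best-effort scalar parsing: booleans, numbers, else the raw string."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return ast.literal_eval(raw)  # handles 32, 16.0, 1.0e-4
    except (ValueError, SyntaxError):
        return raw  # e.g. "attention" or "zh-CN" stay strings

def apply_overrides(config, overrides):
    """Merge dotted key=value strings into a nested dict (hypothetical helper)."""
    for item in overrides:
        dotted_key, raw = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return config

config = {"training": {"batch_size": 48}, "model": {"pooling": "mean"}}
apply_overrides(
    config,
    ["training.batch_size=32", "model.pooling=attention", "data.percentage=25"],
)
print(config)
```

Keys that are not listed in the overrides keep their values from config.yaml.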

Enable selective unfreeze with LoRA from epoch 2:

python train.py --config config.yaml --override \
  training.unfreeze_backbone_epoch=2 \
  model.backbone_unfreeze_strategy=attn_lora \
  model.backbone_lora_rank=8 \
  model.backbone_lora_alpha=16.0 \
  model.backbone_lora_dropout=0.1

Configuration Highlights

Model (SNAC-only):

model:
  backbone: "snac"
  snac_model: "hubertsiuzdak/snac_24khz"
  hidden_size: 512
  dropout: 0.1
  pooling: "attention"   # mean | max | attention
  freeze_backbone: true  # start frozen

  # Selective tuning (applied at init if freeze_backbone=false; typically switched during training)
  backbone_tune_strategy: "frozen"      # frozen | full | attn_only | attn_lora | last_n_blocks | norm_only
  backbone_unfreeze_strategy: "last_n_blocks"
  backbone_last_n_blocks: 1

  # LoRA (for attn_lora)
  backbone_lora_rank: 8
  backbone_lora_alpha: 16.0
  backbone_lora_dropout: 0.1

  # Memory
  backbone_grad_checkpointing: true
  backbone_checkpoint_segments: 4

Data:

data:
  dataset: "mozilla-foundation/common_voice_17_0"
  languages: ["en","es","fr","de","it","pt","ru","zh-CN","ja","ar"]
  split: "train"
  validation_split: 0.1   # ignored if use_official_splits=true
  test_split: 0.1         # ignored if use_official_splits=true
  sample_rate: 24000
  max_audio_length: 10.0
  cache_dir: "/mnt/Hiksemi-2Tb/.cache/"
  streaming: true
  use_official_splits: true
  percentage: 50          # load percentage per split per language (e.g., "train[:50%]")
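
With sample_rate: 24000 and max_audio_length: 10.0, a clip is at most 240,000 samples, which matches the 24000*10 tensors used in the Inference section. The sketch below shows that arithmetic and one possible length policy. Truncating and right-padding with zeros is an assumption, and crop_or_pad is a hypothetical helper, not the repository's loader.

```python
def max_samples(sample_rate, max_audio_length):
    """Number of samples in max_audio_length seconds of audio."""
    return int(round(sample_rate * max_audio_length))

def crop_or_pad(waveform, target_len, pad_value=0.0):
    """Truncate or right-pad a 1-D list of samples (assumed policy)."""
    if len(waveform) >= target_len:
        return waveform[:target_len]
    return waveform + [pad_value] * (target_len - len(waveform))

target = max_samples(24000, 10.0)
short_clip = [0.1] * 1000
print(target, len(crop_or_pad(short_clip, target)))
```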

Training:

training:
  batch_size: 48
  num_epochs: 100
  steps_per_epoch: 2000   # used when streaming=true to bound epoch length
  learning_rate: 1.0e-4
  backbone_lr: 2.0e-5     # separate LR for backbone params
  weight_decay: 0.01
  warmup_steps: 500
  gradient_accumulation_steps: 1
  max_grad_norm: 1.0
  fp16: true
  seed: 42
  label_smoothing: 0.1
  unfreeze_backbone_epoch: 2    # applies model.backbone_unfreeze_strategy
  early_stopping_patience: 15
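
To make label_smoothing: 0.1 concrete, here is a sketch of the common formulation, in which the true class gets 1 - eps + eps/K and every other class gets eps/K. With the 10 languages listed above, that is 0.91 and 0.01. Whether train.py computes the loss exactly this way is an assumption.

```python
import math

def smoothed_targets(true_idx, num_classes, eps):
    """Target distribution: (1 - eps) on the true class plus eps spread uniformly."""
    base = eps / num_classes
    return [base + (1.0 - eps if i == true_idx else 0.0) for i in range(num_classes)]

def cross_entropy(logits, targets):
    """Cross-entropy between softmax(logits) and a soft target distribution."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(t * (x - log_z) for x, t in zip(logits, targets))

targets = smoothed_targets(true_idx=0, num_classes=10, eps=0.1)
print(targets[0], targets[1])
print(cross_entropy([2.0] + [0.0] * 9, targets))
```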

Scheduler:

scheduler:
  name: "cosine"
  num_warmup_steps: 2000
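
A sketch of the learning-rate multiplier that a cosine schedule with linear warmup typically produces: it ramps from 0 to 1 over num_warmup_steps, then decays to 0 along a half cosine. The total of num_epochs * steps_per_epoch steps and the decay-to-zero floor are assumptions. The scheduler train.py actually builds may differ.

```python
import math

def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base LR at a given optimizer step."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)  # linear warmup
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine decay

total_steps = 100 * 2000  # num_epochs * steps_per_epoch (assumed total)
for step in (0, 1000, 2000, 101000, total_steps):
    head_lr = 1.0e-4 * lr_multiplier(step, 2000, total_steps)
    backbone_lr = 2.0e-5 * lr_multiplier(step, 2000, total_steps)
    print(step, head_lr, backbone_lr)
```

The same multiplier scales both param groups, so the head and backbone keep their 5:1 learning-rate ratio throughout.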

Evaluation:

evaluation:
  eval_steps: 500
  save_steps: 2000
  logging_steps: 100
  metric: "accuracy"
  max_val_batches: 200
  max_test_batches: 400

Hardware:

hardware:
  device: "cuda"
  num_workers: 4           # automatically set to 1 for streaming
  pin_memory: true

Selective Tuning Details

  • frozen: backbone stays frozen
  • full: unfreeze entire encoder
  • attn_only: unfreeze only attention-like modules
  • attn_lora: wrap attention linear layers (to_qkv, to_out) with LoRA; base weights remain frozen; LoRA params are trainable
  • last_n_blocks: unfreeze last N EncoderBlocks (+ tail)
  • norm_only: unfreeze LayerNorms
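
The attn_lora idea can be illustrated with a dependency-free sketch of the standard LoRA update: the frozen weight W is left untouched, and a low-rank product B·A scaled by alpha / rank is added to its output. With rank 8 and alpha 16.0, the scale is 2.0. The LoRALinear class below is illustrative only, the layer sizes are hypothetical, dropout is omitted, and the repository's wrapper may differ in initialization and details.

```python
import random

def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]

class LoRALinear:
    """Frozen base weight W plus a trainable update (alpha / rank) * B @ A."""

    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen base weight, never updated
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]  # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x):
        base = matvec(self.weight, x)
        update = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * u for b, u in zip(base, update)]

    def num_trainable(self):
        """Only A and B are trained: rank * (d_in + d_out) values."""
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])

rng = random.Random(1)
W = [[rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(12)]  # hypothetical 16 -> 12 layer
layer = LoRALinear(W, rank=8, alpha=16.0)
x = [1.0] * 16
print(layer.scale, layer.num_trainable(), layer.forward(x) == matvec(W, x))
```

Because B starts at zero, the wrapped layer reproduces the pretrained output exactly until training moves B away from zero.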

Gradient checkpointing:

  • Enabled via model.backbone_grad_checkpointing with model.backbone_checkpoint_segments controlling segments

Note: If attention modules are not detectable, attn_only/attn_lora fall back to last_n_blocks or norm_only.

Inference

Load a trained model and predict:

import torch
from models import LanguageClassifier

model = LanguageClassifier.from_pretrained("training", device="cuda")
# Single file or tensor
label, prob = model.predict("sample.wav", top_k=1)  # returns (label_or_idx, prob)
print(label, prob)

# Batched tensor [B, T] at 24kHz
batch = torch.randn(4, 24000*10)
results = model.predict(batch, top_k=3, return_probs=True)
print(results)

Access embeddings and activations:

# Pooled embeddings (no classification)
emb = model.get_feature_embeddings(torch.randn(1, 24000*10))

# Detailed activations
acts = model.forward_with_activations(torch.randn(1, 24000*10))

Tips

  • Streaming: workers are forced to 1 to keep iterator order stable and memory low
  • steps_per_epoch: when streaming, controls how many optimizer updates define an epoch
  • Mixed precision: training.fp16=true on CUDA uses torch.cuda.amp
  • Optimizer param groups: head and backbone use separate LRs; new params from LoRA are auto-added when strategy switches
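
As a sketch of the steps_per_epoch tip, an endless streaming iterator can be cut into bounded epochs with itertools.islice. This assumes each optimizer update consumes gradient_accumulation_steps batches. The epoch_batches helper is hypothetical, and itertools.count() stands in for a streaming dataloader.

```python
import itertools

def epoch_batches(stream, steps_per_epoch, gradient_accumulation_steps=1):
    """Yield (update_index, micro_batches) for one bounded epoch over an iterator."""
    for update in range(steps_per_epoch):
        micro = list(itertools.islice(stream, gradient_accumulation_steps))
        if len(micro) < gradient_accumulation_steps:
            return  # stream ran out before the epoch was full
        yield update, micro

stream = itertools.count()  # stand-in for an endless streaming dataloader
updates = list(epoch_batches(stream, steps_per_epoch=3, gradient_accumulation_steps=2))
print(updates)  # [(0, [0, 1]), (1, [2, 3]), (2, [4, 5])]
```

Calling epoch_batches again on the same stream continues where the previous epoch stopped rather than restarting the data.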

Troubleshooting

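  • PyTorch 2.6 weights_only: PyTorch 2.6 changed the torch.load default to weights_only=True, which rejects checkpoints that contain pickled objects beyond tensors and allow-listed types. The loaders utils.training_utils.load_checkpoint and models.classifier.LanguageClassifier._safe_load_checkpoint handle this by retrying with weights_only=False when needed, so load only checkpoints you trust.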
  • CUDA OOM: reduce batch size, enable gradient checkpointing, shorten max_audio_length, use gradient_accumulation_steps.
  • Attention not found: attn_lora may fall back to last_n_blocks; check logs for encoder module structure.

Dataset

Common Voice 17.0 (HF: mozilla-foundation/common_voice_17_0), as set in data.dataset.

Citation

SNAC:

@article{siuzdak2024snac,
  title={SNAC: Scalable Neural Audio Codec},
  author={Siuzdak, Hubert},
  journal={GitHub repository},
  year={2024}
}

Common Voice:

@inproceedings{commonvoice:2020,
  author = {Ardila, R. and Branson, M. and Davis, K. and Henretty, M. and Kohler, M. and Meyer, J. and Morais, R. and Saunders, L. and Tyers, F. M. and Weber, G.},
  title = {Common Voice: A Massively-Multilingual Speech Corpus},
  booktitle = {Proceedings of the 12th Conference on Language Resources and Evaluation (LREC 2020)},
  pages = {4211--4215},
  year = {2020}
}

License

  • This project: research/educational use
  • SNAC: see SNAC repository license
  • Common Voice: see dataset license
