Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
from safetensors.torch import load_file
import IPython.display as ipd
from pathlib import Path

# Import F5-TTS modules
from f5_tts.model import CFM, UNetT, DiT
Expand Down Expand Up @@ -126,7 +126,16 @@ def _setup_mel_spec(self):

def _setup_vocoder(self):
    """Initialize the vocoder and move it to ``self.device``.

    A local copy of the model is used when the ``MODEL_DIR_VOCODER``
    environment variable names a directory that contains both
    ``config.yaml`` and ``pytorch_model.bin``. Otherwise ``load_vocoder``
    is called with ``is_local=False`` so the model is downloaded from HF:
    https://huggingface.co/charactr/vocos-mel-24khz
    """
    required_files = ("config.yaml", "pytorch_model.bin")

    # An unset or empty variable must not be converted to a path:
    # Path("") is the current working directory, which would pass is_dir()
    # and could pick up unrelated files that happen to share these names.
    model_dir_env = os.environ.get("MODEL_DIR_VOCODER", "")
    model_dir = Path(model_dir_env) if model_dir_env else None

    has_local_copy = (
        model_dir is not None
        and model_dir.is_dir()
        and all((model_dir / name).is_file() for name in required_files)
    )

    # Load exactly once; an unconditional remote load before this check would
    # defeat the purpose of providing a local copy.
    if has_local_copy:
        # Passed as str to match the type of the "" used for the remote case.
        vocoder = load_vocoder(is_local=True, local_path=str(model_dir))
    else:
        vocoder = load_vocoder(is_local=False, local_path="")

    self.vocos = vocoder.to(self.device)

def _setup_duration_predictor(self, checkpoint_path):
Expand Down