diff --git a/src/infer.py b/src/infer.py
index 1726fd8..01b0969 100644
--- a/src/infer.py
+++ b/src/infer.py
@@ -5,7 +5,7 @@ from torch.nn.utils.rnn import pad_sequence
 from torchdiffeq import odeint
 from safetensors.torch import load_file
 
-import IPython.display as ipd
+from pathlib import Path
 
 # Import F5-TTS modules
 from f5_tts.model import CFM, UNetT, DiT
@@ -126,7 +126,26 @@ def _setup_mel_spec(self):
 
     def _setup_vocoder(self):
-        """Initialize vocoder."""
-        self.vocos = load_vocoder(is_local=False, local_path="")
+        """Initialize the vocoder and move it to ``self.device``.
+
+        If the ``MODEL_DIR_VOCODER`` environment variable names a directory
+        that contains ``config.yaml`` and ``pytorch_model.bin``, the vocoder
+        is loaded from that local copy; otherwise ``load_vocoder`` is asked
+        for the non-local model (downloaded from Hugging Face, see
+        https://huggingface.co/charactr/vocos-mel-24khz).
+        """
+        # An unset or blank variable must not resolve to the current working
+        # directory (``Path("")`` is ``Path(".")``), so test the raw string first.
+        model_dir = os.environ.get("MODEL_DIR_VOCODER", "").strip()
+        required_files = ("config.yaml", "pytorch_model.bin")
+        has_local_copy = bool(model_dir) and all(
+            (Path(model_dir) / name).is_file() for name in required_files
+        )
+
+        if has_local_copy:
+            self.vocos = load_vocoder(is_local=True, local_path=model_dir)
+        else:
+            self.vocos = load_vocoder(is_local=False, local_path="")
+
         self.vocos = self.vocos.to(self.device)
 
     def _setup_duration_predictor(self, checkpoint_path):