diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 75bd6c57eb..295e915c9b 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -275,8 +275,10 @@ def griffinlim( and *Signal estimation from modified short-time Fourier transform* :cite:`1172092`. Args: - specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)` - where freq is ``n_fft // 2 + 1``. + specgram (Tensor): + A STFT spectrogram of dimension `(..., freq, frames)` where freq is ``n_fft // 2 + 1``. + If magnitude-only (non-complex dtype), then either random init or zero init. + If magnitude and phase (complex dtype), then phase used as init. window (Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins hop_length (int): Length of hop between STFT windows. ( @@ -306,10 +308,19 @@ def griffinlim( specgram = specgram.pow(1 / power) # initialize the phase - if rand_init: - angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) + if torch.is_complex(specgram) and rand_init: + raise ValueError("Cannot choose between given phase init and random phase init. Either give complex spectrogram input or rand_init True with non-complex spectrogram.") + elif torch.is_complex(specgram) and not rand_init: + # Use current phase as init + angles = torch.angle(specgram) + angles = torch.polar(abs=torch.ones_like(angles), angle=angles) # angles should be complex dtype + specgram = torch.abs(specgram) else: - angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) + # Either random phase init or zero phase init + if rand_init: + angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) + else: + angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) # And initialize the previous iterate to 0 tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device) diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 36f7ecc1cc..c587fd3178 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -277,8 +277,9 @@ def forward(self, specgram: Tensor) -> Tensor: r""" Args: specgram (Tensor): - A magnitude-only STFT spectrogram of dimension (..., freq, frames) - where freq is ``n_fft // 2 + 1``. + A STFT spectrogram of dimension (..., freq, frames) where freq is ``n_fft // 2 + 1``. + If magnitude-only (non-complex dtype), then either random init or zero init. + If magnitude and phase (complex dtype), then phase used as init. Returns: Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.