Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,10 @@ def griffinlim(
and *Signal estimation from modified short-time Fourier transform* :cite:`1172092`.

Args:
specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
where freq is ``n_fft // 2 + 1``.
specgram (Tensor):
A STFT spectrogram of dimension `(..., freq, frames)` where freq is ``n_fft // 2 + 1``.
If magnitude-only (non-complex dtype), then either random init or zero init.
If magnitude and phase (complex dtype), then phase used as init.
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
hop_length (int): Length of hop between STFT windows. (
Expand Down Expand Up @@ -306,10 +308,19 @@ def griffinlim(
specgram = specgram.pow(1 / power)

# initialize the phase
if rand_init:
angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
if torch.is_complex(specgram) and rand_init:
raise ValueError("Cannot choose between given phase init and random phase init. Either give complex spectrogram input or rand_init True with non-complex spectrogram.")
elif torch.is_complex(specgram) and not rand_init:
# Use current phase as init
angles = torch.angle(specgram)
angles = torch.polar(abs=torch.ones_like(angles), angle=angles) # angles should be complex dtype
specgram = torch.abs(specgram)
else:
angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
# Either random phase init or zero phase init
if rand_init:
angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
else:
angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

# And initialize the previous iterate to 0
tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
Expand Down
5 changes: 3 additions & 2 deletions src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def forward(self, specgram: Tensor) -> Tensor:
r"""
Args:
specgram (Tensor):
A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``.
A STFT spectrogram of dimension (..., freq, frames) where freq is ``n_fft // 2 + 1``.
If magnitude-only (non-complex dtype), then either random init or zero init.
If magnitude and phase (complex dtype), then phase used as init.
Returns:
Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
Expand Down