From 2a5c241fdd82431ce76b8ec4fa254d49d002e80f Mon Sep 17 00:00:00 2001 From: Hunter Hogan Date: Thu, 14 Nov 2024 23:31:46 -0500 Subject: [PATCH] Fix tuple syntax for ConvTranspose2d kernel and stride in SCNet.py. Fix tqdm parameters in apply.py. --- scnet/SCNet.py | 7 ++++--- scnet/apply.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/scnet/SCNet.py b/scnet/SCNet.py index 7c3e687..654a399 100644 --- a/scnet/SCNet.py +++ b/scnet/SCNet.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from collections import deque from .separation import SeparationNet -import typing as tp +from typing import cast import math class Swish(nn.Module): @@ -147,7 +147,7 @@ def __init__(self, channels_in, channels_out, band_configs): # Initializing convolutional layers for each band self.convtrs = nn.ModuleList([ - nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1]) + nn.ConvTranspose2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1)) for _, config in band_configs.items() ]) @@ -351,7 +351,8 @@ def forward(self, x): x = self.separation_net(x) #decoder - for fusion_layer, su_layer in self.decoder: + for dec in self.decoder: + fusion_layer, su_layer = cast(nn.Sequential, dec) x = fusion_layer(x, save_skip.pop()) x = su_layer(x, save_lengths.pop(), save_original_lengths.pop()) diff --git a/scnet/apply.py b/scnet/apply.py index fcb9912..7103052 100644 --- a/scnet/apply.py +++ b/scnet/apply.py @@ -108,7 +108,7 @@ def apply_model(model, mix, shifts=1, split=True, segment=20, samplerate=44100, segment = int(samplerate * segment) stride = int((1 - overlap) * segment) offsets = range(0, length, stride) - scale = stride / samplerate + scale = stride // samplerate weight = th.cat([th.arange(1, segment // 2 + 1, device=device), th.arange(segment - segment // 2, 0, -1, device=device)]) assert len(weight) == segment @@ -122,7 +122,7 @@ def apply_model(model, mix, shifts=1, split=True, segment=20, samplerate=44100, futures.append((future, offset)) offset += segment if progress: - futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds') + futures = tqdm.tqdm(futures, total=scale, ncols=120, unit='chunks') for future, offset in futures: chunk_out = future.result() chunk_length = chunk_out.shape[-1]