From db6ecfbb99ff5759ebeb945fb431da7099c71fdb Mon Sep 17 00:00:00 2001 From: darryllam Date: Mon, 27 Oct 2025 11:29:12 +0900 Subject: [PATCH 1/9] Changed venv installation to use uv --- README.md | 6 ++-- egs/TEMPLATE/path.sh | 2 +- .../path.sh | 2 +- .../path.sh | 2 +- egs/bvcc/path.sh | 2 +- egs/nisqa/path.sh | 2 +- egs/pstn/path.sh | 2 +- egs/singmos/path.sh | 2 +- egs/somos/path.sh | 2 +- egs/tencent/path.sh | 2 +- egs/tmhint-qi/path.sh | 2 +- egs/urgent2024-mos/path.sh | 2 +- hubconf.py | 7 +++- pyproject.toml | 33 +++++++++++++++++-- .../sheet}/bin/construct_datastore.py | 0 {sheet => src/sheet}/bin/inference.py | 0 .../sheet}/bin/nonparametric_inference.py | 0 {sheet => src/sheet}/bin/train.py | 0 {sheet => src/sheet}/bin/train_stack.py | 0 {sheet => src/sheet}/collaters/__init__.py | 0 .../sheet}/collaters/non_intrusive.py | 0 {sheet => src/sheet}/datasets/__init__.py | 0 .../sheet}/datasets/non_intrusive.py | 0 {sheet => src/sheet}/evaluation/metrics.py | 0 {sheet => src/sheet}/evaluation/plot.py | 0 {sheet => src/sheet}/losses/__init__.py | 0 {sheet => src/sheet}/losses/basic_losses.py | 0 {sheet => src/sheet}/losses/nll_losses.py | 0 {sheet => src/sheet}/models/__init__.py | 0 {sheet => src/sheet}/models/alignnet.py | 0 {sheet => src/sheet}/models/ldnet.py | 0 {sheet => src/sheet}/models/sslmos.py | 0 {sheet => src/sheet}/models/sslmos_u.py | 0 {sheet => src/sheet}/models/utmos.py | 1 + {sheet => src/sheet}/modules/__init__.py | 0 .../sheet}/modules/ldnet/__init__.py | 0 .../sheet}/modules/ldnet/mobilenetv2.py | 0 .../sheet}/modules/ldnet/mobilenetv3.py | 0 {sheet => src/sheet}/modules/ldnet/modules.py | 0 {sheet => src/sheet}/modules/utils.py | 0 .../sheet}/nonparametric/__init__.py | 0 .../sheet}/nonparametric/datastore.py | 0 {sheet => src/sheet}/schedulers/__init__.py | 0 {sheet => src/sheet}/schedulers/schedulers.py | 0 {sheet => src/sheet}/trainers/__init__.py | 0 {sheet => src/sheet}/trainers/base.py | 0 .../sheet}/trainers/non_intrusive.py | 0 
{sheet => src/sheet}/utils/__init__.py | 0 {sheet => src/sheet}/utils/download.py | 0 {sheet => src/sheet}/utils/model_io.py | 0 {sheet => src/sheet}/utils/types.py | 0 {sheet => src/sheet}/utils/utils.py | 0 {sheet => src/sheet}/warmup_lr.py | 0 53 files changed, 53 insertions(+), 16 deletions(-) rename {sheet => src/sheet}/bin/construct_datastore.py (100%) rename {sheet => src/sheet}/bin/inference.py (100%) rename {sheet => src/sheet}/bin/nonparametric_inference.py (100%) rename {sheet => src/sheet}/bin/train.py (100%) rename {sheet => src/sheet}/bin/train_stack.py (100%) rename {sheet => src/sheet}/collaters/__init__.py (100%) rename {sheet => src/sheet}/collaters/non_intrusive.py (100%) rename {sheet => src/sheet}/datasets/__init__.py (100%) rename {sheet => src/sheet}/datasets/non_intrusive.py (100%) rename {sheet => src/sheet}/evaluation/metrics.py (100%) rename {sheet => src/sheet}/evaluation/plot.py (100%) rename {sheet => src/sheet}/losses/__init__.py (100%) rename {sheet => src/sheet}/losses/basic_losses.py (100%) rename {sheet => src/sheet}/losses/nll_losses.py (100%) rename {sheet => src/sheet}/models/__init__.py (100%) rename {sheet => src/sheet}/models/alignnet.py (100%) rename {sheet => src/sheet}/models/ldnet.py (100%) rename {sheet => src/sheet}/models/sslmos.py (100%) rename {sheet => src/sheet}/models/sslmos_u.py (100%) rename {sheet => src/sheet}/models/utmos.py (99%) rename {sheet => src/sheet}/modules/__init__.py (100%) rename {sheet => src/sheet}/modules/ldnet/__init__.py (100%) rename {sheet => src/sheet}/modules/ldnet/mobilenetv2.py (100%) rename {sheet => src/sheet}/modules/ldnet/mobilenetv3.py (100%) rename {sheet => src/sheet}/modules/ldnet/modules.py (100%) rename {sheet => src/sheet}/modules/utils.py (100%) rename {sheet => src/sheet}/nonparametric/__init__.py (100%) rename {sheet => src/sheet}/nonparametric/datastore.py (100%) rename {sheet => src/sheet}/schedulers/__init__.py (100%) rename {sheet => 
src/sheet}/schedulers/schedulers.py (100%) rename {sheet => src/sheet}/trainers/__init__.py (100%) rename {sheet => src/sheet}/trainers/base.py (100%) rename {sheet => src/sheet}/trainers/non_intrusive.py (100%) rename {sheet => src/sheet}/utils/__init__.py (100%) rename {sheet => src/sheet}/utils/download.py (100%) rename {sheet => src/sheet}/utils/model_io.py (100%) rename {sheet => src/sheet}/utils/types.py (100%) rename {sheet => src/sheet}/utils/utils.py (100%) rename {sheet => src/sheet}/warmup_lr.py (100%) diff --git a/README.md b/README.md index 18cd502..b1801db 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,10 @@ You don't need to prepare an environment (using conda, etc.) first. The followin ```bash git clone https://github.com/unilight/sheet.git -cd sheet/tools -make +cd sheet +uv venv tools/venv +source tools/venv/bin/activate +uv sync --active ``` ## Information diff --git a/egs/TEMPLATE/path.sh b/egs/TEMPLATE/path.sh index 9ddc626..4069294 100755 --- a/egs/TEMPLATE/path.sh +++ b/egs/TEMPLATE/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh index 9ddc626..4069294 100755 --- a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh +++ b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. 
-export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh index 9ddc626..4069294 100755 --- a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh +++ b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc/path.sh b/egs/bvcc/path.sh index 9ddc626..4069294 100755 --- a/egs/bvcc/path.sh +++ b/egs/bvcc/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/nisqa/path.sh b/egs/nisqa/path.sh index 9ddc626..4069294 100755 --- a/egs/nisqa/path.sh +++ b/egs/nisqa/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/pstn/path.sh b/egs/pstn/path.sh index 9ddc626..4069294 100755 --- a/egs/pstn/path.sh +++ b/egs/pstn/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/singmos/path.sh b/egs/singmos/path.sh index 9ddc626..4069294 100755 --- a/egs/singmos/path.sh +++ b/egs/singmos/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. 
-export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/somos/path.sh b/egs/somos/path.sh index 9ddc626..4069294 100755 --- a/egs/somos/path.sh +++ b/egs/somos/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/tencent/path.sh b/egs/tencent/path.sh index 9ddc626..4069294 100755 --- a/egs/tencent/path.sh +++ b/egs/tencent/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/tmhint-qi/path.sh b/egs/tmhint-qi/path.sh index 9ddc626..4069294 100755 --- a/egs/tmhint-qi/path.sh +++ b/egs/tmhint-qi/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/urgent2024-mos/path.sh b/egs/urgent2024-mos/path.sh index 9ddc626..4069294 100755 --- a/egs/urgent2024-mos/path.sh +++ b/egs/urgent2024-mos/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. 
-export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/hubconf.py b/hubconf.py index c02a9b7..48507b6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -6,7 +6,7 @@ """torch.hub configuration.""" dependencies = ["yaml", "torch", "torchaudio", "sheet", "huggingface_hub"] - +import sys import os import torch import torch.nn.functional as F @@ -14,6 +14,11 @@ import yaml from huggingface_hub import hf_hub_download +repo_root = os.path.dirname(os.path.abspath(__file__)) +src_path = os.path.join(repo_root, "src") +if src_path not in sys.path: + sys.path.insert(0, 0, src_path) + FS = 16000 resamplers = {} MIN_REQUIRED_WAV_LENGTH = 1040 diff --git a/pyproject.toml b/pyproject.toml index fed528d..cfafc4e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,32 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["uv_build"] +build-backend = "uv_build" + +[project] +name = "sheet" +version = "0.1.0" +description = "Speech Human Evaluation Estimation Toolkit (SHEET)" +requires-python = "==3.10.13" + +authors = [ + { name = "Your Name", email = "your.email@example.com" } +] + +dependencies = [ + "torch==2.0.1", + "torchaudio==2.0.2", + "numpy==1.26.4", + "tqdm>=4.67.1", + "h5py>=3.15.1", + "pyyaml>=6.0.3", + "transformers>=4.57.1", + "scipy>=1.15.3", + "soundfile>=0.13.1", + "soxr>=1.0.0", + "wheel>=0.45.1", + "prettytable>=3.16.0", + "matplotlib>=3.10.7", + "s3prl>=0.4.18", + "humanfriendly>=10.0", + "tensorboardx>=2.6.4", +] diff --git a/sheet/bin/construct_datastore.py b/src/sheet/bin/construct_datastore.py similarity index 100% rename from sheet/bin/construct_datastore.py rename to src/sheet/bin/construct_datastore.py diff --git a/sheet/bin/inference.py b/src/sheet/bin/inference.py similarity index 100% rename from sheet/bin/inference.py rename to src/sheet/bin/inference.py diff --git a/sheet/bin/nonparametric_inference.py 
b/src/sheet/bin/nonparametric_inference.py similarity index 100% rename from sheet/bin/nonparametric_inference.py rename to src/sheet/bin/nonparametric_inference.py diff --git a/sheet/bin/train.py b/src/sheet/bin/train.py similarity index 100% rename from sheet/bin/train.py rename to src/sheet/bin/train.py diff --git a/sheet/bin/train_stack.py b/src/sheet/bin/train_stack.py similarity index 100% rename from sheet/bin/train_stack.py rename to src/sheet/bin/train_stack.py diff --git a/sheet/collaters/__init__.py b/src/sheet/collaters/__init__.py similarity index 100% rename from sheet/collaters/__init__.py rename to src/sheet/collaters/__init__.py diff --git a/sheet/collaters/non_intrusive.py b/src/sheet/collaters/non_intrusive.py similarity index 100% rename from sheet/collaters/non_intrusive.py rename to src/sheet/collaters/non_intrusive.py diff --git a/sheet/datasets/__init__.py b/src/sheet/datasets/__init__.py similarity index 100% rename from sheet/datasets/__init__.py rename to src/sheet/datasets/__init__.py diff --git a/sheet/datasets/non_intrusive.py b/src/sheet/datasets/non_intrusive.py similarity index 100% rename from sheet/datasets/non_intrusive.py rename to src/sheet/datasets/non_intrusive.py diff --git a/sheet/evaluation/metrics.py b/src/sheet/evaluation/metrics.py similarity index 100% rename from sheet/evaluation/metrics.py rename to src/sheet/evaluation/metrics.py diff --git a/sheet/evaluation/plot.py b/src/sheet/evaluation/plot.py similarity index 100% rename from sheet/evaluation/plot.py rename to src/sheet/evaluation/plot.py diff --git a/sheet/losses/__init__.py b/src/sheet/losses/__init__.py similarity index 100% rename from sheet/losses/__init__.py rename to src/sheet/losses/__init__.py diff --git a/sheet/losses/basic_losses.py b/src/sheet/losses/basic_losses.py similarity index 100% rename from sheet/losses/basic_losses.py rename to src/sheet/losses/basic_losses.py diff --git a/sheet/losses/nll_losses.py b/src/sheet/losses/nll_losses.py 
similarity index 100% rename from sheet/losses/nll_losses.py rename to src/sheet/losses/nll_losses.py diff --git a/sheet/models/__init__.py b/src/sheet/models/__init__.py similarity index 100% rename from sheet/models/__init__.py rename to src/sheet/models/__init__.py diff --git a/sheet/models/alignnet.py b/src/sheet/models/alignnet.py similarity index 100% rename from sheet/models/alignnet.py rename to src/sheet/models/alignnet.py diff --git a/sheet/models/ldnet.py b/src/sheet/models/ldnet.py similarity index 100% rename from sheet/models/ldnet.py rename to src/sheet/models/ldnet.py diff --git a/sheet/models/sslmos.py b/src/sheet/models/sslmos.py similarity index 100% rename from sheet/models/sslmos.py rename to src/sheet/models/sslmos.py diff --git a/sheet/models/sslmos_u.py b/src/sheet/models/sslmos_u.py similarity index 100% rename from sheet/models/sslmos_u.py rename to src/sheet/models/sslmos_u.py diff --git a/sheet/models/utmos.py b/src/sheet/models/utmos.py similarity index 99% rename from sheet/models/utmos.py rename to src/sheet/models/utmos.py index b0facb3..44f7672 100644 --- a/sheet/models/utmos.py +++ b/src/sheet/models/utmos.py @@ -43,6 +43,7 @@ def __init__( decoder_activation: str = "ReLU", output_type: str = "scalar", range_clipping: bool = True, + num_domains: int = None, ): super().__init__() # this is needed! or else there will be an error. 
self.use_mean_listener = use_mean_listener diff --git a/sheet/modules/__init__.py b/src/sheet/modules/__init__.py similarity index 100% rename from sheet/modules/__init__.py rename to src/sheet/modules/__init__.py diff --git a/sheet/modules/ldnet/__init__.py b/src/sheet/modules/ldnet/__init__.py similarity index 100% rename from sheet/modules/ldnet/__init__.py rename to src/sheet/modules/ldnet/__init__.py diff --git a/sheet/modules/ldnet/mobilenetv2.py b/src/sheet/modules/ldnet/mobilenetv2.py similarity index 100% rename from sheet/modules/ldnet/mobilenetv2.py rename to src/sheet/modules/ldnet/mobilenetv2.py diff --git a/sheet/modules/ldnet/mobilenetv3.py b/src/sheet/modules/ldnet/mobilenetv3.py similarity index 100% rename from sheet/modules/ldnet/mobilenetv3.py rename to src/sheet/modules/ldnet/mobilenetv3.py diff --git a/sheet/modules/ldnet/modules.py b/src/sheet/modules/ldnet/modules.py similarity index 100% rename from sheet/modules/ldnet/modules.py rename to src/sheet/modules/ldnet/modules.py diff --git a/sheet/modules/utils.py b/src/sheet/modules/utils.py similarity index 100% rename from sheet/modules/utils.py rename to src/sheet/modules/utils.py diff --git a/sheet/nonparametric/__init__.py b/src/sheet/nonparametric/__init__.py similarity index 100% rename from sheet/nonparametric/__init__.py rename to src/sheet/nonparametric/__init__.py diff --git a/sheet/nonparametric/datastore.py b/src/sheet/nonparametric/datastore.py similarity index 100% rename from sheet/nonparametric/datastore.py rename to src/sheet/nonparametric/datastore.py diff --git a/sheet/schedulers/__init__.py b/src/sheet/schedulers/__init__.py similarity index 100% rename from sheet/schedulers/__init__.py rename to src/sheet/schedulers/__init__.py diff --git a/sheet/schedulers/schedulers.py b/src/sheet/schedulers/schedulers.py similarity index 100% rename from sheet/schedulers/schedulers.py rename to src/sheet/schedulers/schedulers.py diff --git a/sheet/trainers/__init__.py 
b/src/sheet/trainers/__init__.py similarity index 100% rename from sheet/trainers/__init__.py rename to src/sheet/trainers/__init__.py diff --git a/sheet/trainers/base.py b/src/sheet/trainers/base.py similarity index 100% rename from sheet/trainers/base.py rename to src/sheet/trainers/base.py diff --git a/sheet/trainers/non_intrusive.py b/src/sheet/trainers/non_intrusive.py similarity index 100% rename from sheet/trainers/non_intrusive.py rename to src/sheet/trainers/non_intrusive.py diff --git a/sheet/utils/__init__.py b/src/sheet/utils/__init__.py similarity index 100% rename from sheet/utils/__init__.py rename to src/sheet/utils/__init__.py diff --git a/sheet/utils/download.py b/src/sheet/utils/download.py similarity index 100% rename from sheet/utils/download.py rename to src/sheet/utils/download.py diff --git a/sheet/utils/model_io.py b/src/sheet/utils/model_io.py similarity index 100% rename from sheet/utils/model_io.py rename to src/sheet/utils/model_io.py diff --git a/sheet/utils/types.py b/src/sheet/utils/types.py similarity index 100% rename from sheet/utils/types.py rename to src/sheet/utils/types.py diff --git a/sheet/utils/utils.py b/src/sheet/utils/utils.py similarity index 100% rename from sheet/utils/utils.py rename to src/sheet/utils/utils.py diff --git a/sheet/warmup_lr.py b/src/sheet/warmup_lr.py similarity index 100% rename from sheet/warmup_lr.py rename to src/sheet/warmup_lr.py From a4f3b89130fcd84abd15c7d185baa8e2ee89abe6 Mon Sep 17 00:00:00 2001 From: darryllam Date: Wed, 29 Oct 2025 11:18:01 +0900 Subject: [PATCH 2/9] FiMerged stray sheet folder in root to src --- {sheet => src/sheet}/__init__.py | 0 {sheet => src/sheet}/losses/contrastive_loss.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {sheet => src/sheet}/__init__.py (100%) rename {sheet => src/sheet}/losses/contrastive_loss.py (100%) diff --git a/sheet/__init__.py b/src/sheet/__init__.py similarity index 100% rename from sheet/__init__.py rename to 
src/sheet/__init__.py diff --git a/sheet/losses/contrastive_loss.py b/src/sheet/losses/contrastive_loss.py similarity index 100% rename from sheet/losses/contrastive_loss.py rename to src/sheet/losses/contrastive_loss.py From 5b8faa9a34ba4075dc845c7475bc3fe4dae61c2b Mon Sep 17 00:00:00 2001 From: darryllam Date: Wed, 5 Nov 2025 10:30:46 +0900 Subject: [PATCH 3/9] Changed venv installation and path.sh to use '.venv'. Changed some depdendencies to optional in pyproject.toml and added metadata --- README.md | 10 ++++------ egs/TEMPLATE/path.sh | 4 ++-- .../path.sh | 4 ++-- .../path.sh | 4 ++-- egs/bvcc/path.sh | 4 ++-- egs/nisqa/path.sh | 4 ++-- egs/pstn/path.sh | 4 ++-- egs/singmos/path.sh | 4 ++-- egs/somos/path.sh | 4 ++-- egs/tencent/path.sh | 4 ++-- egs/tmhint-qi/path.sh | 4 ++-- egs/urgent2024-mos/path.sh | 4 ++-- pyproject.toml | 18 ++++++++++++------ 13 files changed, 38 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index b1801db..c206e24 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ You can use the `_id` argument to specify which pre-trained model to use. If not > Since SHEET is a on-going project, if you use our pre-trained model in you paper, it is suggested to specify the version. For instance: `SHEET SSL-MOS v0.1.0`, `SHEET SSL-MOS v0.2.5`, etc. > [!TIP] -> You don't need to install sheet following the [installation instructions](#instsallation). However, you might need to install the following: +> You don't need to install sheet following the [installation instructions](#installation). However, you might need to install the following: > ``` > sheet-sqa > huggingface_hub @@ -63,20 +63,18 @@ You can use the `_id` argument to specify which pre-trained model to use. If not 1.5806346 ``` -## Instsallation +## installation Full installation is needed if your goal is to do **training**. ### Editable installation with virtualenv -You don't need to prepare an environment (using conda, etc.) first. 
The following commands will automatically construct a virtual environment in `tools/`. When you run the recipes, the scripts will automatically activate the virtual environment. +First install the uv package manager [here](https://docs.astral.sh/uv/getting-started/installation/). Then, use the following commands to automatically construct a virtual environment in `.venv`. When you run the recipes, the scripts will automatically activate the virtual environment. ```bash git clone https://github.com/unilight/sheet.git cd sheet -uv venv tools/venv -source tools/venv/bin/activate -uv sync --active +uv sync --extras train ``` ## Information diff --git a/egs/TEMPLATE/path.sh b/egs/TEMPLATE/path.sh index 4069294..cca7225 100755 --- a/egs/TEMPLATE/path.sh +++ b/egs/TEMPLATE/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh index 4069294..cca7225 100755 --- a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh +++ b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. 
diff --git a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh index 4069294..cca7225 100755 --- a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh +++ b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/bvcc/path.sh b/egs/bvcc/path.sh index 4069294..cca7225 100755 --- a/egs/bvcc/path.sh +++ b/egs/bvcc/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/nisqa/path.sh b/egs/nisqa/path.sh index 4069294..cca7225 100755 --- a/egs/nisqa/path.sh +++ b/egs/nisqa/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/pstn/path.sh b/egs/pstn/path.sh index 4069294..cca7225 100755 --- a/egs/pstn/path.sh +++ b/egs/pstn/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. 
diff --git a/egs/singmos/path.sh b/egs/singmos/path.sh index 4069294..cca7225 100755 --- a/egs/singmos/path.sh +++ b/egs/singmos/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/somos/path.sh b/egs/somos/path.sh index 4069294..cca7225 100755 --- a/egs/somos/path.sh +++ b/egs/somos/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/tencent/path.sh b/egs/tencent/path.sh index 4069294..cca7225 100755 --- a/egs/tencent/path.sh +++ b/egs/tencent/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/tmhint-qi/path.sh b/egs/tmhint-qi/path.sh index 4069294..cca7225 100755 --- a/egs/tmhint-qi/path.sh +++ b/egs/tmhint-qi/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/egs/urgent2024-mos/path.sh b/egs/urgent2024-mos/path.sh index 4069294..cca7225 100755 --- a/egs/urgent2024-mos/path.sh +++ b/egs/urgent2024-mos/path.sh @@ -1,8 +1,8 @@ # path related export PRJ_ROOT="${PWD}/../.." 
-if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/tools/venv/bin/activate" + . "${PRJ_ROOT}/.venv/bin/activate" fi MAIN_ROOT=$PWD/../.. diff --git a/pyproject.toml b/pyproject.toml index cfafc4e..ced6198 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,14 +9,13 @@ description = "Speech Human Evaluation Estimation Toolkit (SHEET)" requires-python = "==3.10.13" authors = [ - { name = "Your Name", email = "your.email@example.com" } + { name = "Wen-Chin Huang", email = "wen.chinhuang@g.sp.m.is.nagoya-u.ac.jp" } ] dependencies = [ "torch==2.0.1", "torchaudio==2.0.2", "numpy==1.26.4", - "tqdm>=4.67.1", "h5py>=3.15.1", "pyyaml>=6.0.3", "transformers>=4.57.1", @@ -24,9 +23,16 @@ dependencies = [ "soundfile>=0.13.1", "soxr>=1.0.0", "wheel>=0.45.1", - "prettytable>=3.16.0", - "matplotlib>=3.10.7", "s3prl>=0.4.18", - "humanfriendly>=10.0", - "tensorboardx>=2.6.4", ] + +[project.optional-dependencies] +train = [ + "matplotlib>=3.10.7", + "tqdm>=4.67.1", + "gdown", + "tensorboardx>=2.6.4", + "kaldiio>=2.14.1", + "humanfriendly>=10.0", + "prettytable>=3.16.0", +] \ No newline at end of file From 9c116e22e769b828f11d264b5498d23dd9d81501 Mon Sep 17 00:00:00 2001 From: darryllam Date: Thu, 6 Nov 2025 18:32:21 +0900 Subject: [PATCH 4/9] Moved back to flat structure. 
venv relocated back to /tools/venv --- .gitignore | 4 +- README.md | 2 +- egs/TEMPLATE/path.sh | 6 +- .../path.sh | 6 +- .../path.sh | 6 +- egs/bvcc/path.sh | 6 +- egs/nisqa/path.sh | 6 +- egs/pstn/path.sh | 6 +- egs/singmos/path.sh | 6 +- egs/somos/path.sh | 6 +- egs/tencent/path.sh | 6 +- egs/tmhint-qi/path.sh | 6 +- egs/urgent2024-mos/path.sh | 6 +- sheet/__init__.py | 3 + sheet/bin/construct_datastore.py | 176 +++++++ sheet/bin/inference.py | 434 ++++++++++++++++ sheet/bin/nonparametric_inference.py | 400 +++++++++++++++ sheet/bin/train.py | 396 +++++++++++++++ sheet/bin/train_stack.py | 188 +++++++ sheet/collaters/__init__.py | 1 + sheet/collaters/non_intrusive.py | 108 ++++ sheet/datasets/__init__.py | 1 + sheet/datasets/non_intrusive.py | 340 +++++++++++++ sheet/evaluation/metrics.py | 34 ++ sheet/evaluation/plot.py | 108 ++++ sheet/losses/__init__.py | 3 + sheet/losses/basic_losses.py | 91 ++++ sheet/losses/contrastive_loss.py | 39 ++ sheet/losses/nll_losses.py | 109 ++++ sheet/models/__init__.py | 9 + sheet/models/alignnet.py | 400 +++++++++++++++ sheet/models/ldnet.py | 288 +++++++++++ sheet/models/sslmos.py | 467 ++++++++++++++++++ sheet/models/sslmos_u.py | 256 ++++++++++ sheet/models/utmos.py | 299 +++++++++++ sheet/modules/__init__.py | 0 sheet/modules/ldnet/__init__.py | 0 sheet/modules/ldnet/mobilenetv2.py | 240 +++++++++ sheet/modules/ldnet/mobilenetv3.py | 341 +++++++++++++ sheet/modules/ldnet/modules.py | 181 +++++++ sheet/modules/utils.py | 222 +++++++++ sheet/nonparametric/__init__.py | 0 sheet/nonparametric/datastore.py | 77 +++ sheet/schedulers/__init__.py | 1 + sheet/schedulers/schedulers.py | 21 + sheet/trainers/__init__.py | 2 + sheet/trainers/base.py | 315 ++++++++++++ sheet/trainers/non_intrusive.py | 310 ++++++++++++ sheet/utils/__init__.py | 1 + sheet/utils/download.py | 213 ++++++++ sheet/utils/model_io.py | 166 +++++++ sheet/utils/types.py | 139 ++++++ sheet/utils/utils.py | 164 ++++++ sheet/warmup_lr.py | 62 +++ 54 files changed, 
6641 insertions(+), 36 deletions(-) create mode 100644 sheet/__init__.py create mode 100755 sheet/bin/construct_datastore.py create mode 100755 sheet/bin/inference.py create mode 100755 sheet/bin/nonparametric_inference.py create mode 100755 sheet/bin/train.py create mode 100755 sheet/bin/train_stack.py create mode 100644 sheet/collaters/__init__.py create mode 100644 sheet/collaters/non_intrusive.py create mode 100644 sheet/datasets/__init__.py create mode 100644 sheet/datasets/non_intrusive.py create mode 100644 sheet/evaluation/metrics.py create mode 100644 sheet/evaluation/plot.py create mode 100644 sheet/losses/__init__.py create mode 100644 sheet/losses/basic_losses.py create mode 100644 sheet/losses/contrastive_loss.py create mode 100644 sheet/losses/nll_losses.py create mode 100644 sheet/models/__init__.py create mode 100644 sheet/models/alignnet.py create mode 100644 sheet/models/ldnet.py create mode 100644 sheet/models/sslmos.py create mode 100644 sheet/models/sslmos_u.py create mode 100644 sheet/models/utmos.py create mode 100644 sheet/modules/__init__.py create mode 100644 sheet/modules/ldnet/__init__.py create mode 100644 sheet/modules/ldnet/mobilenetv2.py create mode 100644 sheet/modules/ldnet/mobilenetv3.py create mode 100644 sheet/modules/ldnet/modules.py create mode 100644 sheet/modules/utils.py create mode 100644 sheet/nonparametric/__init__.py create mode 100644 sheet/nonparametric/datastore.py create mode 100644 sheet/schedulers/__init__.py create mode 100644 sheet/schedulers/schedulers.py create mode 100644 sheet/trainers/__init__.py create mode 100644 sheet/trainers/base.py create mode 100644 sheet/trainers/non_intrusive.py create mode 100644 sheet/utils/__init__.py create mode 100644 sheet/utils/download.py create mode 100644 sheet/utils/model_io.py create mode 100644 sheet/utils/types.py create mode 100644 sheet/utils/utils.py create mode 100644 sheet/warmup_lr.py diff --git a/.gitignore b/.gitignore index ac2496a..89f6de2 100644 --- 
a/.gitignore +++ b/.gitignore @@ -103,7 +103,7 @@ celerybeat.pid # Environments .env -.venv +tools/venv env/ venv/ ENV/ @@ -127,7 +127,7 @@ dmypy.json # Pyre type checker .pyre/ - +*.lock exp/ downloads/ data/ diff --git a/README.md b/README.md index c206e24..0400de9 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ Full installation is needed if your goal is to do **training**. ### Editable installation with virtualenv -First install the uv package manager [here](https://docs.astral.sh/uv/getting-started/installation/). Then, use the following commands to automatically construct a virtual environment in `.venv`. When you run the recipes, the scripts will automatically activate the virtual environment. +First install the uv package manager [here](https://docs.astral.sh/uv/getting-started/installation/). Then, use the following commands to automatically construct a virtual environment in `tools/venv`. When you run the recipes, the scripts will automatically activate the virtual environment. ```bash git clone https://github.com/unilight/sheet.git diff --git a/egs/TEMPLATE/path.sh b/egs/TEMPLATE/path.sh index cca7225..9ddc626 100755 --- a/egs/TEMPLATE/path.sh +++ b/egs/TEMPLATE/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh index cca7225..9ddc626 100755 --- a/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh +++ b/egs/bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." 
-if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh index cca7225..9ddc626 100755 --- a/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh +++ b/egs/bvcc+somos+singmos+nisqa+tmhint-qi+tencent+pstn+urgent2024-mos/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/bvcc/path.sh b/egs/bvcc/path.sh index cca7225..9ddc626 100755 --- a/egs/bvcc/path.sh +++ b/egs/bvcc/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/nisqa/path.sh b/egs/nisqa/path.sh index cca7225..9ddc626 100755 --- a/egs/nisqa/path.sh +++ b/egs/nisqa/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." 
-if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/pstn/path.sh b/egs/pstn/path.sh index cca7225..9ddc626 100755 --- a/egs/pstn/path.sh +++ b/egs/pstn/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/singmos/path.sh b/egs/singmos/path.sh index cca7225..9ddc626 100755 --- a/egs/singmos/path.sh +++ b/egs/singmos/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/somos/path.sh b/egs/somos/path.sh index cca7225..9ddc626 100755 --- a/egs/somos/path.sh +++ b/egs/somos/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. 
-export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/tencent/path.sh b/egs/tencent/path.sh index cca7225..9ddc626 100755 --- a/egs/tencent/path.sh +++ b/egs/tencent/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/tmhint-qi/path.sh b/egs/tmhint-qi/path.sh index cca7225..9ddc626 100755 --- a/egs/tmhint-qi/path.sh +++ b/egs/tmhint-qi/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/egs/urgent2024-mos/path.sh b/egs/urgent2024-mos/path.sh index cca7225..9ddc626 100755 --- a/egs/urgent2024-mos/path.sh +++ b/egs/urgent2024-mos/path.sh @@ -1,12 +1,12 @@ # path related export PRJ_ROOT="${PWD}/../.." -if [ -e "${PRJ_ROOT}/.venv/bin/activate" ]; then +if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then # shellcheck disable=SC1090 - . "${PRJ_ROOT}/.venv/bin/activate" + . "${PRJ_ROOT}/tools/venv/bin/activate" fi MAIN_ROOT=$PWD/../.. 
-export PATH=$MAIN_ROOT/src/sheet/bin:$PATH +export PATH=$MAIN_ROOT/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/sheet/__init__.py b/sheet/__init__.py new file mode 100644 index 0000000..20fada3 --- /dev/null +++ b/sheet/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +__version__ = "0.2.5" diff --git a/sheet/bin/construct_datastore.py b/sheet/bin/construct_datastore.py new file mode 100755 index 0000000..06a5867 --- /dev/null +++ b/sheet/bin/construct_datastore.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Construct datastore .""" + +import argparse +import logging +import os + +import h5py +import numpy as np +import sheet +import sheet.datasets +import sheet.models +import torch +import yaml +from s3prl.nn import S3PRLUpstream +from tqdm import tqdm + + +def main(): + """Construct datastore.""" + parser = argparse.ArgumentParser( + description=( + "Construct datastore with ssl_model in trained model " + "(See detail in bin/construct_datastore.py)." + ) + ) + parser.add_argument( + "--csv-path", + required=True, + type=str, + help=("csv file path to construct datastore."), + ) + parser.add_argument( + "--out", + type=str, + required=True, + help="out path to save datastore.", + ) + parser.add_argument( + "--checkpoint", + type=str, + help="checkpoint file to be loaded.", + ) + parser.add_argument( + "--config", + default=None, + type=str, + help=( + "yaml format configuration file. if not explicitly provided, " + "it will be searched in the checkpoint directory. (default=None)" + ), + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. 
(default=1)", + ) + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check directory existence + # if not os.path.exists(args.outdir): + # os.makedirs(args.outdir) + + # load config + if args.config is None: + dirname = os.path.dirname(args.checkpoint) + args.config = os.path.join(dirname, "config.yml") + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + + args_dict = vars(args) + + config.update(args_dict) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + dataset_class = getattr( + sheet.datasets, + config.get("dataset_type", "NonIntrusiveDataset"), + ) + dataset = dataset_class( + csv_path=args.csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + use_phoneme=config.get("use_phoneme", False), + symbols=config.get("symbols", None), + wav_only=True, + allow_cache=False, + ) + logging.info(f"Number of samples = {len(dataset)}.") + + # setup device + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + # get ssl model + s3prl_name = config["model_params"]["s3prl_name"] + ssl_model = S3PRLUpstream(s3prl_name) + + # load pre-trained model + pt_ckpt = torch.load(os.readlink(args.checkpoint), map_location="cpu")["model"] + state_dict = { + k.replace("ssl_model.", ""): v + for k, v in pt_ckpt.items() + if k.startswith("ssl_model") + } + ssl_model.load_state_dict(state_dict) + logging.info(f"Loaded model parameters from {args.checkpoint}.") + 
ssl_model = ssl_model.eval().to(device) + + # start inference + if os.path.exists(args.out): + hdf5_file = h5py.File(args.out, "r+") + else: + hdf5_file = h5py.File(args.out, "w") + + with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: + for batch in pbar: + # set up model input + model_input = batch[config["model_input"]].unsqueeze(0).to(device) + model_lengths = model_input.new_tensor([model_input.size(1)]).long() + + all_encoder_outputs, _ = ssl_model(model_input, model_lengths) + embed = ( + torch.mean( + all_encoder_outputs[ + config["model_params"]["ssl_model_layer_idx"] + ].squeeze(0), + dim=0, + ) + .detach() + .cpu() + .numpy() + ) + + system_id = batch["system_id"] + sample_id = batch["sample_id"] + hdf5_path = system_id + "_" + sample_id + score = batch["avg_score"] + + hdf5_file.create_dataset("embeds/" + hdf5_path, data=embed) + hdf5_file.create_dataset("scores/" + hdf5_path, data=score) + + hdf5_file.flush() + hdf5_file.close() + + +if __name__ == "__main__": + main() diff --git a/sheet/bin/inference.py b/sheet/bin/inference.py new file mode 100755 index 0000000..18a4cf3 --- /dev/null +++ b/sheet/bin/inference.py @@ -0,0 +1,434 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Inference .""" + +import argparse +import csv +import logging +import os +import pickle +import time +from collections import defaultdict + +import numpy as np +import sheet +import sheet.datasets +import sheet.models +import torch +import yaml +from prettytable import MARKDOWN, PrettyTable +from sheet.evaluation.metrics import calculate +from sheet.evaluation.plot import ( + plot_sys_level_scatter, + plot_utt_level_hist, + plot_utt_level_scatter, +) +from sheet.utils import read_csv +from sheet.utils.model_io import model_average +from sheet.utils.types import str2bool +from tqdm import tqdm + + +def main(): + """Run inference process.""" + parser = 
argparse.ArgumentParser( + description=( + "Inference with trained model " "(See detail in bin/inference.py)." + ) + ) + parser.add_argument( + "--csv-path", + required=True, + type=str, + help=("csv file path to do inference."), + ) + parser.add_argument( + "--outdir", + type=str, + required=True, + help="directory to save generated figures.", + ) + parser.add_argument( + "--checkpoint", + type=str, + help="checkpoint file to be loaded.", + ) + parser.add_argument( + "--config", + default=None, + type=str, + help=( + "yaml format configuration file. if not explicitly provided, " + "it will be searched in the checkpoint directory. (default=None)" + ), + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + parser.add_argument( + "--inference-mode", + type=str, + help="inference mode. if not specified, use the default setting in config", + ) + parser.add_argument( + "--model-averaging", + type=str2bool, + default="False", + help="if true, average all model checkpoints in the exp directory", + ) + parser.add_argument( + "--use-stacking", + type=str2bool, + default="False", + help="if true, use the stack model in the exp directory", + ) + parser.add_argument( + "--meta-model-checkpoint", + type=str, + help="checkpoint file of meta model.", + ) + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # load config + if args.config is 
None: + dirname = os.path.dirname(args.checkpoint) + args.config = os.path.join(dirname, "config.yml") + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + + args_dict = vars(args) + # do not override if inference mode not specified + if args_dict["inference_mode"] is None: + del args_dict["inference_mode"] + + # get expdir first + expdir = config["outdir"] + + config.update(args_dict) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + dataset_class = getattr( + sheet.datasets, + config.get("dataset_type", "NonIntrusiveDataset"), + ) + dataset = dataset_class( + csv_path=args.csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + use_phoneme=config.get("use_phoneme", False), + symbols=config.get("symbols", None), + wav_only=True, + allow_cache=False, + ) + logging.info(f"Number of inference samples = {len(dataset)}.") + + # setup device + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + # get model + model_class = getattr(sheet.models, config["model_type"]) + model = model_class( + config["model_input"], + num_listeners=config.get("num_listeners", None), + num_domains=config.get("num_domains", None), + **config["model_params"], + ) + + # set placeholders + eval_results = defaultdict(list) + eval_sys_results = defaultdict(lambda: defaultdict(list)) + logvars = [] + + # stacking model inference + if args.use_stacking: + # load meta model + with open(args.meta_model_checkpoint, "rb") as f: + meta_model = pickle.load(f) + + # run inference on all models + checkpoint_paths = sorted( + [ + os.path.join(expdir, p) + for p in os.listdir(expdir) + if os.path.isfile(os.path.join(expdir, p)) and p.endswith("steps.pkl") + ] + ) + xs = np.empty((len(dataset), len(checkpoint_paths))) + for i, checkpoint_path in enumerate(checkpoint_paths): + # load model + model.load_state_dict( + torch.load(checkpoint_path, 
map_location="cpu")["model"] + ) + logging.info(f"Loaded model parameters from {checkpoint_path}.") + model = model.eval().to(device) + + # start inference + start_time = time.time() + logging.info("Running inference...") + with torch.no_grad(): + for j, batch in enumerate(dataset): + # set up model input + model_input = batch[config["model_input"]].unsqueeze(0).to(device) + model_lengths = model_input.new_tensor([model_input.size(1)]).long() + inputs = { + config["model_input"]: model_input, + config["model_input"] + "_lengths": model_lengths, + } + if "phoneme_idxs" in batch: + inputs["phoneme_idxs"] = ( + batch["phoneme_idxs"].unsqueeze(0).to(device) + ) + inputs["phoneme_lengths"] = batch["phoneme_lengths"] + if "reference_idxs" in batch: + inputs["reference_idxs"] = ( + batch["reference_idxs"].unsqueeze(0).to(device) + ) + inputs["reference_lengths"] = batch["reference_lengths"] + + # model forward + if config["inference_mode"] == "mean_listener": + outputs = model.mean_listener_inference(inputs) + elif config["inference_mode"] == "mean_net": + outputs = model.mean_net_inference(inputs) + else: + raise NotImplementedError + + # store results + pred_score = outputs["scores"].cpu().detach().numpy()[0] + xs[j][i] = pred_score + + total_inference_time = time.time() - start_time + logging.info("Total inference time = {} secs.".format(total_inference_time)) + logging.info( + "Average inference speed = {:.3f} sec / sample.".format( + total_inference_time / len(dataset) + ) + ) + + # run inference on meta model + pred_mean_scores = meta_model.predict(xs) + + # rerun dataset to get system level scores + for i, batch in enumerate(dataset): + true_mean_scores = batch["avg_score"] + eval_results["pred_mean_scores"].append(pred_mean_scores[i]) + eval_results["true_mean_scores"].append(true_mean_scores) + sys_name = batch["system_id"] + eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores[i]) + 
eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) + + # not using stacking + else: + # load parameter, or take average + assert (args.checkpoint == "" and args.model_averaging) or ( + args.checkpoint != "" and not args.model_averaging + ) + if args.checkpoint != "": + if os.path.islink(args.checkpoint): + model.load_state_dict( + torch.load(os.readlink(args.checkpoint), map_location="cpu")[ + "model" + ] + ) + else: + model.load_state_dict( + torch.load(args.checkpoint, map_location="cpu")["model"] + ) + logging.info(f"Loaded model parameters from {args.checkpoint}.") + else: + model, checkpoint_paths = model_average(model, expdir) + logging.info(f"Loaded model parameters from: {', '.join(checkpoint_paths)}") + model = model.eval().to(device) + + # start inference + start_time = time.time() + with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: + for batch in pbar: + # set up model input + model_input = batch[config["model_input"]].unsqueeze(0).to(device) + model_lengths = model_input.new_tensor([model_input.size(1)]).long() + inputs = { + config["model_input"]: model_input, + config["model_input"] + "_lengths": model_lengths, + } + if "phoneme_idxs" in batch: + inputs["phoneme_idxs"] = ( + torch.tensor(batch["phoneme_idxs"], dtype=torch.long) + .unsqueeze(0) + .to(device) + ) + inputs["phoneme_lengths"] = torch.tensor( + [len(batch["phoneme_idxs"])], dtype=torch.long + ) + if "reference_idxs" in batch: + inputs["reference_idxs"] = ( + torch.tensor(batch["reference_idxs"], dtype=torch.long) + .unsqueeze(0) + .to(device) + ) + inputs["reference_lengths"] = torch.tensor( + [len(batch["reference_idxs"])], dtype=torch.long + ) + if "domain_idx" in batch: + inputs["domain_idxs"] = ( + torch.tensor(batch["domain_idx"], dtype=torch.long) + .unsqueeze(0) + .to(device) + ) + + # model forward + if config["inference_mode"] == "mean_listener": + outputs = model.mean_listener_inference(inputs) + elif config["inference_mode"] == "mean_net": + 
outputs = model.mean_net_inference(inputs) + else: + raise NotImplementedError + + # store results + answer = outputs["scores"].cpu().detach().numpy()[0] + if "logvars" in outputs: + logvar = outputs["logvars"].cpu().detach().numpy()[0] + logvars.append(logvar) + else: + logvar = None + dataset.fill_answer(batch["sample_id"], answer, logvar) + pred_mean_scores = answer + true_mean_scores = batch["avg_score"] + eval_results["pred_mean_scores"].append(pred_mean_scores) + eval_results["true_mean_scores"].append(true_mean_scores) + sys_name = batch["system_id"] + eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores) + eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) + + total_inference_time = time.time() - start_time + logging.info("Total inference time = {} secs.".format(total_inference_time)) + logging.info( + "Average inference speed = {:.3f} sec / sample.".format( + total_inference_time / len(dataset) + ) + ) + eval_results["true_mean_scores"] = np.array(eval_results["true_mean_scores"]) + eval_results["pred_mean_scores"] = np.array(eval_results["pred_mean_scores"]) + eval_sys_results["true_mean_scores"] = np.array( + [np.mean(scores) for scores in eval_sys_results["true_mean_scores"].values()] + ) + eval_sys_results["pred_mean_scores"] = np.array( + [np.mean(scores) for scores in eval_sys_results["pred_mean_scores"].values()] + ) + + # calculate metrics + results = calculate( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + eval_sys_results["true_mean_scores"], + eval_sys_results["pred_mean_scores"], + ) + logging.info( + f'[UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n' + ) + if len(logvars) != 0: + logging.info(f'Mean log variance: {np.mean(logvars):.3f}') + + table = PrettyTable() + table.set_style(MARKDOWN) + table.field_names 
= [ + "Utt MSE", + "Utt LCC", + "Utt SRCC", + "Utt KTAU", + "Sys MSE", + "Sys LCC", + "Sys SRCC", + "Sys KTAU", + ] + table.add_row( + [ + round(results["utt_MSE"], 3), + round(results["utt_LCC"], 3), + round(results["utt_SRCC"], 3), + round(results["utt_KTAU"], 3), + round(results["sys_MSE"], 3), + round(results["sys_LCC"], 3), + round(results["sys_SRCC"], 3), + round(results["sys_KTAU"], 3), + ] + ) + print(table) + + # check directory + dirname = args.outdir + if not os.path.exists(dirname): + os.makedirs(dirname) + + # plot + plot_utt_level_hist( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + os.path.join(dirname, "distribution.png"), + ) + plot_utt_level_scatter( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + os.path.join(dirname, "utt_scatter_plot.png"), + results["utt_LCC"], + results["utt_SRCC"], + results["utt_MSE"], + results["utt_KTAU"], + ) + plot_sys_level_scatter( + eval_sys_results["true_mean_scores"], + eval_sys_results["pred_mean_scores"], + os.path.join(dirname, "sys_scatter_plot.png"), + results["sys_LCC"], + results["sys_SRCC"], + results["sys_MSE"], + results["sys_KTAU"], + ) + + # write results + results = dataset.return_results() + results_path = os.path.join(args.outdir, "results.csv") + fieldnames = list(results[0].keys()) + with open(results_path, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for line in results: + writer.writerow(line) + + +if __name__ == "__main__": + main() diff --git a/sheet/bin/nonparametric_inference.py b/sheet/bin/nonparametric_inference.py new file mode 100755 index 0000000..8a25ae1 --- /dev/null +++ b/sheet/bin/nonparametric_inference.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Non-parametric inference .""" + +import argparse +import csv +import logging +import os +import pickle 
+import time +from collections import defaultdict + +import faiss +import h5py +import numpy as np +import sheet +import sheet.datasets +import sheet.models +import torch +import yaml +from prettytable import MARKDOWN, PrettyTable +from scipy.special import softmax +from sheet.evaluation.metrics import calculate +from sheet.evaluation.plot import ( + plot_sys_level_scatter, + plot_utt_level_hist, + plot_utt_level_scatter, +) +from sheet.nonparametric.datastore import Datastore +from sheet.utils.model_io import model_average +from sheet.utils.types import str2bool +from sheet.utils import write_csv +from tqdm import tqdm + + +def main(): + """Run inference process.""" + parser = argparse.ArgumentParser( + description=( + "Inference with trained model " "(See detail in bin/inference.py)." + ) + ) + parser.add_argument( + "--csv-path", + required=True, + type=str, + help=("csv file path to do inference."), + ) + parser.add_argument( + "--datastore", + required=True, + type=str, + help=("h5 file path of the datastore."), + ) + parser.add_argument( + "--outdir", + type=str, + required=True, + help="directory to save generated figures.", + ) + parser.add_argument( + "--checkpoint", + type=str, + help="checkpoint file to be loaded.", + ) + parser.add_argument( + "--config", + default=None, + type=str, + help=( + "yaml format configuration file. if not explicitly provided, " + "it will be searched in the checkpoint directory. (default=None)" + ), + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + parser.add_argument( + "--inference-mode", + type=str, + help="inference mode. 
if not specified, use the default setting in config", + ) + parser.add_argument( + "--k", + type=int, + default=60, + help="number of neighbors", + ) + parser.add_argument( + "--np-inference-mode", + type=str, + required=True, + choices=["naive_knn", "domain_id_knn_1", "fusion"], + help="non-parametric inference mode.", + ) + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # load config + if args.config is None: + dirname = os.path.dirname(args.checkpoint) + args.config = os.path.join(dirname, "config.yml") + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + + args_dict = vars(args) + # do not override if inference mode not specified + if args_dict["inference_mode"] is None: + del args_dict["inference_mode"] + + # get expdir first + expdir = config["outdir"] + + config.update(args_dict) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + dataset_class = getattr( + sheet.datasets, + config.get("dataset_type", "NonIntrusiveDataset"), + ) + dataset = dataset_class( + csv_path=args.csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + use_phoneme=config.get("use_phoneme", False), + symbols=config.get("symbols", None), + wav_only=True, + allow_cache=False, + ) + logging.info(f"Number of inference samples = {len(dataset)}.") + + # setup device + if torch.cuda.is_available(): + device = torch.device("cuda") 
+ else: + device = torch.device("cpu") + + # get model + model_class = getattr(sheet.models, config["model_type"]) + is_ramp = ("RAMP" in config["model_type"]) + if is_ramp: + datastore = Datastore( + args.datastore, + config["model_params"]["parametric_model_params"]["ssl_model_output_dim"], + device=device, + ) + model = model_class( + config["model_input"], + num_listeners=config.get("num_listeners", None), + num_domains=config.get("num_domains", None), + datastore=datastore, + **config["model_params"], + ) + else: + datastore = Datastore( + args.datastore, + config["model_params"]["ssl_model_output_dim"], + device=device, + ) + model = model_class( + config["model_input"], + num_listeners=config.get("num_listeners", None), + num_domains=config.get("num_domains", None), + **config["model_params"], + ) + + # set placeholders + eval_results = defaultdict(list) + eval_sys_results = defaultdict(lambda: defaultdict(list)) + retrieval_results = {} + ramp_results = [] + + # load parameter + if os.path.islink(args.checkpoint): + checkpoint_path = os.readlink(args.checkpoint) + else: + checkpoint_path = os.path.realpath(args.checkpoint) + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]) + logging.info(f"Loaded model parameters from {args.checkpoint}.") + model = model.eval().to(device) + + # start inference + start_time = time.time() + with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: + for batch in pbar: + # set up model input + model_input = batch[config["model_input"]].unsqueeze(0).to(device) + model_lengths = model_input.new_tensor([model_input.size(1)]).long() + inputs = { + config["model_input"]: model_input, + config["model_input"] + "_lengths": model_lengths, + } + if "domain_idx" in batch: + inputs["domain_idxs"] = ( + torch.tensor(batch["domain_idx"], dtype=torch.long) + .unsqueeze(0) + .to(device) + ) + + # nonparametric part + if config["np_inference_mode"] == "naive_knn": + ssl_embed = ( + 
torch.mean(model.get_ssl_embeddings(inputs), dim=1) + .detach() + .cpu() + .numpy() + ) + outputs = datastore.knn(ssl_embed, args.k)["final_score"] + elif config["np_inference_mode"] == "domain_id_knn_1": + # retreive domain ID + ssl_embed = ( + torch.mean(model.get_ssl_embeddings(inputs), dim=1) + .detach() + .cpu() + .numpy() + ) + retrieved_id = int(datastore.knn(ssl_embed, 1)["ids"][0][0][0]) + inputs["domain_idxs"] = ( + torch.tensor(retrieved_id, dtype=torch.long).unsqueeze(0).to(device) + ) + retrieval_results[batch["sample_id"]] = {"retrieved_id": retrieved_id} + + # parametric path + if config["inference_mode"] == "mean_listener": + outputs = model.mean_listener_inference(inputs) + elif config["inference_mode"] == "mean_net": + outputs = model.mean_net_inference(inputs) + else: + raise NotImplementedError + + outputs = outputs["scores"].cpu().detach().numpy()[0] + elif config["np_inference_mode"] == "fusion": + model_outputs = model.inference(inputs, config["np_inference_mode"]) + outputs = ( + model_outputs["scores"] + .cpu() + .detach() + .numpy()[0] + ) + ramp_results.append( + {"sample_id": batch["sample_id"]} | + {k: v.cpu().detach().numpy()[0] for k, v in model_outputs.items() if not k == "scores"} + ) + else: + raise NotImplementedError + + # store results + answer = outputs + dataset.fill_answer(batch["sample_id"], answer) + pred_mean_scores = answer + true_mean_scores = batch["avg_score"] + eval_results["pred_mean_scores"].append(pred_mean_scores) + eval_results["true_mean_scores"].append(true_mean_scores) + sys_name = batch["system_id"] + eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores) + eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) + + total_inference_time = time.time() - start_time + logging.info("Total inference time = {} secs.".format(total_inference_time)) + logging.info( + "Average inference speed = {:.3f} sec / sample.".format( + total_inference_time / len(dataset) + ) + ) + + # print 
retrieval results + for k, v in retrieval_results.items(): + print(k, v) + + # calculate metrics + eval_results["true_mean_scores"] = np.array(eval_results["true_mean_scores"]) + eval_results["pred_mean_scores"] = np.array(eval_results["pred_mean_scores"]) + eval_sys_results["true_mean_scores"] = np.array( + [np.mean(scores) for scores in eval_sys_results["true_mean_scores"].values()] + ) + eval_sys_results["pred_mean_scores"] = np.array( + [np.mean(scores) for scores in eval_sys_results["pred_mean_scores"].values()] + ) + + # calculate metrics + results = calculate( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + eval_sys_results["true_mean_scores"], + eval_sys_results["pred_mean_scores"], + ) + logging.info( + f'[UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n' + ) + + table = PrettyTable() + table.set_style(MARKDOWN) + table.field_names = [ + "Utt MSE", + "Utt LCC", + "Utt SRCC", + "Utt KTAU", + "Sys MSE", + "Sys LCC", + "Sys SRCC", + "Sys KTAU", + ] + table.add_row( + [ + round(results["utt_MSE"], 3), + round(results["utt_LCC"], 3), + round(results["utt_SRCC"], 3), + round(results["utt_KTAU"], 3), + round(results["sys_MSE"], 3), + round(results["sys_LCC"], 3), + round(results["sys_SRCC"], 3), + round(results["sys_KTAU"], 3), + ] + ) + print(table) + + # check directory + dirname = args.outdir + if not os.path.exists(dirname): + os.makedirs(dirname) + + # plot + plot_utt_level_hist( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + os.path.join(dirname, "distribution.png"), + ) + plot_utt_level_scatter( + eval_results["true_mean_scores"], + eval_results["pred_mean_scores"], + os.path.join(dirname, "utt_scatter_plot.png"), + results["utt_LCC"], + results["utt_SRCC"], + results["utt_MSE"], + results["utt_KTAU"], + ) + plot_sys_level_scatter( + 
eval_sys_results["true_mean_scores"], + eval_sys_results["pred_mean_scores"], + os.path.join(dirname, "sys_scatter_plot.png"), + results["sys_LCC"], + results["sys_SRCC"], + results["sys_MSE"], + results["sys_KTAU"], + ) + + # get results + results = dataset.return_results() + + # insert retrieval results + for i in range(len(results)): + sample_id = results[i]["sample_id"] + for k, v in retrieval_results[sample_id].items(): + results[i][k] = v + + # write results + results_path = os.path.join(args.outdir, "results.csv") + fieldnames = list(results[0].keys()) + with open(results_path, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for line in results: + writer.writerow(line) + + # write RAMP results if model is RAMP + if config["np_inference_mode"] == "fusion": + write_csv(ramp_results, os.path.join(args.outdir, "ramp_results.csv")) + +if __name__ == "__main__": + main() diff --git a/sheet/bin/train.py b/sheet/bin/train.py new file mode 100755 index 0000000..572f83a --- /dev/null +++ b/sheet/bin/train.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Train model.""" + +import argparse +import logging +import os +import sys + +import humanfriendly +import numpy as np +import sheet +import sheet.collaters +import sheet.datasets +import sheet.losses +import sheet.models +import sheet.trainers +import torch +import yaml +from sheet.schedulers import get_scheduler +from torch.utils.data import DataLoader + +# scheduler_classes = dict(warmuplr=WarmupLR) + + +def main(): + """Run training process.""" + parser = argparse.ArgumentParser( + description=( + "Train speech human evaluation estimation model (See detail in bin/train.py)." 
+ ) + ) + parser.add_argument( + "--train-csv-path", + required=True, + type=str, + help=("training data csv file path."), + ) + parser.add_argument( + "--dev-csv-path", + required=True, + type=str, + help=("training data csv file path."), + ) + parser.add_argument( + "--outdir", + type=str, + required=True, + help="directory to save checkpoints.", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="yaml format configuration file.", + ) + parser.add_argument( + "--additional-config", + type=str, + default=None, + help="yaml format configuration file (additional; for second-stage pretraining).", + ) + parser.add_argument( + "--init-checkpoint", + default="", + type=str, + nargs="?", + help='checkpoint file path to initialize pretrained params. (default="")', + ) + parser.add_argument( + "--resume", + default="", + type=str, + nargs="?", + help='checkpoint file path to resume training. (default="")', + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + parser.add_argument( + "--rank", + "--local_rank", + default=0, + type=int, + help="rank for distributed training. 
no need to explictly specify.", + ) + parser.add_argument("--seed", default=1337, type=int) + args = parser.parse_args() + + args.distributed = False + if not torch.cuda.is_available(): + device = torch.device("cpu") + else: + device = torch.device("cuda") + # effective when using fixed size inputs + # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 + torch.backends.cudnn.benchmark = False # because we have dynamic input size + torch.cuda.set_device(args.rank) + # setup for distributed training + # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed + if "WORLD_SIZE" in os.environ: + args.world_size = int(os.environ["WORLD_SIZE"]) + args.distributed = args.world_size > 1 + if args.distributed: + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + # suppress logging for distributed training + if args.rank != 0: + sys.stdout = open(os.devnull, "w") + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # Fix seed and make backends deterministic + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + torch.backends.cudnn.deterministic = True + + # fix issue of too many opened files + # https://github.com/pytorch/pytorch/issues/11201 + torch.multiprocessing.set_sharing_strategy("file_system") + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # load main 
config + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + config.update(vars(args)) + + # load additional config + if args.additional_config is not None: + with open(args.additional_config) as f: + additional_config = yaml.load(f, Loader=yaml.Loader) + config.update(additional_config) + + # get dataset + dataset_class = getattr( + sheet.datasets, + config.get("dataset_type", "NonIntrusiveDataset"), + ) + logging.info(f"Loading training set from {args.train_csv_path}.") + train_dataset = dataset_class( + csv_path=args.train_csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + wav_only=config.get("wav_only", False), + use_phoneme=config.get("use_phoneme", False), + symbols=config.get("symbols", None), + use_mean_listener=config["model_params"].get("use_mean_listener", None), + categorical=config.get("categorical", False), + categorical_step=config.get("categorical_step", 1.0), + allow_cache=config["allow_cache"], + load_wav_cache=config.get("load_wav_cache", False), + ) + logging.info(f"The number of training files = {len(train_dataset)}.") + logging.info(f"Loading development set from {args.dev_csv_path}.") + dev_dataset = dataset_class( + csv_path=args.dev_csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + wav_only=True, + use_phoneme=config.get("use_phoneme", False), + symbols=config.get("symbols", None), + allow_cache=False, + # allow_cache=config["allow_cache"], + # load_wav_cache=config.get("load_wav_cache", False), + ) + logging.info(f"The number of development files = {len(dev_dataset)}.") + dataset = { + "train": train_dataset, + "dev": dev_dataset, + } + + # update number of listeners + if hasattr(train_dataset, "num_listeners"): + config["num_listeners"] = train_dataset.num_listeners + + # update number of domains + if config.get("num_domains", None) is None: + if hasattr(train_dataset, "num_domains"): + config["num_domains"] = 
train_dataset.num_domains + + # get data loader + collater_class = getattr( + sheet.collaters, + config.get("collater_type", "NonIntrusiveCollater"), + ) + collater = collater_class(config["padding_mode"]) + sampler = {"train": None, "dev": None} + if args.distributed: + # setup sampler for distributed training + from torch.utils.data.distributed import DistributedSampler + + sampler["train"] = DistributedSampler( + dataset=dataset["train"], + num_replicas=args.world_size, + rank=args.rank, + shuffle=True, + ) + sampler["dev"] = DistributedSampler( + dataset=dataset["dev"], + num_replicas=args.world_size, + rank=args.rank, + shuffle=False, + ) + data_loader = { + "train": DataLoader( + dataset=dataset["train"], + shuffle=False if args.distributed else True, + collate_fn=collater, + batch_size=config["train_batch_size"], + num_workers=config["num_workers"], + sampler=sampler["train"], + pin_memory=config["pin_memory"], + ), + "dev": DataLoader( + dataset=dataset["dev"], + shuffle=False, + collate_fn=collater, + batch_size=config["test_batch_size"], + num_workers=config["num_workers"], + sampler=sampler["dev"], + pin_memory=config["pin_memory"], + ), + } + + # define models + model_class = getattr( + sheet.models, + config["model_type"], + ) + model = model_class( + config["model_input"], + num_listeners=config.get("num_listeners", None), + num_domains=config.get("num_domains", None), + **config["model_params"], + ).to(device) + + # define criterions + criterion = {} + if config["mean_score_criterions"] is not None: + criterion["mean_score_criterions"] = [ + { + "type": criterion_dict["criterion_type"], + "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( + **criterion_dict["criterion_params"] + ), + "weight": criterion_dict["criterion_weight"], + } + for criterion_dict in config["mean_score_criterions"] + ] + if config.get("categorical_head_criterions", None) is not None: + criterion["categorical_head_criterions"] = [ + { + "type": 
criterion_dict["criterion_type"], + "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( + **criterion_dict["criterion_params"] + ), + "weight": criterion_dict["criterion_weight"], + } + for criterion_dict in config["categorical_head_criterions"] + ] + if config.get("listener_score_criterions", None) is not None: + criterion["listener_score_criterions"] = [ + { + "type": criterion_dict["criterion_type"], + "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( + **criterion_dict["criterion_params"] + ), + "weight": criterion_dict["criterion_weight"], + } + for criterion_dict in config["listener_score_criterions"] + ] + + # define optimizers and schedulers + optimizer_class = getattr( + torch.optim, + # keep compatibility + config.get("optimizer_type", "Adam"), + ) + optimizer = optimizer_class( + model.parameters(), + **config["optimizer_params"], + ) + if config["scheduler_type"] is not None: + scheduler = get_scheduler( + optimizer, + config["scheduler_type"], + config["train_max_steps"], + config["scheduler_params"], + ) + else: + scheduler = None + + if args.distributed: + # wrap model for distributed training + try: + from apex.parallel import DistributedDataParallel + except ImportError: + raise ImportError( + "apex is not installed. please check https://github.com/NVIDIA/apex." 
+ ) + model = DistributedDataParallel(model) + + # show settings + logging.info( + "Model parameters: {}".format(humanfriendly.format_size(model.get_num_params())) + ) + logging.info(model) + logging.info(optimizer) + logging.info(scheduler) + logging.info(criterion) + + # define trainer + trainer_class = getattr(sheet.trainers, config["trainer_type"]) + trainer = trainer_class( + steps=0, + epochs=0, + data_loader=data_loader, + sampler=sampler, + model=model, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + config=config, + device=device, + ) + + # load pretrained parameters from checkpoint + if len(args.init_checkpoint) != 0: + trainer.load_trained_modules( + args.init_checkpoint, init_mods=config["init-mods"] + ) + logging.info(f"Successfully load parameters from {args.init_checkpoint}.") + + # resume from checkpoint + if len(args.resume) != 0: + trainer.load_checkpoint(args.resume) + logging.info(f"Successfully resumed from {args.resume}.") + + # freeze modules if necessary + if config.get("freeze-mods", None) is not None: + assert type(config["freeze-mods"]) is list + trainer.freeze_modules(config["freeze-mods"]) + logging.info(f"Freeze modules with prefixes {config['freeze-mods']}.") + + # save config + config["version"] = sheet.__version__ # add version info + with open(os.path.join(args.outdir, "config.yml"), "w") as f: + yaml.dump(config, f, Dumper=yaml.Dumper) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # run training loop + # try: + # trainer.run() + # finally: + # trainer.save_checkpoint( + # os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl") + # ) + # logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.") + # NOTE(unilight): I don't think we need to save again here + trainer.run() + + +if __name__ == "__main__": + main() diff --git a/sheet/bin/train_stack.py b/sheet/bin/train_stack.py new file mode 100755 index 0000000..0132b2f --- /dev/null +++ 
b/sheet/bin/train_stack.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Train meta-model for stacking .""" + +import argparse +import logging +import os +import pickle +import time + +import numpy as np +import sheet +import sheet.datasets +import sheet.models +import torch +import yaml + + +def main(): + """Run inference process.""" + parser = argparse.ArgumentParser( + description=( + "Inference with trained model " "(See detail in bin/inference.py)." + ) + ) + parser.add_argument( + "--csv-path", + required=True, + type=str, + help=("csv file path to train stacking model."), + ) + parser.add_argument( + "--expdir", + type=str, + required=True, + help="directory to save model.", + ) + parser.add_argument( + "--meta-model-config", + required=True, + type=str, + help=("yaml format configuration file for the meta model. "), + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. 
(default=1)", + ) + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # load original config + with open(os.path.join(args.expdir, "config.yml")) as f: + config = yaml.load(f, Loader=yaml.Loader) + + # load meta model config + with open(args.meta_model_config) as f: + meta_model_config = yaml.load(f, Loader=yaml.Loader) + + config.update(meta_model_config) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + dataset_class = getattr( + sheet.datasets, + config.get("dataset_type", "NonIntrusiveDataset"), + ) + dataset = dataset_class( + csv_path=args.csv_path, + target_sample_rate=config["sampling_rate"], + model_input=config["model_input"], + wav_only=True, + allow_cache=False, + ) + logging.info(f"Number of samples to train meta model = {len(dataset)}.") + + # setup device + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + # get model + model_class = getattr(sheet.models, config["model_type"]) + model = model_class( + config["model_input"], + num_listeners=config.get("num_listeners", None), + **config["model_params"], + ) + + # run inference on all models + checkpoint_paths = sorted( + [ + os.path.join(args.expdir, p) + for p in os.listdir(args.expdir) + if os.path.isfile(os.path.join(args.expdir, p)) and p.endswith("steps.pkl") + ] + ) + xs = np.empty((len(dataset), len(checkpoint_paths))) + for i, checkpoint_path in enumerate(checkpoint_paths): + # load model + 
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]) + logging.info(f"Loaded model parameters from {checkpoint_path}.") + model = model.eval().to(device) + + # start inference + start_time = time.time() + logging.info("Running inference...") + with torch.no_grad(): + for j, batch in enumerate(dataset): + # set up model input + model_input = batch[config["model_input"]].unsqueeze(0).to(device) + model_input_lengths = model_input.new_tensor( + [model_input.size(1)] + ).long() + + # model forward + if config["inference_mode"] == "mean_listener": + outputs = model.mean_listener_inference( + model_input, model_input_lengths + ) + elif config["inference_mode"] == "mean_net": + outputs = model.mean_net_inference(model_input, model_input_lengths) + else: + raise NotImplementedError + + # store results + pred_score = outputs["scores"].cpu().detach().numpy()[0] + xs[j][i] = pred_score + + total_inference_time = time.time() - start_time + logging.info("Total inference time = {} secs.".format(total_inference_time)) + logging.info( + "Average inference speed = {:.3f} sec / sample.".format( + total_inference_time / len(dataset) + ) + ) + + ys = np.array([batch["avg_score"] for batch in dataset]) + + # define meta model + if config["meta_model_type"] == "Ridge": + from sklearn.linear_model import Ridge + + meta_model = Ridge(**config["meta_model_params"]) + else: + raise NotImplementedError + + # train meta model + start_time = time.time() + logging.info("Start training meta model...") + meta_model.fit(xs, ys) + total_train_time = time.time() - start_time + logging.info("Total training time = {} secs.".format(total_train_time)) + + # save + with open(os.path.join(args.expdir, "meta_model.pkl"), "wb") as f: + pickle.dump(meta_model, f) + + with open(os.path.join(args.expdir, "config.yml"), "w") as f: + yaml.dump(config, f, Dumper=yaml.Dumper) + + +if __name__ == "__main__": + main() diff --git a/sheet/collaters/__init__.py b/sheet/collaters/__init__.py 
new file mode 100644 index 0000000..ebfb777 --- /dev/null +++ b/sheet/collaters/__init__.py @@ -0,0 +1 @@ +from .non_intrusive import * # NOQA diff --git a/sheet/collaters/non_intrusive.py b/sheet/collaters/non_intrusive.py new file mode 100644 index 0000000..97fcfbc --- /dev/null +++ b/sheet/collaters/non_intrusive.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence + +FEAT_TYPES = ["waveform", "mag_sgram"] + + +class NonIntrusiveCollater(object): + """Customized collater for Pytorch DataLoader in the non-intrusive setting.""" + + def __init__(self, padding_mode): + """Initialize customized collater for PyTorch DataLoader.""" + self.padding_mode = padding_mode + + def __call__(self, batch): + """Convert into batch tensors.""" + + items = {} + sorted_batch = sorted(batch, key=lambda x: -x["waveform"].shape[0]) + bs = len(sorted_batch) # batch_size + all_keys = list(sorted_batch[0].keys()) + + # score & listener id + items["scores"] = torch.tensor( + [sorted_batch[i]["score"] for i in range(bs)], dtype=torch.float + ) + items["avg_scores"] = torch.tensor( + [sorted_batch[i]["avg_score"] for i in range(bs)], dtype=torch.float + ) + if "categorical_score" in all_keys: + items["categorical_scores"] = torch.tensor( + [sorted_batch[i]["categorical_score"] for i in range(bs)], + dtype=torch.float, + ) + if "categorical_avg_score" in all_keys: + items["categorical_avg_scores"] = torch.tensor( + [sorted_batch[i]["categorical_avg_score"] for i in range(bs)], + dtype=torch.float, + ) + if "listener_id" in all_keys: + items["listener_ids"] = [sorted_batch[i]["listener_id"] for i in range(bs)] + if "listener_idx" in all_keys: + items["listener_idxs"] = torch.tensor( + [sorted_batch[i]["listener_idx"] for i in range(bs)], dtype=torch.long + ) + if "domain_idx" in all_keys: + items["domain_idxs"] 
= torch.tensor( + [sorted_batch[i]["domain_idx"] for i in range(bs)], dtype=torch.long + ) + + # phoneme and reference + if "phoneme_idxs" in all_keys: + phonemes = [ + torch.LongTensor(sorted_batch[i]["phoneme_idxs"]) for i in range(bs) + ] + items["phoneme_lengths"] = torch.from_numpy( + np.array([phoneme.size(0) for phoneme in phonemes]) + ) + items["phoneme_idxs"] = pad_sequence(phonemes, batch_first=True) + if "reference_idxs" in all_keys: + references = [ + torch.LongTensor(sorted_batch[i]["reference_idxs"]) for i in range(bs) + ] + items["reference_lengths"] = torch.from_numpy( + np.array([reference.size(0) for reference in references]) + ) + items["reference_idxs"] = pad_sequence(references, batch_first=True) + + # ids + items["system_ids"] = [sorted_batch[i]["system_id"] for i in range(bs)] + items["sample_ids"] = [sorted_batch[i]["sample_id"] for i in range(bs)] + + # pad input features (only those in FEAT TYPES) + for feat_type in FEAT_TYPES: + if not feat_type in sorted_batch[0]: + continue + + feats = [sorted_batch[i][feat_type] for i in range(bs)] + feat_lengths = torch.from_numpy(np.array([feat.size(0) for feat in feats])) + + # padding + if self.padding_mode == "zero_padding": + feats_padded = pad_sequence(feats, batch_first=True) + elif self.padding_mode == "repetitive": + max_len = feat_lengths[0] + feats_padded = [] + for feat in feats: + this_len = feat.shape[0] + dup_times = max_len // this_len + remain = max_len - this_len * dup_times + to_dup = [feat for t in range(dup_times)] + to_dup.append(feat[:remain]) + duplicated_feat = torch.Tensor(np.concatenate(to_dup, axis=0)) + feats_padded.append(duplicated_feat) + feats_padded = torch.stack(feats_padded, dim=0) + else: + raise NotImplementedError + + items[feat_type] = feats_padded + items[feat_type + "_lengths"] = feat_lengths + + return items diff --git a/sheet/datasets/__init__.py b/sheet/datasets/__init__.py new file mode 100644 index 0000000..ebfb777 --- /dev/null +++ 
b/sheet/datasets/__init__.py @@ -0,0 +1 @@ +from .non_intrusive import * # NOQA diff --git a/sheet/datasets/non_intrusive.py b/sheet/datasets/non_intrusive.py new file mode 100644 index 0000000..fc4bb37 --- /dev/null +++ b/sheet/datasets/non_intrusive.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Non-intrusive dataset modules.""" + +from collections import defaultdict +import logging +from multiprocessing import Manager + +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio +from sheet.utils import read_csv +from torch.utils.data import Dataset + +MIN_REQUIRED_WAV_LENGTH = 1040 +MAX_WAV_LENGTH_SECS = 10 + + +def adjust_length(waveform, target_sample_rate): + """Adjust waveform length to a fixed length.""" + if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH: + to_pad = (MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]) // 2 + waveform = F.pad(waveform, (to_pad, to_pad), "constant", 0) + if waveform.shape[0] > MAX_WAV_LENGTH_SECS * target_sample_rate: + waveform = waveform[: MAX_WAV_LENGTH_SECS * target_sample_rate] + return waveform + + +def read_waveform(wav_path, target_sample_rate): + try: + # read waveform + waveform, sample_rate = torchaudio.load( + wav_path, channels_first=False + ) # waveform: [T, 1] + # resample if needed + if sample_rate != target_sample_rate: + resampler = torchaudio.transforms.Resample( + sample_rate, target_sample_rate, dtype=waveform.dtype + ) + waveform = resampler(waveform) + # mono only + if waveform.shape[1] > 1: + waveform = torch.mean(waveform, dim=1, keepdim=True) + except Exception as e: + print(f"Failed to load or resample {wav_path}: {e}") + raise + + waveform = waveform.squeeze(-1) + + # adjust length + waveform = adjust_length(waveform, target_sample_rate) + + return waveform + + 
+def _read_waveform(arg_tuple): + hash_id, wav_path, target_sample_rate = arg_tuple + return hash_id, read_waveform(wav_path, target_sample_rate) + +class NonIntrusiveDataset(Dataset): + """PyTorch compatible dataset for non-intrusive SSQA.""" + + def __init__( + self, + csv_path, + target_sample_rate, + model_input="wav", + wav_only=False, + use_mean_listener=False, + use_phoneme=False, + symbols=None, + categorical=False, + categorical_step=1.0, + no_feat=False, + allow_cache=False, + load_wav_cache=False, + ): + """Initialize dataset. + + Args: + csv path (str): path to the csv file + target_sample_rate (int): resample to this seample rate if there is a mismatch. + model_input (str): defalut is wav. is this is mag_sgram, extract magnitute sgram. + wav_only (bool): whether to return only wavs. Basically this means inference mode. + use_mean_listener (bool): whether to use mean listener. (only for datasets with listener labels) + use_phoneme (bool): whether to use phoneme. (only for UTMOS training) + symbols (str): symbols for phoneme. (only for UTMOS training) + categorical (bool): whether to include categorical output. + categorical_step (float): step for the categorical output. defauly is 1.0. + no_feat (bool): Whether to skip loading features (waveforms, mag_sgrams ...) + allow_cache (bool): Whether to allow cache of the loaded files. + load_wav_cache (bool): Whether to load all waveforms first and store in cache (this might make initialization slower). 
+ + """ + self.target_sample_rate = target_sample_rate + self.use_phoneme = use_phoneme + if self.use_phoneme: + self.symbols = symbols + self.resamplers = {} + assert csv_path != "" + self.categorical = categorical + self.categorical_step = categorical_step + self.no_feat = no_feat + + # set model input transform + self.model_input = model_input + if model_input == "mag_sgram": + self.mag_sgram_transform = torchaudio.transforms.Spectrogram( + n_fft=512, hop_length=256, win_length=512, power=1 + ) + + # read csv file + self.metadata, _ = read_csv(csv_path, dict_reader=True) + + # calculate average score for each sample and add to metadata + self.calculate_avg_score() + + if wav_only: + self.reduce_to_wav_only() + else: + # add mean listener to metadata + if use_mean_listener: + mean_listener_metadata = self.gen_mean_listener_metadata() + self.metadata = self.metadata + mean_listener_metadata + + # get num of listeners + self.num_listeners = self.get_num_listeners() + + # get num of domains if domain_idx exists + if "domain_idx" in self.metadata[0]: + self.num_domains = self.get_num_domains() + + # build hash + self.build_feat_hash() + + # set cache + self.allow_cache = allow_cache + if allow_cache: + # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 + self.manager = Manager() + # self.wav_caches = self.manager.list() + # self.wav_caches += [() for _ in range(self.num_wavs)] + self.wav_caches = [None for _ in range(self.num_wavs)] + if self.model_input == "mag_sgram": + self.mag_sgram_caches = self.manager.list() + self.mag_sgram_caches += [() for _ in range(self.num_wavs)] + if load_wav_cache: + logging.info("Loading waveform cache. 
This might take a while ...") + self.load_wav_cache() + + def load_wav_cache(self): + # put text into csv + # with ProcessPoolExecutor(max_workers=2) as executor: + with ThreadPoolExecutor() as executor: + arg_tuples = [(item["hash_id"], item["wav_path"], self.target_sample_rate) for item in self.metadata] + results = list( + tqdm(executor.map(_read_waveform, arg_tuples), total=len(arg_tuples)) + ) + + for item in results: + if item is None: + continue + hash_id, waveform = item + self.wav_caches[hash_id] = waveform + + def __len__(self): + """Return dataset length. + + Returns: + int: The length of dataset. + + """ + return len(self.metadata) + + def get_num_listeners(self): + """Get number of listeners by counting unique listener id""" + listener_ids = set() + for item in self.metadata: + listener_ids.add(item["listener_id"]) + return len(listener_ids) + + def get_num_domains(self): + """Get number of domains by counting unique domain idxs""" + domain_idxs = set() + for item in self.metadata: + domain_idxs.add(item["domain_idx"]) + return len(domain_idxs) + + def build_feat_hash(self): + sample_ids = {} + count = 0 + for i in range(len(self.metadata)): + item = self.metadata[i] + sample_id = item["sample_id"] + if not sample_id in sample_ids: + sample_ids[sample_id] = count + count += 1 + self.metadata[i]["hash_id"] = sample_ids[sample_id] + self.num_wavs = len(sample_ids) + + def __getitem__(self, idx): + item = self.metadata[idx] + + # handle score + item["score"] = float(item["score"]) # cast from str to int + if self.categorical: + # we assume the score always starts from 1 + item["categorical_score"] = int( + (item["score"] - 1) // self.categorical_step + ) + + if "listener_idx" in item: + item["listener_idx"] = int(item["listener_idx"]) # cast from str to int + if "domain_idx" in item: + item["domain_idx"] = int(item["domain_idx"]) # cast from str to int + hash_id = item["hash_id"] + + # process text + if self.use_phoneme: + if "phoneme" in item: + if 
"phoneme_idxs" not in item: + item["phoneme_idxs"] = [ + self.symbols.index(p) for p in item["phoneme"] + ] + if "reference" in item: + if "reference_idxs" not in item: + item["reference_idxs"] = [ + self.symbols.index(p) for p in item["reference"] + ] + + # fetch waveform. return cached item if exists + if not self.no_feat: + if self.allow_cache and self.wav_caches[hash_id] is not None: + item["waveform"] = self.wav_caches[hash_id] + else: + waveform = read_waveform( + item["wav_path"], self.target_sample_rate + ) + item["waveform"] = waveform + if self.allow_cache: + self.wav_caches[hash_id] = item["waveform"] + + # additional feature extraction + if not self.no_feat: + if self.model_input == "mag_sgram": + # fetch mag_sgram. return cached item if exists + if self.allow_cache and len(self.mag_sgram_caches[hash_id]) != 0: + item["mag_sgram"] = self.mag_sgram_caches[hash_id] + else: + # torchaudio requires waveform to be [..., T] + mag_sgram = self.mag_sgram_transform( + waveform.squeeze(-1) + ) # mag_sgram: [freq, T] + item["mag_sgram"] = mag_sgram.mT # [T, freq] + if self.allow_cache: + self.mag_sgram_caches[hash_id] = item["mag_sgram"] + + return item + + def calculate_avg_score(self): + sample_scores = defaultdict(list) + + # loop through metadata + for item in self.metadata: + sample_scores[item["sample_id"]].append(float(item["score"])) + + # take average + sample_avg_score = { + sample_id: np.mean(np.array(scores)) + for sample_id, scores in sample_scores.items() + } + self.sample_avg_score = sample_avg_score + + # fill back into metadata + for i, item in enumerate(self.metadata): + self.metadata[i]["avg_score"] = sample_avg_score[item["sample_id"]] + if self.categorical: + # we assume the score always starts from 1 + self.metadata[i]["categorical_avg_score"] = int( + (self.metadata[i]["avg_score"] - 1) // self.categorical_step + ) + + def gen_mean_listener_metadata(self): + mean_listener_metadata = [] + sample_ids = set() + for item in self.metadata: + 
sample_id = item["sample_id"] + if sample_id not in sample_ids: + new_item = {k: v for k, v in item.items()} + new_item["listener_id"] = "mean_listener" + mean_listener_metadata.append(new_item) + sample_ids.add(sample_id) + return mean_listener_metadata + + def reduce_to_wav_only(self): + new_metadata = {} # {sample_id: item} + for item in self.metadata: + sample_id = item["sample_id"] + if not sample_id in new_metadata: + new_metadata[sample_id] = { + k: v + for k, v in item.items() + if k not in ["listener_id", "listener_idx"] + } + + self.metadata = list(new_metadata.values()) + + # the following two functions are for writing results during inference + def fill_answer(self, sample_id, score, logvar=None): + for idx, item in enumerate(self.metadata): + if item["sample_id"] == sample_id: + break + self.metadata[idx]["answer"] = score + if logvar is not None: + self.metadata[idx]["logvar"] = logvar + + def return_results(self): + return [ + { + k: item[k] + for k in [ + "wav_path", + "system_id", + "sample_id", + "avg_score", + "answer", + "logvar", + ] + if k in item + } + for item in self.metadata + ] diff --git a/sheet/evaluation/metrics.py b/sheet/evaluation/metrics.py new file mode 100644 index 0000000..b2978cc --- /dev/null +++ b/sheet/evaluation/metrics.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Script to calculate metrics.""" + +import numpy as np +import scipy + + +def calculate( + true_mean_scores, predict_mean_scores, true_sys_mean_scores, predict_sys_mean_scores +): + + utt_MSE = np.mean((true_mean_scores - predict_mean_scores) ** 2) + utt_LCC = np.corrcoef(true_mean_scores, predict_mean_scores)[0][1] + utt_SRCC = scipy.stats.spearmanr(true_mean_scores, predict_mean_scores)[0] + utt_KTAU = scipy.stats.kendalltau(true_mean_scores, predict_mean_scores)[0] + sys_MSE = np.mean((true_sys_mean_scores - predict_sys_mean_scores) ** 2) + sys_LCC = 
def calculate(
    true_mean_scores, predict_mean_scores, true_sys_mean_scores, predict_sys_mean_scores
):
    """Compute utterance-level and system-level agreement metrics.

    Args:
        true_mean_scores: ndarray of ground-truth utterance-level mean scores.
        predict_mean_scores: ndarray of predicted utterance-level scores.
        true_sys_mean_scores: ndarray of ground-truth system-level means.
        predict_sys_mean_scores: ndarray of predicted system-level means.

    Returns:
        dict mapping "<level>_<metric>" to MSE, LCC (Pearson), SRCC
        (Spearman) and KTAU (Kendall's tau) at both levels.
    """
    results = {}
    for prefix, truth, pred in (
        ("utt", true_mean_scores, predict_mean_scores),
        ("sys", true_sys_mean_scores, predict_sys_mean_scores),
    ):
        results[f"{prefix}_MSE"] = np.mean((truth - pred) ** 2)
        results[f"{prefix}_LCC"] = np.corrcoef(truth, pred)[0][1]
        results[f"{prefix}_SRCC"] = scipy.stats.spearmanr(truth, pred)[0]
        results[f"{prefix}_KTAU"] = scipy.stats.kendalltau(truth, pred)[0]
    return results
def plot_utt_level_scatter(
    true_mean_scores, predict_mean_scores, filename, LCC, SRCC, MSE, KTAU
):
    """Draw and save an utterance-level true-vs-predicted scatter plot.

    Args:
        true_mean_scores: ndarray of true scores.
        predict_mean_scores: ndarray of predicted scores.
        filename: path of the figure to write (saved at dpi=150).
        LCC, SRCC, MSE, KTAU: metrics rendered in the figure title.
    """
    # upper axis bound: at least 5, extended when predictions exceed it
    axis_max = np.max([np.max(predict_mean_scores), 5])
    plt.figure(3)
    plt.scatter(
        true_mean_scores,
        predict_mean_scores,
        s=15,
        color="b",
        marker="o",
        edgecolors="b",
        alpha=0.20,
    )
    plt.xlim([0.5, axis_max])
    plt.ylim([0.5, axis_max])
    plt.xlabel("True MOS")
    plt.ylabel("Predicted MOS")
    plt.title(
        "Utt level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}".format(
            LCC, SRCC, MSE, KTAU
        )
    )
    plt.show()  # NOTE(review): no-op under the Agg backend forced at import
    plt.savefig(filename, dpi=150)
    plt.close()
{:.4f}, MSE= {:.4f}, KTAU= {:.4f}".format( + LCC, SRCC, MSE, KTAU + ) + ) + plt.show() + plt.savefig(filename, dpi=150) + plt.close() diff --git a/sheet/losses/__init__.py b/sheet/losses/__init__.py new file mode 100644 index 0000000..250a1a1 --- /dev/null +++ b/sheet/losses/__init__.py @@ -0,0 +1,3 @@ +from .basic_losses import * # NOQA +from .contrastive_loss import * # NOQA +from .nll_losses import * # NOQA \ No newline at end of file diff --git a/sheet/losses/basic_losses.py b/sheet/losses/basic_losses.py new file mode 100644 index 0000000..c9b5589 --- /dev/null +++ b/sheet/losses/basic_losses.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""Basic losses.""" + +import torch +import torch.nn as nn +from sheet.modules.utils import make_non_pad_mask + + +class ScalarLoss(nn.Module): + """ + Loss for scalar output (we use the clipped MSE loss) + """ + + def __init__(self, tau, order=2, masked_loss=False): + super(ScalarLoss, self).__init__() + self.tau = tau + self.masked_loss = masked_loss + if order == 2: + self.criterion = torch.nn.MSELoss(reduction="none") + elif order == 1: + self.criterion = torch.nn.L1Loss(reduction="none") + else: + raise NotImplementedError + + def forward_criterion(self, y_hat, label, criterion_module, masks=None): + # might investigate how to combine masked loss with categorical output + if masks is not None: + y_hat = y_hat.masked_select(masks) + label = label.masked_select(masks) + + y_hat = y_hat.squeeze(-1) + loss = criterion_module(y_hat, label) + threshold = torch.abs(y_hat - label) > self.tau + loss = torch.mean(threshold * loss) + return loss + + def forward(self, pred_score, gt_score, device, lens=None): + """ + Args: + pred_mean, pred_score: [batch, time, 1/5] + """ + # make mask + if self.masked_loss: + masks = make_non_pad_mask(lens).to(device) + else: + masks = None + + # repeat for frame level loss + time = 
class CategoricalLoss(nn.Module):
    """Frame-level cross-entropy loss for categorical score outputs."""

    def __init__(self, masked_loss=False):
        """Initialize CategoricalLoss.

        Args:
            masked_loss: if True, mask padded frames before averaging.
        """
        super(CategoricalLoss, self).__init__()
        self.masked_loss = masked_loss
        # per-element CE; the reduction (mean) is done manually in ce()
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    def ce(self, y_hat, label, criterion, masks=None):
        """Mean cross entropy of y_hat [batch, time, classes] vs label [batch, time]."""
        if masks is not None:
            # NOTE(review): masked_select flattens the tensors, so the
            # permute below assumes masks is None in practice — confirm.
            y_hat = y_hat.masked_select(masks)
            label = label.masked_select(masks)

        # CrossEntropyLoss expects (batch, num_classes, ...)
        per_element = criterion(y_hat.permute(0, 2, 1), label)
        return per_element.mean()

    def forward(self, pred_score, gt_score, device, lens=None):
        """Compute the loss.

        Args:
            pred_score: [batch, time, num_classes] logits.
            gt_score: [batch] integer class targets.
            device: device the mask is moved to (used only when masked_loss).
            lens: frame lengths, required when masked_loss is True.
        """
        masks = make_non_pad_mask(lens).to(device) if self.masked_loss else None

        # broadcast the utterance-level target to every frame
        num_frames = pred_score.shape[1]
        frame_targets = gt_score.unsqueeze(1).repeat(1, num_frames).type(torch.long)

        return self.ce(pred_score, frame_targets, self.criterion, masks)
gt_score.unsqueeze(1) - gt_score.unsqueeze(0) + pred_diff = pred_score.unsqueeze(1) - pred_score.unsqueeze(0) + + loss = torch.maximum( + torch.zeros(gt_diff.shape).to(gt_diff.device), + torch.abs(pred_diff - gt_diff) - self.margin, + ) + loss = loss.mean().div(2) + + return loss diff --git a/sheet/losses/nll_losses.py b/sheet/losses/nll_losses.py new file mode 100644 index 0000000..b0106c2 --- /dev/null +++ b/sheet/losses/nll_losses.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +"""NLL losses.""" + +import torch +import torch.nn as nn +from sheet.modules.utils import make_non_pad_mask + + +class GaussianNLLLoss(nn.Module): + """ + Gaussian NLL loss (for uncertainty modeling) + """ + + def __init__(self, tau, masked_loss=False): + super(GaussianNLLLoss, self).__init__() + self.tau = tau + self.masked_loss = masked_loss + + def forward_criterion(self, y_hat, logvar, label, masks=None): + """ + loss = 0.5 * (precision * (target - mean) ** 2 + log_var) + """ + + # might investigate how to combine masked loss with categorical output + if masks is not None: + y_hat = y_hat.masked_select(masks) + logvar = logvar.masked_select(masks) + label = label.masked_select(masks) + + y_hat = y_hat.squeeze(-1) + logvar = logvar.squeeze(-1) + precision = torch.exp(-logvar) + loss = 0.5 * (precision * (y_hat - label) ** 2 + logvar) + threshold = torch.abs(y_hat - label) > self.tau + loss = torch.mean(threshold * loss) + return loss + + def forward(self, pred_score, pred_logvar, gt_score, device, lens=None): + """ + Args: + pred_mean, pred_score: [batch, time, 1/5] + """ + # make mask + if self.masked_loss: + masks = make_non_pad_mask(lens).to(device) + else: + masks = None + + # repeat for frame level loss + time = pred_score.shape[1] + # gt_mean = gt_mean.unsqueeze(1).repeat(1, time) + gt_score = gt_score.unsqueeze(1).repeat(1, time) + + loss = 
class LaplaceNLLLoss(nn.Module):
    """
    Laplace NLL loss (for uncertainty modeling)
    """

    def __init__(self, tau, masked_loss=False):
        """Initialize LaplaceNLLLoss.

        Args:
            tau: clipping threshold; elements with |y_hat - label| <= tau
                contribute zero loss.
            masked_loss: if True, mask padded frames using lens.
        """
        super(LaplaceNLLLoss, self).__init__()
        self.tau = tau
        self.masked_loss = masked_loss

    def forward_criterion(self, y_hat, logvar, label, masks=None):
        """
        loss = |label - y_hat| / b + logvar, with b = exp(logvar) + 1e-6
        (Laplace negative log-likelihood up to a constant.
        BUGFIX: the previous docstring showed the Gaussian formula,
        which does not match this code.)
        """
        # might investigate how to combine masked loss with categorical output
        if masks is not None:
            y_hat = y_hat.masked_select(masks)
            logvar = logvar.masked_select(masks)
            label = label.masked_select(masks)

        y_hat = y_hat.squeeze(-1)
        logvar = logvar.squeeze(-1)
        # scale parameter; epsilon guards against division by zero
        b = torch.exp(logvar) + 1e-6
        loss = torch.abs(y_hat - label) / b + logvar
        # clipped loss: zero out elements already within tau of the target
        threshold = torch.abs(y_hat - label) > self.tau
        loss = torch.mean(threshold * loss)
        return loss

    def forward(self, pred_score, pred_logvar, gt_score, device, lens=None):
        """
        Args:
            pred_score, pred_logvar: [batch, time, 1]
            gt_score: [batch]
            device: device for the mask (used only when masked_loss)
            lens: frame lengths, required when masked_loss is True
        """
        # make mask
        if self.masked_loss:
            masks = make_non_pad_mask(lens).to(device)
        else:
            masks = None

        # repeat the utterance-level target for frame-level loss
        time = pred_score.shape[1]
        gt_score = gt_score.unsqueeze(1).repeat(1, time)

        loss = self.forward_criterion(pred_score, pred_logvar, gt_score, masks)
        return loss
+# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +# Modified AlignNet model + +import torch +import torch.nn as nn +from sheet.modules.ldnet.modules import Projection, ProjectionWithUncertainty + + +class AlignNet(torch.nn.Module): + def __init__( + self, + # dummy, for signature need + model_input: str, + # model related + ssl_module: str = "s3prl", + s3prl_name: str = "wav2vec2", + ssl_model_output_dim: int = 768, + ssl_model_layer_idx: int = -1, + # listener related + use_listener_modeling: bool = False, + num_listeners: int = None, + listener_emb_dim: int = None, + use_mean_listener: bool = True, + # domain related + use_domain_modeling: bool = False, + num_domains: int = None, + domain_emb_dim: int = None, + # decoder related + use_decoder_rnn: bool = True, + decoder_rnn_dim: int = 512, + decoder_dnn_dim: int = 2048, + decoder_activation: str = "ReLU", + output_type: str = "scalar", + range_clipping: bool = True, + ): + super().__init__() # this is needed! or else there will be an error. 
+ self.use_mean_listener = use_mean_listener + self.output_type = output_type + self.decoder_dnn_dim = decoder_dnn_dim + self.range_clipping = range_clipping + + # define ssl model + if ssl_module == "s3prl": + from s3prl.nn import S3PRLUpstream + + if s3prl_name in S3PRLUpstream.available_names(): + self.ssl_model = S3PRLUpstream(s3prl_name) + self.ssl_model_layer_idx = ssl_model_layer_idx + else: + raise NotImplementedError + decoder_input_dim = ssl_model_output_dim + + # listener modeling related + self.use_listener_modeling = use_listener_modeling + if use_listener_modeling: + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + decoder_input_dim += listener_emb_dim + + # domain modeling related + self.use_domain_modeling = use_domain_modeling + if use_domain_modeling: + self.num_domains = num_domains + self.domain_embeddings = nn.Embedding( + num_embeddings=num_domains, embedding_dim=domain_emb_dim + ) + decoder_input_dim += domain_emb_dim + + # define decoder rnn + self.use_decoder_rnn = use_decoder_rnn + if self.use_decoder_rnn: + self.decoder_rnn = nn.LSTM( + input_size=decoder_input_dim, + hidden_size=decoder_rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + self.decoder_dnn_input_dim = decoder_rnn_dim * 2 + else: + self.decoder_dnn_input_dim = decoder_input_dim + + # define activation + if decoder_activation == "ReLU": + self.decoder_activation = nn.ReLU + else: + raise NotImplementedError + + # there is always decoder dnn + self.decoder_dnn = Projection( + self.decoder_dnn_input_dim, + self.decoder_dnn_dim, + self.decoder_activation, + self.output_type, + self.range_clipping, + ) + + def get_num_params(self): + return sum(p.numel() for n, p in self.named_parameters()) + + def forward(self, inputs): + """Calculate forward propagation. 
+ Args: + inputs: dict, which has the following keys: + - waveform has shape (batch, time) + - waveform_lengths has shape (batch) + - listener_ids has shape (batch) + - domain_ids has shape (batch) + """ + waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] + + # ssl model forward + ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( + waveform, waveform_lengths + ) + to_concat = [ssl_model_outputs] + time = ssl_model_outputs.size(1) + + # get listener embedding + if self.use_listener_modeling: + listener_ids = inputs["listener_idxs"] + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? + # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(listener_embs) + + # get domain embedding + if self.use_domain_modeling: + domain_ids = inputs["domain_idxs"] + domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) + domain_embs = torch.stack( + [domain_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? 
+ # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(domain_embs) + + decoder_inputs = torch.cat(to_concat, dim=2) + + # decoder rnn + if self.use_decoder_rnn: + decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) + + # decoder dnn + decoder_outputs = self.decoder_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # set outputs + # return lengths for masked loss calculation + ret = { + "waveform_lengths": waveform_lengths, + "frame_lengths": ssl_model_output_lengths, + } + if self.use_listener_modeling: + ret["ld_scores"] = decoder_outputs + else: + ret["mean_scores"] = decoder_outputs + + return ret + + def mean_listener_inference(self, inputs): + waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] + batch = waveform.size(0) + + # ssl model forward + ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( + waveform, waveform_lengths + ) + to_concat = [ssl_model_outputs] + time = ssl_model_outputs.size(1) + + # get listener embedding + if self.use_listener_modeling: + device = waveform.device + listener_ids = ( + torch.ones(batch, dtype=torch.long) * self.num_listeners - 1 + ).to( + device + ) # (bs) + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? + # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(listener_embs) + + # get domain embedding + if self.use_domain_modeling: + device = waveform.device + assert "domain_idxs" in inputs, "Must specify domain ID even in inference." 
def forward(self, inputs):
    """Calculate forward propagation (mean + log-variance heads).

    Args:
        inputs: dict, which has the following keys:
            - waveform has shape (batch, time)
            - waveform_lengths has shape (batch)
            - listener_idxs has shape (batch), when listener modeling is on
            - domain_idxs has shape (batch), when domain modeling is on

    Returns:
        dict with frame lengths and per-frame score mean / log-variance.
    """
    waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"]

    # ssl model forward
    ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward(
        waveform, waveform_lengths
    )
    to_concat = [ssl_model_outputs]
    time = ssl_model_outputs.size(1)

    # get listener embedding, broadcast over time
    if self.use_listener_modeling:
        listener_ids = inputs["listener_idxs"]
        listener_embs = self.listener_embeddings(listener_ids)  # (batch, emb_dim)
        listener_embs = torch.stack(
            [listener_embs for i in range(time)], dim=1
        )  # (batch, time, feat_dim)
        to_concat.append(listener_embs)

    # get domain embedding, broadcast over time
    if self.use_domain_modeling:
        domain_ids = inputs["domain_idxs"]
        domain_embs = self.domain_embeddings(domain_ids)  # (batch, emb_dim)
        domain_embs = torch.stack(
            [domain_embs for i in range(time)], dim=1
        )  # (batch, time, feat_dim)
        to_concat.append(domain_embs)

    decoder_inputs = torch.cat(to_concat, dim=2)

    # decoder rnn
    if self.use_decoder_rnn:
        decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs)

    # decoder dnn: uncertainty head returns a (mean, logvar) pair
    decoder_outputs_mean, decoder_outputs_logvar = self.decoder_dnn(
        decoder_inputs
    )  # each [batch, time, 1]

    # set outputs
    # return lengths for masked loss calculation
    ret = {
        "waveform_lengths": waveform_lengths,
        "frame_lengths": ssl_model_output_lengths,
    }
    if self.use_listener_modeling:
        # BUGFIX: the original assigned an undefined `decoder_outputs`
        # here (NameError on any listener-modeling run); use the
        # mean/logvar pair, mirroring the mean-listener branch below.
        ret["ld_scores"] = decoder_outputs_mean
        ret["ld_scores_logvar"] = decoder_outputs_logvar
    else:
        ret["mean_scores"] = decoder_outputs_mean
        ret["mean_scores_logvar"] = decoder_outputs_logvar

    return ret
+ domain_ids = inputs["domain_idxs"] + domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) + domain_embs = torch.stack( + [domain_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? + # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(domain_embs) + + decoder_inputs = torch.cat(to_concat, dim=2) + + # decoder rnn + if self.use_decoder_rnn: + decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) + + # decoder dnn + decoder_outputs_mean, decoder_outputs_logvar = self.decoder_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + scores = torch.mean(decoder_outputs_mean.squeeze(-1), dim=1) + logvars = torch.mean(decoder_outputs_logvar.squeeze(-1), dim=1) + return {"scores": scores, "logvars": logvars} diff --git a/sheet/models/ldnet.py b/sheet/models/ldnet.py new file mode 100644 index 0000000..017193f --- /dev/null +++ b/sheet/models/ldnet.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +# LDNet model +# taken from: https://github.com/unilight/LDNet/blob/main/models/LDNet.py (written by myself) + +import math + +import torch +import torch.nn as nn +from sheet.modules.ldnet.modules import STRIDE, MobileNetV3ConvBlocks, Projection + + +class LDNet(torch.nn.Module): + def __init__( + self, + model_input: str, + # listener related + num_listeners: int, + listener_emb_dim: int, + use_mean_listener: bool, + # model related + activation: str, + encoder_type: str, + encoder_bneck_configs: list, + encoder_output_dim: int, + decoder_type: str, + decoder_dnn_dim: int, + output_type: str, + range_clipping: bool, + # mean net related + use_mean_net: bool = False, + mean_net_type: str = "ffn", + mean_net_dnn_dim: int = 64, + mean_net_range_clipping: bool = True, + ): + super().__init__() # this is needed! or else there will be an error. 
+ self.use_mean_listener = use_mean_listener + self.output_type = output_type + + # only accept mag_sgram as input + assert model_input == "mag_sgram" + + # define listener embedding + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + + # define activation + if activation == "ReLU": + self.activation = nn.ReLU + else: + raise NotImplementedError + + # define encoder + if encoder_type == "mobilenetv3": + self.encoder = MobileNetV3ConvBlocks( + encoder_bneck_configs, encoder_output_dim + ) + else: + raise NotImplementedError + + # define decoder + self.decoder_type = decoder_type + if decoder_type == "ffn": + decoder_dnn_input_dim = encoder_output_dim + listener_emb_dim + else: + raise NotImplementedError + # there is always dnn + self.decoder_dnn = Projection( + decoder_dnn_input_dim, + decoder_dnn_dim, + self.activation, + output_type, + range_clipping, + ) + + # define mean net + self.use_mean_net = use_mean_net + self.mean_net_type = mean_net_type + if use_mean_net: + if mean_net_type == "ffn": + mean_net_dnn_input_dim = encoder_output_dim + else: + raise NotImplementedError + # there is always dnn + self.mean_net_dnn = Projection( + mean_net_dnn_input_dim, + mean_net_dnn_dim, + self.activation, + output_type, + mean_net_range_clipping, + ) + + def _get_output_dim(self, input_size, num_layers, stride=STRIDE): + """ + calculate the final ouptut width (dim) of a CNN using the following formula + w_i = |_ (w_i-1 - 1) / stride + 1 _| + """ + output_dim = input_size + for _ in range(num_layers): + output_dim = math.floor((output_dim - 1) / STRIDE + 1) + return output_dim + + def get_num_params(self): + return sum(p.numel() for n, p in self.named_parameters()) + + def forward(self, inputs): + """Calculate forward propagation. 
+ Args: + mag_sgram has shape (batch, time, dim) + listener_ids has shape (batch) + """ + mag_sgram = inputs["mag_sgram"] + mag_sgram_lengths = inputs["mag_sgram_lengths"] + listener_ids = inputs["listener_idxs"] + + batch, time, _ = mag_sgram.shape + + # get listener embedding + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # encoder and inject listener embedding + mag_sgram = mag_sgram.unsqueeze(1) + encoder_outputs = self.encoder(mag_sgram) # (batch, ch, time, feat_dim) + encoder_outputs = encoder_outputs.view( + (batch, time, -1) + ) # (batch, time, feat_dim) + decoder_inputs = torch.cat( + [encoder_outputs, listener_embs], dim=-1 + ) # concat along feature dimension + + # mean net + if self.use_mean_net: + mean_net_inputs = encoder_outputs + if self.mean_net_type == "rnn": + mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs) + else: + mean_net_outputs = mean_net_inputs + mean_net_outputs = self.mean_net_dnn( + mean_net_outputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # decoder + if self.decoder_type == "rnn": + decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) + else: + decoder_outputs = decoder_inputs + decoder_outputs = self.decoder_dnn( + decoder_outputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # define scores + ret = { + "frame_lengths": mag_sgram_lengths, + "mean_scores": mean_net_outputs if self.use_mean_net else None, + "ld_scores": decoder_outputs, + } + + return ret + + def mean_listener_inference(self, inputs): + """Mean listener inference. 
def average_inference(self, mag_sgram, include_meanspk=False):
    """Average listener inference.

    Decodes once per listener embedding and averages the listener-dependent
    scores into a single prediction per input.

    Args:
        mag_sgram: magnitude spectrogram, shape (batch, time, dim).
        include_meanspk: if True, also include the mean-listener embedding
            (the last listener index) in the average.

    Returns:
        dict with:
            - scores: (batch,) averaged predicted scores
            - posterior_scores: (batch, num_listeners) per-listener scores
    """
    bs, time, _ = mag_sgram.shape
    device = mag_sgram.device
    # exclude the mean listener from the average unless explicitly requested
    if self.use_mean_listener and not include_meanspk:
        actual_num_listeners = self.num_listeners - 1
    else:
        actual_num_listeners = self.num_listeners

    # all listener ids
    listener_id = (
        torch.arange(actual_num_listeners, dtype=torch.long)
        .repeat(bs, 1)
        .to(device)
    )  # (bs, nj)
    # BUGFIX: the attribute defined in __init__ is `listener_embeddings`;
    # the original referenced a non-existent `listener_embedding`,
    # raising AttributeError on every call.
    listener_embs = self.listener_embeddings(listener_id)  # (bs, nj, emb_dim)
    listener_embs = torch.stack(
        [listener_embs for i in range(time)], dim=2
    )  # (bs, nj, time, feat_dim)

    # encoder and inject listener embedding
    mag_sgram = mag_sgram.unsqueeze(1)
    encoder_outputs = self.encoder(mag_sgram)  # (batch, ch, time, feat_dim)
    encoder_outputs = encoder_outputs.view(
        (bs, time, -1)
    )  # (batch, time, feat_dim)
    decoder_inputs = torch.stack(
        [encoder_outputs for i in range(actual_num_listeners)], dim=1
    )  # (bs, nj, time, feat_dim)
    decoder_inputs = torch.cat(
        [decoder_inputs, listener_embs], dim=-1
    )  # concat along feature dimension

    # mean net
    if self.use_mean_net:
        mean_net_inputs = encoder_outputs
        if self.mean_net_type == "rnn":
            mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs)
        else:
            mean_net_outputs = mean_net_inputs
        mean_net_outputs = self.mean_net_dnn(
            mean_net_outputs
        )  # [batch, time, 1 (scalar) / 5 (categorical)]

    # decoder
    if self.decoder_type == "rnn":
        decoder_outputs = decoder_inputs.view((bs * actual_num_listeners, time, -1))
        decoder_outputs, (h, c) = self.decoder_rnn(decoder_outputs)
    else:
        decoder_outputs = decoder_inputs
    decoder_outputs = self.decoder_dnn(
        decoder_outputs
    )  # [batch, time, 1 (scalar) / 5 (categorical)]
    decoder_outputs = decoder_outputs.view(
        (bs, actual_num_listeners, time, -1)
    )  # (bs, nj, time, 1/5)

    if self.output_type == "scalar":
        decoder_outputs = decoder_outputs.squeeze(-1)  # (bs, nj, time)
        posterior_scores = torch.mean(decoder_outputs, dim=2)
        ld_scores = torch.mean(decoder_outputs, dim=1)  # (bs, time)
    elif self.output_type == "categorical":
        # expectation over the 5-point categorical posterior
        ld_posterior = torch.nn.functional.softmax(decoder_outputs, dim=-1)
        ld_scores = torch.inner(
            ld_posterior, torch.Tensor([1, 2, 3, 4, 5]).to(device)
        )
        posterior_scores = torch.mean(ld_scores, dim=2)
        ld_scores = torch.mean(ld_scores, dim=1)  # (bs, time)

    # average across time to one score per input
    scores = torch.mean(ld_scores, dim=1)

    return {"scores": scores, "posterior_scores": posterior_scores}
float = 0.25, + categorical_head_range_clipping: bool = True, + # dummy, for signature need + num_domains: int = None, + ): + super().__init__() # this is needed! or else there will be an error. + self.use_mean_listener = use_mean_listener + self.output_type = output_type + self.use_additional_categorical_head = use_additional_categorical_head + + # define listener embedding + self.use_listener_modeling = use_listener_modeling + + # define ssl model + if ssl_module == "s3prl": + from s3prl.nn import S3PRLUpstream + + if s3prl_name in S3PRLUpstream.available_names(): + self.ssl_model = S3PRLUpstream(s3prl_name) + self.ssl_model_layer_idx = ssl_model_layer_idx + else: + raise NotImplementedError + + # default uses ffn type mean net + self.mean_net_dnn = Projection( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + mean_net_output_type, + mean_net_output_dim, + mean_net_output_step, + mean_net_range_clipping, + ) + + # additional categorical head (for RAMP) + if use_additional_categorical_head: + # make sure mean net is not categorical + assert ( + mean_net_output_type != "categorical" + ), "mean net cannot be categorical if additional categorical head is used" + self.categorical_head = Projection( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + "categorical", + categorical_head_output_dim, + categorical_head_output_step, + categorical_head_range_clipping, + ) + + # listener modeling related + self.use_listener_modeling = use_listener_modeling + if use_listener_modeling: + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + # define decoder + self.decoder_type = decoder_type + if decoder_type == "ffn": + decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim + else: + raise NotImplementedError + # there is always dnn + self.decoder_dnn = Projection( + decoder_dnn_input_dim, + decoder_dnn_dim, + self.activation, + output_type, + range_clipping, + ) + 
+ def get_num_params(self): + return sum(p.numel() for n, p in self.named_parameters()) + + def forward(self, inputs): + """Calculate forward propagation. + Args: + waveform has shape (batch, time) + waveform_lengths has shape (batch) + listener_ids has shape (batch) + """ + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + batch, time = waveform.shape + + # get listener embedding + if self.use_listener_modeling: + listener_ids = inputs["listener_idxs"] + # NOTE(unlight): not tested yet + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] + + # inject listener embedding + if self.use_listener_modeling: + # NOTE(unlight): not tested yet + encoder_outputs = encoder_outputs.view( + (batch, time, -1) + ) # (batch, time, feat_dim) + decoder_inputs = torch.cat( + [encoder_outputs, listener_embs], dim=-1 + ) # concat along feature dimension + else: + decoder_inputs = encoder_outputs + + # masked mean pooling + # masks = make_non_pad_mask(encoder_outputs_lens) + # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] + # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) + + # mean net + mean_net_outputs = self.mean_net_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # additional categorical head + if self.use_additional_categorical_head: + categorical_head_outputs = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + # decoder + if self.use_listener_modeling: + if self.decoder_type == "rnn": + decoder_outputs, (h, c) = 
self.decoder_rnn(decoder_inputs) + else: + decoder_outputs = decoder_inputs + decoder_outputs = self.decoder_dnn( + decoder_outputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # set outputs + # return lengths for masked loss calculation + ret = { + "waveform_lengths": waveform_lengths, + "frame_lengths": encoder_outputs_lens, + } + + # define scores + ret["mean_scores"] = mean_net_outputs + ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None + if self.use_additional_categorical_head: + ret["categorical_head_scores"] = categorical_head_outputs + + return ret + + def mean_net_inference(self, inputs): + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + + # mean net + decoder_inputs = encoder_outputs + mean_net_outputs = self.mean_net_dnn( + decoder_inputs, inference=True + ) # [batch, time, 1 (scalar) / 5 (categorical)] + mean_net_outputs = mean_net_outputs.squeeze(-1) + scores = torch.mean(mean_net_outputs.to(torch.float), dim=1) # [batch] + + ret = {"ssl_embeddings": encoder_outputs, "scores": scores} + + if self.use_additional_categorical_head: + ret["confidences"] = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + return ret + + def mean_net_inference_p1(self, waveform, waveform_lengths): + # ssl model forward + all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + return encoder_outputs + + def mean_net_inference_p2(self, encoder_outputs): + # mean net + mean_net_outputs = self.mean_net_dnn( + encoder_outputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + mean_net_outputs = mean_net_outputs.squeeze(-1) + scores = torch.mean(mean_net_outputs, dim=1) + + return scores + + def 
get_ssl_embeddings(self, inputs): + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + return encoder_outputs + + +class SSLMOS_U(SSLMOS): + def __init__( + self, + # dummy, for signature need + model_input: str, + # model related + ssl_module: str = "s3prl", + s3prl_name: str = "wav2vec2", + ssl_model_output_dim: int = 768, + ssl_model_layer_idx: int = -1, + # mean net related + mean_net_dnn_dim: int = 64, + mean_net_output_type: str = "scalar", + mean_net_output_dim: int = 5, + mean_net_output_step: float = 0.25, + mean_net_range_clipping: bool = True, + # listener related + use_listener_modeling: bool = False, + num_listeners: int = None, + listener_emb_dim: int = None, + use_mean_listener: bool = True, + # decoder related + decoder_type: str = "ffn", + decoder_dnn_dim: int = 64, + output_type: str = "scalar", + range_clipping: bool = True, + # additional head (for RAMP) + use_additional_categorical_head: bool = False, + categorical_head_dnn_dim: int = 64, + categorical_head_output_dim: int = 17, + categorical_head_output_step: float = 0.25, + categorical_head_range_clipping: bool = True, + # dummy, for signature need + num_domains: int = None, + ): + super().__init__() # this is needed! or else there will be an error. 
+ self.use_mean_listener = use_mean_listener + self.output_type = output_type + self.use_additional_categorical_head = use_additional_categorical_head + + # define listener embedding + self.use_listener_modeling = use_listener_modeling + + # define ssl model + if ssl_module == "s3prl": + from s3prl.nn import S3PRLUpstream + + if s3prl_name in S3PRLUpstream.available_names(): + self.ssl_model = S3PRLUpstream(s3prl_name) + self.ssl_model_layer_idx = ssl_model_layer_idx + else: + raise NotImplementedError + + # default uses ffn type mean net + self.mean_net_dnn = ProjectionWithUncertainty( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + mean_net_output_type, + mean_net_output_dim, + mean_net_output_step, + mean_net_range_clipping, + ) + + # additional categorical head (for RAMP) + if use_additional_categorical_head: + # make sure mean net is not categorical + assert ( + mean_net_output_type != "categorical" + ), "mean net cannot be categorical if additional categorical head is used" + self.categorical_head = Projection( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + "categorical", + categorical_head_output_dim, + categorical_head_output_step, + categorical_head_range_clipping, + ) + + # listener modeling related + self.use_listener_modeling = use_listener_modeling + if use_listener_modeling: + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + # define decoder + self.decoder_type = decoder_type + if decoder_type == "ffn": + decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim + else: + raise NotImplementedError + # there is always dnn + self.decoder_dnn = Projection( + decoder_dnn_input_dim, + decoder_dnn_dim, + self.activation, + output_type, + range_clipping, + ) + + def forward(self, inputs): + """Calculate forward propagation. 
+ Args: + waveform has shape (batch, time) + waveform_lengths has shape (batch) + listener_ids has shape (batch) + """ + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + batch, time = waveform.shape + + # get listener embedding + if self.use_listener_modeling: + listener_ids = inputs["listener_idxs"] + # NOTE(unlight): not tested yet + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] + + # inject listener embedding + if self.use_listener_modeling: + # NOTE(unlight): not tested yet + encoder_outputs = encoder_outputs.view( + (batch, time, -1) + ) # (batch, time, feat_dim) + decoder_inputs = torch.cat( + [encoder_outputs, listener_embs], dim=-1 + ) # concat along feature dimension + else: + decoder_inputs = encoder_outputs + + # masked mean pooling + # masks = make_non_pad_mask(encoder_outputs_lens) + # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] + # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) + + # mean net + mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # additional categorical head + if self.use_additional_categorical_head: + categorical_head_outputs = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + # decoder + if self.use_listener_modeling: + if self.decoder_type == "rnn": + decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) + else: + decoder_outputs = decoder_inputs + decoder_outputs = self.decoder_dnn( + decoder_outputs + ) # 
[batch, time, 1 (scalar) / 5 (categorical)] + + # set outputs + # return lengths for masked loss calculation + ret = { + "waveform_lengths": waveform_lengths, + "frame_lengths": encoder_outputs_lens, + } + + # define scores + ret["mean_scores"] = mean_net_outputs_mean + ret["mean_scores_logvar"] = mean_net_outputs_logvar + ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None + if self.use_additional_categorical_head: + ret["categorical_head_scores"] = categorical_head_outputs + + return ret + + def mean_net_inference(self, inputs): + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + + # mean net + decoder_inputs = encoder_outputs + mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( + decoder_inputs, inference=True + ) # [batch, time, 1 (scalar) / 5 (categorical)] + mean_net_outputs_mean = mean_net_outputs_mean.squeeze(-1) + mean_net_outputs_logvar = mean_net_outputs_logvar.squeeze(-1) + scores = torch.mean(mean_net_outputs_mean, dim=1) # [batch] + logvars = torch.mean(mean_net_outputs_logvar, dim=1) # [batch] + + ret = {"ssl_embeddings": encoder_outputs, "scores": scores, "logvars": logvars} + + if self.use_additional_categorical_head: + ret["confidences"] = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + return ret diff --git a/sheet/models/sslmos_u.py b/sheet/models/sslmos_u.py new file mode 100644 index 0000000..5e799f2 --- /dev/null +++ b/sheet/models/sslmos_u.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- + +# Copyright 2025 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +# SSLMOS model which can output uncertainty +# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper) + +import math + 
+import torch +import torch.nn as nn +from sheet.modules.ldnet.modules import ProjectionWithUncertainty + + +class SSLMOS_U(torch.nn.Module): + def __init__( + self, + # dummy, for signature need + model_input: str, + # model related + ssl_module: str = "s3prl", + s3prl_name: str = "wav2vec2", + ssl_model_output_dim: int = 768, + ssl_model_layer_idx: int = -1, + # mean net related + mean_net_dnn_dim: int = 64, + mean_net_output_type: str = "scalar", + mean_net_output_dim: int = 5, + mean_net_output_step: float = 0.25, + mean_net_range_clipping: bool = True, + # listener related + use_listener_modeling: bool = False, + num_listeners: int = None, + listener_emb_dim: int = None, + use_mean_listener: bool = True, + # decoder related + decoder_type: str = "ffn", + decoder_dnn_dim: int = 64, + output_type: str = "scalar", + range_clipping: bool = True, + # additional head (for RAMP) + use_additional_categorical_head: bool = False, + categorical_head_dnn_dim: int = 64, + categorical_head_output_dim: int = 17, + categorical_head_output_step: float = 0.25, + categorical_head_range_clipping: bool = True, + # dummy, for signature need + num_domains: int = None, + ): + super().__init__() # this is needed! or else there will be an error. 
+ self.use_mean_listener = use_mean_listener + self.output_type = output_type + self.use_additional_categorical_head = use_additional_categorical_head + + # define listener embedding + self.use_listener_modeling = use_listener_modeling + + # define ssl model + if ssl_module == "s3prl": + from s3prl.nn import S3PRLUpstream + + if s3prl_name in S3PRLUpstream.available_names(): + self.ssl_model = S3PRLUpstream(s3prl_name) + self.ssl_model_layer_idx = ssl_model_layer_idx + else: + raise NotImplementedError + + # default uses ffn type mean net + self.mean_net_dnn = ProjectionWithUncertainty( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + mean_net_output_type, + mean_net_output_dim, + mean_net_output_step, + mean_net_range_clipping, + ) + + # additional categorical head (for RAMP) + if use_additional_categorical_head: + # make sure mean net is not categorical + assert ( + mean_net_output_type != "categorical" + ), "mean net cannot be categorical if additional categorical head is used" + self.categorical_head = Projection( + ssl_model_output_dim, + mean_net_dnn_dim, + nn.ReLU, + "categorical", + categorical_head_output_dim, + categorical_head_output_step, + categorical_head_range_clipping, + ) + + # listener modeling related + self.use_listener_modeling = use_listener_modeling + if use_listener_modeling: + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + # define decoder + self.decoder_type = decoder_type + if decoder_type == "ffn": + decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim + else: + raise NotImplementedError + # there is always dnn + self.decoder_dnn = Projection( + decoder_dnn_input_dim, + decoder_dnn_dim, + self.activation, + output_type, + range_clipping, + ) + + def get_num_params(self): + return sum(p.numel() for n, p in self.named_parameters()) + + def forward(self, inputs): + """Calculate forward propagation. 
+ Args: + waveform has shape (batch, time) + waveform_lengths has shape (batch) + listener_ids has shape (batch) + """ + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + batch, time = waveform.shape + + # get listener embedding + if self.use_listener_modeling: + listener_ids = inputs["listener_idxs"] + # NOTE(unlight): not tested yet + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] + + # inject listener embedding + if self.use_listener_modeling: + # NOTE(unlight): not tested yet + encoder_outputs = encoder_outputs.view( + (batch, time, -1) + ) # (batch, time, feat_dim) + decoder_inputs = torch.cat( + [encoder_outputs, listener_embs], dim=-1 + ) # concat along feature dimension + else: + decoder_inputs = encoder_outputs + + # masked mean pooling + # masks = make_non_pad_mask(encoder_outputs_lens) + # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] + # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) + + # mean net + mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # additional categorical head + if self.use_additional_categorical_head: + categorical_head_outputs = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + # decoder + if self.use_listener_modeling: + if self.decoder_type == "rnn": + decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) + else: + decoder_outputs = decoder_inputs + decoder_outputs = self.decoder_dnn( + decoder_outputs + ) # 
[batch, time, 1 (scalar) / 5 (categorical)] + + # set outputs + # return lengths for masked loss calculation + ret = { + "waveform_lengths": waveform_lengths, + "frame_lengths": encoder_outputs_lens, + } + + # define scores + ret["mean_scores"] = mean_net_outputs_mean + ret["mean_scores_logvar"] = mean_net_outputs_logvar + ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None + if self.use_additional_categorical_head: + ret["categorical_head_scores"] = categorical_head_outputs + + return ret + + def mean_net_inference(self, inputs): + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + # ssl model forward + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + + # mean net + decoder_inputs = encoder_outputs + mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( + decoder_inputs, inference=True + ) # [batch, time, 1 (scalar) / 5 (categorical)] + mean_net_outputs_mean = mean_net_outputs_mean.squeeze(-1) + mean_net_outputs_logvar = mean_net_outputs_logvar.squeeze(-1) + scores = torch.mean(mean_net_outputs_mean, dim=1) # [batch] + logvars = torch.mean(mean_net_outputs_logvar, dim=1) # [batch] + + ret = {"ssl_embeddings": encoder_outputs, "scores": scores, "logvars": logvars} + + if self.use_additional_categorical_head: + ret["confidences"] = self.categorical_head( + decoder_inputs + ) # [batch, time, categorical steps] + + return ret + + def mean_net_inference_p1(self, waveform, waveform_lengths): + # ssl model forward + all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + return encoder_outputs + + def mean_net_inference_p2(self, encoder_outputs): + # mean net + mean_net_outputs = self.mean_net_dnn( + encoder_outputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + mean_net_outputs = 
mean_net_outputs.squeeze(-1) + scores = torch.mean(mean_net_outputs, dim=1) + + return scores + + def get_ssl_embeddings(self, inputs): + waveform = inputs["waveform"] + waveform_lengths = inputs["waveform_lengths"] + + all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( + waveform, waveform_lengths + ) + encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] + return encoder_outputs \ No newline at end of file diff --git a/sheet/models/utmos.py b/sheet/models/utmos.py new file mode 100644 index 0000000..44f7672 --- /dev/null +++ b/sheet/models/utmos.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +# UTMOS model +# modified from: https://github.com/sarulab-speech/UTMOS22/tree/master/strong + +import math + +import torch +import torch.nn as nn +from sheet.modules.ldnet.modules import Projection +from sheet.modules.utils import make_non_pad_mask + + +class UTMOS(torch.nn.Module): + def __init__( + self, + model_input: str, + # model related + ssl_module: str, + s3prl_name: str, + ssl_model_output_dim: int, + ssl_model_layer_idx: int, + # phoneme and reference related + use_phoneme: bool = True, + phoneme_encoder_dim: int = 256, + phoneme_encoder_emb_dim: int = 256, + phoneme_encoder_out_dim: int = 256, + phoneme_encoder_n_lstm_layers: int = 3, + phoneme_encoder_vocab_size: int = 300, + use_reference: bool = True, + # listener related + use_listener_modeling: bool = False, + num_listeners: int = None, + listener_emb_dim: int = None, + use_mean_listener: bool = True, + # decoder related + use_decoder_rnn: bool = True, + decoder_rnn_dim: int = 512, + decoder_dnn_dim: int = 2048, + decoder_activation: str = "ReLU", + output_type: str = "scalar", + range_clipping: bool = True, + num_domains: int = None, + ): + super().__init__() # this is needed! or else there will be an error. 
+ self.use_mean_listener = use_mean_listener + self.output_type = output_type + + # define listener embedding + self.use_listener_modeling = use_listener_modeling + + # define ssl model + if ssl_module == "s3prl": + from s3prl.nn import S3PRLUpstream + + if s3prl_name in S3PRLUpstream.available_names(): + self.ssl_model = S3PRLUpstream(s3prl_name) + self.ssl_model_layer_idx = ssl_model_layer_idx + else: + raise NotImplementedError + decoder_input_dim = ssl_model_output_dim + + # define phoneme encoder + self.use_phoneme = use_phoneme + self.use_reference = use_reference + if self.use_phoneme: + self.phoneme_embedding = nn.Embedding( + phoneme_encoder_vocab_size, phoneme_encoder_emb_dim + ) + self.phoneme_encoder_lstm = nn.LSTM( + phoneme_encoder_emb_dim, + phoneme_encoder_dim, + num_layers=phoneme_encoder_n_lstm_layers, + dropout=0.1, + bidirectional=True, + ) + if self.use_reference: + + phoneme_encoder_linear_input_dim = ( + phoneme_encoder_dim + phoneme_encoder_dim + ) + else: + phoneme_encoder_linear_input_dim = phoneme_encoder_dim + self.phoneme_encoder_linear = nn.Sequential( + nn.Linear(phoneme_encoder_linear_input_dim, phoneme_encoder_out_dim), + nn.ReLU(), + ) + decoder_input_dim += phoneme_encoder_out_dim + + # NOTE(unlight): ignore domain embedding right now + + # listener modeling related + self.use_listener_modeling = use_listener_modeling + if use_listener_modeling: + self.num_listeners = num_listeners + self.listener_embeddings = nn.Embedding( + num_embeddings=num_listeners, embedding_dim=listener_emb_dim + ) + decoder_input_dim += listener_emb_dim + + # define decoder rnn + self.use_decoder_rnn = use_decoder_rnn + if self.use_decoder_rnn: + self.decoder_rnn = nn.LSTM( + input_size=decoder_input_dim, + hidden_size=decoder_rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + decoder_dnn_input_dim = decoder_rnn_dim * 2 + else: + decoder_dnn_input_dim = decoder_input_dim + + # define activation + if decoder_activation == "ReLU": + 
self.decoder_activation = nn.ReLU + else: + raise NotImplementedError + + # there is always decoder dnn + self.decoder_dnn = Projection( + decoder_dnn_input_dim, + decoder_dnn_dim, + self.decoder_activation, + output_type, + range_clipping, + ) + + def get_num_params(self): + return sum(p.numel() for n, p in self.named_parameters()) + + def forward(self, inputs): + """Calculate forward propagation. + Args: + inputs: dict, which has the following keys: + - waveform has shape (batch, time) + - waveform_lengths has shape (batch) + - listener_ids has shape (batch) + """ + waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] + + # ssl model forward + ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( + waveform, waveform_lengths + ) + to_concat = [ssl_model_outputs] + time = ssl_model_outputs.size(1) + + # phoneme encoder forward + if self.use_phoneme: + phoneme_encoder_outputs = self.phoneme_encoder_forward(inputs, time) + to_concat.append(phoneme_encoder_outputs) + + # get listener embedding + if self.use_listener_modeling: + listener_ids = inputs["listener_idxs"] + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? 
+ # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(listener_embs) + + decoder_inputs = torch.cat(to_concat, dim=2) + + # decoder rnn + if self.use_decoder_rnn: + decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) + + # decoder dnn + decoder_outputs = self.decoder_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + # set outputs + # return lengths for masked loss calculation + ret = { + "waveform_lengths": waveform_lengths, + "frame_lengths": ssl_model_output_lengths, + } + if self.use_listener_modeling: + ret["ld_scores"] = decoder_outputs + else: + ret["mean_scores"] = decoder_outputs + + return ret + + def mean_listener_inference(self, inputs): + waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] + batch = waveform.size(0) + + # ssl model forward + ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( + waveform, waveform_lengths + ) + to_concat = [ssl_model_outputs] + time = ssl_model_outputs.size(1) + + # phoneme encoder forward + if self.use_phoneme: + phoneme_encoder_outputs = self.phoneme_encoder_forward(inputs, time) + to_concat.append(phoneme_encoder_outputs) + + # get listener embedding + if self.use_listener_modeling: + device = waveform.device + listener_ids = ( + torch.ones(batch, dtype=torch.long) * self.num_listeners - 1 + ).to( + device + ) # (bs) + listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) + listener_embs = torch.stack( + [listener_embs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + # NOTE(unilight): is this needed? 
+ # encoder_outputs = encoder_outputs.view( + # (batch, time, -1) + # ) # (batch, time, feat_dim) + to_concat.append(listener_embs) + + decoder_inputs = torch.cat(to_concat, dim=2) + + # decoder rnn + if self.use_decoder_rnn: + decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) + + # decoder dnn + decoder_outputs = self.decoder_dnn( + decoder_inputs + ) # [batch, time, 1 (scalar) / 5 (categorical)] + + scores = torch.mean(decoder_outputs.squeeze(-1), dim=1) + return {"scores": scores} + + def ssl_model_forward(self, waveform, waveform_lengths): + all_ssl_model_outputs, all_ssl_model_output_lengths = self.ssl_model( + waveform, waveform_lengths + ) + ssl_model_outputs = all_ssl_model_outputs[self.ssl_model_layer_idx] + ssl_model_output_lengths = all_ssl_model_output_lengths[ + self.ssl_model_layer_idx + ] + return ssl_model_outputs, ssl_model_output_lengths + + def phoneme_encoder_forward(self, inputs, time): + phoneme, phoneme_lengths = inputs["phoneme_idxs"], inputs["phoneme_lengths"] + phoneme_embeddings = self.phoneme_embedding(phoneme) + phoneme_embeddings = torch.nn.utils.rnn.pack_padded_sequence( + phoneme_embeddings, phoneme_lengths, batch_first=True, enforce_sorted=False + ) + _, (phoneme_encoder_outputs, _) = self.phoneme_encoder_lstm(phoneme_embeddings) + phoneme_encoder_outputs = ( + phoneme_encoder_outputs[-1] + phoneme_encoder_outputs[0] + ) + if self.use_reference: + assert ( + "reference_idxs" in inputs and "reference_lengths" in inputs + ), "reference and reference_lenghts should not be None when use_reference is True" + reference, reference_lengths = ( + inputs["reference_idxs"], + inputs["reference_lengths"], + ) + reference_embeddings = self.phoneme_embedding(reference) + reference_embeddings = torch.nn.utils.rnn.pack_padded_sequence( + reference_embeddings, + reference_lengths, + batch_first=True, + enforce_sorted=False, + ) + _, (reference_encoder_outputs, _) = self.phoneme_encoder_lstm( + reference_embeddings + ) + 
reference_encoder_outputs = ( + reference_encoder_outputs[-1] + reference_encoder_outputs[0] + ) + phoneme_encoder_outputs = self.phoneme_encoder_linear( + torch.cat([phoneme_encoder_outputs, reference_encoder_outputs], 1) + ) + else: + phoneme_encoder_outputs = self.phoneme_encoder_linear( + phoneme_encoder_outputs + ) + + # expand + phoneme_encoder_outputs = torch.stack( + [phoneme_encoder_outputs for i in range(time)], dim=1 + ) # (batch, time, feat_dim) + + return phoneme_encoder_outputs diff --git a/sheet/modules/__init__.py b/sheet/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sheet/modules/ldnet/__init__.py b/sheet/modules/ldnet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sheet/modules/ldnet/mobilenetv2.py b/sheet/modules/ldnet/mobilenetv2.py new file mode 100644 index 0000000..87af658 --- /dev/null +++ b/sheet/modules/ldnet/mobilenetv2.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- + +from typing import Any, Callable, List, Optional + +import torch +from torch import Tensor, nn + +__all__ = ["MobileNetV2", "mobilenet_v2"] + + +model_urls = { + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", +} + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNActivation(nn.Sequential): + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, + dilation: int = 1, + ) -> None: + padding = (kernel_size - 1) // 2 * dilation + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if activation_layer is None: + activation_layer = nn.ReLU6 + super().__init__( + # NOTE(unilight): stride only operates on the last axis + nn.Conv2d( + in_planes, + out_planes, + kernel_size, + (1, stride), + padding, + dilation=dilation, + groups=groups, + bias=False, + ), + norm_layer(out_planes), + activation_layer(inplace=True), + ) + self.out_channels = out_planes + + +# necessary for backwards compatibility +ConvBNReLU = ConvBNActivation + + +class InvertedResidual(nn.Module): + def __init__( + self, + inp: int, + oup: int, + stride: int, + expand_ratio: int, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2, 3] + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers: List[nn.Module] = [] + if expand_ratio != 1: + # pw + layers.append( + ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer) + ) + layers.extend( + [ + # dw + ConvBNReLU( + hidden_dim, + hidden_dim, + stride=stride, + groups=hidden_dim, + norm_layer=norm_layer, + ), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ] + ) + self.conv = nn.Sequential(*layers) + self.out_channels = oup + self._is_cn = stride > 1 + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class 
MobileNetV2(nn.Module): + def __init__( + self, + num_classes: int = 1000, + width_mult: float = 1.0, + inverted_residual_setting: Optional[List[List[int]]] = None, + round_nearest: int = 8, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + """ + MobileNet V2 main class + + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if ( + len(inverted_residual_setting) == 0 + or len(inverted_residual_setting[0]) != 4 + ): + raise ValueError( + "inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting) + ) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible( + last_channel * max(1.0, width_mult), round_nearest + ) + features: List[nn.Module] = [ + ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) + ] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): 
+ stride = s if i == 0 else 1 + features.append( + block( + input_channel, + output_channel, + stride, + expand_ratio=t, + norm_layer=norm_layer, + ) + ) + input_channel = output_channel + # building last several layers + features.append( + ConvBNReLU( + input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer + ) + ) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + x = self.features(x) + # Cannot use "squeeze" as batch-size can be 1 + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) diff --git a/sheet/modules/ldnet/mobilenetv3.py b/sheet/modules/ldnet/mobilenetv3.py new file mode 100644 index 0000000..5759712 --- /dev/null +++ b/sheet/modules/ldnet/mobilenetv3.py @@ -0,0 +1,341 @@ +# -*- coding: utf-8 -*- + +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from .mobilenetv2 import ConvBNActivation, _make_divisible + +__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] + + +model_urls = { + "mobilenet_v3_large": 
"https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", +} + + +class SqueezeExcitation(nn.Module): + # Implemented as described at Figure 4 of the MobileNetV3 paper + def __init__(self, input_channels: int, squeeze_factor: int = 4): + super().__init__() + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) + + def _scale(self, input: Tensor, inplace: bool) -> Tensor: + scale = F.adaptive_avg_pool2d(input, 1) + scale = self.fc1(scale) + scale = self.relu(scale) + scale = self.fc2(scale) + return F.hardsigmoid(scale, inplace=inplace) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input, True) + return scale * input + + +class InvertedResidualConfig: + # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper + def __init__( + self, + input_channels: int, + kernel: int, + expanded_channels: int, + out_channels: int, + use_se: bool, + activation: str, + stride: int, + dilation: int, + width_mult: float, + ): + self.input_channels = self.adjust_channels(input_channels, width_mult) + self.kernel = kernel + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + self.dilation = dilation + + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + + +class InvertedResidual(nn.Module): + # Implemented as described at section 5 of MobileNetV3 paper + def __init__( + self, + cnf: InvertedResidualConfig, + norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation, + ): + 
super().__init__() + if not (1 <= cnf.stride <= 3): + raise ValueError("illegal stride value") + + self.use_res_connect = ( + cnf.stride == 1 and cnf.input_channels == cnf.out_channels + ) + + layers: List[nn.Module] = [] + activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU + + # expand + if cnf.expanded_channels != cnf.input_channels: + layers.append( + ConvBNActivation( + cnf.input_channels, + cnf.expanded_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + # depthwise + stride = 1 if cnf.dilation > 1 else cnf.stride + layers.append( + ConvBNActivation( + cnf.expanded_channels, + cnf.expanded_channels, + kernel_size=cnf.kernel, + stride=stride, + dilation=cnf.dilation, + groups=cnf.expanded_channels, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + if cnf.use_se: + layers.append(se_layer(cnf.expanded_channels)) + + # project + layers.append( + ConvBNActivation( + cnf.expanded_channels, + cnf.out_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Identity, + ) + ) + + self.block = nn.Sequential(*layers) + self.out_channels = cnf.out_channels + self._is_cn = cnf.stride > 1 + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result += input + return result + + +class MobileNetV3(nn.Module): + + def __init__( + self, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any + ) -> None: + """ + MobileNet V3 main class + + Args: + inverted_residual_setting (List[InvertedResidualConfig]): Network structure + last_channel (int): The number of channels on the penultimate layer + num_classes (int): Number of classes + block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet + norm_layer 
(Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + """ + super().__init__() + + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + elif not ( + isinstance(inverted_residual_setting, Sequence) + and all( + [ + isinstance(s, InvertedResidualConfig) + for s in inverted_residual_setting + ] + ) + ): + raise TypeError( + "The inverted_residual_setting should be List[InvertedResidualConfig]" + ) + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + + layers: List[nn.Module] = [] + + # building first layer + firstconv_output_channels = inverted_residual_setting[0].input_channels + layers.append( + ConvBNActivation( + 3, + firstconv_output_channels, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) + + # building inverted residual blocks + for cnf in inverted_residual_setting: + layers.append(block(cnf, norm_layer)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].out_channels + lastconv_output_channels = 6 * lastconv_input_channels + layers.append( + ConvBNActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Sequential( + nn.Linear(lastconv_output_channels, last_channel), + nn.Hardswish(inplace=True), + nn.Dropout(p=0.2, inplace=True), + nn.Linear(last_channel, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 
0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.classifier(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _mobilenet_v3_conf( + arch: str, + width_mult: float = 1.0, + reduced_tail: bool = False, + dilated: bool = False, + **kwargs: Any +): + reduce_divider = 2 if reduced_tail else 1 + dilation = 2 if dilated else 1 + + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + adjust_channels = partial( + InvertedResidualConfig.adjust_channels, width_mult=width_mult + ) + + if arch == "mobilenet_v3_large": + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), + bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 + bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 + bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), + bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), + bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 + bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), + bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), + bneck_conf( + 112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation + ), # C4 + bneck_conf( + 160 // reduce_divider, + 5, + 960 // reduce_divider, + 160 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + bneck_conf( + 160 // reduce_divider, + 5, + 960 // reduce_divider, + 160 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + ] + last_channel = adjust_channels(1280 // reduce_divider) # C5 + elif arch == "mobilenet_v3_small": + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 + bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 + bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 96, 40, True, 
"HS", 2, 1), # C3 + bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), + bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 + bneck_conf( + 96 // reduce_divider, + 5, + 576 // reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + bneck_conf( + 96 // reduce_divider, + 5, + 576 // reduce_divider, + 96 // reduce_divider, + True, + "HS", + 1, + dilation, + ), + ] + last_channel = adjust_channels(1024 // reduce_divider) # C5 + else: + raise ValueError("Unsupported model type {}".format(arch)) + + return inverted_residual_setting, last_channel diff --git a/sheet/modules/ldnet/modules.py b/sheet/modules/ldnet/modules.py new file mode 100644 index 0000000..1aa64ba --- /dev/null +++ b/sheet/modules/ldnet/modules.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Wen-Chin Huang +# MIT License (https://opensource.org/licenses/MIT) + +# LDNet modules +# taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself) + +from functools import partial +from typing import List + +import torch +from sheet.modules.ldnet.mobilenetv2 import ConvBNActivation +from sheet.modules.ldnet.mobilenetv3 import InvertedResidual as InvertedResidualV3 +from sheet.modules.ldnet.mobilenetv3 import InvertedResidualConfig +from torch import nn + +STRIDE = 3 + + +class Projection(nn.Module): + def __init__( + self, + in_dim, + hidden_dim, + activation, + output_type, + _output_dim, + output_step=1.0, + range_clipping=False, + ): + super(Projection, self).__init__() + self.output_type = output_type + self.range_clipping = range_clipping + if output_type == "scalar": + output_dim = 1 + if range_clipping: + self.proj = nn.Tanh() + elif output_type == "categorical": + output_dim = _output_dim + self.output_step = output_step + else: + raise NotImplementedError("wrong 
# LDNet head / conv-block modules (interior of sheet/modules/ldnet/modules.py).

# stride of the first conv in MobileNetV3ConvBlocks; ConvBNActivation applies
# it as (1, STRIDE), i.e. along the last (feature) axis only — TODO confirm
STRIDE = 3


class Projection(nn.Module):
    """Projection head mapping frame features to a score.

    Args:
        in_dim (int): input feature dimension.
        hidden_dim (int): hidden layer dimension.
        activation: activation class (e.g. ``nn.ReLU``), instantiated here.
        output_type (str): "scalar" (one regressed value) or "categorical"
            (a distribution over ``_output_dim`` discrete score bins).
        _output_dim (int): number of categories for "categorical" output.
        output_step (float): bin width used to map a category index back to
            a score at inference time.
        range_clipping (bool): for "scalar", squash the output into [1, 5]
            via ``tanh(x) * 2 + 3``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        if output_type == "scalar":
            output_dim = 1
            if range_clipping:
                self.proj = nn.Tanh()
        elif output_type == "categorical":
            output_dim = _output_dim
            self.output_step = output_step
        else:
            raise NotImplementedError(f"wrong output_type: {output_type}")

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        """Return scores (scalar) or logits / decoded categories (categorical)."""
        output = self.net(x)

        # scalar / categorical
        if self.output_type == "scalar":
            # range clipping into [1, 5]
            if self.range_clipping:
                return self.proj(output) * 2.0 + 3
            else:
                return output
        else:
            if inference:
                # decode category index back to a score
                return torch.argmax(output, dim=-1) * self.output_step + 1
            else:
                return output


class ProjectionWithUncertainty(nn.Module):
    """Projection head that additionally predicts a log-variance.

    Same interface as :class:`Projection`; for "scalar" outputs it returns
    ``(mean, logvar)`` instead of a single value.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        if output_type == "scalar":
            output_dim = 2  # [mean, logvar]
            if range_clipping:
                self.proj = nn.Tanh()
        elif output_type == "categorical":
            output_dim = _output_dim
            self.output_step = output_step
        else:
            raise NotImplementedError(f"wrong output_type: {output_type}")

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        """Return (mean, logvar) for scalar output, logits otherwise."""
        output = self.net(x)  # output shape: [B, T, d]

        # scalar / categorical
        if self.output_type == "scalar":
            mean, logvar = output[:, :, 0], output[:, :, 1]
            # range clipping into [1, 5]
            if self.range_clipping:
                return self.proj(mean) * 2.0 + 3, logvar
            else:
                return mean, logvar
        else:
            if inference:
                return torch.argmax(output, dim=-1) * self.output_step + 1
            else:
                return output


class MobileNetV3ConvBlocks(nn.Module):
    """Stack of MobileNetV3 inverted-residual blocks producing (B, T, D) features."""

    def __init__(self, bneck_confs, output_dim):
        super().__init__()

        bneck_conf = partial(InvertedResidualConfig, width_mult=1)
        inverted_residual_setting = [bneck_conf(*b_conf) for b_conf in bneck_confs]

        block = InvertedResidualV3

        # Never tested if a different eps and momentum is needed
        # norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
        norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer (single-channel input, e.g. a spectrogram)
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            ConvBNActivation(
                1,
                firstconv_output_channels,
                kernel_size=3,
                stride=STRIDE,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = output_dim
        layers.append(
            ConvBNActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )
        self.features = nn.Sequential(*layers)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Pool the feature axis, keeping the time axis length unchanged."""
        time = x.shape[2]
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (time, 1))
        x = x.squeeze(-1).transpose(1, 2)
        return x
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
# MIT License (https://opensource.org/licenses/MIT)

"""Model utilities.

Some functions are based on:
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/nets_utils.py
"""

import torch


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
        maxlen (int, optional): Explicit maximum length; mutually exclusive
            with ``xs``.

    Returns:
        Tensor: Bool mask, True at padded positions.

    Examples:
        >>> make_pad_mask([5, 3, 2])
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask([5, 3, 2], xs).shape
        torch.Size([3, 2, 4])
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.long().tolist()

    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    # new_tensor() instead of the deprecated legacy Tensor.new()
    seq_length_expand = seq_range_expand.new_tensor(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of non-padded part.

    Logical negation of :func:`make_pad_mask`; ``maxlen`` is forwarded for
    consistency with it (backward-compatible: defaults to None).

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
        maxlen (int, optional): Explicit maximum length.

    Returns:
        Tensor: Bool mask, True at non-padded positions.

    Examples:
        >>> make_non_pad_mask([5, 3, 2])
        tensor([[ True,  True,  True,  True,  True],
                [ True,  True,  True, False, False],
                [ True,  True, False, False, False]])
    """
    return ~make_pad_mask(lengths, xs, length_dim, maxlen)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
# MIT License (https://opensource.org/licenses/MIT)

"""datastore related"""

import faiss
import h5py
import numpy as np
from scipy.special import softmax


class Datastore:
    """In-memory kNN datastore backed by a faiss exact-L2 index."""

    def __init__(
        self,
        datastore_path,
        embed_dim,
        device,
    ):
        """
        Args:
            datastore_path (str): path to the datastore; an HDF5 file whose
                "embeds" and "scores" groups are keyed by utterance path.
            embed_dim (int): dimension of the embed in the datastore.
            device (torch.device): index is moved to GPU when device.type == "cuda".
        """
        embeds = []
        scores = []
        paths = []
        with h5py.File(datastore_path, "r") as f:
            for hdf5_path in list(f["scores"].keys()):
                paths.append(hdf5_path)
                embeds.append(f["embeds"][hdf5_path][()])
                scores.append(f["scores"][hdf5_path][()])
        embeds = np.stack(embeds, axis=0)
        scores = np.array(scores)

        # build exact L2 index
        index = faiss.IndexFlatL2(embed_dim)
        if device.type == "cuda":
            index = faiss.index_cpu_to_all_gpus(index, ngpu=1)
        index.add(embeds)

        self.embeds = embeds
        self.scores = scores
        self.paths = paths
        self.index = index

    def knn(self, query, k, search_only=False):
        """Search the k nearest neighbors of ``query``.

        Args:
            query (np.ndarray): query embeddings, shape (n_queries, embed_dim).
            k (int): number of neighbors.
            search_only (bool): if True, return only distances and scores.

        Returns:
            dict: with "distances" and "scores" (both (n_queries, k)); unless
                ``search_only``, also "final_score" (softmax-weighted score of
                the first query) and "ids" (neighbor paths per query).
        """
        distances, indices = self.index.search(query, k)
        # vectorized gather instead of a per-row Python loop
        scores = self.scores[indices]
        ret = {"distances": distances, "scores": scores}

        if search_only:
            return ret

        # NOTE(unilight) 20250205: change to negative
        # inv_dist = 1 / (distances + 1e-8)
        inv_dist = -distances

        # softmax over neighbors -> weights that sum to 1 per query
        norm_dist = softmax(inv_dist, axis=1)
        mult = np.multiply(norm_dist, scores)

        # NOTE(review): only the first query row is reduced here — this assumes
        # a single-query batch; confirm against callers.
        final_score = np.sum(mult, axis=1)[0]

        # retrieve IDs
        ids = [[self.paths[e] for e in row] for row in indices]

        ret["final_score"] = final_score
        ret["ids"] = ids

        return ret
"""Learning-rate scheduler factory."""

import copy

from torch.optim.lr_scheduler import MultiStepLR, StepLR

# Reference: https://github.com/s3prl/s3prl/blob/master/s3prl/schedulers.py


def get_scheduler(optimizer, scheduler_name, total_steps, scheduler_config):
    """Instantiate a scheduler by name.

    Args:
        optimizer: wrapped torch optimizer.
        scheduler_name (str): scheduler name; a ``get_<scheduler_name>``
            function must exist in this module (e.g. "multistep", "stepLR").
        total_steps (int): total number of training steps.
        scheduler_config (dict): extra keyword arguments for the scheduler
            (deep-copied so the caller's config is never mutated).

    Returns:
        The instantiated torch scheduler.

    Raises:
        ValueError: if ``scheduler_name`` is unknown.
    """
    scheduler_config = copy.deepcopy(scheduler_config)
    # explicit module-level lookup instead of eval() on a config-supplied string
    try:
        factory = globals()[f"get_{scheduler_name}"]
    except KeyError as err:
        raise ValueError(f"unknown scheduler: {scheduler_name}") from err
    return factory(optimizer, num_training_steps=total_steps, **scheduler_config)


def get_multistep(optimizer, num_training_steps, milestones, gamma):
    """MultiStepLR: decay the lr by ``gamma`` at each step in ``milestones``."""
    return MultiStepLR(optimizer, milestones, gamma)


def get_stepLR(optimizer, num_training_steps, step_size, gamma):
    """StepLR: decay the lr by ``gamma`` every ``step_size`` steps."""
    return StepLR(optimizer, step_size, gamma)
class Trainer(object):
    """Base trainer: training loop, periodic evaluation, and n-best checkpointing.

    Subclasses implement ``_train_step``, ``_eval_step`` and
    ``_log_metrics_and_save_figures`` (which is expected to append
    ``[steps, results]`` entries to ``self.reporter``).
    """

    def __init__(
        self,
        steps,
        epochs,
        data_loader,
        sampler,
        model,
        criterion,
        optimizer,
        scheduler,
        config,
        device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            sampler (dict): Dict of samplers; "train" is used for distributed shuffling.
            model (torch.nn.Module): Model to train (a single module, not a dict).
            criterion (dict): Dict of loss criterions (consumed by subclasses).
            optimizer (torch.optim.Optimizer): Optimizer.
            scheduler: Learning-rate scheduler, or None.
            config (dict): Config dict loaded from yaml format configuration file.
            device (torch.device): Pytorch device instance.

        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.finish_train = False

        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.reset_eval_results()

        self.gradient_accumulate_steps = self.config.get("gradient_accumulate_steps", 1)

        self.reporter = list()  # each element is [steps: int, results: dict]
        self.original_patience = self.config.get("patience", None)
        self.current_patience = self.config.get("patience", None)

    def run(self):
        """Run training until max steps or patience is exhausted."""
        self.backward_steps = 0
        self.all_loss = 0.0
        self.tqdm = tqdm(
            initial=self.steps,
            total=self.config["train_max_steps"],
            desc="[train]",
            mininterval=5,
            maxinterval=5,
        )
        while True:
            # train one epoch
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    def save_checkpoint(self, checkpoint_path):
        """Save checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be saved.

        """
        state_dict = {
            "optimizer": self.optimizer.state_dict(),
            "steps": self.steps,
            "epochs": self.epochs,
        }
        if self.scheduler is not None:
            state_dict["scheduler"] = self.scheduler.state_dict()

        # unwrap DistributedDataParallel before saving
        if self.config["distributed"]:
            state_dict["model"] = self.model.module.state_dict()
        else:
            state_dict["model"] = self.model.state_dict()

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(self, checkpoint_path, load_only_params=False):
        """Load checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
            load_only_params (bool): Whether to load only model parameters.

        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if self.config["distributed"]:
            self.model.module.load_state_dict(state_dict["model"])
        else:
            self.model.load_state_dict(state_dict["model"])
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer.load_state_dict(state_dict["optimizer"])
            if self.scheduler is not None:
                self.scheduler.load_state_dict(state_dict["scheduler"])

    def _train_step(self, batch):
        """Train model one step (implemented by subclasses)."""
        pass

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)
            # skip checks until a full accumulation window has been consumed
            if self.backward_steps % self.gradient_accumulate_steps > 0:
                continue

            # check interval
            if self.config["rank"] == 0:
                self._check_log_interval()
                self._check_eval_and_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

        # needed for shuffle in distributed training
        if self.config["distributed"]:
            self.sampler["train"].set_epoch(self.epochs)

    @torch.no_grad()
    def _eval_step(self, batch):
        """Evaluate model one step (implemented by subclasses)."""
        pass

    def _eval(self):
        """Evaluate model with dev set."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        self.model.eval()
        start_time = time.time()

        # loop through dev set
        for count, batch in enumerate(self.data_loader["dev"], 1):
            self._eval_step(batch)
            if "dev_samples_per_eval_loop" in self.config:
                if count > self.config["dev_samples_per_eval_loop"]:
                    break

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({time.time() - start_time} secs)."
        )

    @torch.no_grad()
    def _log_metrics_and_save_figures(self):
        """Log metrics and save figures (implemented by subclasses)."""
        pass

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_eval_and_save_interval(self):
        """Every N steps: evaluate, keep the n-best checkpoints, track patience."""
        if self.steps % self.config["eval_and_save_interval_steps"] == 0:
            # run evaluation on dev set
            self._eval()

            # get metrics and save figures
            self._log_metrics_and_save_figures()

            # get best n steps
            best_n_steps = self.get_and_show_best_n_models()

            # save current if in best n
            if self.steps in best_n_steps:
                current_checkpoint_path = os.path.join(
                    self.config["outdir"], f"checkpoint-{self.steps}steps.pkl"
                )
                self.save_checkpoint(current_checkpoint_path)
                logging.info(
                    f"Saved checkpoint @ {self.steps} steps because it is in best {self.config['keep_nbest_models']}."
                )

                # restore patience
                if self.original_patience is not None:
                    self.current_patience = self.original_patience
                    logging.info(f"Restoring patience to {self.original_patience}.")
            else:
                # minus patience
                if self.current_patience is not None:
                    self.current_patience -= 1
                    logging.info(f"Reducing patience to {self.current_patience}.")

            # if current is best, link to best
            # (guard against an empty reporter: best_n_steps may be [])
            if best_n_steps and self.steps == best_n_steps[0]:
                best_checkpoint_path = os.path.join(
                    self.config["outdir"], "checkpoint-best.pkl"
                )
                if os.path.islink(best_checkpoint_path) or os.path.exists(
                    best_checkpoint_path
                ):
                    os.remove(best_checkpoint_path)
                os.symlink(current_checkpoint_path, best_checkpoint_path)
                logging.info(f"Updated best checkpoint to {self.steps} steps.")

            # delete those not in best n
            existing_checkpoint_paths = [
                fname
                for fname in os.listdir(self.config["outdir"])
                if os.path.isfile(os.path.join(self.config["outdir"], fname))
                and fname.endswith("steps.pkl")
                and not fname.startswith("original")
            ]
            for checkpoint_path in existing_checkpoint_paths:
                steps = int(
                    checkpoint_path.replace("steps.pkl", "").replace("checkpoint-", "")
                )
                if steps not in best_n_steps:
                    os.remove(os.path.join(self.config["outdir"], checkpoint_path))
                    logging.info(f"Deleting checkpoint @ {steps} steps.")

            # reset
            self.reset_eval_results()

            # restore mode
            self.model.train()

    def _check_log_interval(self):
        """Every N steps: average, log and tensorboard the accumulated train losses."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        """Stop when max steps are reached or patience runs out."""
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True

        if self.current_patience is not None and self.current_patience <= 0:
            self.finish_train = True

    def freeze_modules(self, modules):
        # delegates to sheet.utils.model_io.freeze_modules (module-level import)
        freeze_modules(self.model, modules)

    def reset_eval_results(self):
        """Clear the per-evaluation accumulators."""
        self.eval_results = defaultdict(list)
        self.eval_sys_results = defaultdict(lambda: defaultdict(list))

    def get_and_show_best_n_models(self):
        """Return the steps of the n best models according to the configured criterion."""
        # sort according to key
        best_n = sorted(
            self.reporter,
            key=lambda x: x[1][self.config["best_model_criterion"]["key"]],
        )
        if (
            self.config["best_model_criterion"]["order"] == "highest"
        ):  # reverse if highest
            best_n.reverse()

        # log the results
        logging.info(
            f"Best {self.config['keep_nbest_models']} models at step {self.steps}:"
        )
        log_string = "; ".join(
            f"{i+1}. {steps} steps: {self.config['best_model_criterion']['key']}={results[self.config['best_model_criterion']['key']]:.4f}"
            for i, (steps, results) in enumerate(
                best_n[: self.config["keep_nbest_models"]]
            )
        )
        logging.info(log_string)

        # only return the steps
        return [steps for steps, _ in best_n[: self.config["keep_nbest_models"]]]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

"""Trainer for non-intrusive (no-reference) speech quality estimators."""

import logging
import os
import time

# set to avoid matplotlib error in CLI environment
import matplotlib
import numpy as np
import soundfile as sf
import torch
from sheet.evaluation.metrics import calculate
from sheet.evaluation.plot import (
    plot_sys_level_scatter,
    plot_utt_level_hist,
    plot_utt_level_scatter,
)
from sheet.trainers.base import Trainer
from sheet.utils.model_io import (
    filter_modules,
    get_partial_state_dict,
    print_new_keys,
    transfer_verification,
)

matplotlib.use("Agg")
import matplotlib.pyplot as plt


class NonIntrusiveEstimatorTrainer(Trainer):
    """Customized trainer module for non-intrusive estimator."""

    def load_trained_modules(self, checkpoint_path, init_mods):
        """Initialize selected submodules from a pre-trained checkpoint.

        Args:
            checkpoint_path (str): Path to the pre-trained checkpoint.
            init_mods (list): Module-name prefixes to transfer.

        """
        if self.config["distributed"]:
            main_state_dict = self.model.module.state_dict()
        else:
            main_state_dict = self.model.state_dict()

        if os.path.isfile(checkpoint_path):
            model_state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

            # first make sure that all modules in `init_mods` are in `checkpoint_path`
            modules = filter_modules(model_state_dict, init_mods)

            # then, actually get the partial state_dict
            partial_state_dict = get_partial_state_dict(model_state_dict, modules)

            if partial_state_dict:
                if transfer_verification(main_state_dict, partial_state_dict, modules):
                    print_new_keys(partial_state_dict, modules, checkpoint_path)
                    main_state_dict.update(partial_state_dict)
        else:
            # NOTE(review): exit(1) from a library method is abrupt; matches the
            # convention used in sheet.utils.model_io, so kept as-is.
            logging.error(f"Specified model was not found: {checkpoint_path}")
            exit(1)

        if self.config["distributed"]:
            self.model.module.load_state_dict(main_state_dict)
        else:
            self.model.load_state_dict(main_state_dict)

    def _train_step(self, batch):
        """Train model one step (forward, loss accumulation, optional update)."""

        # set inputs; the main input key is config-driven (e.g. "waveform")
        gen_loss = 0.0
        inputs = {
            self.config["model_input"]: batch[self.config["model_input"]].to(
                self.device
            ),
            self.config["model_input"]
            + "_lengths": batch[self.config["model_input"] + "_lengths"].to(
                self.device
            ),
        }
        if "listener_idxs" in batch:
            inputs["listener_idxs"] = batch["listener_idxs"].to(self.device)
        if "domain_idxs" in batch:
            inputs["domain_idxs"] = batch["domain_idxs"].to(self.device)
        if "phoneme_idxs" in batch:
            inputs["phoneme_idxs"] = batch["phoneme_idxs"].to(self.device)
            # lengths stay on CPU — presumably consumed by pack/pad utilities;
            # confirm against the model implementations
            inputs["phoneme_lengths"] = batch["phoneme_lengths"]
        if "reference_idxs" in batch:
            inputs["reference_idxs"] = batch["reference_idxs"].to(self.device)
            inputs["reference_lengths"] = batch["reference_lengths"]

        # model forward
        outputs = self.model(inputs)

        # get frame lengths if exist (used to mask frame-level losses)
        if "frame_lengths" in outputs:
            output_frame_lengths = outputs["frame_lengths"]
        else:
            output_frame_lengths = None

        # get ground truth scores
        gt_scores = batch["scores"].to(self.device)
        gt_avg_scores = batch["avg_scores"].to(self.device)
        if "categorical_scores" in batch:
            categorical_gt_scores = batch["categorical_scores"].to(self.device)
        if "categorical_avg_scores" in batch:
            categorical_gt_avg_scores = batch["categorical_avg_scores"].to(self.device)

        # mean (utterance-average) score loss
        if "mean_score_criterions" in self.criterion:
            for criterion_dict in self.criterion["mean_score_criterions"]:
                if criterion_dict["type"] in ["GaussianNLLLoss", "LaplaceNLLLoss"]:
                    # NLL losses additionally take the predicted log-variance
                    loss = criterion_dict["criterion"](
                        outputs["mean_scores"],
                        outputs["mean_scores_logvar"],
                        gt_avg_scores,
                        self.device,
                        lens=output_frame_lengths,
                    )
                else:
                    # always pass the following arguments
                    loss = criterion_dict["criterion"](
                        outputs["mean_scores"],
                        (
                            categorical_gt_avg_scores
                            if criterion_dict["type"] == "CategoricalLoss"
                            else gt_avg_scores
                        ),
                        self.device,
                        lens=output_frame_lengths,
                    )
                gen_loss += loss * criterion_dict["weight"]
                self.total_train_loss["train/mean_" + criterion_dict["type"]] += (
                    loss.item() / self.gradient_accumulate_steps
                )

        # categorical head loss (for RAMP only)
        if "categorical_head_criterions" in self.criterion:
            for criterion_dict in self.criterion["categorical_head_criterions"]:
                # always pass the following arguments
                loss = criterion_dict["criterion"](
                    outputs["categorical_head_scores"],
                    categorical_gt_avg_scores,
                    self.device,
                    lens=output_frame_lengths,
                )
                gen_loss += loss * criterion_dict["weight"]
                self.total_train_loss["train/categorical_head_loss"] += (
                    loss.item() / self.gradient_accumulate_steps
                )

        # listener-dependent score loss
        if "listener_score_criterions" in self.criterion:
            for criterion_dict in self.criterion["listener_score_criterions"]:
                # always pass the following arguments
                loss = criterion_dict["criterion"](
                    outputs["ld_scores"],
                    (
                        categorical_gt_scores
                        if criterion_dict["type"] == "CategoricalLoss"
                        else gt_scores
                    ),
                    self.device,
                    lens=output_frame_lengths,
                )
                gen_loss += loss * criterion_dict["weight"]
                self.total_train_loss["train/listener_" + criterion_dict["type"]] += (
                    loss.item() / self.gradient_accumulate_steps
                )

        self.total_train_loss["train/loss"] += (
            gen_loss.item() / self.gradient_accumulate_steps
        )

        # update model
        if self.gradient_accumulate_steps > 1:
            gen_loss = gen_loss / self.gradient_accumulate_steps
        gen_loss.backward()
        # NOTE(review): this adds the LAST per-criterion `loss`, not the
        # accumulated `gen_loss`, and `loss` is unbound if no criterion branch
        # ran — looks like a bug; confirm intended semantics of `all_loss`.
        self.all_loss += loss.item()

        self.backward_steps += 1
        # defer the optimizer update until a full accumulation window is done
        if self.backward_steps % self.gradient_accumulate_steps > 0:
            return

        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config["grad_norm"],
            )
        self.optimizer.step()
        self.optimizer.zero_grad()
        if self.scheduler is not None:
            self.scheduler.step()

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Evaluate model one step and accumulate utterance/system results."""

        # set up model input (listener indices intentionally omitted at eval)
        inputs = {
            self.config["model_input"]: batch[self.config["model_input"]].to(
                self.device
            ),
            self.config["model_input"]
            + "_lengths": batch[self.config["model_input"] + "_lengths"].to(
                self.device
            ),
        }
        if "domain_idxs" in batch:
            inputs["domain_idxs"] = batch["domain_idxs"].to(self.device)
        if "phoneme_idxs" in batch:
            inputs["phoneme_idxs"] = batch["phoneme_idxs"].to(self.device)
            inputs["phoneme_lengths"] = batch["phoneme_lengths"]
        if "reference_idxs" in batch:
            inputs["reference_idxs"] = batch["reference_idxs"].to(self.device)
            inputs["reference_lengths"] = batch["reference_lengths"]

        # model forward
        # NOTE(review): `outputs` is unbound if inference_mode is neither
        # "mean_listener" nor "mean_net" — confirm config validation upstream.
        if self.config["inference_mode"] == "mean_listener":
            outputs = self.model.mean_listener_inference(inputs)
        elif self.config["inference_mode"] == "mean_net":
            outputs = self.model.mean_net_inference(inputs)

        # construct the eval_results dict
        pred_mean_scores = outputs["scores"].cpu().detach().numpy()
        true_mean_scores = batch["avg_scores"].numpy()
        self.eval_results["pred_mean_scores"].extend(pred_mean_scores.tolist())
        self.eval_results["true_mean_scores"].extend(true_mean_scores.tolist())
        sys_names = batch["system_ids"]
        for j, sys_name in enumerate(sys_names):
            self.eval_sys_results["pred_mean_scores"][sys_name].append(
                pred_mean_scores[j]
            )
            self.eval_sys_results["true_mean_scores"][sys_name].append(
                true_mean_scores[j]
            )

    @torch.no_grad()
    def _log_metrics_and_save_figures(self):
        """Compute utt/sys-level metrics, log them, and save scatter/hist plots."""

        # flatten accumulated per-utterance lists into arrays
        self.eval_results["true_mean_scores"] = np.array(
            self.eval_results["true_mean_scores"]
        )
        self.eval_results["pred_mean_scores"] = np.array(
            self.eval_results["pred_mean_scores"]
        )
        # system-level scores = mean over each system's utterances
        self.eval_sys_results["true_mean_scores"] = np.array(
            [
                np.mean(scores)
                for scores in self.eval_sys_results["true_mean_scores"].values()
            ]
        )
        self.eval_sys_results["pred_mean_scores"] = np.array(
            [
                np.mean(scores)
                for scores in self.eval_sys_results["pred_mean_scores"].values()
            ]
        )

        # calculate metrics
        results = calculate(
            self.eval_results["true_mean_scores"],
            self.eval_results["pred_mean_scores"],
            self.eval_sys_results["true_mean_scores"],
            self.eval_sys_results["pred_mean_scores"],
        )

        # log metrics
        # NOTE(review): UTT uses .3f while SYS LCC/SRCC use .4f — presumably
        # unintentional inconsistency in the format string; confirm.
        logging.info(
            f'[{self.steps} steps][UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n'
        )

        # register metrics to reporter (consumed by get_and_show_best_n_models)
        self.reporter.append([self.steps, results])

        # check directory
        dirname = os.path.join(
            self.config["outdir"], f"intermediate_results/{self.steps}steps"
        )
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        # plot
        plot_utt_level_hist(
            self.eval_results["true_mean_scores"],
            self.eval_results["pred_mean_scores"],
            os.path.join(dirname, "distribution.png"),
        )
        plot_utt_level_scatter(
            self.eval_results["true_mean_scores"],
            self.eval_results["pred_mean_scores"],
            os.path.join(dirname, "utt_scatter_plot.png"),
            results["utt_LCC"],
            results["utt_SRCC"],
            results["utt_MSE"],
            results["utt_KTAU"],
        )
        plot_sys_level_scatter(
            self.eval_sys_results["true_mean_scores"],
            self.eval_sys_results["pred_mean_scores"],
            os.path.join(dirname, "sys_scatter_plot.png"),
            results["sys_LCC"],
            results["sys_SRCC"],
            results["sys_MSE"],
            results["sys_KTAU"],
        )
"""
Thread-safe file downloading and caching

Authors
 * Leo 2022
 * Cheng Liang 2022
"""

import hashlib
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from urllib.request import Request, urlopen

logger = logging.getLogger(__name__)


_download_dir = Path.home() / ".cache" / "sheet" / "download"

__all__ = [
    "get_dir",
    "set_dir",
    "download",
    "urls_to_filepaths",
]


def get_dir():
    """Return the cache directory, creating it if necessary."""
    _download_dir.mkdir(exist_ok=True, parents=True)
    return _download_dir


def set_dir(d):
    """Override the cache directory (module-global)."""
    global _download_dir
    _download_dir = Path(d)


def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
    """Download `url` to `dst` via urllib, with optional sha256 verification.

    This function is not thread-safe. Please ensure only a single
    thread or process can enter this block at the same time.

    Args:
        url (str): Source URL.
        dst (str): Destination file path.
        hash_prefix (str): If given, the sha256 of the download must start
            with this prefix, otherwise RuntimeError is raised.
        progress (bool): Whether to show a progress bar.
    """
    # tqdm imported lazily so the module can be imported without it installed
    from tqdm import tqdm

    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    u = urlopen(req)
    meta = u.info()
    # getheaders() is the legacy (Python 2) API; get_all() is the modern one
    if hasattr(meta, "getheaders"):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    # download into a temp file in the same directory so the final move is atomic
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()

        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                if hash_prefix is not None:
                    sha256.update(buffer)
                pbar.update(len(buffer))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest
                    )
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
        # remove the temp file if the move did not happen (error path)
        if os.path.exists(f.name):
            os.remove(f.name)


def _download_url_to_file_requests(url, dst, hash_prefix=None, progress=True):
    """Alternative download when urllib.Request fails.

    Same contract as `_download_url_to_file`, but uses `requests` with
    streaming chunks.
    """
    # third-party deps imported lazily: only needed on this fallback path
    import requests
    from tqdm import tqdm

    req = requests.get(url, stream=True, headers={"User-Agent": "torch.hub"})
    file_size = int(req.headers["Content-Length"])

    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        if hash_prefix is not None:
            sha256 = hashlib.sha256()

        # plain string: the original used an f-string with no placeholders
        tqdm.write(
            "urllib.Request method failed. Trying using another method...",
            file=sys.stderr,
        )
        tqdm.write(f"Downloading: {url}", file=sys.stderr)
        tqdm.write(f"Destination: {dst}", file=sys.stderr)
        with tqdm(
            total=file_size,
            disable=not progress,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for chunk in req.iter_content(chunk_size=1024 * 1024 * 10):
                if chunk:
                    f.write(chunk)
                    f.flush()
                    os.fsync(f.fileno())
                    if hash_prefix is not None:
                        sha256.update(chunk)
                    pbar.update(len(chunk))

        f.close()
        if hash_prefix is not None:
            digest = sha256.hexdigest()
            if digest[: len(hash_prefix)] != hash_prefix:
                raise RuntimeError(
                    'invalid hash value (expected "{}", got "{}")'.format(
                        hash_prefix, digest
                    )
                )
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)


def _download(filepath: Path, url, refresh: bool, new_enough_secs: float = 2.0):
    """Download `url` to `filepath` under a file lock.

    If refresh is True, check the latest modified time of the filepath.
    If the file is new enough (no older than `new_enough_secs`), then directly
    use it. If the file is older than `new_enough_secs`, then re-download it.

    This function is useful when multiple processes are all downloading the
    same large file.
    """
    # filelock imported lazily so the module can be imported without it
    from filelock import FileLock

    Path(filepath).parent.mkdir(exist_ok=True, parents=True)

    lock_file = Path(str(filepath) + ".lock")
    logger.info(f"Requesting URL: {url}")

    with FileLock(str(lock_file)):
        if not filepath.is_file() or (
            refresh and (time.time() - os.path.getmtime(filepath)) > new_enough_secs
        ):
            try:
                _download_url_to_file(url, filepath)
            # was a bare `except:` — that would also swallow KeyboardInterrupt
            # and SystemExit; fall back to requests only on real errors
            except Exception:
                _download_url_to_file_requests(url, filepath)

    logger.info(f"Using URL's local file: {filepath}")
    try:
        lock_file.unlink()
    except FileNotFoundError:
        pass


def _urls_to_filepaths(*args, refresh=False, download: bool = True):
    """
    Preprocess the URL specified in *args into local file paths after downloading

    Args:
        Any number of URLs (1 ~ any)

    Return:
        Same number of downloaded file paths
    """

    def _url_to_filepath(url):
        assert isinstance(url, str)
        # cache key: sha256 of the URL, suffixed with the original filename
        m = hashlib.sha256()
        m.update(str.encode(url))
        filepath = get_dir() / f"{str(m.hexdigest())}.{Path(url).name}"
        if download:
            _download(filepath, url, refresh=refresh)
        return str(filepath.resolve())

    paths = [_url_to_filepath(url) for url in args]
    # single URL -> single path (not a one-element list)
    return paths if len(paths) > 1 else paths[0]


download = _download
urls_to_filepaths = _urls_to_filepaths
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

"""Helpers for partial loading, verification, freezing and averaging of models."""

import logging
import os
from collections import OrderedDict

import torch


def print_new_keys(state_dict, modules, model_path):
    """Log every key that is about to be overridden by a transferred state dict."""
    logging.info(f"Loading {modules} from model: {model_path}")

    for k in state_dict.keys():
        logging.warning(f"Overriding module {k}")


def filter_modules(model_state_dict, modules):
    """Filter non-matched modules in model state dict.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        new_mods (List): Filtered module list.
    """
    available_keys = list(model_state_dict.keys())
    new_mods = [
        mod for mod in modules if any(key.startswith(mod) for key in available_keys)
    ]
    incorrect_mods = [mod for mod in modules if mod not in new_mods]

    # abort when the caller requested prefixes the checkpoint does not contain
    if incorrect_mods:
        logging.error(
            "Specified module(s) don't match or (partially match) "
            f"available modules in model. You specified: {incorrect_mods}."
        )
        logging.error("The existing modules in model are:")
        logging.error(f"{available_keys}")
        exit(1)

    return new_mods


def get_partial_state_dict(model_state_dict, modules):
    """Create state dict with specified modules matching input model modules.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (Dict): Specified module(s) to transfer.

    Return:
        new_state_dict (Dict): State dict with specified modules weights.
    """
    return OrderedDict(
        (key, value)
        for key, value in model_state_dict.items()
        if any(key.startswith(mod) for mod in modules)
    )


def transfer_verification(model_state_dict, partial_state_dict, modules):
    """Verify tuples (key, shape) for input model modules match specified modules.

    Args:
        model_state_dict (Dict): Main model state dict.
        partial_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        (bool): Whether transfer learning is allowed.
    """

    def _key_shape_pairs(state_dict):
        # collect (key, shape) for every parameter under the selected prefixes
        pairs = [
            (key, value.shape)
            for key, value in state_dict.items()
            if any(key.startswith(mod) for mod in modules)
        ]
        return sorted(pairs, key=lambda x: (x[0], x[1]))

    model_modules = _key_shape_pairs(model_state_dict)
    partial_modules = _key_shape_pairs(partial_state_dict)
    module_match = model_modules == partial_modules

    if not module_match:
        logging.error(
            "Some specified modules from the pre-trained model "
            "don't match with the new model modules:"
        )
        logging.error(f"Pre-trained: {set(partial_modules) - set(model_modules)}")
        logging.error(f"New model: {set(model_modules) - set(partial_modules)}")
        exit(1)

    return module_match


def freeze_modules(model, modules):
    """Freeze model parameters according to modules list.

    Args:
        model (torch.nn.Module): Main model.
        modules (List): Specified module(s) to freeze.

    Return:
        model (torch.nn.Module): Updated main model.
        model_params (filter): Filtered model parameters.
    """
    for name, param in model.named_parameters():
        if any(name.startswith(mod) for mod in modules):
            logging.warning(f"Freezing {name}. It will not be updated during training.")
            param.requires_grad = False

    # only the still-trainable parameters are handed back to the optimizer
    model_params = filter(lambda p: p.requires_grad, model.parameters())

    return model, model_params


@torch.no_grad()
def model_average(model, outdir):
    """Generate averaged model from existing models.

    Args:
        model: the model instance
        outdir: the directory contains the model files
    """
    # get model checkpoints
    checkpoint_paths = [
        os.path.join(outdir, p)
        for p in os.listdir(outdir)
        if os.path.isfile(os.path.join(outdir, p)) and p.endswith("steps.pkl")
    ]
    n = len(checkpoint_paths)

    # sum all checkpoints key-by-key
    avg = None
    for ckpt in checkpoint_paths:
        states = torch.load(ckpt, map_location="cpu")["model"]
        if avg is None:
            avg = states
        else:
            for k in avg:
                avg[k] = avg[k] + states[k]

    # divide by the number of checkpoints; integer tensors (e.g.
    # BatchNorm.num_batches_tracked) are kept accumulated on purpose
    for k in avg:
        if str(avg[k].dtype).startswith("torch.int"):
            logging.info(f"Accumulating {k} instead of averaging")
        else:
            avg[k] = avg[k] / n

    # load into model
    model.load_state_dict(avg)

    return model, checkpoint_paths
"""Argparse-friendly string conversion helpers."""

from typing import Optional, Tuple, Union

# Accepted spellings, matching the semantics of distutils.util.strtobool,
# which was removed from the standard library in Python 3.12 (PEP 632).
_TRUE_STRINGS = ("y", "yes", "t", "true", "on", "1")
_FALSE_STRINGS = ("n", "no", "f", "false", "off", "0")


def str2bool(value: str) -> bool:
    """Convert a truthy/falsy string to bool.

    Drop-in replacement for the removed ``distutils.util.strtobool``
    (PEP 632); accepted values and the ValueError behavior are identical.

    Raises:
        ValueError: If *value* is not a recognized truth value.
    """
    lowered = value.lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise ValueError(f"invalid truth value {value!r}")


def remove_parenthesis(value: str):
    """Strip one matching pair of surrounding ``()`` or ``[]``."""
    value = value.strip()
    if value.startswith("(") and value.endswith(")"):
        value = value[1:-1]
    elif value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    return value


def remove_quotes(value: str):
    """Strip one matching pair of surrounding single or double quotes."""
    value = value.strip()
    if value.startswith('"') and value.endswith('"'):
        value = value[1:-1]
    elif value.startswith("'") and value.endswith("'"):
        value = value[1:-1]
    return value


def int_or_none(value: str) -> Optional[int]:
    """int_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=int_or_none)
        >>> parser.parse_args(['--foo', '456'])
        Namespace(foo=456)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return int(value)


def float_or_none(value: str) -> Optional[float]:
    """float_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=float_or_none)
        >>> parser.parse_args(['--foo', '4.5'])
        Namespace(foo=4.5)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return float(value)


def str_or_int(value: str) -> Union[str, int]:
    """Return int(value) when possible, otherwise the original string."""
    try:
        return int(value)
    except ValueError:
        return value


def str_or_none(value: str) -> Optional[str]:
    """str_or_none.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str_or_none)
        >>> parser.parse_args(['--foo', 'aaa'])
        Namespace(foo='aaa')
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    if value.strip().lower() in ("none", "null", "nil"):
        return None
    return value


def str2pair_str(value: str) -> Tuple[str, str]:
    """str2pair_str.

    Examples:
        >>> import argparse
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str2pair_str)
        >>> parser.parse_args(['--foo', 'abc,def'])
        Namespace(foo=('abc', 'def'))

    """
    value = remove_parenthesis(value)
    a, b = value.split(",")

    # Workaround for configargparse issues:
    # If the list values are given from a yaml file,
    # the value given to type() is shaped as a python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove quotes from it.
    return remove_quotes(a), remove_quotes(b)


def str2triple_str(value: str) -> Tuple[str, str, str]:
    """str2triple_str.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    value = remove_parenthesis(value)
    a, b, c = value.split(",")

    # Workaround for configargparse issues:
    # If the list values are given from a yaml file,
    # the value given to type() is shaped as a python-list,
    # e.g. ['a', 'b', 'c'],
    # so we need to remove quotes from it.
    return remove_quotes(a), remove_quotes(b), remove_quotes(c)
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Utility functions."""

import csv
import fnmatch
import logging
import os
import sys

import numpy as np


def get_basename(path):
    """Return the filename of *path* without directory and extension."""
    return os.path.splitext(os.path.split(path)[-1])[0]


def read_csv(path, dict_reader=False, lazy=False, encoding=None):
    """Read the csv file.

    (The original had a second, stray docstring fragment as a dead statement;
    the two are consolidated here.)

    Args:
        path (str): path to the csv file
        dict_reader (bool): whether to use dict reader. This should be set to
            true when the csv file has a header.
        lazy (bool): whether to read the file in this function.
        encoding (str): text encoding passed to ``open``.

    Return:
        contents: reader or list of contents
        fieldnames (list): header. If dict_reader is False, then return None.

    """
    with open(path, newline="", encoding=encoding) as csvfile:
        if dict_reader:
            reader = csv.DictReader(csvfile)
            fieldnames = reader.fieldnames
        else:
            reader = csv.reader(csvfile)
            fieldnames = None

        if lazy:
            # NOTE(review): the file is closed when this function returns, so
            # iterating the returned reader fails. Kept for interface
            # compatibility — confirm whether any caller actually uses lazy=True.
            contents = reader
        else:
            contents = list(reader)

    return contents, fieldnames


def write_csv(data, path):
    """Write data to the output path.

    Args:
        data (list): a list of dicts (all sharing the first dict's keys)
        path (str): path to the output csv file

    """
    fieldnames = list(data[0].keys())
    with open(path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for line in data:
            writer.writerow(line)


def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root root_dir to find.
        query (str): Query to find.
        include_root_dir (bool): If False, root_dir name is not included.

    Returns:
        list: List of found filenames.

    """
    files = []
    for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
        for filename in fnmatch.filter(filenames, query):
            files.append(os.path.join(root, filename))
    if not include_root_dir:
        files = [file_.replace(root_dir + "/", "") for file_ in files]

    return files


def read_hdf5(hdf5_name, hdf5_path):
    """Read hdf5 dataset.

    Args:
        hdf5_name (str): Filename of hdf5 file.
        hdf5_path (str): Dataset name in hdf5 file.

    Return:
        any: Dataset values.

    """
    # imported lazily so this module works when h5py is not installed
    import h5py

    if not os.path.exists(hdf5_name):
        logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
        sys.exit(1)

    hdf5_file = h5py.File(hdf5_name, "r")

    if hdf5_path not in hdf5_file:
        logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
        sys.exit(1)

    hdf5_data = hdf5_file[hdf5_path][()]
    hdf5_file.close()

    return hdf5_data


def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
    """Write dataset to hdf5.

    Args:
        hdf5_name (str): Hdf5 dataset filename.
        hdf5_path (str): Dataset path in hdf5.
        write_data (ndarray): Data to write.
        is_overwrite (bool): Whether to overwrite dataset.

    """
    # imported lazily so this module works when h5py is not installed
    import h5py

    # convert to numpy array
    write_data = np.array(write_data)

    # check folder existence
    folder_name, _ = os.path.split(hdf5_name)
    if not os.path.exists(folder_name) and len(folder_name) != 0:
        os.makedirs(folder_name)

    # check hdf5 existence
    if os.path.exists(hdf5_name):
        # if already exists, open with r+ mode
        hdf5_file = h5py.File(hdf5_name, "r+")
        # check dataset existence
        if hdf5_path in hdf5_file:
            if is_overwrite:
                logging.warning(
                    "Dataset in hdf5 file already exists. recreate dataset in hdf5."
                )
                del hdf5_file[hdf5_path]
            else:
                logging.error(
                    "Dataset in hdf5 file already exists. "
                    "if you want to overwrite, please set is_overwrite = True."
                )
                hdf5_file.close()
                sys.exit(1)
    else:
        # if not exists, open with w mode
        hdf5_file = h5py.File(hdf5_name, "w")

    # write data to hdf5
    hdf5_file.create_dataset(hdf5_path, data=write_data)
    hdf5_file.flush()
    hdf5_file.close()
+ + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 4000, + last_epoch: int = -1, + ): + self.warmup_steps = warmup_steps + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + return [ + lr + * self.warmup_steps**0.5 + * min(step_num**-0.5, step_num * self.warmup_steps**-1.5) + for lr in self.base_lrs + ] From dd748525f97a3716bf666f9bc036eb106eef5f42 Mon Sep 17 00:00:00 2001 From: darryllam Date: Thu, 6 Nov 2025 18:40:09 +0900 Subject: [PATCH 5/9] Added install bash script --- README.md | 2 +- egs/bvcc/path.sh | 2 +- install.sh | 13 +++++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 install.sh diff --git a/README.md b/README.md index 0400de9..4d06114 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ First install the uv package manager [here](https://docs.astral.sh/uv/getting-st ```bash git clone https://github.com/unilight/sheet.git cd sheet -uv sync --extras train +bash install.sh train ``` ## Information diff --git a/egs/bvcc/path.sh b/egs/bvcc/path.sh index 9ddc626..4069294 100755 --- a/egs/bvcc/path.sh +++ b/egs/bvcc/path.sh @@ -6,7 +6,7 @@ if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then fi MAIN_ROOT=$PWD/../.. -export PATH=$MAIN_ROOT/sheet/bin:$PATH +export PATH=$MAIN_ROOT/src/sheet/bin:$PATH # python related export OMP_NUM_THREADS=1 diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..c4f5580 --- /dev/null +++ b/install.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +uv venv tools/venv +source tools/venv/bin/activate + +if [[ "$1" == "train" ]]; then + echo "Including 'train' extras..." + uv sync --extra train --active +else + echo "Syncing without 'train' extras..." 
+ uv sync --active +fi \ No newline at end of file From 59a4b2f571ec086f22a7e40e5c680d69aee4b379 Mon Sep 17 00:00:00 2001 From: darryllam Date: Thu, 6 Nov 2025 18:44:06 +0900 Subject: [PATCH 6/9] Fix to version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ced6198..67ebf4a 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "uv_build" [project] name = "sheet" -version = "0.1.0" +version = "0.2.5" description = "Speech Human Evaluation Estimation Toolkit (SHEET)" requires-python = "==3.10.13" From 565758955e32f94bbf456d20ec499f5422d3e5c4 Mon Sep 17 00:00:00 2001 From: darryllam Date: Mon, 10 Nov 2025 12:00:23 +0900 Subject: [PATCH 7/9] Removed setup.cfg. Updated pyproject.toml. Removed stray src folder --- pyproject.toml | 11 +- setup.cfg | 41 -- src/sheet/__init__.py | 3 - src/sheet/bin/construct_datastore.py | 176 --------- src/sheet/bin/inference.py | 434 --------------------- src/sheet/bin/nonparametric_inference.py | 400 ------------------- src/sheet/bin/train.py | 396 ------------------- src/sheet/bin/train_stack.py | 188 --------- src/sheet/collaters/__init__.py | 1 - src/sheet/collaters/non_intrusive.py | 108 ------ src/sheet/datasets/__init__.py | 1 - src/sheet/datasets/non_intrusive.py | 340 ----------------- src/sheet/evaluation/metrics.py | 34 -- src/sheet/evaluation/plot.py | 108 ------ src/sheet/losses/__init__.py | 3 - src/sheet/losses/basic_losses.py | 91 ----- src/sheet/losses/contrastive_loss.py | 39 -- src/sheet/losses/nll_losses.py | 109 ------ src/sheet/models/__init__.py | 9 - src/sheet/models/alignnet.py | 400 ------------------- src/sheet/models/ldnet.py | 288 -------------- src/sheet/models/sslmos.py | 467 ----------------------- src/sheet/models/sslmos_u.py | 256 ------------- src/sheet/models/utmos.py | 299 --------------- src/sheet/modules/__init__.py | 0 src/sheet/modules/ldnet/__init__.py | 0 src/sheet/modules/ldnet/mobilenetv2.py | 
240 ------------ src/sheet/modules/ldnet/mobilenetv3.py | 341 ----------------- src/sheet/modules/ldnet/modules.py | 181 --------- src/sheet/modules/utils.py | 222 ----------- src/sheet/nonparametric/__init__.py | 0 src/sheet/nonparametric/datastore.py | 77 ---- src/sheet/schedulers/__init__.py | 1 - src/sheet/schedulers/schedulers.py | 21 - src/sheet/trainers/__init__.py | 2 - src/sheet/trainers/base.py | 315 --------------- src/sheet/trainers/non_intrusive.py | 310 --------------- src/sheet/utils/__init__.py | 1 - src/sheet/utils/download.py | 213 ----------- src/sheet/utils/model_io.py | 166 -------- src/sheet/utils/types.py | 139 ------- src/sheet/utils/utils.py | 164 -------- src/sheet/warmup_lr.py | 62 --- 43 files changed, 8 insertions(+), 6649 deletions(-) delete mode 100755 setup.cfg delete mode 100644 src/sheet/__init__.py delete mode 100755 src/sheet/bin/construct_datastore.py delete mode 100755 src/sheet/bin/inference.py delete mode 100755 src/sheet/bin/nonparametric_inference.py delete mode 100755 src/sheet/bin/train.py delete mode 100755 src/sheet/bin/train_stack.py delete mode 100644 src/sheet/collaters/__init__.py delete mode 100644 src/sheet/collaters/non_intrusive.py delete mode 100644 src/sheet/datasets/__init__.py delete mode 100644 src/sheet/datasets/non_intrusive.py delete mode 100644 src/sheet/evaluation/metrics.py delete mode 100644 src/sheet/evaluation/plot.py delete mode 100644 src/sheet/losses/__init__.py delete mode 100644 src/sheet/losses/basic_losses.py delete mode 100644 src/sheet/losses/contrastive_loss.py delete mode 100644 src/sheet/losses/nll_losses.py delete mode 100644 src/sheet/models/__init__.py delete mode 100644 src/sheet/models/alignnet.py delete mode 100644 src/sheet/models/ldnet.py delete mode 100644 src/sheet/models/sslmos.py delete mode 100644 src/sheet/models/sslmos_u.py delete mode 100644 src/sheet/models/utmos.py delete mode 100644 src/sheet/modules/__init__.py delete mode 100644 src/sheet/modules/ldnet/__init__.py 
delete mode 100644 src/sheet/modules/ldnet/mobilenetv2.py delete mode 100644 src/sheet/modules/ldnet/mobilenetv3.py delete mode 100644 src/sheet/modules/ldnet/modules.py delete mode 100644 src/sheet/modules/utils.py delete mode 100644 src/sheet/nonparametric/__init__.py delete mode 100644 src/sheet/nonparametric/datastore.py delete mode 100644 src/sheet/schedulers/__init__.py delete mode 100644 src/sheet/schedulers/schedulers.py delete mode 100644 src/sheet/trainers/__init__.py delete mode 100644 src/sheet/trainers/base.py delete mode 100644 src/sheet/trainers/non_intrusive.py delete mode 100644 src/sheet/utils/__init__.py delete mode 100644 src/sheet/utils/download.py delete mode 100644 src/sheet/utils/model_io.py delete mode 100644 src/sheet/utils/types.py delete mode 100644 src/sheet/utils/utils.py delete mode 100644 src/sheet/warmup_lr.py diff --git a/pyproject.toml b/pyproject.toml index 67ebf4a..f61dc76 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["uv_build"] -build-backend = "uv_build" +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" [project] name = "sheet" @@ -35,4 +35,9 @@ train = [ "kaldiio>=2.14.1", "humanfriendly>=10.0", "prettytable>=3.16.0", -] \ No newline at end of file +] + +[tool.setuptools] +packages = [ + "sheet" +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100755 index 99c581a..0000000 --- a/setup.cfg +++ /dev/null @@ -1,41 +0,0 @@ -[options] -packages = find: -install_requires = - librosa >= 0.8.0 - soundfile>=0.10.2 - pyyaml - h5py>=2.9.0 - filelock - protobuf<=3.20.1 - scipy - s3prl - faiss-cpu - -[options.extras_require] -train = - matplotlib>=3.1.0 - tqdm>=4.26.1 - gdown - tensorboardX - kaldiio>=2.14.1 - humanfriendly - prettytable - -[metadata] -name = sheet_sqa -version = 0.2.5 -author = Wen-Chin Huang -author_email = wen.chinhuang@g.sp.m.is.nagoya-u.ac.jp -description = Speech Human Evaluation Estimation Toolkit (SHEET) -keywords = speech 
quality assessment -license = MIT -url = https://github.com/unilight/sheet -project_urls = - Source = https://github.com/unilight/sheet - Tracker = https://github.com/unilight/sheet/issues -long_description=README.md -long_description_content_type=text/markdown -classifiers = - License :: OSI Approved :: MIT License - Programming Language :: Python :: 3 - diff --git a/src/sheet/__init__.py b/src/sheet/__init__.py deleted file mode 100644 index 20fada3..0000000 --- a/src/sheet/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# -*- coding: utf-8 -*- - -__version__ = "0.2.5" diff --git a/src/sheet/bin/construct_datastore.py b/src/sheet/bin/construct_datastore.py deleted file mode 100755 index 06a5867..0000000 --- a/src/sheet/bin/construct_datastore.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Construct datastore .""" - -import argparse -import logging -import os - -import h5py -import numpy as np -import sheet -import sheet.datasets -import sheet.models -import torch -import yaml -from s3prl.nn import S3PRLUpstream -from tqdm import tqdm - - -def main(): - """Construct datastore.""" - parser = argparse.ArgumentParser( - description=( - "Construct datastore with ssl_model in trained model " - "(See detail in bin/construct_datastore.py)." - ) - ) - parser.add_argument( - "--csv-path", - required=True, - type=str, - help=("csv file path to construct datastore."), - ) - parser.add_argument( - "--out", - type=str, - required=True, - help="out path to save datastore.", - ) - parser.add_argument( - "--checkpoint", - type=str, - help="checkpoint file to be loaded.", - ) - parser.add_argument( - "--config", - default=None, - type=str, - help=( - "yaml format configuration file. if not explicitly provided, " - "it will be searched in the checkpoint directory. 
(default=None)" - ), - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="logging level. higher is more logging. (default=1)", - ) - args = parser.parse_args() - - # set logger - if args.verbose > 1: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose > 0: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - - # check directory existence - # if not os.path.exists(args.outdir): - # os.makedirs(args.outdir) - - # load config - if args.config is None: - dirname = os.path.dirname(args.checkpoint) - args.config = os.path.join(dirname, "config.yml") - with open(args.config) as f: - config = yaml.load(f, Loader=yaml.Loader) - - args_dict = vars(args) - - config.update(args_dict) - for key, value in config.items(): - logging.info(f"{key} = {value}") - - # get dataset - dataset_class = getattr( - sheet.datasets, - config.get("dataset_type", "NonIntrusiveDataset"), - ) - dataset = dataset_class( - csv_path=args.csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - use_phoneme=config.get("use_phoneme", False), - symbols=config.get("symbols", None), - wav_only=True, - allow_cache=False, - ) - logging.info(f"Number of samples = {len(dataset)}.") - - # setup device - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - - # get ssl model - s3prl_name = config["model_params"]["s3prl_name"] - ssl_model = S3PRLUpstream(s3prl_name) - - # load pre-trained model - pt_ckpt = torch.load(os.readlink(args.checkpoint), map_location="cpu")["model"] - state_dict = { - k.replace("ssl_model.", ""): v - for k, v in pt_ckpt.items() - if 
k.startswith("ssl_model") - } - ssl_model.load_state_dict(state_dict) - logging.info(f"Loaded model parameters from {args.checkpoint}.") - ssl_model = ssl_model.eval().to(device) - - # start inference - if os.path.exists(args.out): - hdf5_file = h5py.File(args.out, "r+") - else: - hdf5_file = h5py.File(args.out, "w") - - with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: - for batch in pbar: - # set up model input - model_input = batch[config["model_input"]].unsqueeze(0).to(device) - model_lengths = model_input.new_tensor([model_input.size(1)]).long() - - all_encoder_outputs, _ = ssl_model(model_input, model_lengths) - embed = ( - torch.mean( - all_encoder_outputs[ - config["model_params"]["ssl_model_layer_idx"] - ].squeeze(0), - dim=0, - ) - .detach() - .cpu() - .numpy() - ) - - system_id = batch["system_id"] - sample_id = batch["sample_id"] - hdf5_path = system_id + "_" + sample_id - score = batch["avg_score"] - - hdf5_file.create_dataset("embeds/" + hdf5_path, data=embed) - hdf5_file.create_dataset("scores/" + hdf5_path, data=score) - - hdf5_file.flush() - hdf5_file.close() - - -if __name__ == "__main__": - main() diff --git a/src/sheet/bin/inference.py b/src/sheet/bin/inference.py deleted file mode 100755 index 18a4cf3..0000000 --- a/src/sheet/bin/inference.py +++ /dev/null @@ -1,434 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Inference .""" - -import argparse -import csv -import logging -import os -import pickle -import time -from collections import defaultdict - -import numpy as np -import sheet -import sheet.datasets -import sheet.models -import torch -import yaml -from prettytable import MARKDOWN, PrettyTable -from sheet.evaluation.metrics import calculate -from sheet.evaluation.plot import ( - plot_sys_level_scatter, - plot_utt_level_hist, - plot_utt_level_scatter, -) -from sheet.utils import read_csv -from sheet.utils.model_io import 
model_average -from sheet.utils.types import str2bool -from tqdm import tqdm - - -def main(): - """Run inference process.""" - parser = argparse.ArgumentParser( - description=( - "Inference with trained model " "(See detail in bin/inference.py)." - ) - ) - parser.add_argument( - "--csv-path", - required=True, - type=str, - help=("csv file path to do inference."), - ) - parser.add_argument( - "--outdir", - type=str, - required=True, - help="directory to save generated figures.", - ) - parser.add_argument( - "--checkpoint", - type=str, - help="checkpoint file to be loaded.", - ) - parser.add_argument( - "--config", - default=None, - type=str, - help=( - "yaml format configuration file. if not explicitly provided, " - "it will be searched in the checkpoint directory. (default=None)" - ), - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="logging level. higher is more logging. (default=1)", - ) - parser.add_argument( - "--inference-mode", - type=str, - help="inference mode. 
if not specified, use the default setting in config", - ) - parser.add_argument( - "--model-averaging", - type=str2bool, - default="False", - help="if true, average all model checkpoints in the exp directory", - ) - parser.add_argument( - "--use-stacking", - type=str2bool, - default="False", - help="if true, use the stack model in the exp directory", - ) - parser.add_argument( - "--meta-model-checkpoint", - type=str, - help="checkpoint file of meta model.", - ) - args = parser.parse_args() - - # set logger - if args.verbose > 1: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose > 0: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - - # check directory existence - if not os.path.exists(args.outdir): - os.makedirs(args.outdir) - - # load config - if args.config is None: - dirname = os.path.dirname(args.checkpoint) - args.config = os.path.join(dirname, "config.yml") - with open(args.config) as f: - config = yaml.load(f, Loader=yaml.Loader) - - args_dict = vars(args) - # do not override if inference mode not specified - if args_dict["inference_mode"] is None: - del args_dict["inference_mode"] - - # get expdir first - expdir = config["outdir"] - - config.update(args_dict) - for key, value in config.items(): - logging.info(f"{key} = {value}") - - # get dataset - dataset_class = getattr( - sheet.datasets, - config.get("dataset_type", "NonIntrusiveDataset"), - ) - dataset = dataset_class( - csv_path=args.csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - use_phoneme=config.get("use_phoneme", False), - symbols=config.get("symbols", None), - wav_only=True, - allow_cache=False, - ) - 
logging.info(f"Number of inference samples = {len(dataset)}.") - - # setup device - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - - # get model - model_class = getattr(sheet.models, config["model_type"]) - model = model_class( - config["model_input"], - num_listeners=config.get("num_listeners", None), - num_domains=config.get("num_domains", None), - **config["model_params"], - ) - - # set placeholders - eval_results = defaultdict(list) - eval_sys_results = defaultdict(lambda: defaultdict(list)) - logvars = [] - - # stacking model inference - if args.use_stacking: - # load meta model - with open(args.meta_model_checkpoint, "rb") as f: - meta_model = pickle.load(f) - - # run inference on all models - checkpoint_paths = sorted( - [ - os.path.join(expdir, p) - for p in os.listdir(expdir) - if os.path.isfile(os.path.join(expdir, p)) and p.endswith("steps.pkl") - ] - ) - xs = np.empty((len(dataset), len(checkpoint_paths))) - for i, checkpoint_path in enumerate(checkpoint_paths): - # load model - model.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["model"] - ) - logging.info(f"Loaded model parameters from {checkpoint_path}.") - model = model.eval().to(device) - - # start inference - start_time = time.time() - logging.info("Running inference...") - with torch.no_grad(): - for j, batch in enumerate(dataset): - # set up model input - model_input = batch[config["model_input"]].unsqueeze(0).to(device) - model_lengths = model_input.new_tensor([model_input.size(1)]).long() - inputs = { - config["model_input"]: model_input, - config["model_input"] + "_lengths": model_lengths, - } - if "phoneme_idxs" in batch: - inputs["phoneme_idxs"] = ( - batch["phoneme_idxs"].unsqueeze(0).to(device) - ) - inputs["phoneme_lengths"] = batch["phoneme_lengths"] - if "reference_idxs" in batch: - inputs["reference_idxs"] = ( - batch["reference_idxs"].unsqueeze(0).to(device) - ) - inputs["reference_lengths"] = 
batch["reference_lengths"] - - # model forward - if config["inference_mode"] == "mean_listener": - outputs = model.mean_listener_inference(inputs) - elif config["inference_mode"] == "mean_net": - outputs = model.mean_net_inference(inputs) - else: - raise NotImplementedError - - # store results - pred_score = outputs["scores"].cpu().detach().numpy()[0] - xs[j][i] = pred_score - - total_inference_time = time.time() - start_time - logging.info("Total inference time = {} secs.".format(total_inference_time)) - logging.info( - "Average inference speed = {:.3f} sec / sample.".format( - total_inference_time / len(dataset) - ) - ) - - # run inference on meta model - pred_mean_scores = meta_model.predict(xs) - - # rerun dataset to get system level scores - for i, batch in enumerate(dataset): - true_mean_scores = batch["avg_score"] - eval_results["pred_mean_scores"].append(pred_mean_scores[i]) - eval_results["true_mean_scores"].append(true_mean_scores) - sys_name = batch["system_id"] - eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores[i]) - eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) - - # not using stacking - else: - # load parameter, or take average - assert (args.checkpoint == "" and args.model_averaging) or ( - args.checkpoint != "" and not args.model_averaging - ) - if args.checkpoint != "": - if os.path.islink(args.checkpoint): - model.load_state_dict( - torch.load(os.readlink(args.checkpoint), map_location="cpu")[ - "model" - ] - ) - else: - model.load_state_dict( - torch.load(args.checkpoint, map_location="cpu")["model"] - ) - logging.info(f"Loaded model parameters from {args.checkpoint}.") - else: - model, checkpoint_paths = model_average(model, expdir) - logging.info(f"Loaded model parameters from: {', '.join(checkpoint_paths)}") - model = model.eval().to(device) - - # start inference - start_time = time.time() - with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: - for batch in pbar: - # set up model 
input - model_input = batch[config["model_input"]].unsqueeze(0).to(device) - model_lengths = model_input.new_tensor([model_input.size(1)]).long() - inputs = { - config["model_input"]: model_input, - config["model_input"] + "_lengths": model_lengths, - } - if "phoneme_idxs" in batch: - inputs["phoneme_idxs"] = ( - torch.tensor(batch["phoneme_idxs"], dtype=torch.long) - .unsqueeze(0) - .to(device) - ) - inputs["phoneme_lengths"] = torch.tensor( - [len(batch["phoneme_idxs"])], dtype=torch.long - ) - if "reference_idxs" in batch: - inputs["reference_idxs"] = ( - torch.tensor(batch["reference_idxs"], dtype=torch.long) - .unsqueeze(0) - .to(device) - ) - inputs["reference_lengths"] = torch.tensor( - [len(batch["reference_idxs"])], dtype=torch.long - ) - if "domain_idx" in batch: - inputs["domain_idxs"] = ( - torch.tensor(batch["domain_idx"], dtype=torch.long) - .unsqueeze(0) - .to(device) - ) - - # model forward - if config["inference_mode"] == "mean_listener": - outputs = model.mean_listener_inference(inputs) - elif config["inference_mode"] == "mean_net": - outputs = model.mean_net_inference(inputs) - else: - raise NotImplementedError - - # store results - answer = outputs["scores"].cpu().detach().numpy()[0] - if "logvars" in outputs: - logvar = outputs["logvars"].cpu().detach().numpy()[0] - logvars.append(logvar) - else: - logvar = None - dataset.fill_answer(batch["sample_id"], answer, logvar) - pred_mean_scores = answer - true_mean_scores = batch["avg_score"] - eval_results["pred_mean_scores"].append(pred_mean_scores) - eval_results["true_mean_scores"].append(true_mean_scores) - sys_name = batch["system_id"] - eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores) - eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) - - total_inference_time = time.time() - start_time - logging.info("Total inference time = {} secs.".format(total_inference_time)) - logging.info( - "Average inference speed = {:.3f} sec / sample.".format( - 
total_inference_time / len(dataset) - ) - ) - eval_results["true_mean_scores"] = np.array(eval_results["true_mean_scores"]) - eval_results["pred_mean_scores"] = np.array(eval_results["pred_mean_scores"]) - eval_sys_results["true_mean_scores"] = np.array( - [np.mean(scores) for scores in eval_sys_results["true_mean_scores"].values()] - ) - eval_sys_results["pred_mean_scores"] = np.array( - [np.mean(scores) for scores in eval_sys_results["pred_mean_scores"].values()] - ) - - # calculate metrics - results = calculate( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - eval_sys_results["true_mean_scores"], - eval_sys_results["pred_mean_scores"], - ) - logging.info( - f'[UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n' - ) - if len(logvars) != 0: - logging.info(f'Mean log variance: {np.mean(logvars):.3f}') - - table = PrettyTable() - table.set_style(MARKDOWN) - table.field_names = [ - "Utt MSE", - "Utt LCC", - "Utt SRCC", - "Utt KTAU", - "Sys MSE", - "Sys LCC", - "Sys SRCC", - "Sys KTAU", - ] - table.add_row( - [ - round(results["utt_MSE"], 3), - round(results["utt_LCC"], 3), - round(results["utt_SRCC"], 3), - round(results["utt_KTAU"], 3), - round(results["sys_MSE"], 3), - round(results["sys_LCC"], 3), - round(results["sys_SRCC"], 3), - round(results["sys_KTAU"], 3), - ] - ) - print(table) - - # check directory - dirname = args.outdir - if not os.path.exists(dirname): - os.makedirs(dirname) - - # plot - plot_utt_level_hist( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - os.path.join(dirname, "distribution.png"), - ) - plot_utt_level_scatter( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - os.path.join(dirname, "utt_scatter_plot.png"), - results["utt_LCC"], - results["utt_SRCC"], - results["utt_MSE"], - results["utt_KTAU"], - ) - 
plot_sys_level_scatter( - eval_sys_results["true_mean_scores"], - eval_sys_results["pred_mean_scores"], - os.path.join(dirname, "sys_scatter_plot.png"), - results["sys_LCC"], - results["sys_SRCC"], - results["sys_MSE"], - results["sys_KTAU"], - ) - - # write results - results = dataset.return_results() - results_path = os.path.join(args.outdir, "results.csv") - fieldnames = list(results[0].keys()) - with open(results_path, "w", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for line in results: - writer.writerow(line) - - -if __name__ == "__main__": - main() diff --git a/src/sheet/bin/nonparametric_inference.py b/src/sheet/bin/nonparametric_inference.py deleted file mode 100755 index 8a25ae1..0000000 --- a/src/sheet/bin/nonparametric_inference.py +++ /dev/null @@ -1,400 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Non-parametric inference .""" - -import argparse -import csv -import logging -import os -import pickle -import time -from collections import defaultdict - -import faiss -import h5py -import numpy as np -import sheet -import sheet.datasets -import sheet.models -import torch -import yaml -from prettytable import MARKDOWN, PrettyTable -from scipy.special import softmax -from sheet.evaluation.metrics import calculate -from sheet.evaluation.plot import ( - plot_sys_level_scatter, - plot_utt_level_hist, - plot_utt_level_scatter, -) -from sheet.nonparametric.datastore import Datastore -from sheet.utils.model_io import model_average -from sheet.utils.types import str2bool -from sheet.utils import write_csv -from tqdm import tqdm - - -def main(): - """Run inference process.""" - parser = argparse.ArgumentParser( - description=( - "Inference with trained model " "(See detail in bin/inference.py)." 
- ) - ) - parser.add_argument( - "--csv-path", - required=True, - type=str, - help=("csv file path to do inference."), - ) - parser.add_argument( - "--datastore", - required=True, - type=str, - help=("h5 file path of the datastore."), - ) - parser.add_argument( - "--outdir", - type=str, - required=True, - help="directory to save generated figures.", - ) - parser.add_argument( - "--checkpoint", - type=str, - help="checkpoint file to be loaded.", - ) - parser.add_argument( - "--config", - default=None, - type=str, - help=( - "yaml format configuration file. if not explicitly provided, " - "it will be searched in the checkpoint directory. (default=None)" - ), - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="logging level. higher is more logging. (default=1)", - ) - parser.add_argument( - "--inference-mode", - type=str, - help="inference mode. if not specified, use the default setting in config", - ) - parser.add_argument( - "--k", - type=int, - default=60, - help="number of neighbors", - ) - parser.add_argument( - "--np-inference-mode", - type=str, - required=True, - choices=["naive_knn", "domain_id_knn_1", "fusion"], - help="non-parametric inference mode.", - ) - args = parser.parse_args() - - # set logger - if args.verbose > 1: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose > 0: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - - # check directory existence - if not os.path.exists(args.outdir): - os.makedirs(args.outdir) - - # load config - if args.config is None: - dirname = os.path.dirname(args.checkpoint) - args.config = os.path.join(dirname, "config.yml") - with open(args.config) as f: - 
config = yaml.load(f, Loader=yaml.Loader) - - args_dict = vars(args) - # do not override if inference mode not specified - if args_dict["inference_mode"] is None: - del args_dict["inference_mode"] - - # get expdir first - expdir = config["outdir"] - - config.update(args_dict) - for key, value in config.items(): - logging.info(f"{key} = {value}") - - # get dataset - dataset_class = getattr( - sheet.datasets, - config.get("dataset_type", "NonIntrusiveDataset"), - ) - dataset = dataset_class( - csv_path=args.csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - use_phoneme=config.get("use_phoneme", False), - symbols=config.get("symbols", None), - wav_only=True, - allow_cache=False, - ) - logging.info(f"Number of inference samples = {len(dataset)}.") - - # setup device - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - - # get model - model_class = getattr(sheet.models, config["model_type"]) - is_ramp = ("RAMP" in config["model_type"]) - if is_ramp: - datastore = Datastore( - args.datastore, - config["model_params"]["parametric_model_params"]["ssl_model_output_dim"], - device=device, - ) - model = model_class( - config["model_input"], - num_listeners=config.get("num_listeners", None), - num_domains=config.get("num_domains", None), - datastore=datastore, - **config["model_params"], - ) - else: - datastore = Datastore( - args.datastore, - config["model_params"]["ssl_model_output_dim"], - device=device, - ) - model = model_class( - config["model_input"], - num_listeners=config.get("num_listeners", None), - num_domains=config.get("num_domains", None), - **config["model_params"], - ) - - # set placeholders - eval_results = defaultdict(list) - eval_sys_results = defaultdict(lambda: defaultdict(list)) - retrieval_results = {} - ramp_results = [] - - # load parameter - if os.path.islink(args.checkpoint): - checkpoint_path = os.readlink(args.checkpoint) - else: - checkpoint_path = 
os.path.realpath(args.checkpoint) - model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]) - logging.info(f"Loaded model parameters from {args.checkpoint}.") - model = model.eval().to(device) - - # start inference - start_time = time.time() - with torch.no_grad(), tqdm(dataset, desc="[inference]") as pbar: - for batch in pbar: - # set up model input - model_input = batch[config["model_input"]].unsqueeze(0).to(device) - model_lengths = model_input.new_tensor([model_input.size(1)]).long() - inputs = { - config["model_input"]: model_input, - config["model_input"] + "_lengths": model_lengths, - } - if "domain_idx" in batch: - inputs["domain_idxs"] = ( - torch.tensor(batch["domain_idx"], dtype=torch.long) - .unsqueeze(0) - .to(device) - ) - - # nonparametric part - if config["np_inference_mode"] == "naive_knn": - ssl_embed = ( - torch.mean(model.get_ssl_embeddings(inputs), dim=1) - .detach() - .cpu() - .numpy() - ) - outputs = datastore.knn(ssl_embed, args.k)["final_score"] - elif config["np_inference_mode"] == "domain_id_knn_1": - # retreive domain ID - ssl_embed = ( - torch.mean(model.get_ssl_embeddings(inputs), dim=1) - .detach() - .cpu() - .numpy() - ) - retrieved_id = int(datastore.knn(ssl_embed, 1)["ids"][0][0][0]) - inputs["domain_idxs"] = ( - torch.tensor(retrieved_id, dtype=torch.long).unsqueeze(0).to(device) - ) - retrieval_results[batch["sample_id"]] = {"retrieved_id": retrieved_id} - - # parametric path - if config["inference_mode"] == "mean_listener": - outputs = model.mean_listener_inference(inputs) - elif config["inference_mode"] == "mean_net": - outputs = model.mean_net_inference(inputs) - else: - raise NotImplementedError - - outputs = outputs["scores"].cpu().detach().numpy()[0] - elif config["np_inference_mode"] == "fusion": - model_outputs = model.inference(inputs, config["np_inference_mode"]) - outputs = ( - model_outputs["scores"] - .cpu() - .detach() - .numpy()[0] - ) - ramp_results.append( - {"sample_id": 
batch["sample_id"]} | - {k: v.cpu().detach().numpy()[0] for k, v in model_outputs.items() if not k == "scores"} - ) - else: - raise NotImplementedError - - # store results - answer = outputs - dataset.fill_answer(batch["sample_id"], answer) - pred_mean_scores = answer - true_mean_scores = batch["avg_score"] - eval_results["pred_mean_scores"].append(pred_mean_scores) - eval_results["true_mean_scores"].append(true_mean_scores) - sys_name = batch["system_id"] - eval_sys_results["pred_mean_scores"][sys_name].append(pred_mean_scores) - eval_sys_results["true_mean_scores"][sys_name].append(true_mean_scores) - - total_inference_time = time.time() - start_time - logging.info("Total inference time = {} secs.".format(total_inference_time)) - logging.info( - "Average inference speed = {:.3f} sec / sample.".format( - total_inference_time / len(dataset) - ) - ) - - # print retrieval results - for k, v in retrieval_results.items(): - print(k, v) - - # calculate metrics - eval_results["true_mean_scores"] = np.array(eval_results["true_mean_scores"]) - eval_results["pred_mean_scores"] = np.array(eval_results["pred_mean_scores"]) - eval_sys_results["true_mean_scores"] = np.array( - [np.mean(scores) for scores in eval_sys_results["true_mean_scores"].values()] - ) - eval_sys_results["pred_mean_scores"] = np.array( - [np.mean(scores) for scores in eval_sys_results["pred_mean_scores"].values()] - ) - - # calculate metrics - results = calculate( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - eval_sys_results["true_mean_scores"], - eval_sys_results["pred_mean_scores"], - ) - logging.info( - f'[UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n' - ) - - table = PrettyTable() - table.set_style(MARKDOWN) - table.field_names = [ - "Utt MSE", - "Utt LCC", - "Utt SRCC", - "Utt KTAU", - "Sys MSE", - "Sys 
LCC", - "Sys SRCC", - "Sys KTAU", - ] - table.add_row( - [ - round(results["utt_MSE"], 3), - round(results["utt_LCC"], 3), - round(results["utt_SRCC"], 3), - round(results["utt_KTAU"], 3), - round(results["sys_MSE"], 3), - round(results["sys_LCC"], 3), - round(results["sys_SRCC"], 3), - round(results["sys_KTAU"], 3), - ] - ) - print(table) - - # check directory - dirname = args.outdir - if not os.path.exists(dirname): - os.makedirs(dirname) - - # plot - plot_utt_level_hist( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - os.path.join(dirname, "distribution.png"), - ) - plot_utt_level_scatter( - eval_results["true_mean_scores"], - eval_results["pred_mean_scores"], - os.path.join(dirname, "utt_scatter_plot.png"), - results["utt_LCC"], - results["utt_SRCC"], - results["utt_MSE"], - results["utt_KTAU"], - ) - plot_sys_level_scatter( - eval_sys_results["true_mean_scores"], - eval_sys_results["pred_mean_scores"], - os.path.join(dirname, "sys_scatter_plot.png"), - results["sys_LCC"], - results["sys_SRCC"], - results["sys_MSE"], - results["sys_KTAU"], - ) - - # get results - results = dataset.return_results() - - # insert retrieval results - for i in range(len(results)): - sample_id = results[i]["sample_id"] - for k, v in retrieval_results[sample_id].items(): - results[i][k] = v - - # write results - results_path = os.path.join(args.outdir, "results.csv") - fieldnames = list(results[0].keys()) - with open(results_path, "w", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for line in results: - writer.writerow(line) - - # write RAMP results if model is RAMP - if config["np_inference_mode"] == "fusion": - write_csv(ramp_results, os.path.join(args.outdir, "ramp_results.csv")) - -if __name__ == "__main__": - main() diff --git a/src/sheet/bin/train.py b/src/sheet/bin/train.py deleted file mode 100755 index 572f83a..0000000 --- a/src/sheet/bin/train.py +++ /dev/null @@ -1,396 +0,0 @@ 
-#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Train model.""" - -import argparse -import logging -import os -import sys - -import humanfriendly -import numpy as np -import sheet -import sheet.collaters -import sheet.datasets -import sheet.losses -import sheet.models -import sheet.trainers -import torch -import yaml -from sheet.schedulers import get_scheduler -from torch.utils.data import DataLoader - -# scheduler_classes = dict(warmuplr=WarmupLR) - - -def main(): - """Run training process.""" - parser = argparse.ArgumentParser( - description=( - "Train speech human evaluation estimation model (See detail in bin/train.py)." - ) - ) - parser.add_argument( - "--train-csv-path", - required=True, - type=str, - help=("training data csv file path."), - ) - parser.add_argument( - "--dev-csv-path", - required=True, - type=str, - help=("training data csv file path."), - ) - parser.add_argument( - "--outdir", - type=str, - required=True, - help="directory to save checkpoints.", - ) - parser.add_argument( - "--config", - type=str, - required=True, - help="yaml format configuration file.", - ) - parser.add_argument( - "--additional-config", - type=str, - default=None, - help="yaml format configuration file (additional; for second-stage pretraining).", - ) - parser.add_argument( - "--init-checkpoint", - default="", - type=str, - nargs="?", - help='checkpoint file path to initialize pretrained params. (default="")', - ) - parser.add_argument( - "--resume", - default="", - type=str, - nargs="?", - help='checkpoint file path to resume training. (default="")', - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="logging level. higher is more logging. (default=1)", - ) - parser.add_argument( - "--rank", - "--local_rank", - default=0, - type=int, - help="rank for distributed training. 
no need to explictly specify.", - ) - parser.add_argument("--seed", default=1337, type=int) - args = parser.parse_args() - - args.distributed = False - if not torch.cuda.is_available(): - device = torch.device("cpu") - else: - device = torch.device("cuda") - # effective when using fixed size inputs - # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 - torch.backends.cudnn.benchmark = False # because we have dynamic input size - torch.cuda.set_device(args.rank) - # setup for distributed training - # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed - if "WORLD_SIZE" in os.environ: - args.world_size = int(os.environ["WORLD_SIZE"]) - args.distributed = args.world_size > 1 - if args.distributed: - torch.distributed.init_process_group(backend="nccl", init_method="env://") - - # suppress logging for distributed training - if args.rank != 0: - sys.stdout = open(os.devnull, "w") - - # set logger - if args.verbose > 1: - logging.basicConfig( - level=logging.DEBUG, - stream=sys.stdout, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose > 0: - logging.basicConfig( - level=logging.INFO, - stream=sys.stdout, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - stream=sys.stdout, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - - # Fix seed and make backends deterministic - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(args.seed) - torch.backends.cudnn.deterministic = True - - # fix issue of too many opened files - # https://github.com/pytorch/pytorch/issues/11201 - torch.multiprocessing.set_sharing_strategy("file_system") - - # check directory existence - if not os.path.exists(args.outdir): - os.makedirs(args.outdir) - - # load main 
config - with open(args.config) as f: - config = yaml.load(f, Loader=yaml.Loader) - config.update(vars(args)) - - # load additional config - if args.additional_config is not None: - with open(args.additional_config) as f: - additional_config = yaml.load(f, Loader=yaml.Loader) - config.update(additional_config) - - # get dataset - dataset_class = getattr( - sheet.datasets, - config.get("dataset_type", "NonIntrusiveDataset"), - ) - logging.info(f"Loading training set from {args.train_csv_path}.") - train_dataset = dataset_class( - csv_path=args.train_csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - wav_only=config.get("wav_only", False), - use_phoneme=config.get("use_phoneme", False), - symbols=config.get("symbols", None), - use_mean_listener=config["model_params"].get("use_mean_listener", None), - categorical=config.get("categorical", False), - categorical_step=config.get("categorical_step", 1.0), - allow_cache=config["allow_cache"], - load_wav_cache=config.get("load_wav_cache", False), - ) - logging.info(f"The number of training files = {len(train_dataset)}.") - logging.info(f"Loading development set from {args.dev_csv_path}.") - dev_dataset = dataset_class( - csv_path=args.dev_csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - wav_only=True, - use_phoneme=config.get("use_phoneme", False), - symbols=config.get("symbols", None), - allow_cache=False, - # allow_cache=config["allow_cache"], - # load_wav_cache=config.get("load_wav_cache", False), - ) - logging.info(f"The number of development files = {len(dev_dataset)}.") - dataset = { - "train": train_dataset, - "dev": dev_dataset, - } - - # update number of listeners - if hasattr(train_dataset, "num_listeners"): - config["num_listeners"] = train_dataset.num_listeners - - # update number of domains - if config.get("num_domains", None) is None: - if hasattr(train_dataset, "num_domains"): - config["num_domains"] = 
train_dataset.num_domains - - # get data loader - collater_class = getattr( - sheet.collaters, - config.get("collater_type", "NonIntrusiveCollater"), - ) - collater = collater_class(config["padding_mode"]) - sampler = {"train": None, "dev": None} - if args.distributed: - # setup sampler for distributed training - from torch.utils.data.distributed import DistributedSampler - - sampler["train"] = DistributedSampler( - dataset=dataset["train"], - num_replicas=args.world_size, - rank=args.rank, - shuffle=True, - ) - sampler["dev"] = DistributedSampler( - dataset=dataset["dev"], - num_replicas=args.world_size, - rank=args.rank, - shuffle=False, - ) - data_loader = { - "train": DataLoader( - dataset=dataset["train"], - shuffle=False if args.distributed else True, - collate_fn=collater, - batch_size=config["train_batch_size"], - num_workers=config["num_workers"], - sampler=sampler["train"], - pin_memory=config["pin_memory"], - ), - "dev": DataLoader( - dataset=dataset["dev"], - shuffle=False, - collate_fn=collater, - batch_size=config["test_batch_size"], - num_workers=config["num_workers"], - sampler=sampler["dev"], - pin_memory=config["pin_memory"], - ), - } - - # define models - model_class = getattr( - sheet.models, - config["model_type"], - ) - model = model_class( - config["model_input"], - num_listeners=config.get("num_listeners", None), - num_domains=config.get("num_domains", None), - **config["model_params"], - ).to(device) - - # define criterions - criterion = {} - if config["mean_score_criterions"] is not None: - criterion["mean_score_criterions"] = [ - { - "type": criterion_dict["criterion_type"], - "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( - **criterion_dict["criterion_params"] - ), - "weight": criterion_dict["criterion_weight"], - } - for criterion_dict in config["mean_score_criterions"] - ] - if config.get("categorical_head_criterions", None) is not None: - criterion["categorical_head_criterions"] = [ - { - "type": 
criterion_dict["criterion_type"], - "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( - **criterion_dict["criterion_params"] - ), - "weight": criterion_dict["criterion_weight"], - } - for criterion_dict in config["categorical_head_criterions"] - ] - if config.get("listener_score_criterions", None) is not None: - criterion["listener_score_criterions"] = [ - { - "type": criterion_dict["criterion_type"], - "criterion": getattr(sheet.losses, criterion_dict["criterion_type"])( - **criterion_dict["criterion_params"] - ), - "weight": criterion_dict["criterion_weight"], - } - for criterion_dict in config["listener_score_criterions"] - ] - - # define optimizers and schedulers - optimizer_class = getattr( - torch.optim, - # keep compatibility - config.get("optimizer_type", "Adam"), - ) - optimizer = optimizer_class( - model.parameters(), - **config["optimizer_params"], - ) - if config["scheduler_type"] is not None: - scheduler = get_scheduler( - optimizer, - config["scheduler_type"], - config["train_max_steps"], - config["scheduler_params"], - ) - else: - scheduler = None - - if args.distributed: - # wrap model for distributed training - try: - from apex.parallel import DistributedDataParallel - except ImportError: - raise ImportError( - "apex is not installed. please check https://github.com/NVIDIA/apex." 
- ) - model = DistributedDataParallel(model) - - # show settings - logging.info( - "Model parameters: {}".format(humanfriendly.format_size(model.get_num_params())) - ) - logging.info(model) - logging.info(optimizer) - logging.info(scheduler) - logging.info(criterion) - - # define trainer - trainer_class = getattr(sheet.trainers, config["trainer_type"]) - trainer = trainer_class( - steps=0, - epochs=0, - data_loader=data_loader, - sampler=sampler, - model=model, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - config=config, - device=device, - ) - - # load pretrained parameters from checkpoint - if len(args.init_checkpoint) != 0: - trainer.load_trained_modules( - args.init_checkpoint, init_mods=config["init-mods"] - ) - logging.info(f"Successfully load parameters from {args.init_checkpoint}.") - - # resume from checkpoint - if len(args.resume) != 0: - trainer.load_checkpoint(args.resume) - logging.info(f"Successfully resumed from {args.resume}.") - - # freeze modules if necessary - if config.get("freeze-mods", None) is not None: - assert type(config["freeze-mods"]) is list - trainer.freeze_modules(config["freeze-mods"]) - logging.info(f"Freeze modules with prefixes {config['freeze-mods']}.") - - # save config - config["version"] = sheet.__version__ # add version info - with open(os.path.join(args.outdir, "config.yml"), "w") as f: - yaml.dump(config, f, Dumper=yaml.Dumper) - for key, value in config.items(): - logging.info(f"{key} = {value}") - - # run training loop - # try: - # trainer.run() - # finally: - # trainer.save_checkpoint( - # os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl") - # ) - # logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.") - # NOTE(unilight): I don't think we need to save again here - trainer.run() - - -if __name__ == "__main__": - main() diff --git a/src/sheet/bin/train_stack.py b/src/sheet/bin/train_stack.py deleted file mode 100755 index 0132b2f..0000000 --- 
a/src/sheet/bin/train_stack.py +++ /dev/null @@ -1,188 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Train meta-model for stacking .""" - -import argparse -import logging -import os -import pickle -import time - -import numpy as np -import sheet -import sheet.datasets -import sheet.models -import torch -import yaml - - -def main(): - """Run inference process.""" - parser = argparse.ArgumentParser( - description=( - "Inference with trained model " "(See detail in bin/inference.py)." - ) - ) - parser.add_argument( - "--csv-path", - required=True, - type=str, - help=("csv file path to train stacking model."), - ) - parser.add_argument( - "--expdir", - type=str, - required=True, - help="directory to save model.", - ) - parser.add_argument( - "--meta-model-config", - required=True, - type=str, - help=("yaml format configuration file for the meta model. "), - ) - parser.add_argument( - "--verbose", - type=int, - default=1, - help="logging level. higher is more logging. 
(default=1)", - ) - args = parser.parse_args() - - # set logger - if args.verbose > 1: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose > 0: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - - # load original config - with open(os.path.join(args.expdir, "config.yml")) as f: - config = yaml.load(f, Loader=yaml.Loader) - - # load meta model config - with open(args.meta_model_config) as f: - meta_model_config = yaml.load(f, Loader=yaml.Loader) - - config.update(meta_model_config) - for key, value in config.items(): - logging.info(f"{key} = {value}") - - # get dataset - dataset_class = getattr( - sheet.datasets, - config.get("dataset_type", "NonIntrusiveDataset"), - ) - dataset = dataset_class( - csv_path=args.csv_path, - target_sample_rate=config["sampling_rate"], - model_input=config["model_input"], - wav_only=True, - allow_cache=False, - ) - logging.info(f"Number of samples to train meta model = {len(dataset)}.") - - # setup device - if torch.cuda.is_available(): - device = torch.device("cuda") - else: - device = torch.device("cpu") - - # get model - model_class = getattr(sheet.models, config["model_type"]) - model = model_class( - config["model_input"], - num_listeners=config.get("num_listeners", None), - **config["model_params"], - ) - - # run inference on all models - checkpoint_paths = sorted( - [ - os.path.join(args.expdir, p) - for p in os.listdir(args.expdir) - if os.path.isfile(os.path.join(args.expdir, p)) and p.endswith("steps.pkl") - ] - ) - xs = np.empty((len(dataset), len(checkpoint_paths))) - for i, checkpoint_path in enumerate(checkpoint_paths): - # load model - 
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]) - logging.info(f"Loaded model parameters from {checkpoint_path}.") - model = model.eval().to(device) - - # start inference - start_time = time.time() - logging.info("Running inference...") - with torch.no_grad(): - for j, batch in enumerate(dataset): - # set up model input - model_input = batch[config["model_input"]].unsqueeze(0).to(device) - model_input_lengths = model_input.new_tensor( - [model_input.size(1)] - ).long() - - # model forward - if config["inference_mode"] == "mean_listener": - outputs = model.mean_listener_inference( - model_input, model_input_lengths - ) - elif config["inference_mode"] == "mean_net": - outputs = model.mean_net_inference(model_input, model_input_lengths) - else: - raise NotImplementedError - - # store results - pred_score = outputs["scores"].cpu().detach().numpy()[0] - xs[j][i] = pred_score - - total_inference_time = time.time() - start_time - logging.info("Total inference time = {} secs.".format(total_inference_time)) - logging.info( - "Average inference speed = {:.3f} sec / sample.".format( - total_inference_time / len(dataset) - ) - ) - - ys = np.array([batch["avg_score"] for batch in dataset]) - - # define meta model - if config["meta_model_type"] == "Ridge": - from sklearn.linear_model import Ridge - - meta_model = Ridge(**config["meta_model_params"]) - else: - raise NotImplementedError - - # train meta model - start_time = time.time() - logging.info("Start training meta model...") - meta_model.fit(xs, ys) - total_train_time = time.time() - start_time - logging.info("Total training time = {} secs.".format(total_train_time)) - - # save - with open(os.path.join(args.expdir, "meta_model.pkl"), "wb") as f: - pickle.dump(meta_model, f) - - with open(os.path.join(args.expdir, "config.yml"), "w") as f: - yaml.dump(config, f, Dumper=yaml.Dumper) - - -if __name__ == "__main__": - main() diff --git a/src/sheet/collaters/__init__.py 
b/src/sheet/collaters/__init__.py deleted file mode 100644 index ebfb777..0000000 --- a/src/sheet/collaters/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .non_intrusive import * # NOQA diff --git a/src/sheet/collaters/non_intrusive.py b/src/sheet/collaters/non_intrusive.py deleted file mode 100644 index 97fcfbc..0000000 --- a/src/sheet/collaters/non_intrusive.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -import numpy as np -import torch -from torch.nn.utils.rnn import pad_sequence - -FEAT_TYPES = ["waveform", "mag_sgram"] - - -class NonIntrusiveCollater(object): - """Customized collater for Pytorch DataLoader in the non-intrusive setting.""" - - def __init__(self, padding_mode): - """Initialize customized collater for PyTorch DataLoader.""" - self.padding_mode = padding_mode - - def __call__(self, batch): - """Convert into batch tensors.""" - - items = {} - sorted_batch = sorted(batch, key=lambda x: -x["waveform"].shape[0]) - bs = len(sorted_batch) # batch_size - all_keys = list(sorted_batch[0].keys()) - - # score & listener id - items["scores"] = torch.tensor( - [sorted_batch[i]["score"] for i in range(bs)], dtype=torch.float - ) - items["avg_scores"] = torch.tensor( - [sorted_batch[i]["avg_score"] for i in range(bs)], dtype=torch.float - ) - if "categorical_score" in all_keys: - items["categorical_scores"] = torch.tensor( - [sorted_batch[i]["categorical_score"] for i in range(bs)], - dtype=torch.float, - ) - if "categorical_avg_score" in all_keys: - items["categorical_avg_scores"] = torch.tensor( - [sorted_batch[i]["categorical_avg_score"] for i in range(bs)], - dtype=torch.float, - ) - if "listener_id" in all_keys: - items["listener_ids"] = [sorted_batch[i]["listener_id"] for i in range(bs)] - if "listener_idx" in all_keys: - items["listener_idxs"] = torch.tensor( - [sorted_batch[i]["listener_idx"] for i in range(bs)], dtype=torch.long 
- ) - if "domain_idx" in all_keys: - items["domain_idxs"] = torch.tensor( - [sorted_batch[i]["domain_idx"] for i in range(bs)], dtype=torch.long - ) - - # phoneme and reference - if "phoneme_idxs" in all_keys: - phonemes = [ - torch.LongTensor(sorted_batch[i]["phoneme_idxs"]) for i in range(bs) - ] - items["phoneme_lengths"] = torch.from_numpy( - np.array([phoneme.size(0) for phoneme in phonemes]) - ) - items["phoneme_idxs"] = pad_sequence(phonemes, batch_first=True) - if "reference_idxs" in all_keys: - references = [ - torch.LongTensor(sorted_batch[i]["reference_idxs"]) for i in range(bs) - ] - items["reference_lengths"] = torch.from_numpy( - np.array([reference.size(0) for reference in references]) - ) - items["reference_idxs"] = pad_sequence(references, batch_first=True) - - # ids - items["system_ids"] = [sorted_batch[i]["system_id"] for i in range(bs)] - items["sample_ids"] = [sorted_batch[i]["sample_id"] for i in range(bs)] - - # pad input features (only those in FEAT TYPES) - for feat_type in FEAT_TYPES: - if not feat_type in sorted_batch[0]: - continue - - feats = [sorted_batch[i][feat_type] for i in range(bs)] - feat_lengths = torch.from_numpy(np.array([feat.size(0) for feat in feats])) - - # padding - if self.padding_mode == "zero_padding": - feats_padded = pad_sequence(feats, batch_first=True) - elif self.padding_mode == "repetitive": - max_len = feat_lengths[0] - feats_padded = [] - for feat in feats: - this_len = feat.shape[0] - dup_times = max_len // this_len - remain = max_len - this_len * dup_times - to_dup = [feat for t in range(dup_times)] - to_dup.append(feat[:remain]) - duplicated_feat = torch.Tensor(np.concatenate(to_dup, axis=0)) - feats_padded.append(duplicated_feat) - feats_padded = torch.stack(feats_padded, dim=0) - else: - raise NotImplementedError - - items[feat_type] = feats_padded - items[feat_type + "_lengths"] = feat_lengths - - return items diff --git a/src/sheet/datasets/__init__.py b/src/sheet/datasets/__init__.py deleted file mode 
100644 index ebfb777..0000000 --- a/src/sheet/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .non_intrusive import * # NOQA diff --git a/src/sheet/datasets/non_intrusive.py b/src/sheet/datasets/non_intrusive.py deleted file mode 100644 index fc4bb37..0000000 --- a/src/sheet/datasets/non_intrusive.py +++ /dev/null @@ -1,340 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Non-intrusive dataset modules.""" - -from collections import defaultdict -import logging -from multiprocessing import Manager - -from tqdm import tqdm -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from functools import partial - -import numpy as np -import torch -import torch.nn.functional as F -import torchaudio -from sheet.utils import read_csv -from torch.utils.data import Dataset - -MIN_REQUIRED_WAV_LENGTH = 1040 -MAX_WAV_LENGTH_SECS = 10 - - -def adjust_length(waveform, target_sample_rate): - """Adjust waveform length to a fixed length.""" - if waveform.shape[0] < MIN_REQUIRED_WAV_LENGTH: - to_pad = (MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]) // 2 - waveform = F.pad(waveform, (to_pad, to_pad), "constant", 0) - if waveform.shape[0] > MAX_WAV_LENGTH_SECS * target_sample_rate: - waveform = waveform[: MAX_WAV_LENGTH_SECS * target_sample_rate] - return waveform - - -def read_waveform(wav_path, target_sample_rate): - try: - # read waveform - waveform, sample_rate = torchaudio.load( - wav_path, channels_first=False - ) # waveform: [T, 1] - # resample if needed - if sample_rate != target_sample_rate: - resampler = torchaudio.transforms.Resample( - sample_rate, target_sample_rate, dtype=waveform.dtype - ) - waveform = resampler(waveform) - # mono only - if waveform.shape[1] > 1: - waveform = torch.mean(waveform, dim=1, keepdim=True) - except Exception as e: - print(f"Failed to load or resample {wav_path}: {e}") - raise - - waveform = waveform.squeeze(-1) - - # adjust length - waveform = 
adjust_length(waveform, target_sample_rate) - - return waveform - - -def _read_waveform(arg_tuple): - hash_id, wav_path, target_sample_rate = arg_tuple - return hash_id, read_waveform(wav_path, target_sample_rate) - -class NonIntrusiveDataset(Dataset): - """PyTorch compatible dataset for non-intrusive SSQA.""" - - def __init__( - self, - csv_path, - target_sample_rate, - model_input="wav", - wav_only=False, - use_mean_listener=False, - use_phoneme=False, - symbols=None, - categorical=False, - categorical_step=1.0, - no_feat=False, - allow_cache=False, - load_wav_cache=False, - ): - """Initialize dataset. - - Args: - csv path (str): path to the csv file - target_sample_rate (int): resample to this seample rate if there is a mismatch. - model_input (str): defalut is wav. is this is mag_sgram, extract magnitute sgram. - wav_only (bool): whether to return only wavs. Basically this means inference mode. - use_mean_listener (bool): whether to use mean listener. (only for datasets with listener labels) - use_phoneme (bool): whether to use phoneme. (only for UTMOS training) - symbols (str): symbols for phoneme. (only for UTMOS training) - categorical (bool): whether to include categorical output. - categorical_step (float): step for the categorical output. defauly is 1.0. - no_feat (bool): Whether to skip loading features (waveforms, mag_sgrams ...) - allow_cache (bool): Whether to allow cache of the loaded files. - load_wav_cache (bool): Whether to load all waveforms first and store in cache (this might make initialization slower). 
- - """ - self.target_sample_rate = target_sample_rate - self.use_phoneme = use_phoneme - if self.use_phoneme: - self.symbols = symbols - self.resamplers = {} - assert csv_path != "" - self.categorical = categorical - self.categorical_step = categorical_step - self.no_feat = no_feat - - # set model input transform - self.model_input = model_input - if model_input == "mag_sgram": - self.mag_sgram_transform = torchaudio.transforms.Spectrogram( - n_fft=512, hop_length=256, win_length=512, power=1 - ) - - # read csv file - self.metadata, _ = read_csv(csv_path, dict_reader=True) - - # calculate average score for each sample and add to metadata - self.calculate_avg_score() - - if wav_only: - self.reduce_to_wav_only() - else: - # add mean listener to metadata - if use_mean_listener: - mean_listener_metadata = self.gen_mean_listener_metadata() - self.metadata = self.metadata + mean_listener_metadata - - # get num of listeners - self.num_listeners = self.get_num_listeners() - - # get num of domains if domain_idx exists - if "domain_idx" in self.metadata[0]: - self.num_domains = self.get_num_domains() - - # build hash - self.build_feat_hash() - - # set cache - self.allow_cache = allow_cache - if allow_cache: - # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 - self.manager = Manager() - # self.wav_caches = self.manager.list() - # self.wav_caches += [() for _ in range(self.num_wavs)] - self.wav_caches = [None for _ in range(self.num_wavs)] - if self.model_input == "mag_sgram": - self.mag_sgram_caches = self.manager.list() - self.mag_sgram_caches += [() for _ in range(self.num_wavs)] - if load_wav_cache: - logging.info("Loading waveform cache. 
This might take a while ...") - self.load_wav_cache() - - def load_wav_cache(self): - # put text into csv - # with ProcessPoolExecutor(max_workers=2) as executor: - with ThreadPoolExecutor() as executor: - arg_tuples = [(item["hash_id"], item["wav_path"], self.target_sample_rate) for item in self.metadata] - results = list( - tqdm(executor.map(_read_waveform, arg_tuples), total=len(arg_tuples)) - ) - - for item in results: - if item is None: - continue - hash_id, waveform = item - self.wav_caches[hash_id] = waveform - - def __len__(self): - """Return dataset length. - - Returns: - int: The length of dataset. - - """ - return len(self.metadata) - - def get_num_listeners(self): - """Get number of listeners by counting unique listener id""" - listener_ids = set() - for item in self.metadata: - listener_ids.add(item["listener_id"]) - return len(listener_ids) - - def get_num_domains(self): - """Get number of domains by counting unique domain idxs""" - domain_idxs = set() - for item in self.metadata: - domain_idxs.add(item["domain_idx"]) - return len(domain_idxs) - - def build_feat_hash(self): - sample_ids = {} - count = 0 - for i in range(len(self.metadata)): - item = self.metadata[i] - sample_id = item["sample_id"] - if not sample_id in sample_ids: - sample_ids[sample_id] = count - count += 1 - self.metadata[i]["hash_id"] = sample_ids[sample_id] - self.num_wavs = len(sample_ids) - - def __getitem__(self, idx): - item = self.metadata[idx] - - # handle score - item["score"] = float(item["score"]) # cast from str to int - if self.categorical: - # we assume the score always starts from 1 - item["categorical_score"] = int( - (item["score"] - 1) // self.categorical_step - ) - - if "listener_idx" in item: - item["listener_idx"] = int(item["listener_idx"]) # cast from str to int - if "domain_idx" in item: - item["domain_idx"] = int(item["domain_idx"]) # cast from str to int - hash_id = item["hash_id"] - - # process text - if self.use_phoneme: - if "phoneme" in item: - if 
"phoneme_idxs" not in item: - item["phoneme_idxs"] = [ - self.symbols.index(p) for p in item["phoneme"] - ] - if "reference" in item: - if "reference_idxs" not in item: - item["reference_idxs"] = [ - self.symbols.index(p) for p in item["reference"] - ] - - # fetch waveform. return cached item if exists - if not self.no_feat: - if self.allow_cache and self.wav_caches[hash_id] is not None: - item["waveform"] = self.wav_caches[hash_id] - else: - waveform = read_waveform( - item["wav_path"], self.target_sample_rate - ) - item["waveform"] = waveform - if self.allow_cache: - self.wav_caches[hash_id] = item["waveform"] - - # additional feature extraction - if not self.no_feat: - if self.model_input == "mag_sgram": - # fetch mag_sgram. return cached item if exists - if self.allow_cache and len(self.mag_sgram_caches[hash_id]) != 0: - item["mag_sgram"] = self.mag_sgram_caches[hash_id] - else: - # torchaudio requires waveform to be [..., T] - mag_sgram = self.mag_sgram_transform( - waveform.squeeze(-1) - ) # mag_sgram: [freq, T] - item["mag_sgram"] = mag_sgram.mT # [T, freq] - if self.allow_cache: - self.mag_sgram_caches[hash_id] = item["mag_sgram"] - - return item - - def calculate_avg_score(self): - sample_scores = defaultdict(list) - - # loop through metadata - for item in self.metadata: - sample_scores[item["sample_id"]].append(float(item["score"])) - - # take average - sample_avg_score = { - sample_id: np.mean(np.array(scores)) - for sample_id, scores in sample_scores.items() - } - self.sample_avg_score = sample_avg_score - - # fill back into metadata - for i, item in enumerate(self.metadata): - self.metadata[i]["avg_score"] = sample_avg_score[item["sample_id"]] - if self.categorical: - # we assume the score always starts from 1 - self.metadata[i]["categorical_avg_score"] = int( - (self.metadata[i]["avg_score"] - 1) // self.categorical_step - ) - - def gen_mean_listener_metadata(self): - mean_listener_metadata = [] - sample_ids = set() - for item in self.metadata: - 
sample_id = item["sample_id"] - if sample_id not in sample_ids: - new_item = {k: v for k, v in item.items()} - new_item["listener_id"] = "mean_listener" - mean_listener_metadata.append(new_item) - sample_ids.add(sample_id) - return mean_listener_metadata - - def reduce_to_wav_only(self): - new_metadata = {} # {sample_id: item} - for item in self.metadata: - sample_id = item["sample_id"] - if not sample_id in new_metadata: - new_metadata[sample_id] = { - k: v - for k, v in item.items() - if k not in ["listener_id", "listener_idx"] - } - - self.metadata = list(new_metadata.values()) - - # the following two functions are for writing results during inference - def fill_answer(self, sample_id, score, logvar=None): - for idx, item in enumerate(self.metadata): - if item["sample_id"] == sample_id: - break - self.metadata[idx]["answer"] = score - if logvar is not None: - self.metadata[idx]["logvar"] = logvar - - def return_results(self): - return [ - { - k: item[k] - for k in [ - "wav_path", - "system_id", - "sample_id", - "avg_score", - "answer", - "logvar", - ] - if k in item - } - for item in self.metadata - ] diff --git a/src/sheet/evaluation/metrics.py b/src/sheet/evaluation/metrics.py deleted file mode 100644 index b2978cc..0000000 --- a/src/sheet/evaluation/metrics.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Script to calculate metrics.""" - -import numpy as np -import scipy - - -def calculate( - true_mean_scores, predict_mean_scores, true_sys_mean_scores, predict_sys_mean_scores -): - - utt_MSE = np.mean((true_mean_scores - predict_mean_scores) ** 2) - utt_LCC = np.corrcoef(true_mean_scores, predict_mean_scores)[0][1] - utt_SRCC = scipy.stats.spearmanr(true_mean_scores, predict_mean_scores)[0] - utt_KTAU = scipy.stats.kendalltau(true_mean_scores, predict_mean_scores)[0] - sys_MSE = np.mean((true_sys_mean_scores - predict_sys_mean_scores) ** 2) - sys_LCC = 
np.corrcoef(true_sys_mean_scores, predict_sys_mean_scores)[0][1] - sys_SRCC = scipy.stats.spearmanr(true_sys_mean_scores, predict_sys_mean_scores)[0] - sys_KTAU = scipy.stats.kendalltau(true_sys_mean_scores, predict_sys_mean_scores)[0] - - return { - "utt_MSE": utt_MSE, - "utt_LCC": utt_LCC, - "utt_SRCC": utt_SRCC, - "utt_KTAU": utt_KTAU, - "sys_MSE": sys_MSE, - "sys_LCC": sys_LCC, - "sys_SRCC": sys_SRCC, - "sys_KTAU": sys_KTAU, - } diff --git a/src/sheet/evaluation/plot.py b/src/sheet/evaluation/plot.py deleted file mode 100644 index 2350d4b..0000000 --- a/src/sheet/evaluation/plot.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Script to plot figures.""" - -import matplotlib -import numpy as np - -# Force matplotlib to not use any Xwindows backend. -matplotlib.use("Agg") -import matplotlib.pyplot as plt - -STYLE = "seaborn-v0_8-deep" - - -def plot_utt_level_hist(true_mean_scores, predict_mean_scores, filename): - """Plot utterance-level histrogram. 
- - Args: - true_mean_scores: ndarray of true scores - predict_mean_scores: ndarray of predicted scores - filename: name of the saved figure - """ - plt.style.use(STYLE) - bins = np.linspace(1, 5, 40) - plt.figure(2) - plt.hist( - [true_mean_scores, predict_mean_scores], bins, label=["true_mos", "predict_mos"] - ) - plt.legend(loc="upper right") - plt.xlabel("MOS") - plt.ylabel("number") - plt.show() - plt.savefig(filename, dpi=150) - plt.close() - - -def plot_utt_level_scatter( - true_mean_scores, predict_mean_scores, filename, LCC, SRCC, MSE, KTAU -): - """Plot utterance-level scatter plot - - Args: - true_mean_scores: ndarray of true scores - predict_mean_scores: ndarray of predicted scores - filename: name of the saved figure - LCC, SRCC, MSE, KTAU: metrics to be shown on the figure - """ - M = np.max([np.max(predict_mean_scores), 5]) - plt.figure(3) - plt.scatter( - true_mean_scores, - predict_mean_scores, - s=15, - color="b", - marker="o", - edgecolors="b", - alpha=0.20, - ) - plt.xlim([0.5, M]) - plt.ylim([0.5, M]) - plt.xlabel("True MOS") - plt.ylabel("Predicted MOS") - plt.title( - "Utt level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}".format( - LCC, SRCC, MSE, KTAU - ) - ) - plt.show() - plt.savefig(filename, dpi=150) - plt.close() - - -def plot_sys_level_scatter( - true_sys_mean_scores, predict_sys_mean_scores, filename, LCC, SRCC, MSE, KTAU -): - """Plot system-level scatter plot - - Args: - true_sys_mean_scores: ndarray of true system level scores - predict_sys_mean_scores: ndarray of predicted system level scores - filename: name of the saved figure - LCC, SRCC, MSE, KTAU: metrics to be shown on the figure - """ - M = np.max([np.max(predict_sys_mean_scores), 5]) - plt.figure(3) - plt.scatter( - true_sys_mean_scores, - predict_sys_mean_scores, - s=15, - color="b", - marker="o", - edgecolors="b", - ) - plt.xlim([0.5, M]) - plt.ylim([0.5, M]) - plt.xlabel("True MOS") - plt.ylabel("Predicted MOS") - plt.title( - "Sys level LCC= {:.4f}, SRCC= 
{:.4f}, MSE= {:.4f}, KTAU= {:.4f}".format( - LCC, SRCC, MSE, KTAU - ) - ) - plt.show() - plt.savefig(filename, dpi=150) - plt.close() diff --git a/src/sheet/losses/__init__.py b/src/sheet/losses/__init__.py deleted file mode 100644 index 250a1a1..0000000 --- a/src/sheet/losses/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .basic_losses import * # NOQA -from .contrastive_loss import * # NOQA -from .nll_losses import * # NOQA \ No newline at end of file diff --git a/src/sheet/losses/basic_losses.py b/src/sheet/losses/basic_losses.py deleted file mode 100644 index c9b5589..0000000 --- a/src/sheet/losses/basic_losses.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Basic losses.""" - -import torch -import torch.nn as nn -from sheet.modules.utils import make_non_pad_mask - - -class ScalarLoss(nn.Module): - """ - Loss for scalar output (we use the clipped MSE loss) - """ - - def __init__(self, tau, order=2, masked_loss=False): - super(ScalarLoss, self).__init__() - self.tau = tau - self.masked_loss = masked_loss - if order == 2: - self.criterion = torch.nn.MSELoss(reduction="none") - elif order == 1: - self.criterion = torch.nn.L1Loss(reduction="none") - else: - raise NotImplementedError - - def forward_criterion(self, y_hat, label, criterion_module, masks=None): - # might investigate how to combine masked loss with categorical output - if masks is not None: - y_hat = y_hat.masked_select(masks) - label = label.masked_select(masks) - - y_hat = y_hat.squeeze(-1) - loss = criterion_module(y_hat, label) - threshold = torch.abs(y_hat - label) > self.tau - loss = torch.mean(threshold * loss) - return loss - - def forward(self, pred_score, gt_score, device, lens=None): - """ - Args: - pred_mean, pred_score: [batch, time, 1/5] - """ - # make mask - if self.masked_loss: - masks = make_non_pad_mask(lens).to(device) - else: - masks = None - - # repeat 
for frame level loss - time = pred_score.shape[1] - # gt_mean = gt_mean.unsqueeze(1).repeat(1, time) - gt_score = gt_score.unsqueeze(1).repeat(1, time) - - loss = self.forward_criterion(pred_score, gt_score, self.criterion, masks) - return loss - - -class CategoricalLoss(nn.Module): - def __init__(self, masked_loss=False): - super(CategoricalLoss, self).__init__() - self.masked_loss = masked_loss - self.criterion = nn.CrossEntropyLoss(reduction="none") - - def ce(self, y_hat, label, criterion, masks=None): - if masks is not None: - y_hat = y_hat.masked_select(masks) - label = label.masked_select(masks) - - # y_hat must have shape (batch_size, num_classes, ...) - y_hat = y_hat.permute(0, 2, 1) - - ce = criterion(y_hat, label) - return torch.mean(ce) - - def forward(self, pred_score, gt_score, device, lens=None): - # make mask - if self.masked_loss: - masks = make_non_pad_mask(lens).to(device) - else: - masks = None - - # repeat for frame level loss - time = pred_score.shape[1] - gt_score = gt_score.unsqueeze(1).repeat(1, time).type(torch.long) - - score_ce = self.ce(pred_score, gt_score, self.criterion, masks) - return score_ce diff --git a/src/sheet/losses/contrastive_loss.py b/src/sheet/losses/contrastive_loss.py deleted file mode 100644 index d381a5f..0000000 --- a/src/sheet/losses/contrastive_loss.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""Contrastive loss proposed in UTMOS.""" - -import torch -import torch.nn as nn - - -class ContrastiveLoss(nn.Module): - """ - Contrastive Loss - Args: - margin: non-neg value, the smaller the stricter the loss will be, default: 0.2 - - """ - - def __init__(self, margin=0.2): - super(ContrastiveLoss, self).__init__() - self.margin = margin - - def forward(self, pred_score, gt_score, lens, device): - if pred_score.dim() > 2: - pred_score = pred_score.mean(dim=1).squeeze(1) - # pred_score, gt_score: 
tensor, [batch_size] - - gt_diff = gt_score.unsqueeze(1) - gt_score.unsqueeze(0) - pred_diff = pred_score.unsqueeze(1) - pred_score.unsqueeze(0) - - loss = torch.maximum( - torch.zeros(gt_diff.shape).to(gt_diff.device), - torch.abs(pred_diff - gt_diff) - self.margin, - ) - loss = loss.mean().div(2) - - return loss diff --git a/src/sheet/losses/nll_losses.py b/src/sheet/losses/nll_losses.py deleted file mode 100644 index b0106c2..0000000 --- a/src/sheet/losses/nll_losses.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""NLL losses.""" - -import torch -import torch.nn as nn -from sheet.modules.utils import make_non_pad_mask - - -class GaussianNLLLoss(nn.Module): - """ - Gaussian NLL loss (for uncertainty modeling) - """ - - def __init__(self, tau, masked_loss=False): - super(GaussianNLLLoss, self).__init__() - self.tau = tau - self.masked_loss = masked_loss - - def forward_criterion(self, y_hat, logvar, label, masks=None): - """ - loss = 0.5 * (precision * (target - mean) ** 2 + log_var) - """ - - # might investigate how to combine masked loss with categorical output - if masks is not None: - y_hat = y_hat.masked_select(masks) - logvar = logvar.masked_select(masks) - label = label.masked_select(masks) - - y_hat = y_hat.squeeze(-1) - logvar = logvar.squeeze(-1) - precision = torch.exp(-logvar) - loss = 0.5 * (precision * (y_hat - label) ** 2 + logvar) - threshold = torch.abs(y_hat - label) > self.tau - loss = torch.mean(threshold * loss) - return loss - - def forward(self, pred_score, pred_logvar, gt_score, device, lens=None): - """ - Args: - pred_mean, pred_score: [batch, time, 1/5] - """ - # make mask - if self.masked_loss: - masks = make_non_pad_mask(lens).to(device) - else: - masks = None - - # repeat for frame level loss - time = pred_score.shape[1] - # gt_mean = gt_mean.unsqueeze(1).repeat(1, time) - gt_score = 
gt_score.unsqueeze(1).repeat(1, time) - - loss = self.forward_criterion(pred_score, pred_logvar, gt_score, masks) - return loss - - -class LaplaceNLLLoss(nn.Module): - """ - Laplace NLL loss (for uncertainty modeling) - """ - - def __init__(self, tau, masked_loss=False): - super(LaplaceNLLLoss, self).__init__() - self.tau = tau - self.masked_loss = masked_loss - - def forward_criterion(self, y_hat, logvar, label, masks=None): - """ - loss = 0.5 * (precision * (target - mean) ** 2 + log_var) - """ - - # might investigate how to combine masked loss with categorical output - if masks is not None: - y_hat = y_hat.masked_select(masks) - logvar = logvar.masked_select(masks) - label = label.masked_select(masks) - - y_hat = y_hat.squeeze(-1) - logvar = logvar.squeeze(-1) - b = torch.exp(logvar) + 1e-6 - loss = torch.abs(y_hat - label) / b + logvar - threshold = torch.abs(y_hat - label) > self.tau - loss = torch.mean(threshold * loss) - return loss - - def forward(self, pred_score, pred_logvar, gt_score, device, lens=None): - """ - Args: - pred_mean, pred_score: [batch, time, 1/5] - """ - # make mask - if self.masked_loss: - masks = make_non_pad_mask(lens).to(device) - else: - masks = None - - # repeat for frame level loss - time = pred_score.shape[1] - # gt_mean = gt_mean.unsqueeze(1).repeat(1, time) - gt_score = gt_score.unsqueeze(1).repeat(1, time) - - loss = self.forward_criterion(pred_score, pred_logvar, gt_score, masks) - return loss diff --git a/src/sheet/models/__init__.py b/src/sheet/models/__init__.py deleted file mode 100644 index c773ca8..0000000 --- a/src/sheet/models/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .alignnet import * # NOQA -from .ldnet import * # NOQA - -# from .ramp_simple import * # NOQA -# from .ramp import * # NOQA -from .sslmos import * # NOQA -from .utmos import * # NOQA - -from .sslmos_u import * # NOQA \ No newline at end of file diff --git a/src/sheet/models/alignnet.py b/src/sheet/models/alignnet.py deleted file mode 100644 index 
d194669..0000000 --- a/src/sheet/models/alignnet.py +++ /dev/null @@ -1,400 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# Modified AlignNet model - -import torch -import torch.nn as nn -from sheet.modules.ldnet.modules import Projection, ProjectionWithUncertainty - - -class AlignNet(torch.nn.Module): - def __init__( - self, - # dummy, for signature need - model_input: str, - # model related - ssl_module: str = "s3prl", - s3prl_name: str = "wav2vec2", - ssl_model_output_dim: int = 768, - ssl_model_layer_idx: int = -1, - # listener related - use_listener_modeling: bool = False, - num_listeners: int = None, - listener_emb_dim: int = None, - use_mean_listener: bool = True, - # domain related - use_domain_modeling: bool = False, - num_domains: int = None, - domain_emb_dim: int = None, - # decoder related - use_decoder_rnn: bool = True, - decoder_rnn_dim: int = 512, - decoder_dnn_dim: int = 2048, - decoder_activation: str = "ReLU", - output_type: str = "scalar", - range_clipping: bool = True, - ): - super().__init__() # this is needed! or else there will be an error. 
- self.use_mean_listener = use_mean_listener - self.output_type = output_type - self.decoder_dnn_dim = decoder_dnn_dim - self.range_clipping = range_clipping - - # define ssl model - if ssl_module == "s3prl": - from s3prl.nn import S3PRLUpstream - - if s3prl_name in S3PRLUpstream.available_names(): - self.ssl_model = S3PRLUpstream(s3prl_name) - self.ssl_model_layer_idx = ssl_model_layer_idx - else: - raise NotImplementedError - decoder_input_dim = ssl_model_output_dim - - # listener modeling related - self.use_listener_modeling = use_listener_modeling - if use_listener_modeling: - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - decoder_input_dim += listener_emb_dim - - # domain modeling related - self.use_domain_modeling = use_domain_modeling - if use_domain_modeling: - self.num_domains = num_domains - self.domain_embeddings = nn.Embedding( - num_embeddings=num_domains, embedding_dim=domain_emb_dim - ) - decoder_input_dim += domain_emb_dim - - # define decoder rnn - self.use_decoder_rnn = use_decoder_rnn - if self.use_decoder_rnn: - self.decoder_rnn = nn.LSTM( - input_size=decoder_input_dim, - hidden_size=decoder_rnn_dim, - num_layers=1, - batch_first=True, - bidirectional=True, - ) - self.decoder_dnn_input_dim = decoder_rnn_dim * 2 - else: - self.decoder_dnn_input_dim = decoder_input_dim - - # define activation - if decoder_activation == "ReLU": - self.decoder_activation = nn.ReLU - else: - raise NotImplementedError - - # there is always decoder dnn - self.decoder_dnn = Projection( - self.decoder_dnn_input_dim, - self.decoder_dnn_dim, - self.decoder_activation, - self.output_type, - self.range_clipping, - ) - - def get_num_params(self): - return sum(p.numel() for n, p in self.named_parameters()) - - def forward(self, inputs): - """Calculate forward propagation. 
- Args: - inputs: dict, which has the following keys: - - waveform has shape (batch, time) - - waveform_lengths has shape (batch) - - listener_ids has shape (batch) - - domain_ids has shape (batch) - """ - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - # get domain embedding - if self.use_domain_modeling: - domain_ids = inputs["domain_idxs"] - domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) - domain_embs = torch.stack( - [domain_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? 
- # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(domain_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": ssl_model_output_lengths, - } - if self.use_listener_modeling: - ret["ld_scores"] = decoder_outputs - else: - ret["mean_scores"] = decoder_outputs - - return ret - - def mean_listener_inference(self, inputs): - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - batch = waveform.size(0) - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # get listener embedding - if self.use_listener_modeling: - device = waveform.device - listener_ids = ( - torch.ones(batch, dtype=torch.long) * self.num_listeners - 1 - ).to( - device - ) # (bs) - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - # get domain embedding - if self.use_domain_modeling: - device = waveform.device - assert "domain_idxs" in inputs, "Must specify domain ID even in inference." 
- domain_ids = inputs["domain_idxs"] - domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) - domain_embs = torch.stack( - [domain_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(domain_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - scores = torch.mean(decoder_outputs.squeeze(-1), dim=1) - return {"scores": scores} - - def ssl_model_forward(self, waveform, waveform_lengths): - all_ssl_model_outputs, all_ssl_model_output_lengths = self.ssl_model( - waveform, waveform_lengths - ) - ssl_model_outputs = all_ssl_model_outputs[self.ssl_model_layer_idx] - ssl_model_output_lengths = all_ssl_model_output_lengths[ - self.ssl_model_layer_idx - ] - return ssl_model_outputs, ssl_model_output_lengths - - def get_ssl_embeddings(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - return encoder_outputs - - -class AlignNet_U(AlignNet): - def __init__(self, model_input, *args, **kwargs): - super().__init__(model_input, *args, **kwargs) - - self.decoder_dnn = ProjectionWithUncertainty( - self.decoder_dnn_input_dim, - self.decoder_dnn_dim, - self.decoder_activation, - self.output_type, - 5 # fix this if one day we want to use categorical output - ) - - def forward(self, inputs): - """Calculate forward propagation. 
- Args: - inputs: dict, which has the following keys: - - waveform has shape (batch, time) - - waveform_lengths has shape (batch) - - listener_ids has shape (batch) - - domain_ids has shape (batch) - """ - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - # get domain embedding - if self.use_domain_modeling: - domain_ids = inputs["domain_idxs"] - domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) - domain_embs = torch.stack( - [domain_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? 
- # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(domain_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs_mean, decoder_outputs_logvar = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": ssl_model_output_lengths, - } - if self.use_listener_modeling: - ret["ld_scores"] = decoder_outputs - else: - ret["mean_scores"] = decoder_outputs_mean - ret["mean_scores_logvar"] = decoder_outputs_logvar - - return ret - - def mean_listener_inference(self, inputs): - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - batch = waveform.size(0) - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # get listener embedding - if self.use_listener_modeling: - device = waveform.device - listener_ids = ( - torch.ones(batch, dtype=torch.long) * self.num_listeners - 1 - ).to( - device - ) # (bs) - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - # get domain embedding - if self.use_domain_modeling: - device = waveform.device - assert "domain_idxs" in inputs, "Must specify domain ID even in inference." 
- domain_ids = inputs["domain_idxs"] - domain_embs = self.domain_embeddings(domain_ids) # (batch, emb_dim) - domain_embs = torch.stack( - [domain_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? - # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(domain_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs_mean, decoder_outputs_logvar = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - scores = torch.mean(decoder_outputs_mean.squeeze(-1), dim=1) - logvars = torch.mean(decoder_outputs_logvar.squeeze(-1), dim=1) - return {"scores": scores, "logvars": logvars} diff --git a/src/sheet/models/ldnet.py b/src/sheet/models/ldnet.py deleted file mode 100644 index 017193f..0000000 --- a/src/sheet/models/ldnet.py +++ /dev/null @@ -1,288 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# LDNet model -# taken from: https://github.com/unilight/LDNet/blob/main/models/LDNet.py (written by myself) - -import math - -import torch -import torch.nn as nn -from sheet.modules.ldnet.modules import STRIDE, MobileNetV3ConvBlocks, Projection - - -class LDNet(torch.nn.Module): - def __init__( - self, - model_input: str, - # listener related - num_listeners: int, - listener_emb_dim: int, - use_mean_listener: bool, - # model related - activation: str, - encoder_type: str, - encoder_bneck_configs: list, - encoder_output_dim: int, - decoder_type: str, - decoder_dnn_dim: int, - output_type: str, - range_clipping: bool, - # mean net related - use_mean_net: bool = False, - mean_net_type: str = "ffn", - mean_net_dnn_dim: int = 64, - mean_net_range_clipping: bool = True, - ): - super().__init__() # this is needed! 
or else there will be an error. - self.use_mean_listener = use_mean_listener - self.output_type = output_type - - # only accept mag_sgram as input - assert model_input == "mag_sgram" - - # define listener embedding - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - - # define activation - if activation == "ReLU": - self.activation = nn.ReLU - else: - raise NotImplementedError - - # define encoder - if encoder_type == "mobilenetv3": - self.encoder = MobileNetV3ConvBlocks( - encoder_bneck_configs, encoder_output_dim - ) - else: - raise NotImplementedError - - # define decoder - self.decoder_type = decoder_type - if decoder_type == "ffn": - decoder_dnn_input_dim = encoder_output_dim + listener_emb_dim - else: - raise NotImplementedError - # there is always dnn - self.decoder_dnn = Projection( - decoder_dnn_input_dim, - decoder_dnn_dim, - self.activation, - output_type, - range_clipping, - ) - - # define mean net - self.use_mean_net = use_mean_net - self.mean_net_type = mean_net_type - if use_mean_net: - if mean_net_type == "ffn": - mean_net_dnn_input_dim = encoder_output_dim - else: - raise NotImplementedError - # there is always dnn - self.mean_net_dnn = Projection( - mean_net_dnn_input_dim, - mean_net_dnn_dim, - self.activation, - output_type, - mean_net_range_clipping, - ) - - def _get_output_dim(self, input_size, num_layers, stride=STRIDE): - """ - calculate the final ouptut width (dim) of a CNN using the following formula - w_i = |_ (w_i-1 - 1) / stride + 1 _| - """ - output_dim = input_size - for _ in range(num_layers): - output_dim = math.floor((output_dim - 1) / STRIDE + 1) - return output_dim - - def get_num_params(self): - return sum(p.numel() for n, p in self.named_parameters()) - - def forward(self, inputs): - """Calculate forward propagation. 
- Args: - mag_sgram has shape (batch, time, dim) - listener_ids has shape (batch) - """ - mag_sgram = inputs["mag_sgram"] - mag_sgram_lengths = inputs["mag_sgram_lengths"] - listener_ids = inputs["listener_idxs"] - - batch, time, _ = mag_sgram.shape - - # get listener embedding - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # encoder and inject listener embedding - mag_sgram = mag_sgram.unsqueeze(1) - encoder_outputs = self.encoder(mag_sgram) # (batch, ch, time, feat_dim) - encoder_outputs = encoder_outputs.view( - (batch, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.cat( - [encoder_outputs, listener_embs], dim=-1 - ) # concat along feature dimension - - # mean net - if self.use_mean_net: - mean_net_inputs = encoder_outputs - if self.mean_net_type == "rnn": - mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs) - else: - mean_net_outputs = mean_net_inputs - mean_net_outputs = self.mean_net_dnn( - mean_net_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # decoder - if self.decoder_type == "rnn": - decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # define scores - ret = { - "frame_lengths": mag_sgram_lengths, - "mean_scores": mean_net_outputs if self.use_mean_net else None, - "ld_scores": decoder_outputs, - } - - return ret - - def mean_listener_inference(self, inputs): - """Mean listener inference. 
- Args: - mag_sgram has shape (batch, time, dim) - """ - - assert self.use_mean_listener - mag_sgram = inputs["mag_sgram"] - batch, time, dim = mag_sgram.shape - device = mag_sgram.device - - # get listener embedding - listener_id = (torch.ones(batch, dtype=torch.long) * self.num_listeners - 1).to( - device - ) # (bs) - listener_embs = self.listener_embeddings(listener_id) # (bs, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # encoder and inject listener embedding - mag_sgram = mag_sgram.unsqueeze(1) - encoder_outputs = self.encoder(mag_sgram) # (batch, ch, time, feat_dim) - encoder_outputs = encoder_outputs.view( - (batch, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.cat( - [encoder_outputs, listener_embs], dim=-1 - ) # concat along feature dimension - - # decoder - if self.decoder_type == "rnn": - decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # define scores - decoder_outputs = decoder_outputs.squeeze(-1) - scores = torch.mean(decoder_outputs, dim=1) - return {"scores": scores} - - def average_inference(self, mag_sgram, include_meanspk=False): - """Average listener inference. 
- Args: - mag_sgram has shape (batch, time, dim) - """ - - bs, time, _ = mag_sgram.shape - device = mag_sgram.device - if self.use_mean_listener and not include_meanspk: - actual_num_listeners = self.num_listeners - 1 - else: - actual_num_listeners = self.num_listeners - - # all listener ids - listener_id = ( - torch.arange(actual_num_listeners, dtype=torch.long) - .repeat(bs, 1) - .to(device) - ) # (bs, nj) - listener_embs = self.listener_embedding(listener_id) # (bs, nj, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=2 - ) # (bs, nj, time, feat_dim) - - # encoder and inject listener embedding - mag_sgram = mag_sgram.unsqueeze(1) - encoder_outputs = self.encoder(mag_sgram) # (batch, ch, time, feat_dim) - encoder_outputs = encoder_outputs.view( - (bs, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.stack( - [encoder_outputs for i in range(actual_num_listeners)], dim=1 - ) # (bs, nj, time, feat_dim) - decoder_inputs = torch.cat( - [decoder_inputs, listener_embs], dim=-1 - ) # concat along feature dimension - - # mean net - if self.use_mean_net: - mean_net_inputs = encoder_outputs - if self.mean_net_type == "rnn": - mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs) - else: - mean_net_outputs = mean_net_inputs - mean_net_outputs = self.mean_net_dnn( - mean_net_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # decoder - if self.decoder_type == "rnn": - decoder_outputs = decoder_inputs.view((bs * actual_num_listeners, time, -1)) - decoder_outputs, (h, c) = self.decoder_rnn(decoder_outputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - decoder_outputs = decoder_outputs.view( - (bs, actual_num_listeners, time, -1) - ) # (bs, nj, time, 1/5) - - if self.output_type == "scalar": - decoder_outputs = decoder_outputs.squeeze(-1) # (bs, nj, time) - posterior_scores = torch.mean(decoder_outputs, 
dim=2) - ld_scores = torch.mean(decoder_outputs, dim=1) # (bs, time) - elif self.output_type == "categorical": - ld_posterior = torch.nn.functional.softmax(decoder_outputs, dim=-1) - ld_scores = torch.inner( - ld_posterior, torch.Tensor([1, 2, 3, 4, 5]).to(device) - ) - posterior_scores = torch.mean(ld_scores, dim=2) - ld_scores = torch.mean(ld_scores, dim=1) # (bs, time) - - # define scores - scores = torch.mean(ld_scores, dim=1) - - return {"scores": scores, "posterior_scores": posterior_scores} diff --git a/src/sheet/models/sslmos.py b/src/sheet/models/sslmos.py deleted file mode 100644 index 453485d..0000000 --- a/src/sheet/models/sslmos.py +++ /dev/null @@ -1,467 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# SSLMOS model -# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper) - -import math - -import torch -import torch.nn as nn -from sheet.modules.ldnet.modules import Projection, ProjectionWithUncertainty - - -class SSLMOS(torch.nn.Module): - def __init__( - self, - # dummy, for signature need - model_input: str, - # model related - ssl_module: str = "s3prl", - s3prl_name: str = "wav2vec2", - ssl_model_output_dim: int = 768, - ssl_model_layer_idx: int = -1, - # mean net related - mean_net_dnn_dim: int = 64, - mean_net_output_type: str = "scalar", - mean_net_output_dim: int = 5, - mean_net_output_step: float = 0.25, - mean_net_range_clipping: bool = True, - # listener related - use_listener_modeling: bool = False, - num_listeners: int = None, - listener_emb_dim: int = None, - use_mean_listener: bool = True, - # decoder related - decoder_type: str = "ffn", - decoder_dnn_dim: int = 64, - output_type: str = "scalar", - range_clipping: bool = True, - # additional head (for RAMP) - use_additional_categorical_head: bool = False, - categorical_head_dnn_dim: int = 64, - categorical_head_output_dim: int = 17, - 
categorical_head_output_step: float = 0.25, - categorical_head_range_clipping: bool = True, - # dummy, for signature need - num_domains: int = None, - ): - super().__init__() # this is needed! or else there will be an error. - self.use_mean_listener = use_mean_listener - self.output_type = output_type - self.use_additional_categorical_head = use_additional_categorical_head - - # define listener embedding - self.use_listener_modeling = use_listener_modeling - - # define ssl model - if ssl_module == "s3prl": - from s3prl.nn import S3PRLUpstream - - if s3prl_name in S3PRLUpstream.available_names(): - self.ssl_model = S3PRLUpstream(s3prl_name) - self.ssl_model_layer_idx = ssl_model_layer_idx - else: - raise NotImplementedError - - # default uses ffn type mean net - self.mean_net_dnn = Projection( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - mean_net_output_type, - mean_net_output_dim, - mean_net_output_step, - mean_net_range_clipping, - ) - - # additional categorical head (for RAMP) - if use_additional_categorical_head: - # make sure mean net is not categorical - assert ( - mean_net_output_type != "categorical" - ), "mean net cannot be categorical if additional categorical head is used" - self.categorical_head = Projection( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - "categorical", - categorical_head_output_dim, - categorical_head_output_step, - categorical_head_range_clipping, - ) - - # listener modeling related - self.use_listener_modeling = use_listener_modeling - if use_listener_modeling: - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - # define decoder - self.decoder_type = decoder_type - if decoder_type == "ffn": - decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim - else: - raise NotImplementedError - # there is always dnn - self.decoder_dnn = Projection( - decoder_dnn_input_dim, - decoder_dnn_dim, - self.activation, - 
output_type, - range_clipping, - ) - - def get_num_params(self): - return sum(p.numel() for n, p in self.named_parameters()) - - def forward(self, inputs): - """Calculate forward propagation. - Args: - waveform has shape (batch, time) - waveform_lengths has shape (batch) - listener_ids has shape (batch) - """ - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - batch, time = waveform.shape - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - # NOTE(unlight): not tested yet - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] - - # inject listener embedding - if self.use_listener_modeling: - # NOTE(unlight): not tested yet - encoder_outputs = encoder_outputs.view( - (batch, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.cat( - [encoder_outputs, listener_embs], dim=-1 - ) # concat along feature dimension - else: - decoder_inputs = encoder_outputs - - # masked mean pooling - # masks = make_non_pad_mask(encoder_outputs_lens) - # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] - # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) - - # mean net - mean_net_outputs = self.mean_net_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # additional categorical head - if self.use_additional_categorical_head: - categorical_head_outputs = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - # decoder - if self.use_listener_modeling: - if self.decoder_type == "rnn": - 
decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": encoder_outputs_lens, - } - - # define scores - ret["mean_scores"] = mean_net_outputs - ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None - if self.use_additional_categorical_head: - ret["categorical_head_scores"] = categorical_head_outputs - - return ret - - def mean_net_inference(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - - # mean net - decoder_inputs = encoder_outputs - mean_net_outputs = self.mean_net_dnn( - decoder_inputs, inference=True - ) # [batch, time, 1 (scalar) / 5 (categorical)] - mean_net_outputs = mean_net_outputs.squeeze(-1) - scores = torch.mean(mean_net_outputs.to(torch.float), dim=1) # [batch] - - ret = {"ssl_embeddings": encoder_outputs, "scores": scores} - - if self.use_additional_categorical_head: - ret["confidences"] = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - return ret - - def mean_net_inference_p1(self, waveform, waveform_lengths): - # ssl model forward - all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - return encoder_outputs - - def mean_net_inference_p2(self, encoder_outputs): - # mean net - mean_net_outputs = self.mean_net_dnn( - encoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - mean_net_outputs = mean_net_outputs.squeeze(-1) - scores = torch.mean(mean_net_outputs, dim=1) - - return scores - - 
def get_ssl_embeddings(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - return encoder_outputs - - -class SSLMOS_U(SSLMOS): - def __init__( - self, - # dummy, for signature need - model_input: str, - # model related - ssl_module: str = "s3prl", - s3prl_name: str = "wav2vec2", - ssl_model_output_dim: int = 768, - ssl_model_layer_idx: int = -1, - # mean net related - mean_net_dnn_dim: int = 64, - mean_net_output_type: str = "scalar", - mean_net_output_dim: int = 5, - mean_net_output_step: float = 0.25, - mean_net_range_clipping: bool = True, - # listener related - use_listener_modeling: bool = False, - num_listeners: int = None, - listener_emb_dim: int = None, - use_mean_listener: bool = True, - # decoder related - decoder_type: str = "ffn", - decoder_dnn_dim: int = 64, - output_type: str = "scalar", - range_clipping: bool = True, - # additional head (for RAMP) - use_additional_categorical_head: bool = False, - categorical_head_dnn_dim: int = 64, - categorical_head_output_dim: int = 17, - categorical_head_output_step: float = 0.25, - categorical_head_range_clipping: bool = True, - # dummy, for signature need - num_domains: int = None, - ): - super().__init__() # this is needed! or else there will be an error. 
- self.use_mean_listener = use_mean_listener - self.output_type = output_type - self.use_additional_categorical_head = use_additional_categorical_head - - # define listener embedding - self.use_listener_modeling = use_listener_modeling - - # define ssl model - if ssl_module == "s3prl": - from s3prl.nn import S3PRLUpstream - - if s3prl_name in S3PRLUpstream.available_names(): - self.ssl_model = S3PRLUpstream(s3prl_name) - self.ssl_model_layer_idx = ssl_model_layer_idx - else: - raise NotImplementedError - - # default uses ffn type mean net - self.mean_net_dnn = ProjectionWithUncertainty( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - mean_net_output_type, - mean_net_output_dim, - mean_net_output_step, - mean_net_range_clipping, - ) - - # additional categorical head (for RAMP) - if use_additional_categorical_head: - # make sure mean net is not categorical - assert ( - mean_net_output_type != "categorical" - ), "mean net cannot be categorical if additional categorical head is used" - self.categorical_head = Projection( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - "categorical", - categorical_head_output_dim, - categorical_head_output_step, - categorical_head_range_clipping, - ) - - # listener modeling related - self.use_listener_modeling = use_listener_modeling - if use_listener_modeling: - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - # define decoder - self.decoder_type = decoder_type - if decoder_type == "ffn": - decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim - else: - raise NotImplementedError - # there is always dnn - self.decoder_dnn = Projection( - decoder_dnn_input_dim, - decoder_dnn_dim, - self.activation, - output_type, - range_clipping, - ) - - def forward(self, inputs): - """Calculate forward propagation. 
- Args: - waveform has shape (batch, time) - waveform_lengths has shape (batch) - listener_ids has shape (batch) - """ - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - batch, time = waveform.shape - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - # NOTE(unlight): not tested yet - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] - - # inject listener embedding - if self.use_listener_modeling: - # NOTE(unlight): not tested yet - encoder_outputs = encoder_outputs.view( - (batch, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.cat( - [encoder_outputs, listener_embs], dim=-1 - ) # concat along feature dimension - else: - decoder_inputs = encoder_outputs - - # masked mean pooling - # masks = make_non_pad_mask(encoder_outputs_lens) - # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] - # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) - - # mean net - mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # additional categorical head - if self.use_additional_categorical_head: - categorical_head_outputs = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - # decoder - if self.use_listener_modeling: - if self.decoder_type == "rnn": - decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # 
[batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": encoder_outputs_lens, - } - - # define scores - ret["mean_scores"] = mean_net_outputs_mean - ret["mean_scores_logvar"] = mean_net_outputs_logvar - ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None - if self.use_additional_categorical_head: - ret["categorical_head_scores"] = categorical_head_outputs - - return ret - - def mean_net_inference(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - - # mean net - decoder_inputs = encoder_outputs - mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( - decoder_inputs, inference=True - ) # [batch, time, 1 (scalar) / 5 (categorical)] - mean_net_outputs_mean = mean_net_outputs_mean.squeeze(-1) - mean_net_outputs_logvar = mean_net_outputs_logvar.squeeze(-1) - scores = torch.mean(mean_net_outputs_mean, dim=1) # [batch] - logvars = torch.mean(mean_net_outputs_logvar, dim=1) # [batch] - - ret = {"ssl_embeddings": encoder_outputs, "scores": scores, "logvars": logvars} - - if self.use_additional_categorical_head: - ret["confidences"] = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - return ret diff --git a/src/sheet/models/sslmos_u.py b/src/sheet/models/sslmos_u.py deleted file mode 100644 index 5e799f2..0000000 --- a/src/sheet/models/sslmos_u.py +++ /dev/null @@ -1,256 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2025 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# SSLMOS model which can output uncertainty -# modified from: https://github.com/nii-yamagishilab/mos-finetune-ssl/blob/main/mos_fairseq.py (written by Erica Cooper) 
- -import math - -import torch -import torch.nn as nn -from sheet.modules.ldnet.modules import ProjectionWithUncertainty - - -class SSLMOS_U(torch.nn.Module): - def __init__( - self, - # dummy, for signature need - model_input: str, - # model related - ssl_module: str = "s3prl", - s3prl_name: str = "wav2vec2", - ssl_model_output_dim: int = 768, - ssl_model_layer_idx: int = -1, - # mean net related - mean_net_dnn_dim: int = 64, - mean_net_output_type: str = "scalar", - mean_net_output_dim: int = 5, - mean_net_output_step: float = 0.25, - mean_net_range_clipping: bool = True, - # listener related - use_listener_modeling: bool = False, - num_listeners: int = None, - listener_emb_dim: int = None, - use_mean_listener: bool = True, - # decoder related - decoder_type: str = "ffn", - decoder_dnn_dim: int = 64, - output_type: str = "scalar", - range_clipping: bool = True, - # additional head (for RAMP) - use_additional_categorical_head: bool = False, - categorical_head_dnn_dim: int = 64, - categorical_head_output_dim: int = 17, - categorical_head_output_step: float = 0.25, - categorical_head_range_clipping: bool = True, - # dummy, for signature need - num_domains: int = None, - ): - super().__init__() # this is needed! or else there will be an error. 
- self.use_mean_listener = use_mean_listener - self.output_type = output_type - self.use_additional_categorical_head = use_additional_categorical_head - - # define listener embedding - self.use_listener_modeling = use_listener_modeling - - # define ssl model - if ssl_module == "s3prl": - from s3prl.nn import S3PRLUpstream - - if s3prl_name in S3PRLUpstream.available_names(): - self.ssl_model = S3PRLUpstream(s3prl_name) - self.ssl_model_layer_idx = ssl_model_layer_idx - else: - raise NotImplementedError - - # default uses ffn type mean net - self.mean_net_dnn = ProjectionWithUncertainty( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - mean_net_output_type, - mean_net_output_dim, - mean_net_output_step, - mean_net_range_clipping, - ) - - # additional categorical head (for RAMP) - if use_additional_categorical_head: - # make sure mean net is not categorical - assert ( - mean_net_output_type != "categorical" - ), "mean net cannot be categorical if additional categorical head is used" - self.categorical_head = Projection( - ssl_model_output_dim, - mean_net_dnn_dim, - nn.ReLU, - "categorical", - categorical_head_output_dim, - categorical_head_output_step, - categorical_head_range_clipping, - ) - - # listener modeling related - self.use_listener_modeling = use_listener_modeling - if use_listener_modeling: - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - # define decoder - self.decoder_type = decoder_type - if decoder_type == "ffn": - decoder_dnn_input_dim = ssl_model_output_dim + listener_emb_dim - else: - raise NotImplementedError - # there is always dnn - self.decoder_dnn = Projection( - decoder_dnn_input_dim, - decoder_dnn_dim, - self.activation, - output_type, - range_clipping, - ) - - def get_num_params(self): - return sum(p.numel() for n, p in self.named_parameters()) - - def forward(self, inputs): - """Calculate forward propagation. 
- Args: - waveform has shape (batch, time) - waveform_lengths has shape (batch) - listener_ids has shape (batch) - """ - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - batch, time = waveform.shape - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - # NOTE(unlight): not tested yet - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - encoder_outputs_lens = all_encoder_outputs_lens[self.ssl_model_layer_idx] - - # inject listener embedding - if self.use_listener_modeling: - # NOTE(unlight): not tested yet - encoder_outputs = encoder_outputs.view( - (batch, time, -1) - ) # (batch, time, feat_dim) - decoder_inputs = torch.cat( - [encoder_outputs, listener_embs], dim=-1 - ) # concat along feature dimension - else: - decoder_inputs = encoder_outputs - - # masked mean pooling - # masks = make_non_pad_mask(encoder_outputs_lens) - # masks = masks.unsqueeze(-1).to(decoder_inputs.device) # [B, max_time, 1] - # decoder_inputs = torch.sum(decoder_inputs * masks, dim=1) / encoder_outputs_lens.unsqueeze(-1) - - # mean net - mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # additional categorical head - if self.use_additional_categorical_head: - categorical_head_outputs = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - # decoder - if self.use_listener_modeling: - if self.decoder_type == "rnn": - decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) - else: - decoder_outputs = decoder_inputs - decoder_outputs = self.decoder_dnn( - decoder_outputs - ) # 
[batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": encoder_outputs_lens, - } - - # define scores - ret["mean_scores"] = mean_net_outputs_mean - ret["mean_scores_logvar"] = mean_net_outputs_logvar - ret["ld_scores"] = decoder_outputs if self.use_listener_modeling else None - if self.use_additional_categorical_head: - ret["categorical_head_scores"] = categorical_head_outputs - - return ret - - def mean_net_inference(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - # ssl model forward - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - - # mean net - decoder_inputs = encoder_outputs - mean_net_outputs_mean, mean_net_outputs_logvar = self.mean_net_dnn( - decoder_inputs, inference=True - ) # [batch, time, 1 (scalar) / 5 (categorical)] - mean_net_outputs_mean = mean_net_outputs_mean.squeeze(-1) - mean_net_outputs_logvar = mean_net_outputs_logvar.squeeze(-1) - scores = torch.mean(mean_net_outputs_mean, dim=1) # [batch] - logvars = torch.mean(mean_net_outputs_logvar, dim=1) # [batch] - - ret = {"ssl_embeddings": encoder_outputs, "scores": scores, "logvars": logvars} - - if self.use_additional_categorical_head: - ret["confidences"] = self.categorical_head( - decoder_inputs - ) # [batch, time, categorical steps] - - return ret - - def mean_net_inference_p1(self, waveform, waveform_lengths): - # ssl model forward - all_encoder_outputs, _ = self.ssl_model(waveform, waveform_lengths) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - return encoder_outputs - - def mean_net_inference_p2(self, encoder_outputs): - # mean net - mean_net_outputs = self.mean_net_dnn( - encoder_outputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - mean_net_outputs = 
mean_net_outputs.squeeze(-1) - scores = torch.mean(mean_net_outputs, dim=1) - - return scores - - def get_ssl_embeddings(self, inputs): - waveform = inputs["waveform"] - waveform_lengths = inputs["waveform_lengths"] - - all_encoder_outputs, all_encoder_outputs_lens = self.ssl_model( - waveform, waveform_lengths - ) - encoder_outputs = all_encoder_outputs[self.ssl_model_layer_idx] - return encoder_outputs \ No newline at end of file diff --git a/src/sheet/models/utmos.py b/src/sheet/models/utmos.py deleted file mode 100644 index 44f7672..0000000 --- a/src/sheet/models/utmos.py +++ /dev/null @@ -1,299 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# UTMOS model -# modified from: https://github.com/sarulab-speech/UTMOS22/tree/master/strong - -import math - -import torch -import torch.nn as nn -from sheet.modules.ldnet.modules import Projection -from sheet.modules.utils import make_non_pad_mask - - -class UTMOS(torch.nn.Module): - def __init__( - self, - model_input: str, - # model related - ssl_module: str, - s3prl_name: str, - ssl_model_output_dim: int, - ssl_model_layer_idx: int, - # phoneme and reference related - use_phoneme: bool = True, - phoneme_encoder_dim: int = 256, - phoneme_encoder_emb_dim: int = 256, - phoneme_encoder_out_dim: int = 256, - phoneme_encoder_n_lstm_layers: int = 3, - phoneme_encoder_vocab_size: int = 300, - use_reference: bool = True, - # listener related - use_listener_modeling: bool = False, - num_listeners: int = None, - listener_emb_dim: int = None, - use_mean_listener: bool = True, - # decoder related - use_decoder_rnn: bool = True, - decoder_rnn_dim: int = 512, - decoder_dnn_dim: int = 2048, - decoder_activation: str = "ReLU", - output_type: str = "scalar", - range_clipping: bool = True, - num_domains: int = None, - ): - super().__init__() # this is needed! or else there will be an error. 
- self.use_mean_listener = use_mean_listener - self.output_type = output_type - - # define listener embedding - self.use_listener_modeling = use_listener_modeling - - # define ssl model - if ssl_module == "s3prl": - from s3prl.nn import S3PRLUpstream - - if s3prl_name in S3PRLUpstream.available_names(): - self.ssl_model = S3PRLUpstream(s3prl_name) - self.ssl_model_layer_idx = ssl_model_layer_idx - else: - raise NotImplementedError - decoder_input_dim = ssl_model_output_dim - - # define phoneme encoder - self.use_phoneme = use_phoneme - self.use_reference = use_reference - if self.use_phoneme: - self.phoneme_embedding = nn.Embedding( - phoneme_encoder_vocab_size, phoneme_encoder_emb_dim - ) - self.phoneme_encoder_lstm = nn.LSTM( - phoneme_encoder_emb_dim, - phoneme_encoder_dim, - num_layers=phoneme_encoder_n_lstm_layers, - dropout=0.1, - bidirectional=True, - ) - if self.use_reference: - - phoneme_encoder_linear_input_dim = ( - phoneme_encoder_dim + phoneme_encoder_dim - ) - else: - phoneme_encoder_linear_input_dim = phoneme_encoder_dim - self.phoneme_encoder_linear = nn.Sequential( - nn.Linear(phoneme_encoder_linear_input_dim, phoneme_encoder_out_dim), - nn.ReLU(), - ) - decoder_input_dim += phoneme_encoder_out_dim - - # NOTE(unlight): ignore domain embedding right now - - # listener modeling related - self.use_listener_modeling = use_listener_modeling - if use_listener_modeling: - self.num_listeners = num_listeners - self.listener_embeddings = nn.Embedding( - num_embeddings=num_listeners, embedding_dim=listener_emb_dim - ) - decoder_input_dim += listener_emb_dim - - # define decoder rnn - self.use_decoder_rnn = use_decoder_rnn - if self.use_decoder_rnn: - self.decoder_rnn = nn.LSTM( - input_size=decoder_input_dim, - hidden_size=decoder_rnn_dim, - num_layers=1, - batch_first=True, - bidirectional=True, - ) - decoder_dnn_input_dim = decoder_rnn_dim * 2 - else: - decoder_dnn_input_dim = decoder_input_dim - - # define activation - if decoder_activation == "ReLU": - 
self.decoder_activation = nn.ReLU - else: - raise NotImplementedError - - # there is always decoder dnn - self.decoder_dnn = Projection( - decoder_dnn_input_dim, - decoder_dnn_dim, - self.decoder_activation, - output_type, - range_clipping, - ) - - def get_num_params(self): - return sum(p.numel() for n, p in self.named_parameters()) - - def forward(self, inputs): - """Calculate forward propagation. - Args: - inputs: dict, which has the following keys: - - waveform has shape (batch, time) - - waveform_lengths has shape (batch) - - listener_ids has shape (batch) - """ - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # phoneme encoder forward - if self.use_phoneme: - phoneme_encoder_outputs = self.phoneme_encoder_forward(inputs, time) - to_concat.append(phoneme_encoder_outputs) - - # get listener embedding - if self.use_listener_modeling: - listener_ids = inputs["listener_idxs"] - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? 
- # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - # set outputs - # return lengths for masked loss calculation - ret = { - "waveform_lengths": waveform_lengths, - "frame_lengths": ssl_model_output_lengths, - } - if self.use_listener_modeling: - ret["ld_scores"] = decoder_outputs - else: - ret["mean_scores"] = decoder_outputs - - return ret - - def mean_listener_inference(self, inputs): - waveform, waveform_lengths = inputs["waveform"], inputs["waveform_lengths"] - batch = waveform.size(0) - - # ssl model forward - ssl_model_outputs, ssl_model_output_lengths = self.ssl_model_forward( - waveform, waveform_lengths - ) - to_concat = [ssl_model_outputs] - time = ssl_model_outputs.size(1) - - # phoneme encoder forward - if self.use_phoneme: - phoneme_encoder_outputs = self.phoneme_encoder_forward(inputs, time) - to_concat.append(phoneme_encoder_outputs) - - # get listener embedding - if self.use_listener_modeling: - device = waveform.device - listener_ids = ( - torch.ones(batch, dtype=torch.long) * self.num_listeners - 1 - ).to( - device - ) # (bs) - listener_embs = self.listener_embeddings(listener_ids) # (batch, emb_dim) - listener_embs = torch.stack( - [listener_embs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - # NOTE(unilight): is this needed? 
- # encoder_outputs = encoder_outputs.view( - # (batch, time, -1) - # ) # (batch, time, feat_dim) - to_concat.append(listener_embs) - - decoder_inputs = torch.cat(to_concat, dim=2) - - # decoder rnn - if self.use_decoder_rnn: - decoder_inputs, (h, c) = self.decoder_rnn(decoder_inputs) - - # decoder dnn - decoder_outputs = self.decoder_dnn( - decoder_inputs - ) # [batch, time, 1 (scalar) / 5 (categorical)] - - scores = torch.mean(decoder_outputs.squeeze(-1), dim=1) - return {"scores": scores} - - def ssl_model_forward(self, waveform, waveform_lengths): - all_ssl_model_outputs, all_ssl_model_output_lengths = self.ssl_model( - waveform, waveform_lengths - ) - ssl_model_outputs = all_ssl_model_outputs[self.ssl_model_layer_idx] - ssl_model_output_lengths = all_ssl_model_output_lengths[ - self.ssl_model_layer_idx - ] - return ssl_model_outputs, ssl_model_output_lengths - - def phoneme_encoder_forward(self, inputs, time): - phoneme, phoneme_lengths = inputs["phoneme_idxs"], inputs["phoneme_lengths"] - phoneme_embeddings = self.phoneme_embedding(phoneme) - phoneme_embeddings = torch.nn.utils.rnn.pack_padded_sequence( - phoneme_embeddings, phoneme_lengths, batch_first=True, enforce_sorted=False - ) - _, (phoneme_encoder_outputs, _) = self.phoneme_encoder_lstm(phoneme_embeddings) - phoneme_encoder_outputs = ( - phoneme_encoder_outputs[-1] + phoneme_encoder_outputs[0] - ) - if self.use_reference: - assert ( - "reference_idxs" in inputs and "reference_lengths" in inputs - ), "reference and reference_lenghts should not be None when use_reference is True" - reference, reference_lengths = ( - inputs["reference_idxs"], - inputs["reference_lengths"], - ) - reference_embeddings = self.phoneme_embedding(reference) - reference_embeddings = torch.nn.utils.rnn.pack_padded_sequence( - reference_embeddings, - reference_lengths, - batch_first=True, - enforce_sorted=False, - ) - _, (reference_encoder_outputs, _) = self.phoneme_encoder_lstm( - reference_embeddings - ) - 
reference_encoder_outputs = ( - reference_encoder_outputs[-1] + reference_encoder_outputs[0] - ) - phoneme_encoder_outputs = self.phoneme_encoder_linear( - torch.cat([phoneme_encoder_outputs, reference_encoder_outputs], 1) - ) - else: - phoneme_encoder_outputs = self.phoneme_encoder_linear( - phoneme_encoder_outputs - ) - - # expand - phoneme_encoder_outputs = torch.stack( - [phoneme_encoder_outputs for i in range(time)], dim=1 - ) # (batch, time, feat_dim) - - return phoneme_encoder_outputs diff --git a/src/sheet/modules/__init__.py b/src/sheet/modules/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sheet/modules/ldnet/__init__.py b/src/sheet/modules/ldnet/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sheet/modules/ldnet/mobilenetv2.py b/src/sheet/modules/ldnet/mobilenetv2.py deleted file mode 100644 index 87af658..0000000 --- a/src/sheet/modules/ldnet/mobilenetv2.py +++ /dev/null @@ -1,240 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Any, Callable, List, Optional - -import torch -from torch import Tensor, nn - -__all__ = ["MobileNetV2", "mobilenet_v2"] - - -model_urls = { - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", -} - - -def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. 
- if new_v < 0.9 * v: - new_v += divisor - return new_v - - -class ConvBNActivation(nn.Sequential): - def __init__( - self, - in_planes: int, - out_planes: int, - kernel_size: int = 3, - stride: int = 1, - groups: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - activation_layer: Optional[Callable[..., nn.Module]] = None, - dilation: int = 1, - ) -> None: - padding = (kernel_size - 1) // 2 * dilation - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if activation_layer is None: - activation_layer = nn.ReLU6 - super().__init__( - # NOTE(unilight): stride only operates on the last axis - nn.Conv2d( - in_planes, - out_planes, - kernel_size, - (1, stride), - padding, - dilation=dilation, - groups=groups, - bias=False, - ), - norm_layer(out_planes), - activation_layer(inplace=True), - ) - self.out_channels = out_planes - - -# necessary for backwards compatibility -ConvBNReLU = ConvBNActivation - - -class InvertedResidual(nn.Module): - def __init__( - self, - inp: int, - oup: int, - stride: int, - expand_ratio: int, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2, 3] - - if norm_layer is None: - norm_layer = nn.BatchNorm2d - - hidden_dim = int(round(inp * expand_ratio)) - self.use_res_connect = self.stride == 1 and inp == oup - - layers: List[nn.Module] = [] - if expand_ratio != 1: - # pw - layers.append( - ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer) - ) - layers.extend( - [ - # dw - ConvBNReLU( - hidden_dim, - hidden_dim, - stride=stride, - groups=hidden_dim, - norm_layer=norm_layer, - ), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - norm_layer(oup), - ] - ) - self.conv = nn.Sequential(*layers) - self.out_channels = oup - self._is_cn = stride > 1 - - def forward(self, x: Tensor) -> Tensor: - if self.use_res_connect: - return x + self.conv(x) - else: - return self.conv(x) - - -class 
MobileNetV2(nn.Module): - def __init__( - self, - num_classes: int = 1000, - width_mult: float = 1.0, - inverted_residual_setting: Optional[List[List[int]]] = None, - round_nearest: int = 8, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - """ - MobileNet V2 main class - - Args: - num_classes (int): Number of classes - width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount - inverted_residual_setting: Network structure - round_nearest (int): Round the number of channels in each layer to be a multiple of this number - Set to 1 to turn off rounding - block: Module specifying inverted residual building block for mobilenet - norm_layer: Module specifying the normalization layer to use - - """ - super(MobileNetV2, self).__init__() - - if block is None: - block = InvertedResidual - - if norm_layer is None: - norm_layer = nn.BatchNorm2d - - input_channel = 32 - last_channel = 1280 - - if inverted_residual_setting is None: - inverted_residual_setting = [ - # t, c, n, s - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 320, 1, 1], - ] - - # only check the first element, assuming user knows t,c,n,s are required - if ( - len(inverted_residual_setting) == 0 - or len(inverted_residual_setting[0]) != 4 - ): - raise ValueError( - "inverted_residual_setting should be non-empty " - "or a 4-element list, got {}".format(inverted_residual_setting) - ) - - # building first layer - input_channel = _make_divisible(input_channel * width_mult, round_nearest) - self.last_channel = _make_divisible( - last_channel * max(1.0, width_mult), round_nearest - ) - features: List[nn.Module] = [ - ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) - ] - # building inverted residual blocks - for t, c, n, s in inverted_residual_setting: - output_channel = _make_divisible(c * width_mult, round_nearest) - for i in range(n): 
- stride = s if i == 0 else 1 - features.append( - block( - input_channel, - output_channel, - stride, - expand_ratio=t, - norm_layer=norm_layer, - ) - ) - input_channel = output_channel - # building last several layers - features.append( - ConvBNReLU( - input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer - ) - ) - # make it nn.Sequential - self.features = nn.Sequential(*features) - - # building classifier - self.classifier = nn.Sequential( - nn.Dropout(0.2), - nn.Linear(self.last_channel, num_classes), - ) - - # weight initialization - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.zeros_(m.bias) - - def _forward_impl(self, x: Tensor) -> Tensor: - # This exists since TorchScript doesn't support inheritance, so the superclass method - # (this one) needs to have a name other than `forward` that can be accessed in a subclass - x = self.features(x) - # Cannot use "squeeze" as batch-size can be 1 - x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) - x = torch.flatten(x, 1) - x = self.classifier(x) - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) diff --git a/src/sheet/modules/ldnet/mobilenetv3.py b/src/sheet/modules/ldnet/mobilenetv3.py deleted file mode 100644 index 5759712..0000000 --- a/src/sheet/modules/ldnet/mobilenetv3.py +++ /dev/null @@ -1,341 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence - -import torch -from torch import Tensor, nn -from torch.nn import functional as F - -from .mobilenetv2 import ConvBNActivation, _make_divisible - -__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] - - -model_urls = { - 
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", -} - - -class SqueezeExcitation(nn.Module): - # Implemented as described at Figure 4 of the MobileNetV3 paper - def __init__(self, input_channels: int, squeeze_factor: int = 4): - super().__init__() - squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) - self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) - self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) - - def _scale(self, input: Tensor, inplace: bool) -> Tensor: - scale = F.adaptive_avg_pool2d(input, 1) - scale = self.fc1(scale) - scale = self.relu(scale) - scale = self.fc2(scale) - return F.hardsigmoid(scale, inplace=inplace) - - def forward(self, input: Tensor) -> Tensor: - scale = self._scale(input, True) - return scale * input - - -class InvertedResidualConfig: - # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper - def __init__( - self, - input_channels: int, - kernel: int, - expanded_channels: int, - out_channels: int, - use_se: bool, - activation: str, - stride: int, - dilation: int, - width_mult: float, - ): - self.input_channels = self.adjust_channels(input_channels, width_mult) - self.kernel = kernel - self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) - self.out_channels = self.adjust_channels(out_channels, width_mult) - self.use_se = use_se - self.use_hs = activation == "HS" - self.stride = stride - self.dilation = dilation - - @staticmethod - def adjust_channels(channels: int, width_mult: float): - return _make_divisible(channels * width_mult, 8) - - -class InvertedResidual(nn.Module): - # Implemented as described at section 5 of MobileNetV3 paper - def __init__( - self, - cnf: InvertedResidualConfig, - norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = SqueezeExcitation, - 
): - super().__init__() - if not (1 <= cnf.stride <= 3): - raise ValueError("illegal stride value") - - self.use_res_connect = ( - cnf.stride == 1 and cnf.input_channels == cnf.out_channels - ) - - layers: List[nn.Module] = [] - activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU - - # expand - if cnf.expanded_channels != cnf.input_channels: - layers.append( - ConvBNActivation( - cnf.input_channels, - cnf.expanded_channels, - kernel_size=1, - norm_layer=norm_layer, - activation_layer=activation_layer, - ) - ) - - # depthwise - stride = 1 if cnf.dilation > 1 else cnf.stride - layers.append( - ConvBNActivation( - cnf.expanded_channels, - cnf.expanded_channels, - kernel_size=cnf.kernel, - stride=stride, - dilation=cnf.dilation, - groups=cnf.expanded_channels, - norm_layer=norm_layer, - activation_layer=activation_layer, - ) - ) - if cnf.use_se: - layers.append(se_layer(cnf.expanded_channels)) - - # project - layers.append( - ConvBNActivation( - cnf.expanded_channels, - cnf.out_channels, - kernel_size=1, - norm_layer=norm_layer, - activation_layer=nn.Identity, - ) - ) - - self.block = nn.Sequential(*layers) - self.out_channels = cnf.out_channels - self._is_cn = cnf.stride > 1 - - def forward(self, input: Tensor) -> Tensor: - result = self.block(input) - if self.use_res_connect: - result += input - return result - - -class MobileNetV3(nn.Module): - - def __init__( - self, - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any - ) -> None: - """ - MobileNet V3 main class - - Args: - inverted_residual_setting (List[InvertedResidualConfig]): Network structure - last_channel (int): The number of channels on the penultimate layer - num_classes (int): Number of classes - block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet - norm_layer 
(Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use - """ - super().__init__() - - if not inverted_residual_setting: - raise ValueError("The inverted_residual_setting should not be empty") - elif not ( - isinstance(inverted_residual_setting, Sequence) - and all( - [ - isinstance(s, InvertedResidualConfig) - for s in inverted_residual_setting - ] - ) - ): - raise TypeError( - "The inverted_residual_setting should be List[InvertedResidualConfig]" - ) - - if block is None: - block = InvertedResidual - - if norm_layer is None: - norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) - - layers: List[nn.Module] = [] - - # building first layer - firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append( - ConvBNActivation( - 3, - firstconv_output_channels, - kernel_size=3, - stride=2, - norm_layer=norm_layer, - activation_layer=nn.Hardswish, - ) - ) - - # building inverted residual blocks - for cnf in inverted_residual_setting: - layers.append(block(cnf, norm_layer)) - - # building last several layers - lastconv_input_channels = inverted_residual_setting[-1].out_channels - lastconv_output_channels = 6 * lastconv_input_channels - layers.append( - ConvBNActivation( - lastconv_input_channels, - lastconv_output_channels, - kernel_size=1, - norm_layer=norm_layer, - activation_layer=nn.Hardswish, - ) - ) - - self.features = nn.Sequential(*layers) - self.avgpool = nn.AdaptiveAvgPool2d(1) - self.classifier = nn.Sequential( - nn.Linear(lastconv_output_channels, last_channel), - nn.Hardswish(inplace=True), - nn.Dropout(p=0.2, inplace=True), - nn.Linear(last_channel, num_classes), - ) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 
0.01) - nn.init.zeros_(m.bias) - - def _forward_impl(self, x: Tensor) -> Tensor: - x = self.features(x) - - x = self.avgpool(x) - x = torch.flatten(x, 1) - - x = self.classifier(x) - - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def _mobilenet_v3_conf( - arch: str, - width_mult: float = 1.0, - reduced_tail: bool = False, - dilated: bool = False, - **kwargs: Any -): - reduce_divider = 2 if reduced_tail else 1 - dilation = 2 if dilated else 1 - - bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - adjust_channels = partial( - InvertedResidualConfig.adjust_channels, width_mult=width_mult - ) - - if arch == "mobilenet_v3_large": - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), - bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 - bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), - bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 - bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), - bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), - bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 - bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), - bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), - bneck_conf( - 112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation - ), # C4 - bneck_conf( - 160 // reduce_divider, - 5, - 960 // reduce_divider, - 160 // reduce_divider, - True, - "HS", - 1, - dilation, - ), - bneck_conf( - 160 // reduce_divider, - 5, - 960 // reduce_divider, - 160 // reduce_divider, - True, - "HS", - 1, - dilation, - ), - ] - last_channel = adjust_channels(1280 // reduce_divider) # C5 - elif arch == "mobilenet_v3_small": - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 - bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 - bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), - bneck_conf(24, 5, 96, 40, True, 
"HS", 2, 1), # C3 - bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), - bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), - bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), - bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), - bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 - bneck_conf( - 96 // reduce_divider, - 5, - 576 // reduce_divider, - 96 // reduce_divider, - True, - "HS", - 1, - dilation, - ), - bneck_conf( - 96 // reduce_divider, - 5, - 576 // reduce_divider, - 96 // reduce_divider, - True, - "HS", - 1, - dilation, - ), - ] - last_channel = adjust_channels(1024 // reduce_divider) # C5 - else: - raise ValueError("Unsupported model type {}".format(arch)) - - return inverted_residual_setting, last_channel diff --git a/src/sheet/modules/ldnet/modules.py b/src/sheet/modules/ldnet/modules.py deleted file mode 100644 index 1aa64ba..0000000 --- a/src/sheet/modules/ldnet/modules.py +++ /dev/null @@ -1,181 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -# LDNet modules -# taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself) - -from functools import partial -from typing import List - -import torch -from sheet.modules.ldnet.mobilenetv2 import ConvBNActivation -from sheet.modules.ldnet.mobilenetv3 import InvertedResidual as InvertedResidualV3 -from sheet.modules.ldnet.mobilenetv3 import InvertedResidualConfig -from torch import nn - -STRIDE = 3 - - -class Projection(nn.Module): - def __init__( - self, - in_dim, - hidden_dim, - activation, - output_type, - _output_dim, - output_step=1.0, - range_clipping=False, - ): - super(Projection, self).__init__() - self.output_type = output_type - self.range_clipping = range_clipping - if output_type == "scalar": - output_dim = 1 - if range_clipping: - self.proj = nn.Tanh() - elif output_type == "categorical": - output_dim = _output_dim - self.output_step = output_step - else: - raise 
NotImplementedError("wrong output_type: {}".format(output_type)) - - self.net = nn.Sequential( - nn.Linear(in_dim, hidden_dim), - activation(), - nn.Dropout(0.3), - nn.Linear(hidden_dim, output_dim), - ) - - def forward(self, x, inference=False): - output = self.net(x) - - # scalar / categorical - if self.output_type == "scalar": - # range clipping - if self.range_clipping: - return self.proj(output) * 2.0 + 3 - else: - return output - else: - if inference: - return torch.argmax(output, dim=-1) * self.output_step + 1 - else: - return output - - -class ProjectionWithUncertainty(nn.Module): - def __init__( - self, - in_dim, - hidden_dim, - activation, - output_type, - _output_dim, - output_step=1.0, - range_clipping=False, - ): - super(ProjectionWithUncertainty, self).__init__() - self.output_type = output_type - self.range_clipping = range_clipping - if output_type == "scalar": - output_dim = 2 - if range_clipping: - self.proj = nn.Tanh() - elif output_type == "categorical": - output_dim = _output_dim - self.output_step = output_step - else: - raise NotImplementedError("wrong output_type: {}".format(output_type)) - - self.net = nn.Sequential( - nn.Linear(in_dim, hidden_dim), - activation(), - nn.Dropout(0.3), - nn.Linear(hidden_dim, output_dim), - ) - - def forward(self, x, inference=False): - output = self.net(x) # output shape: [B, T, d] - - # scalar / categorical - if self.output_type == "scalar": - mean, logvar = output[:, :, 0], output[:, :, 1] - # range clipping - if self.range_clipping: - return self.proj(mean) * 2.0 + 3, logvar - else: - return mean, logvar - else: - if inference: - return torch.argmax(output, dim=-1) * self.output_step + 1 - else: - return output - - -class MobileNetV3ConvBlocks(nn.Module): - def __init__(self, bneck_confs, output_dim): - super(MobileNetV3ConvBlocks, self).__init__() - - bneck_conf = partial(InvertedResidualConfig, width_mult=1) - inverted_residual_setting = [bneck_conf(*b_conf) for b_conf in bneck_confs] - - block = 
InvertedResidualV3 - - # Never tested if a different eps and momentum is needed - # norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) - norm_layer = nn.BatchNorm2d - - layers: List[nn.Module] = [] - - # building first layer - firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append( - ConvBNActivation( - 1, - firstconv_output_channels, - kernel_size=3, - stride=STRIDE, - norm_layer=norm_layer, - activation_layer=nn.Hardswish, - ) - ) - - # building inverted residual blocks - for cnf in inverted_residual_setting: - layers.append(block(cnf, norm_layer)) - - # building last several layers - lastconv_input_channels = inverted_residual_setting[-1].out_channels - lastconv_output_channels = output_dim - layers.append( - ConvBNActivation( - lastconv_input_channels, - lastconv_output_channels, - kernel_size=1, - norm_layer=norm_layer, - activation_layer=nn.Hardswish, - ) - ) - self.features = nn.Sequential(*layers) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.zeros_(m.bias) - - def forward(self, x): - time = x.shape[2] - x = self.features(x) - x = nn.functional.adaptive_avg_pool2d(x, (time, 1)) - x = x.squeeze(-1).transpose(1, 2) - return x diff --git a/src/sheet/modules/utils.py b/src/sheet/modules/utils.py deleted file mode 100644 index 2e9786a..0000000 --- a/src/sheet/modules/utils.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -""" Model utilities. 
- - Some functions are based on: - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/nets_utils.py -""" - -import torch - - -def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): - """Make mask tensor containing indices of padded part. - - Args: - lengths (LongTensor or List): Batch of lengths (B,). - xs (Tensor, optional): The reference tensor. - If set, masks will be the same shape as this tensor. - length_dim (int, optional): Dimension indicator of the above tensor. - See the example. - - Returns: - Tensor: Mask tensor containing indices of padded part. - dtype=torch.uint8 in PyTorch 1.2- - dtype=torch.bool in PyTorch 1.2+ (including 1.2) - - Examples: - With only lengths. - - >>> lengths = [5, 3, 2] - >>> make_pad_mask(lengths) - masks = [[0, 0, 0, 0 ,0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1]] - - With the reference tensor. - - >>> xs = torch.zeros((3, 2, 4)) - >>> make_pad_mask(lengths, xs) - tensor([[[0, 0, 0, 0], - [0, 0, 0, 0]], - [[0, 0, 0, 1], - [0, 0, 0, 1]], - [[0, 0, 1, 1], - [0, 0, 1, 1]]], dtype=torch.uint8) - >>> xs = torch.zeros((3, 2, 6)) - >>> make_pad_mask(lengths, xs) - tensor([[[0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1]], - [[0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1]], - [[0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) - - With the reference tensor and dimension indicator. 
- - >>> xs = torch.zeros((3, 6, 6)) - >>> make_pad_mask(lengths, xs, 1) - tensor([[[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) - >>> make_pad_mask(lengths, xs, 2) - tensor([[[0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 1]], - [[0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1]], - [[0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) - - """ - if length_dim == 0: - raise ValueError("length_dim cannot be 0: {}".format(length_dim)) - - if not isinstance(lengths, list): - lengths = lengths.long().tolist() - - bs = int(len(lengths)) - if maxlen is None: - if xs is None: - maxlen = int(max(lengths)) - else: - maxlen = xs.size(length_dim) - else: - assert xs is None - assert maxlen >= int(max(lengths)) - - seq_range = torch.arange(0, maxlen, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) - seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - if xs is not None: - assert xs.size(0) == bs, (xs.size(0), bs) - - if length_dim < 0: - length_dim = xs.dim() + length_dim - # ind = (:, None, ..., None, :, , None, ..., None) - ind = tuple( - slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) - ) - mask = mask[ind].expand_as(xs).to(xs.device) - return mask - - -def make_non_pad_mask(lengths, xs=None, length_dim=-1): - """Make mask tensor containing 
indices of non-padded part. - - Args: - lengths (LongTensor or List): Batch of lengths (B,). - xs (Tensor, optional): The reference tensor. - If set, masks will be the same shape as this tensor. - length_dim (int, optional): Dimension indicator of the above tensor. - See the example. - - Returns: - ByteTensor: mask tensor containing indices of padded part. - dtype=torch.uint8 in PyTorch 1.2- - dtype=torch.bool in PyTorch 1.2+ (including 1.2) - - Examples: - With only lengths. - - >>> lengths = [5, 3, 2] - >>> make_non_pad_mask(lengths) - masks = [[1, 1, 1, 1 ,1], - [1, 1, 1, 0, 0], - [1, 1, 0, 0, 0]] - - With the reference tensor. - - >>> xs = torch.zeros((3, 2, 4)) - >>> make_non_pad_mask(lengths, xs) - tensor([[[1, 1, 1, 1], - [1, 1, 1, 1]], - [[1, 1, 1, 0], - [1, 1, 1, 0]], - [[1, 1, 0, 0], - [1, 1, 0, 0]]], dtype=torch.uint8) - >>> xs = torch.zeros((3, 2, 6)) - >>> make_non_pad_mask(lengths, xs) - tensor([[[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0]], - [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0]], - [[1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) - - With the reference tensor and dimension indicator. 
- - >>> xs = torch.zeros((3, 6, 6)) - >>> make_non_pad_mask(lengths, xs, 1) - tensor([[[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) - >>> make_non_pad_mask(lengths, xs, 2) - tensor([[[1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 0]], - [[1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0]], - [[1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) - - """ - return ~make_pad_mask(lengths, xs, length_dim) diff --git a/src/sheet/nonparametric/__init__.py b/src/sheet/nonparametric/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sheet/nonparametric/datastore.py b/src/sheet/nonparametric/datastore.py deleted file mode 100644 index 17574eb..0000000 --- a/src/sheet/nonparametric/datastore.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -"""datastore related""" - -import faiss -import h5py -import numpy as np -from scipy.special import softmax - - -class Datastore: - def __init__( - self, - datastore_path, - embed_dim, - device, - ): - """ - Args: - datastore_path (str): path to the datastore. 
- embed_dim (int): dimension of the embed in the datastore - """ - embeds = [] - scores = [] - paths = [] - with h5py.File(datastore_path, "r") as f: - for hdf5_path in list(f["scores"].keys()): - paths.append(hdf5_path) - embeds.append(f["embeds"][hdf5_path][()]) - scores.append(f["scores"][hdf5_path][()]) - embeds = np.stack(embeds, axis=0) - scores = np.array(scores) - - # build index - index = faiss.IndexFlatL2(embed_dim) - if device.type == "cuda": - # index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index) - index = faiss.index_cpu_to_all_gpus(index, ngpu=1) - # else: - # embeds = torch.tensor(embeds, device=device) - index.add(embeds) - - self.embeds = embeds - self.scores = scores - self.paths = paths - self.index = index - - def knn(self, query, k, search_only=False): - # search - distances, I = self.index.search(query, k) - scores = np.stack([self.scores[row] for row in I]) - ret = {"distances": distances, "scores": scores} - - if search_only: - return ret - - # NOTE(unilight) 20250205: change to negative - # inv_dist = 1 / (distances + 1e-8) - inv_dist = -distances - - norm_dist = softmax(inv_dist, axis=1) - - mult = np.multiply(norm_dist, scores) - - final_score = np.sum(mult, axis=1)[0] - - # retrieve IDs - ids = [[self.paths[e] for e in row] for row in I] - - ret["final_score"] = final_score - ret["ids"] = ids - - return ret diff --git a/src/sheet/schedulers/__init__.py b/src/sheet/schedulers/__init__.py deleted file mode 100644 index ebb6cb2..0000000 --- a/src/sheet/schedulers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .schedulers import get_scheduler # NOQA diff --git a/src/sheet/schedulers/schedulers.py b/src/sheet/schedulers/schedulers.py deleted file mode 100644 index 36db3c8..0000000 --- a/src/sheet/schedulers/schedulers.py +++ /dev/null @@ -1,21 +0,0 @@ -import copy - -from torch.optim.lr_scheduler import MultiStepLR, StepLR - -# Reference: https://github.com/s3prl/s3prl/blob/master/s3prl/schedulers.py - - -def 
get_scheduler(optimizer, scheduler_name, total_steps, scheduler_config): - scheduler_config = copy.deepcopy(scheduler_config) - scheduler = eval(f"get_{scheduler_name}")( - optimizer, num_training_steps=total_steps, **scheduler_config - ) - return scheduler - - -def get_multistep(optimizer, num_training_steps, milestones, gamma): - return MultiStepLR(optimizer, milestones, gamma) - - -def get_stepLR(optimizer, num_training_steps, step_size, gamma): - return StepLR(optimizer, step_size, gamma) diff --git a/src/sheet/trainers/__init__.py b/src/sheet/trainers/__init__.py deleted file mode 100644 index aab6c7a..0000000 --- a/src/sheet/trainers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .non_intrusive import * # NOQA -# from .ramp import * # NOQA diff --git a/src/sheet/trainers/base.py b/src/sheet/trainers/base.py deleted file mode 100644 index d66fea9..0000000 --- a/src/sheet/trainers/base.py +++ /dev/null @@ -1,315 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -import logging -import os -import time -from collections import defaultdict - -import torch -from sheet.utils.model_io import freeze_modules -from tensorboardX import SummaryWriter -from tqdm import tqdm - - -class Trainer(object): - """Customized trainer module.""" - - def __init__( - self, - steps, - epochs, - data_loader, - sampler, - model, - criterion, - optimizer, - scheduler, - config, - device=torch.device("cpu"), - ): - """Initialize trainer. - - Args: - steps (int): Initial global steps. - epochs (int): Initial global epochs. - data_loader (dict): Dict of data loaders. It must contrain "train" and "dev" loaders. - model (dict): Dict of models. It must contrain "generator" and "discriminator" models. - criterion (dict): Dict of criterions. It must contrain "stft" and "mse" criterions. - optimizer (dict): Dict of optimizers. It must contrain "generator" and "discriminator" optimizers. 
- scheduler (dict): Dict of schedulers. It must contrain "generator" and "discriminator" schedulers. - config (dict): Config dict loaded from yaml format configuration file. - device (torch.deive): Pytorch device instance. - - """ - self.steps = steps - self.epochs = epochs - self.data_loader = data_loader - self.sampler = sampler - self.model = model - self.criterion = criterion - self.optimizer = optimizer - self.scheduler = scheduler - self.config = config - self.device = device - self.writer = SummaryWriter(config["outdir"]) - self.finish_train = False - - self.total_train_loss = defaultdict(float) - self.total_eval_loss = defaultdict(float) - self.reset_eval_results() - - self.gradient_accumulate_steps = self.config.get("gradient_accumulate_steps", 1) - - self.reporter = list() # each element is [steps: int, results: dict] - self.original_patience = self.config.get("patience", None) - self.current_patience = self.config.get("patience", None) - - def run(self): - """Run training.""" - self.backward_steps = 0 - self.all_loss = 0.0 - self.tqdm = tqdm( - initial=self.steps, total=self.config["train_max_steps"], desc="[train]", mininterval=5, maxinterval=5, - ) - while True: - # train one epoch - self._train_epoch() - - # check whether training is finished - if self.finish_train: - break - - self.tqdm.close() - logging.info("Finished training.") - - def save_checkpoint(self, checkpoint_path): - """Save checkpoint. - - Args: - checkpoint_path (str): Checkpoint path to be saved. 
- - """ - state_dict = { - "optimizer": self.optimizer.state_dict(), - "steps": self.steps, - "epochs": self.epochs, - } - if self.scheduler is not None: - state_dict["scheduler"] = self.scheduler.state_dict() - - if self.config["distributed"]: - state_dict["model"] = self.model.module.state_dict() - else: - state_dict["model"] = self.model.state_dict() - - if not os.path.exists(os.path.dirname(checkpoint_path)): - os.makedirs(os.path.dirname(checkpoint_path)) - torch.save(state_dict, checkpoint_path) - - def load_checkpoint(self, checkpoint_path, load_only_params=False): - """Load checkpoint. - - Args: - checkpoint_path (str): Checkpoint path to be loaded. - load_only_params (bool): Whether to load only model parameters. - - """ - state_dict = torch.load(checkpoint_path, map_location="cpu") - if self.config["distributed"]: - self.model.module.load_state_dict(state_dict["model"]) - else: - self.model.load_state_dict(state_dict["model"]) - if not load_only_params: - self.steps = state_dict["steps"] - self.epochs = state_dict["epochs"] - self.optimizer.load_state_dict(state_dict["optimizer"]) - if self.scheduler is not None: - self.scheduler.load_state_dict(state_dict["scheduler"]) - - def _train_step(self, batch): - """Train model one step.""" - pass - - def _train_epoch(self): - """Train model one epoch.""" - for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1): - # train one step - self._train_step(batch) - if self.backward_steps % self.gradient_accumulate_steps > 0: - continue - - # check interval - if self.config["rank"] == 0: - self._check_log_interval() - self._check_eval_and_save_interval() - - # check whether training is finished - if self.finish_train: - return - - # update - self.epochs += 1 - self.train_steps_per_epoch = train_steps_per_epoch - logging.info( - f"(Steps: {self.steps}) Finished {self.epochs} epoch training " - f"({self.train_steps_per_epoch} steps per epoch)." 
- ) - - # needed for shuffle in distributed training - if self.config["distributed"]: - self.sampler["train"].set_epoch(self.epochs) - - @torch.no_grad() - def _eval_step(self, batch): - """Evaluate model one step.""" - pass - - def _eval(self): - """Evaluate model with dev set.""" - logging.info(f"(Steps: {self.steps}) Start evaluation.") - # change mode - self.model.eval() - start_time = time.time() - - # loop through dev set - for count, batch in enumerate(self.data_loader["dev"], 1): - self._eval_step(batch) - if "dev_samples_per_eval_loop" in self.config: - if count > self.config["dev_samples_per_eval_loop"]: - break - - logging.info( - f"(Steps: {self.steps}) Finished evaluation " - f"({time.time() - start_time} secs)." - ) - - @torch.no_grad() - def _log_metrics_and_save_figures(self): - """Log metrics and save figures.""" - pass - - def _write_to_tensorboard(self, loss): - """Write to tensorboard.""" - for key, value in loss.items(): - self.writer.add_scalar(key, value, self.steps) - - def _check_eval_and_save_interval(self): - if self.steps % self.config["eval_and_save_interval_steps"] == 0: - # run evaluation on dev set - self._eval() - - # get metrics and save figures - self._log_metrics_and_save_figures() - - # get best n steps - best_n_steps = self.get_and_show_best_n_models() - - # save current if in best n - if self.steps in best_n_steps: - current_checkpoint_path = os.path.join( - self.config["outdir"], f"checkpoint-{self.steps}steps.pkl" - ) - self.save_checkpoint(current_checkpoint_path) - logging.info( - f"Saved checkpoint @ {self.steps} steps because it is in best {self.config['keep_nbest_models']}." 
- ) - - # retstore patience - if self.original_patience is not None: - self.current_patience = self.original_patience - logging.info(f"Restoring patience to {self.original_patience}.") - else: - # minus patience - if self.current_patience is not None: - self.current_patience -= 1 - logging.info(f"Reducing patience to {self.current_patience}.") - - # if current is best, link to best - if self.steps == best_n_steps[0]: - best_checkpoint_path = os.path.join( - self.config["outdir"], f"checkpoint-best.pkl" - ) - if os.path.islink(best_checkpoint_path) or os.path.exists( - best_checkpoint_path - ): - os.remove(best_checkpoint_path) - os.symlink(current_checkpoint_path, best_checkpoint_path) - logging.info(f"Updated best checkpoint to {self.steps} steps.") - - # delete those not in best n - existing_checkpoint_paths = [ - fname - for fname in os.listdir(self.config["outdir"]) - if os.path.isfile(os.path.join(self.config["outdir"], fname)) - and fname.endswith("steps.pkl") - and not fname.startswith("original") - ] - for checkpoint_path in existing_checkpoint_paths: - steps = int( - checkpoint_path.replace("steps.pkl", "").replace("checkpoint-", "") - ) - if steps not in best_n_steps: - os.remove(os.path.join(self.config["outdir"], checkpoint_path)) - logging.info(f"Deleting checkpoint @ {steps} steps.") - - # reset - self.reset_eval_results() - - # restore mode - self.model.train() - - def _check_log_interval(self): - if self.steps % self.config["log_interval_steps"] == 0: - for key in self.total_train_loss.keys(): - self.total_train_loss[key] /= self.config["log_interval_steps"] - logging.info( - f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}." 
- ) - self._write_to_tensorboard(self.total_train_loss) - - # reset - self.total_train_loss = defaultdict(float) - - def _check_train_finish(self): - if self.steps >= self.config["train_max_steps"]: - self.finish_train = True - - if self.current_patience is not None and self.current_patience <= 0: - self.finish_train = True - - def freeze_modules(self, modules): - freeze_modules(self.model, modules) - - def reset_eval_results(self): - self.eval_results = defaultdict(list) - self.eval_sys_results = defaultdict(lambda: defaultdict(list)) - - def get_and_show_best_n_models(self): - # sort according to key - best_n = sorted( - self.reporter, - key=lambda x: x[1][self.config["best_model_criterion"]["key"]], - ) - if ( - self.config["best_model_criterion"]["order"] == "highest" - ): # reverse if highest - best_n.reverse() - - # log the results - logging.info( - f"Best {self.config['keep_nbest_models']} models at step {self.steps}:" - ) - log_string = "; ".join( - f"{i+1}. {steps} steps: {self.config['best_model_criterion']['key']}={results[self.config['best_model_criterion']['key']]:.4f}" - for i, (steps, results) in enumerate( - best_n[: self.config["keep_nbest_models"]] - ) - ) - logging.info(log_string) - - # only return the steps - return [steps for steps, _ in best_n[: self.config["keep_nbest_models"]]] diff --git a/src/sheet/trainers/non_intrusive.py b/src/sheet/trainers/non_intrusive.py deleted file mode 100644 index fbbbf90..0000000 --- a/src/sheet/trainers/non_intrusive.py +++ /dev/null @@ -1,310 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -import logging -import os -import time - -# set to avoid matplotlib error in CLI environment -import matplotlib -import numpy as np -import soundfile as sf -import torch -from sheet.evaluation.metrics import calculate -from sheet.evaluation.plot import ( - plot_sys_level_scatter, - plot_utt_level_hist, - 
plot_utt_level_scatter, -) -from sheet.trainers.base import Trainer -from sheet.utils.model_io import ( - filter_modules, - get_partial_state_dict, - print_new_keys, - transfer_verification, -) - -matplotlib.use("Agg") -import matplotlib.pyplot as plt - - -class NonIntrusiveEstimatorTrainer(Trainer): - """Customized trainer module for non-intrusive estimator.""" - - def load_trained_modules(self, checkpoint_path, init_mods): - if self.config["distributed"]: - main_state_dict = self.model.module.state_dict() - else: - main_state_dict = self.model.state_dict() - - if os.path.isfile(checkpoint_path): - model_state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] - - # first make sure that all modules in `init_mods` are in `checkpoint_path` - modules = filter_modules(model_state_dict, init_mods) - - # then, actually get the partial state_dict - partial_state_dict = get_partial_state_dict(model_state_dict, modules) - - if partial_state_dict: - if transfer_verification(main_state_dict, partial_state_dict, modules): - print_new_keys(partial_state_dict, modules, checkpoint_path) - main_state_dict.update(partial_state_dict) - else: - logging.error(f"Specified model was not found: {checkpoint_path}") - exit(1) - - if self.config["distributed"]: - self.model.module.load_state_dict(main_state_dict) - else: - self.model.load_state_dict(main_state_dict) - - def _train_step(self, batch): - """Train model one step.""" - - # set inputs - gen_loss = 0.0 - inputs = { - self.config["model_input"]: batch[self.config["model_input"]].to( - self.device - ), - self.config["model_input"] - + "_lengths": batch[self.config["model_input"] + "_lengths"].to( - self.device - ), - } - if "listener_idxs" in batch: - inputs["listener_idxs"] = batch["listener_idxs"].to(self.device) - if "domain_idxs" in batch: - inputs["domain_idxs"] = batch["domain_idxs"].to(self.device) - if "phoneme_idxs" in batch: - inputs["phoneme_idxs"] = batch["phoneme_idxs"].to(self.device) - 
inputs["phoneme_lengths"] = batch["phoneme_lengths"] - if "reference_idxs" in batch: - inputs["reference_idxs"] = batch["reference_idxs"].to(self.device) - inputs["reference_lengths"] = batch["reference_lengths"] - - # model forward - outputs = self.model(inputs) - - # get frame lengths if exist - if "frame_lengths" in outputs: - output_frame_lengths = outputs["frame_lengths"] - else: - output_frame_lengths = None - - # get ground truth scores - gt_scores = batch["scores"].to(self.device) - gt_avg_scores = batch["avg_scores"].to(self.device) - if "categorical_scores" in batch: - categorical_gt_scores = batch["categorical_scores"].to(self.device) - if "categorical_avg_scores" in batch: - categorical_gt_avg_scores = batch["categorical_avg_scores"].to(self.device) - - # mean loss - if "mean_score_criterions" in self.criterion: - for criterion_dict in self.criterion["mean_score_criterions"]: - if criterion_dict["type"] in ["GaussianNLLLoss", "LaplaceNLLLoss"]: - loss = criterion_dict["criterion"]( - outputs["mean_scores"], - outputs["mean_scores_logvar"], - gt_avg_scores, - self.device, - lens=output_frame_lengths, - ) - else: - # always pass the following arguments - loss = criterion_dict["criterion"]( - outputs["mean_scores"], - ( - categorical_gt_avg_scores - if criterion_dict["type"] == "CategoricalLoss" - else gt_avg_scores - ), - self.device, - lens=output_frame_lengths, - ) - gen_loss += loss * criterion_dict["weight"] - self.total_train_loss["train/mean_" + criterion_dict["type"]] += ( - loss.item() / self.gradient_accumulate_steps - ) - - # categorical head loss (for RAMP only) - if "categorical_head_criterions" in self.criterion: - for criterion_dict in self.criterion["categorical_head_criterions"]: - # always pass the following arguments - loss = criterion_dict["criterion"]( - outputs["categorical_head_scores"], - categorical_gt_avg_scores, - self.device, - lens=output_frame_lengths, - ) - gen_loss += loss * criterion_dict["weight"] - 
self.total_train_loss["train/categorical_head_loss"] += ( - loss.item() / self.gradient_accumulate_steps - ) - - # listener loss - if "listener_score_criterions" in self.criterion: - for criterion_dict in self.criterion["listener_score_criterions"]: - # always pass the following arguments - loss = criterion_dict["criterion"]( - outputs["ld_scores"], - ( - categorical_gt_scores - if criterion_dict["type"] == "CategoricalLoss" - else gt_scores - ), - self.device, - lens=output_frame_lengths, - ) - gen_loss += loss * criterion_dict["weight"] - self.total_train_loss["train/listener_" + criterion_dict["type"]] += ( - loss.item() / self.gradient_accumulate_steps - ) - - self.total_train_loss["train/loss"] += ( - gen_loss.item() / self.gradient_accumulate_steps - ) - - # update model - if self.gradient_accumulate_steps > 1: - gen_loss = gen_loss / self.gradient_accumulate_steps - gen_loss.backward() - self.all_loss += loss.item() - - self.backward_steps += 1 - if self.backward_steps % self.gradient_accumulate_steps > 0: - return - - if self.config["grad_norm"] > 0: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.config["grad_norm"], - ) - self.optimizer.step() - self.optimizer.zero_grad() - if self.scheduler is not None: - self.scheduler.step() - - # update counts - self.steps += 1 - self.tqdm.update(1) - self._check_train_finish() - - @torch.no_grad() - def _eval_step(self, batch): - """Evaluate model one step.""" - - # set up model input - inputs = { - self.config["model_input"]: batch[self.config["model_input"]].to( - self.device - ), - self.config["model_input"] - + "_lengths": batch[self.config["model_input"] + "_lengths"].to( - self.device - ), - } - if "domain_idxs" in batch: - inputs["domain_idxs"] = batch["domain_idxs"].to(self.device) - if "phoneme_idxs" in batch: - inputs["phoneme_idxs"] = batch["phoneme_idxs"].to(self.device) - inputs["phoneme_lengths"] = batch["phoneme_lengths"] - if "reference_idxs" in batch: - inputs["reference_idxs"] = 
batch["reference_idxs"].to(self.device) - inputs["reference_lengths"] = batch["reference_lengths"] - - # model forward - if self.config["inference_mode"] == "mean_listener": - outputs = self.model.mean_listener_inference(inputs) - elif self.config["inference_mode"] == "mean_net": - outputs = self.model.mean_net_inference(inputs) - - # construct the eval_results dict - pred_mean_scores = outputs["scores"].cpu().detach().numpy() - true_mean_scores = batch["avg_scores"].numpy() - self.eval_results["pred_mean_scores"].extend(pred_mean_scores.tolist()) - self.eval_results["true_mean_scores"].extend(true_mean_scores.tolist()) - sys_names = batch["system_ids"] - for j, sys_name in enumerate(sys_names): - self.eval_sys_results["pred_mean_scores"][sys_name].append( - pred_mean_scores[j] - ) - self.eval_sys_results["true_mean_scores"][sys_name].append( - true_mean_scores[j] - ) - - @torch.no_grad() - def _log_metrics_and_save_figures(self): - """Log metrics and save figures.""" - - self.eval_results["true_mean_scores"] = np.array( - self.eval_results["true_mean_scores"] - ) - self.eval_results["pred_mean_scores"] = np.array( - self.eval_results["pred_mean_scores"] - ) - self.eval_sys_results["true_mean_scores"] = np.array( - [ - np.mean(scores) - for scores in self.eval_sys_results["true_mean_scores"].values() - ] - ) - self.eval_sys_results["pred_mean_scores"] = np.array( - [ - np.mean(scores) - for scores in self.eval_sys_results["pred_mean_scores"].values() - ] - ) - - # calculate metrics - results = calculate( - self.eval_results["true_mean_scores"], - self.eval_results["pred_mean_scores"], - self.eval_sys_results["true_mean_scores"], - self.eval_sys_results["pred_mean_scores"], - ) - - # log metrics - logging.info( - f'[{self.steps} steps][UTT][ MSE = {results["utt_MSE"]:.3f} | LCC = {results["utt_LCC"]:.3f} | SRCC = {results["utt_SRCC"]:.3f} ] [SYS][ MSE = {results["sys_MSE"]:.3f} | LCC = {results["sys_LCC"]:.4f} | SRCC = {results["sys_SRCC"]:.4f} ]\n' - ) - - # 
register metrics to reporter - self.reporter.append([self.steps, results]) - - # check directory - dirname = os.path.join( - self.config["outdir"], f"intermediate_results/{self.steps}steps" - ) - if not os.path.exists(dirname): - os.makedirs(dirname) - - # plot - plot_utt_level_hist( - self.eval_results["true_mean_scores"], - self.eval_results["pred_mean_scores"], - os.path.join(dirname, "distribution.png"), - ) - plot_utt_level_scatter( - self.eval_results["true_mean_scores"], - self.eval_results["pred_mean_scores"], - os.path.join(dirname, "utt_scatter_plot.png"), - results["utt_LCC"], - results["utt_SRCC"], - results["utt_MSE"], - results["utt_KTAU"], - ) - plot_sys_level_scatter( - self.eval_sys_results["true_mean_scores"], - self.eval_sys_results["pred_mean_scores"], - os.path.join(dirname, "sys_scatter_plot.png"), - results["sys_LCC"], - results["sys_SRCC"], - results["sys_MSE"], - results["sys_KTAU"], - ) diff --git a/src/sheet/utils/__init__.py b/src/sheet/utils/__init__.py deleted file mode 100644 index e8fa95a..0000000 --- a/src/sheet/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import * # NOQA diff --git a/src/sheet/utils/download.py b/src/sheet/utils/download.py deleted file mode 100644 index 0eb6bd7..0000000 --- a/src/sheet/utils/download.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Thread-safe file downloading and cacheing - -Authors - * Leo 2022 - * Cheng Liang 2022 -""" - -import hashlib -import logging -import os -import shutil -import sys -import tempfile -import time -from pathlib import Path -from urllib.request import Request, urlopen - -import requests -from filelock import FileLock -from tqdm import tqdm - -logger = logging.getLogger(__name__) - - -_download_dir = Path.home() / ".cache" / "sheet" / "download" - -__all__ = [ - "get_dir", - "set_dir", - "download", - "urls_to_filepaths", -] - - -def get_dir(): - _download_dir.mkdir(exist_ok=True, parents=True) - return _download_dir - - -def set_dir(d): - global _download_dir - 
_download_dir = Path(d) - - -def _download_url_to_file(url, dst, hash_prefix=None, progress=True): - """ - This function is not thread-safe. Please ensure only a single - thread or process can enter this block at the same time - """ - - file_size = None - req = Request(url, headers={"User-Agent": "torch.hub"}) - u = urlopen(req) - meta = u.info() - if hasattr(meta, "getheaders"): - content_length = meta.getheaders("Content-Length") - else: - content_length = meta.get_all("Content-Length") - if content_length is not None and len(content_length) > 0: - file_size = int(content_length[0]) - - dst = os.path.expanduser(dst) - dst_dir = os.path.dirname(dst) - f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) - - try: - if hash_prefix is not None: - sha256 = hashlib.sha256() - - tqdm.write(f"Downloading: {url}", file=sys.stderr) - tqdm.write(f"Destination: {dst}", file=sys.stderr) - with tqdm( - total=file_size, - disable=not progress, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as pbar: - while True: - buffer = u.read(8192) - if len(buffer) == 0: - break - f.write(buffer) - if hash_prefix is not None: - sha256.update(buffer) - pbar.update(len(buffer)) - - f.close() - if hash_prefix is not None: - digest = sha256.hexdigest() - if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest - ) - ) - shutil.move(f.name, dst) - finally: - f.close() - if os.path.exists(f.name): - os.remove(f.name) - - -def _download_url_to_file_requests(url, dst, hash_prefix=None, progress=True): - """ - Alternative download when urllib.Request fails. 
- """ - - req = requests.get(url, stream=True, headers={"User-Agent": "torch.hub"}) - file_size = int(req.headers["Content-Length"]) - - dst = os.path.expanduser(dst) - dst_dir = os.path.dirname(dst) - f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) - - try: - if hash_prefix is not None: - sha256 = hashlib.sha256() - - tqdm.write( - f"urllib.Request method failed. Trying using another method...", - file=sys.stderr, - ) - tqdm.write(f"Downloading: {url}", file=sys.stderr) - tqdm.write(f"Destination: {dst}", file=sys.stderr) - with tqdm( - total=file_size, - disable=not progress, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as pbar: - for chunk in req.iter_content(chunk_size=1024 * 1024 * 10): - if chunk: - f.write(chunk) - f.flush() - os.fsync(f.fileno()) - if hash_prefix is not None: - sha256.update(chunk) - pbar.update(len(chunk)) - - f.close() - if hash_prefix is not None: - digest = sha256.hexdigest() - if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest - ) - ) - shutil.move(f.name, dst) - finally: - f.close() - if os.path.exists(f.name): - os.remove(f.name) - - -def _download(filepath: Path, url, refresh: bool, new_enough_secs: float = 2.0): - """ - If refresh is True, check the latest modfieid time of the filepath. - If the file is new enough (no older than `new_enough_secs`), than directly use it. - If the file is older than `new_enough_secs`, than re-download the file. 
- This function is useful when multi-processes are all downloading the same large file - """ - - Path(filepath).parent.mkdir(exist_ok=True, parents=True) - - lock_file = Path(str(filepath) + ".lock") - logger.info(f"Requesting URL: {url}") - - with FileLock(str(lock_file)): - if not filepath.is_file() or ( - refresh and (time.time() - os.path.getmtime(filepath)) > new_enough_secs - ): - try: - _download_url_to_file(url, filepath) - except: - _download_url_to_file_requests(url, filepath) - - logger.info(f"Using URL's local file: {filepath}") - try: - lock_file.unlink() - except FileNotFoundError: - pass - - -def _urls_to_filepaths(*args, refresh=False, download: bool = True): - """ - Preprocess the URL specified in *args into local file paths after downloading - - Args: - Any number of URLs (1 ~ any) - - Return: - Same number of downloaded file paths - """ - - def _url_to_filepath(url): - assert isinstance(url, str) - m = hashlib.sha256() - m.update(str.encode(url)) - filepath = get_dir() / f"{str(m.hexdigest())}.{Path(url).name}" - if download: - _download(filepath, url, refresh=refresh) - return str(filepath.resolve()) - - paths = [_url_to_filepath(url) for url in args] - return paths if len(paths) > 1 else paths[0] - - -download = _download -urls_to_filepaths = _urls_to_filepaths diff --git a/src/sheet/utils/model_io.py b/src/sheet/utils/model_io.py deleted file mode 100644 index bde38b5..0000000 --- a/src/sheet/utils/model_io.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2024 Wen-Chin Huang -# MIT License (https://opensource.org/licenses/MIT) - -import logging -import os -from collections import OrderedDict - -import torch - - -def print_new_keys(state_dict, modules, model_path): - logging.info(f"Loading {modules} from model: {model_path}") - - for k in state_dict.keys(): - logging.warning(f"Overriding module {k}") - - -def filter_modules(model_state_dict, modules): - """Filter non-matched modules in model 
state dict. - Args: - model_state_dict (Dict): Pre-trained model state dict. - modules (List): Specified module(s) to transfer. - Return: - new_mods (List): Filtered module list. - """ - new_mods = [] - incorrect_mods = [] - - mods_model = list(model_state_dict.keys()) - for mod in modules: - if any(key.startswith(mod) for key in mods_model): - new_mods += [mod] - else: - incorrect_mods += [mod] - - if incorrect_mods: - logging.error( - "Specified module(s) don't match or (partially match) " - f"available modules in model. You specified: {incorrect_mods}." - ) - logging.error("The existing modules in model are:") - logging.error(f"{mods_model}") - exit(1) - - return new_mods - - -def get_partial_state_dict(model_state_dict, modules): - """Create state dict with specified modules matching input model modules. - Args: - model_state_dict (Dict): Pre-trained model state dict. - modules (Dict): Specified module(s) to transfer. - Return: - new_state_dict (Dict): State dict with specified modules weights. - """ - new_state_dict = OrderedDict() - - for key, value in model_state_dict.items(): - if any(key.startswith(m) for m in modules): - new_state_dict[key] = value - - return new_state_dict - - -def transfer_verification(model_state_dict, partial_state_dict, modules): - """Verify tuples (key, shape) for input model modules match specified modules. - Args: - model_state_dict (Dict) : Main model state dict. - partial_state_dict (Dict): Pre-trained model state dict. - modules (List): Specified module(s) to transfer. - Return: - (bool): Whether transfer learning is allowed. 
- """ - model_modules = [] - partial_modules = [] - - for key_m, value_m in model_state_dict.items(): - if any(key_m.startswith(m) for m in modules): - model_modules += [(key_m, value_m.shape)] - model_modules = sorted(model_modules, key=lambda x: (x[0], x[1])) - - for key_p, value_p in partial_state_dict.items(): - if any(key_p.startswith(m) for m in modules): - partial_modules += [(key_p, value_p.shape)] - partial_modules = sorted(partial_modules, key=lambda x: (x[0], x[1])) - - module_match = model_modules == partial_modules - - if not module_match: - logging.error( - "Some specified modules from the pre-trained model " - "don't match with the new model modules:" - ) - logging.error(f"Pre-trained: {set(partial_modules) - set(model_modules)}") - logging.error(f"New model: {set(model_modules) - set(partial_modules)}") - exit(1) - - return module_match - - -def freeze_modules(model, modules): - """Freeze model parameters according to modules list. - Args: - model (torch.nn.Module): Main model. - modules (List): Specified module(s) to freeze. - Return: - model (torch.nn.Module) : Updated main model. - model_params (filter): Filtered model parameters. - """ - for mod, param in model.named_parameters(): - if any(mod.startswith(m) for m in modules): - logging.warning(f"Freezing {mod}. 
It will not be updated during training.") - param.requires_grad = False - - model_params = filter(lambda x: x.requires_grad, model.parameters()) - - return model, model_params - - -@torch.no_grad() -def model_average(model, outdir): - """Generate averaged model from existing models - - Args: - model: the model instance - outdir: the directory contains the model files - """ - # get model checkpoints - checkpoint_paths = [ - os.path.join(outdir, p) - for p in os.listdir(outdir) - if os.path.isfile(os.path.join(outdir, p)) and p.endswith("steps.pkl") - ] - n = len(checkpoint_paths) - - # load the checkpoints - avg = None - for checkpoint_path in checkpoint_paths: - states = torch.load(checkpoint_path, map_location="cpu")["model"] - if avg is None: - avg = states - else: - # Accumulated - for k in avg: - avg[k] = avg[k] + states[k] - - # take average - for k in avg: - if str(avg[k].dtype).startswith("torch.int"): - # For int type, not averaged, but only accumulated. - # e.g. BatchNorm.num_batches_tracked - # (If there are any cases that requires averaging - # or the other reducing method, e.g. max/min, for integer type, - # please report.) 
- logging.info(f"Accumulating {k} instead of averaging") - pass - else: - avg[k] = avg[k] / n - - # load into model - model.load_state_dict(avg) - - return model, checkpoint_paths diff --git a/src/sheet/utils/types.py b/src/sheet/utils/types.py deleted file mode 100644 index fd43b9c..0000000 --- a/src/sheet/utils/types.py +++ /dev/null @@ -1,139 +0,0 @@ -from distutils.util import strtobool -from typing import Optional, Tuple, Union - - -def str2bool(value: str) -> bool: - return bool(strtobool(value)) - - -def remove_parenthesis(value: str): - value = value.strip() - if value.startswith("(") and value.endswith(")"): - value = value[1:-1] - elif value.startswith("[") and value.endswith("]"): - value = value[1:-1] - return value - - -def remove_quotes(value: str): - value = value.strip() - if value.startswith('"') and value.endswith('"'): - value = value[1:-1] - elif value.startswith("'") and value.endswith("'"): - value = value[1:-1] - return value - - -def int_or_none(value: str) -> Optional[int]: - """int_or_none. - - Examples: - >>> import argparse - >>> parser = argparse.ArgumentParser() - >>> _ = parser.add_argument('--foo', type=int_or_none) - >>> parser.parse_args(['--foo', '456']) - Namespace(foo=456) - >>> parser.parse_args(['--foo', 'none']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'null']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'nil']) - Namespace(foo=None) - - """ - if value.strip().lower() in ("none", "null", "nil"): - return None - return int(value) - - -def float_or_none(value: str) -> Optional[float]: - """float_or_none. 
- - Examples: - >>> import argparse - >>> parser = argparse.ArgumentParser() - >>> _ = parser.add_argument('--foo', type=float_or_none) - >>> parser.parse_args(['--foo', '4.5']) - Namespace(foo=4.5) - >>> parser.parse_args(['--foo', 'none']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'null']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'nil']) - Namespace(foo=None) - - """ - if value.strip().lower() in ("none", "null", "nil"): - return None - return float(value) - - -def str_or_int(value: str) -> Union[str, int]: - try: - return int(value) - except ValueError: - return value - - -def str_or_none(value: str) -> Optional[str]: - """str_or_none. - - Examples: - >>> import argparse - >>> parser = argparse.ArgumentParser() - >>> _ = parser.add_argument('--foo', type=str_or_none) - >>> parser.parse_args(['--foo', 'aaa']) - Namespace(foo='aaa') - >>> parser.parse_args(['--foo', 'none']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'null']) - Namespace(foo=None) - >>> parser.parse_args(['--foo', 'nil']) - Namespace(foo=None) - - """ - if value.strip().lower() in ("none", "null", "nil"): - return None - return value - - -def str2pair_str(value: str) -> Tuple[str, str]: - """str2pair_str. - - Examples: - >>> import argparse - >>> str2pair_str('abc,def ') - ('abc', 'def') - >>> parser = argparse.ArgumentParser() - >>> _ = parser.add_argument('--foo', type=str2pair_str) - >>> parser.parse_args(['--foo', 'abc,def']) - Namespace(foo=('abc', 'def')) - - """ - value = remove_parenthesis(value) - a, b = value.split(",") - - # Workaround for configargparse issues: - # If the list values are given from yaml file, - # the value givent to type() is shaped as python-list, - # e.g. ['a', 'b', 'c'], - # so we need to remove double quotes from it. - return remove_quotes(a), remove_quotes(b) - - -def str2triple_str(value: str) -> Tuple[str, str, str]: - """str2triple_str. 
- - Examples: - >>> str2triple_str('abc,def ,ghi') - ('abc', 'def', 'ghi') - """ - value = remove_parenthesis(value) - a, b, c = value.split(",") - - # Workaround for configargparse issues: - # If the list values are given from yaml file, - # the value givent to type() is shaped as python-list, - # e.g. ['a', 'b', 'c'], - # so we need to remove quotes from it. - return remove_quotes(a), remove_quotes(b), remove_quotes(c) diff --git a/src/sheet/utils/utils.py b/src/sheet/utils/utils.py deleted file mode 100644 index e039370..0000000 --- a/src/sheet/utils/utils.py +++ /dev/null @@ -1,164 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 Tomoki Hayashi -# MIT License (https://opensource.org/licenses/MIT) - -"""Utility functions.""" - -import csv -import fnmatch -import logging -import os -import sys - -import h5py -import numpy as np - - -def get_basename(path): - return os.path.splitext(os.path.split(path)[-1])[0] - - -def read_csv(path, dict_reader=False, lazy=False, encoding=None): - """ - - If `dict_reader` is set to True, then return . - If `dict_reader` is set to False, then return . - """ - - """Read the csv file. - - Args: - path (str): path to the csv file - dict_reader (bool): whether to use dict reader. This should be set to true when the csv file has header. - lazy (bool): whether to read the file in this funcion. - - Return: - contents: reader or line of contents - fieldnames (list): header. If dict_reader is False, then return None. - - """ - - with open(path, newline="", encoding=encoding) as csvfile: - if dict_reader: - reader = csv.DictReader(csvfile) - fieldnames = reader.fieldnames - else: - reader = csv.reader(csvfile) - fieldnames = None - - if lazy: - contents = reader - else: - contents = [line for line in reader] - - return contents, fieldnames - -def write_csv(data, path): - """Write data to the output path. 
- - Args: - path (str): path to the output csv file - data (list): a list of dicts - - """ - fieldnames = list(data[0].keys()) - with open(path, "w", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for line in data: - writer.writerow(line) - -def find_files(root_dir, query="*.wav", include_root_dir=True): - """Find files recursively. - - Args: - root_dir (str): Root root_dir to find. - query (str): Query to find. - include_root_dir (bool): If False, root_dir name is not included. - - Returns: - list: List of found filenames. - - """ - files = [] - for root, dirnames, filenames in os.walk(root_dir, followlinks=True): - for filename in fnmatch.filter(filenames, query): - files.append(os.path.join(root, filename)) - if not include_root_dir: - files = [file_.replace(root_dir + "/", "") for file_ in files] - - return files - - -def read_hdf5(hdf5_name, hdf5_path): - """Read hdf5 dataset. - - Args: - hdf5_name (str): Filename of hdf5 file. - hdf5_path (str): Dataset name in hdf5 file. - - Return: - any: Dataset values. - - """ - if not os.path.exists(hdf5_name): - logging.error(f"There is no such a hdf5 file ({hdf5_name}).") - sys.exit(1) - - hdf5_file = h5py.File(hdf5_name, "r") - - if hdf5_path not in hdf5_file: - logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") - sys.exit(1) - - hdf5_data = hdf5_file[hdf5_path][()] - hdf5_file.close() - - return hdf5_data - - -def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): - """Write dataset to hdf5. - - Args: - hdf5_name (str): Hdf5 dataset filename. - hdf5_path (str): Dataset path in hdf5. - write_data (ndarray): Data to write. - is_overwrite (bool): Whether to overwrite dataset. 
- - """ - # convert to numpy array - write_data = np.array(write_data) - - # check folder existence - folder_name, _ = os.path.split(hdf5_name) - if not os.path.exists(folder_name) and len(folder_name) != 0: - os.makedirs(folder_name) - - # check hdf5 existence - if os.path.exists(hdf5_name): - # if already exists, open with r+ mode - hdf5_file = h5py.File(hdf5_name, "r+") - # check dataset existence - if hdf5_path in hdf5_file: - if is_overwrite: - logging.warning( - "Dataset in hdf5 file already exists. recreate dataset in hdf5." - ) - hdf5_file.__delitem__(hdf5_path) - else: - logging.error( - "Dataset in hdf5 file already exists. " - "if you want to overwrite, please set is_overwrite = True." - ) - hdf5_file.close() - sys.exit(1) - else: - # if not exists, open with w mode - hdf5_file = h5py.File(hdf5_name, "w") - - # write data to hdf5 - hdf5_file.create_dataset(hdf5_path, data=write_data) - hdf5_file.flush() - hdf5_file.close() diff --git a/src/sheet/warmup_lr.py b/src/sheet/warmup_lr.py deleted file mode 100644 index 8406894..0000000 --- a/src/sheet/warmup_lr.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Warm up learning rate scheduler module.""" - -from abc import ABC, abstractmethod -from typing import Union - -import torch -from torch.optim.lr_scheduler import _LRScheduler - - -class AbsBatchStepScheduler(ABC): - @abstractmethod - def step(self, epoch: int = None): - pass - - @abstractmethod - def state_dict(self): - pass - - @abstractmethod - def load_state_dict(self, state): - pass - - -class WarmupLR(_LRScheduler, AbsBatchStepScheduler): - """The WarmupLR scheduler - - This scheduler is almost same as NoamLR Scheduler except for following difference: - - NoamLR: - lr = optimizer.lr * model_size ** -0.5 - * min(step ** -0.5, step * warmup_step ** -1.5) - WarmupLR: - lr = optimizer.lr * warmup_step ** 0.5 - * min(step ** -0.5, step * warmup_step ** -1.5) - - Note that the maximum lr equals to optimizer.lr in this scheduler. 
- - """ - - def __init__( - self, - optimizer: torch.optim.Optimizer, - warmup_steps: Union[int, float] = 4000, - last_epoch: int = -1, - ): - self.warmup_steps = warmup_steps - - # __init__() must be invoked before setting field - # because step() is also invoked in __init__() - super().__init__(optimizer, last_epoch) - - def __repr__(self): - return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" - - def get_lr(self): - step_num = self.last_epoch + 1 - return [ - lr - * self.warmup_steps**0.5 - * min(step_num**-0.5, step_num * self.warmup_steps**-1.5) - for lr in self.base_lrs - ] From a6abed5dbd84879058d18fba9f98f03e9e45f134 Mon Sep 17 00:00:00 2001 From: darryllam Date: Mon, 10 Nov 2025 16:47:57 +0900 Subject: [PATCH 8/9] Changed name in project.toml metadata to sheet_sqa --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f61dc76..a6baf06 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=61", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "sheet" +name = "sheet_sqa" version = "0.2.5" description = "Speech Human Evaluation Estimation Toolkit (SHEET)" requires-python = "==3.10.13" From 54b37144598785efae2ea478379cf1e9c191ebd0 Mon Sep 17 00:00:00 2001 From: darryllam Date: Tue, 11 Nov 2025 11:31:40 +0900 Subject: [PATCH 9/9] Revert "Changed name in project.toml metadata to sheet_sqa" This reverts commit a6abed5dbd84879058d18fba9f98f03e9e45f134. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a6baf06..f61dc76 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=61", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "sheet_sqa" +name = "sheet" version = "0.2.5" description = "Speech Human Evaluation Estimation Toolkit (SHEET)" requires-python = "==3.10.13"