From f5101ddc6fe24d733c372f5b442094e7867c643a Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 11:51:17 -0800 Subject: [PATCH 01/11] Add shard_from_audio_dir command and flexible audio duration handling Support ingesting raw audio directories into WSDS shards via the new shard_from_audio_dir command. Update extract_index_for_shard to infer audio duration from metadata instead of requiring pre-computed fields, and broaden the torchcodec fallback to catch all exceptions (not just ImportError). --- wsds/ws_tools.py | 178 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 158 insertions(+), 20 deletions(-) diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index 7cb0ff3..ad3a58f 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -1,15 +1,12 @@ import functools import json import os -import sys -import tarfile -from collections import defaultdict +from collections.abc import Callable from pathlib import Path import numpy as np import polars as pl import pyarrow as pa -import webdataset as wds from . import WSSample, WSSink @@ -37,21 +34,13 @@ def _list(input_shard: str): # FIXME: implement keys pass else: - has_invalid_batches = False reader = pa.RecordBatchFileReader(pa.memory_map(input_shard)) - batch_size = int(reader.schema.metadata[b'batch_size']) try: for i in range(reader.num_record_batches): - b = reader.get_batch(i) - if b.num_rows != batch_size and i != reader.num_record_batches - 1: - sys.stderr.write(f"Batch {i} has {b.num_rows} rows instead of {batch_size}\n") - has_invalid_batches = True - for key in b["__key__"]: + for key in reader.get_batch(i)["__key__"]: print(key) except BrokenPipeError: pass - if has_invalid_batches: - sys.exit(1) def inspect_dataset(input_path, verbose=True): @@ -400,6 +389,7 @@ def init( source_dataset: Path | None = None, vad_column: str | None = None, num_workers: int = 32, + require_audio_duration: bool = True, ): """Initialize a new dataset, from scratch or from a segmentation of an existing one.""" import multiprocessing @@ -417,8 +407,13 @@ def init( source_dataset = new_dataset ds = WSDataset(source_dataset) - shard_extractor = functools.partial(extract_index_for_shard, source_dataset, vad_column=vad_column) - all_shards = ds.get_shard_list(ignore_index = True) + shard_extractor = functools.partial( + extract_index_for_shard, + source_dataset, + vad_column=vad_column, + require_audio_duration=require_audio_duration, + ) + all_shards = ds.get_shard_list() with AtomicFile(new_dataset / "index.sqlite3") as fname: with WSDSIndexWriter(fname) as index: @@ -510,15 +505,24 @@ def init_split( new_fields["audio"] = [("audio.wsds-computed", "audio")] index.append_metadata({"fields": new_fields}) -def extract_index_for_shard(dataset, shard, vad_column=None): +def extract_index_for_shard(dataset, shard, vad_column=None, require_audio_duration=True): from . import WSDataset ds = WSDataset(dataset) index = [] i = 0 + if isinstance(shard, (tuple, list)) and len(shard) == 2: + dataset_path, shard_name = shard + else: + dataset_path, shard_name = "", shard - for s in ds.iter_shard(shard): - key = s["__key__"] + sample = WSSample(ds, (dataset_path, shard_name), 0) + + for s in ds.sequential_from(sample, 0): + try: + key = str(s["__key__"]) + except IndexError: + break if not vad_column: n = 1 @@ -530,7 +534,21 @@ def extract_index_for_shard(dataset, shard, vad_column=None): if vad.size > 0: speech_duration = float((vad[:, -1] - vad[:, -2]).sum()) # tend - tstart - audio_duration = s['load_duration'] or s['est_duration'] or -1 + if not require_audio_duration: + audio_duration = -1.0 + elif "inspected_duration" in s: + audio_duration = float(s["inspected_duration"]) + else: + try: + audio_reader = s.get_audio() + meta = audio_reader.metadata + audio_duration = _duration_seconds_from_metadata(meta, audio_reader) + if audio_duration is None: + raise ValueError("could not infer duration from audio metadata") + except Exception as e: + print("Audio loading error:", e) + print(" for sample:", s) + raise if ( n > 0 @@ -539,12 +557,132 @@ def extract_index_for_shard(dataset, shard, vad_column=None): i += n return { - "shard_name": shard[1], + "shard_name": shard_name, + "dataset_path": dataset_path, "index": index, "n_samples": i, } +def _duration_seconds_from_metadata(meta, audio_reader=None): + for attr in ("duration_seconds_from_header", "duration", "duration_seconds"): + val = getattr(meta, attr, None) + if val is not None: + return float(val) + num_frames = getattr(meta, "num_frames", None) or getattr(meta, "num_samples", None) + sample_rate = getattr(meta, "sample_rate", None) + if num_frames is not None and sample_rate: + return float(num_frames) / float(sample_rate) + if audio_reader is not None and sample_rate: + try: + raw_bytes = audio_reader.unwrap() + bits_per_sample = getattr(meta, "bits_per_sample", None) or 16 + num_channels = getattr(meta, "num_channels", None) or 1 + bytes_per_sample = (bits_per_sample // 8) * num_channels + data_bytes = len(raw_bytes) - 44 + if data_bytes > 0 and bytes_per_sample > 0: + num_samples = data_bytes // bytes_per_sample + return float(num_samples) / float(sample_rate) + except Exception: + pass + return None + + +def write_batch(batch, out_path, compression: str | None = None): + with WSSink(out_path, compression=compression) as sink: + for row in batch: + sink.write(row) + + +@command +def shard_from_audio_dir( + input_dir: str, + output_dir: str, + max_files_per_shard: int = 300, + resume=True, + init_index: bool = False, + require_audio_duration: bool = True, + key_fn: Callable[[str], str] | None = None, + write_key_mapping: bool = False, + key_prefix: str = "", +): + """Write batched Feather (.wsds) shards with up to N audio files each. + + Args: + key_fn: Optional function to transform the file stem into a key. + E.g., a hash function for obfuscation. + write_key_mapping: If True and key_fn is provided, writes a JSON file + mapping transformed keys back to original stems. + key_prefix: Optional prefix to prepend to keys before hashing. Useful to + avoid collisions when processing multiple directories with + files that share the same names (e.g., "egyptian", "saudi"). + """ + from tqdm import tqdm + + input_dir, output_dir = Path(input_dir), Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + exts = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus") + all_files = sorted(p for p in input_dir.rglob("*") if p.suffix.lower() in exts) + print(f"[INFO] Found {len(all_files):,} audio files under {input_dir}") + + MAX_ARROW_BYTES = 2_140_000_000 # ~2.1 GB Arrow cell limit + + shard_idx = 0 + batch = [] + key_mapping = {} + + def flush_batch(): + nonlocal shard_idx, batch + if not batch: + return + out_path = output_dir / f"audio-{shard_idx:05d}.wsds" + write_batch(batch, out_path, compression=None) + shard_idx += 1 + batch = [] + + for i, path in enumerate(tqdm(all_files, ncols=90, desc="Writing WSDS shards")): + rel_path = path.relative_to(input_dir).with_suffix('') + stem = str(rel_path) + if key_prefix: + stem = f"{key_prefix}/{stem}" + key = key_fn(stem) if key_fn else stem + if key_fn and write_key_mapping: + key_mapping[key] = stem + ext = path.suffix.lower().lstrip(".") + try: + audio_bytes = path.read_bytes() + except Exception as e: + print(f"[WARN] Skipping {path}: {e}") + continue + + size = len(audio_bytes) + if size > MAX_ARROW_BYTES: + print(f"[SKIP] {path.name}: {size / 1e6:.1f} MB exceeds 2 GB Arrow limit") + continue + + batch.append({"__key__": key, "audio": audio_bytes, "audio_type": ext}) + + if len(batch) >= max_files_per_shard: + flush_batch() + + if batch: + flush_batch() + + print(f"[DONE] Wrote {shard_idx} WSDS shards -> {output_dir}") + + if key_mapping: + dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir + mapping_path = dataset_root / "key_mapping.json" + with open(mapping_path, "w") as f: + json.dump(key_mapping, f, indent=2) + print(f"[INFO] Wrote key mapping ({len(key_mapping):,} entries) -> {mapping_path}") + + if init_index: + dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir + init(dataset_root, require_audio_duration=require_audio_duration) + + @command def _remove_columns(*fnames, remove: str = ""): """ From 0c064425bbc9aad6bf1d4bcb15f9b7fffe61f7ab Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 11:57:18 -0800 Subject: [PATCH 02/11] Clean up shard_from_audio_dir: remove dead param, unused var, dedup --- wsds/ws_tools.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index ad3a58f..3d31a49 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -599,7 +599,6 @@ def shard_from_audio_dir( input_dir: str, output_dir: str, max_files_per_shard: int = 300, - resume=True, init_index: bool = False, require_audio_duration: bool = True, key_fn: Callable[[str], str] | None = None, @@ -641,7 +640,7 @@ def flush_batch(): shard_idx += 1 batch = [] - for i, path in enumerate(tqdm(all_files, ncols=90, desc="Writing WSDS shards")): + for path in tqdm(all_files, ncols=90, desc="Writing WSDS shards"): rel_path = path.relative_to(input_dir).with_suffix('') stem = str(rel_path) if key_prefix: @@ -671,15 +670,15 @@ def flush_batch(): print(f"[DONE] Wrote {shard_idx} WSDS shards -> {output_dir}") + dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir + if key_mapping: - dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir mapping_path = dataset_root / "key_mapping.json" with open(mapping_path, "w") as f: json.dump(key_mapping, f, indent=2) print(f"[INFO] Wrote key mapping ({len(key_mapping):,} entries) -> {mapping_path}") if init_index: - dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir init(dataset_root, require_audio_duration=require_audio_duration) From d9040644dc929b3736d725189b1772c44750d6d6 Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 12:24:06 -0800 Subject: [PATCH 03/11] Add test suite for shard_from_audio_dir Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 8 +- tests/__init__.py | 0 tests/test_shard_from_audio.py | 204 +++++++++++++++++++++++++++++++++ 3 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_shard_from_audio.py diff --git a/pyproject.toml b/pyproject.toml index f064d25..261f11c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ allow-direct-references = true files = ["requirements.txt"] [tool.hatch.build.targets.wheel] -packages = ["."] +packages = ["wsds"] [tool.ruff] line-length = 120 @@ -24,6 +24,12 @@ indent-width = 4 ignore = ["E203", "E501", "E731"] extend-select = ["I"] +[project.optional-dependencies] +test = ["pytest", "tqdm"] + +[tool.pytest.ini_options] +testpaths = ["tests"] + # --- build-data --- # [build-system] requires = ["hatchling", "hatch-requirements-txt"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_shard_from_audio.py b/tests/test_shard_from_audio.py new file mode 100644 index 0000000..9a9c31c --- /dev/null +++ b/tests/test_shard_from_audio.py @@ -0,0 +1,204 @@ +import struct +from pathlib import Path + +import pyarrow as pa +import pyarrow.ipc + +from wsds.ws_tools import shard_from_audio_dir + + +def make_wav(path, num_samples=100, sample_rate=16000, num_channels=1): + """Write a minimal valid WAV file.""" + bits_per_sample = 16 + data_size = num_samples * num_channels * (bits_per_sample // 8) + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", + 36 + data_size, + b"WAVE", + b"fmt ", + 16, + 1, # PCM + num_channels, + sample_rate, + sample_rate * num_channels * bits_per_sample // 8, + num_channels * bits_per_sample // 8, + bits_per_sample, + b"data", + data_size, + ) + pcm = b"\x00\x01" * num_samples * num_channels + path.write_bytes(header + pcm) + + +def _collect_shards(output_dir): + """Read all .wsds shards in output_dir, return list of (keys, audio_bytes, audio_types) per shard.""" + shards = [] + for shard_path in sorted(output_dir.glob("*.wsds")): + reader = pa.ipc.open_file(str(shard_path)) + table = reader.read_all() + keys = table.column("__key__").to_pylist() + audio = [v.as_py() for v in table.column("audio")] + audio_types = table.column("audio_type").to_pylist() + shards.append((keys, audio, audio_types)) + return shards + + +class TestShardFromAudioDir: + def test_basic_sharding(self, tmp_path): + """Files are split into correct number of shards and content matches.""" + input_dir = tmp_path / "audio_in" + output_dir = tmp_path / "audio_out" + input_dir.mkdir() + + stems = [f"clip_{i:03d}" for i in range(5)] + original_bytes = {} + for stem in stems: + p = input_dir / f"{stem}.wav" + make_wav(p, num_samples=50 + len(stem)) + original_bytes[stem] = p.read_bytes() + + shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2) + + shards = _collect_shards(output_dir) + # 5 files / 2 per shard = 3 shards + assert len(shards) == 3 + + all_keys = [] + all_audio = {} + all_types = [] + for keys, audio, audio_types in shards: + all_keys.extend(keys) + for k, a in zip(keys, audio): + all_audio[k] = a + all_types.extend(audio_types) + + assert sorted(all_keys) == sorted(stems) + for stem in stems: + assert all_audio[stem] == original_bytes[stem] + assert all(t == "wav" for t in all_types) + + def test_key_prefix(self, tmp_path): + """key_prefix is prepended to each key.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + make_wav(input_dir / "hello.wav") + + shard_from_audio_dir(str(input_dir), str(output_dir), key_prefix="dataset1") + + shards = _collect_shards(output_dir) + keys = shards[0][0] + assert keys == ["dataset1/hello"] + + def test_key_fn(self, tmp_path): + """key_fn transforms the key.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + make_wav(input_dir / "original.wav") + + shard_from_audio_dir( + str(input_dir), str(output_dir), key_fn=lambda s: s.upper() + ) + + shards = _collect_shards(output_dir) + keys = shards[0][0] + assert keys == ["ORIGINAL"] + + def test_key_fn_with_prefix(self, tmp_path): + """key_fn receives the prefixed stem.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + make_wav(input_dir / "file.wav") + + shard_from_audio_dir( + str(input_dir), + str(output_dir), + key_prefix="pfx", + key_fn=lambda s: s.replace("/", "__"), + ) + + shards = _collect_shards(output_dir) + keys = shards[0][0] + assert keys == ["pfx__file"] + + def test_oversized_files_skipped(self, tmp_path, monkeypatch): + """Files exceeding the Arrow byte limit are skipped.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + make_wav(input_dir / "small.wav", num_samples=10) + make_wav(input_dir / "big.wav", num_samples=100) + + small_size = (input_dir / "small.wav").stat().st_size + large_size = (input_dir / "big.wav").stat().st_size + + # Pick a fake limit between the two file sizes so big.wav gets skipped + fake_limit = (small_size + large_size) // 2 + + # Patch the read_bytes to attach a fake size, then patch len check via + # a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES. + # Since MAX_ARROW_BYTES is a local, we instead wrap the whole function + # by replacing it with one that sets a lower limit. + import wsds.ws_tools as mod + + orig_code = mod.shard_from_audio_dir.__code__ + + # Replace the constant in the code object's co_consts + new_consts = tuple( + fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts + ) + new_code = orig_code.replace(co_consts=new_consts) + monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code) + + shard_from_audio_dir(str(input_dir), str(output_dir)) + + shards = _collect_shards(output_dir) + all_keys = [k for keys, _, _ in shards for k in keys] + assert "small" in all_keys + assert "big" not in all_keys + + def test_empty_input_dir(self, tmp_path): + """Empty input directory produces no shards.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + shard_from_audio_dir(str(input_dir), str(output_dir)) + + assert list(output_dir.glob("*.wsds")) == [] + + def test_subdirectory_files(self, tmp_path): + """Audio files in subdirectories use relative path as key.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + sub = input_dir / "speaker1" + sub.mkdir(parents=True) + + make_wav(sub / "utt.wav") + + shard_from_audio_dir(str(input_dir), str(output_dir)) + + shards = _collect_shards(output_dir) + keys = shards[0][0] + assert keys == ["speaker1/utt"] + + def test_shard_naming(self, tmp_path): + """Shard files are named audio-NNNNN.wsds.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "out" + input_dir.mkdir() + + for i in range(4): + make_wav(input_dir / f"f{i}.wav") + + shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2) + + shard_names = sorted(p.name for p in output_dir.glob("*.wsds")) + assert shard_names == ["audio-00000.wsds", "audio-00001.wsds"] From bd7ac1768d0500db73e8fca7833bb47e219f0ece Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 12:45:09 -0800 Subject: [PATCH 04/11] Restore _list batch validation, add type hints to extract_index_for_shard Co-Authored-By: Claude Opus 4.6 --- wsds/ws_tools.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index 3d31a49..8425dce 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -1,14 +1,17 @@ +from __future__ import annotations + import functools import json import os +import sys from collections.abc import Callable from pathlib import Path - import numpy as np import polars as pl import pyarrow as pa from . import WSSample, WSSink +from .ws_audio import AudioReader commands = {} @@ -34,13 +37,21 @@ def _list(input_shard: str): # FIXME: implement keys pass else: + has_invalid_batches = False reader = pa.RecordBatchFileReader(pa.memory_map(input_shard)) + batch_size = int(reader.schema.metadata[b'batch_size']) try: for i in range(reader.num_record_batches): - for key in reader.get_batch(i)["__key__"]: + b = reader.get_batch(i) + if b.num_rows != batch_size and i != reader.num_record_batches - 1: + sys.stderr.write(f"Batch {i} has {b.num_rows} rows instead of {batch_size}\n") + has_invalid_batches = True + for key in b["__key__"]: print(key) except BrokenPipeError: pass + if has_invalid_batches: + sys.exit(1) def inspect_dataset(input_path, verbose=True): @@ -505,7 +516,12 @@ def init_split( new_fields["audio"] = [("audio.wsds-computed", "audio")] index.append_metadata({"fields": new_fields}) -def extract_index_for_shard(dataset, shard, vad_column=None, require_audio_duration=True): +def extract_index_for_shard( + dataset: str | Path, + shard: tuple[str, str] | list[str] | str, + vad_column: str | None = None, + require_audio_duration: bool = True, +) -> dict: from . import WSDataset ds = WSDataset(dataset) @@ -564,7 +580,7 @@ def extract_index_for_shard(dataset, shard, vad_column=None, require_audio_durat } -def _duration_seconds_from_metadata(meta, audio_reader=None): +def _duration_seconds_from_metadata(meta, audio_reader: AudioReader | None = None) -> float | None: for attr in ("duration_seconds_from_header", "duration", "duration_seconds"): val = getattr(meta, attr, None) if val is not None: From 4a9d25ae21a29c14e3a0ebac8328bff07b39cfac Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 21:32:55 -0500 Subject: [PATCH 05/11] sorting is configurable --- pyproject.toml | 2 +- requirements.txt | 1 + wsds/ws_tools.py | 16 +++++++++++----- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 261f11c..95695bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ ignore = ["E203", "E501", "E731"] extend-select = ["I"] [project.optional-dependencies] -test = ["pytest", "tqdm"] +test = ["pytest"] [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt index 4fe9e43..20365fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ polars>=1.36.1 pyarrow>=20 torch torchaudio +tqdm # torchcodec – optional, it causes serious performance regressions diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index 8425dce..cd4e4cd 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -620,6 +620,7 @@ def shard_from_audio_dir( key_fn: Callable[[str], str] | None = None, write_key_mapping: bool = False, key_prefix: str = "", + sort_files: bool = False, ): """Write batched Feather (.wsds) shards with up to N audio files each. @@ -631,16 +632,21 @@ def shard_from_audio_dir( key_prefix: Optional prefix to prepend to keys before hashing. Useful to avoid collisions when processing multiple directories with files that share the same names (e.g., "egyptian", "saudi"). + sort_files: If True, sorts files by name before processing. + This can be memory intensive for large datasets, but ensures deterministic shard assignment. """ - from tqdm import tqdm input_dir, output_dir = Path(input_dir), Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) exts = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus") - all_files = sorted(p for p in input_dir.rglob("*") if p.suffix.lower() in exts) - print(f"[INFO] Found {len(all_files):,} audio files under {input_dir}") - + file_iter = (p for p in input_dir.rglob("*") if p.suffix.lower() in exts) + if sort_files: + files = sorted(file_iter) + print(f"[INFO] Found {len(files):,} audio files under {input_dir}") + else: + files = file_iter + print(f"[INFO] Processing audio files under {input_dir}") MAX_ARROW_BYTES = 2_140_000_000 # ~2.1 GB Arrow cell limit shard_idx = 0 @@ -656,7 +662,7 @@ def flush_batch(): shard_idx += 1 batch = [] - for path in tqdm(all_files, ncols=90, desc="Writing WSDS shards"): + for path in tqdm(files, ncols=90, desc="Writing WSDS shards"): rel_path = path.relative_to(input_dir).with_suffix('') stem = str(rel_path) if key_prefix: From d49e1ad61fa51e6e18821405765ca18d3259a625 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Feb 2026 23:51:25 +0000 Subject: [PATCH 06/11] Initial plan From b7c7954569428722f0780f5885657520d54b2f16 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Feb 2026 23:56:32 +0000 Subject: [PATCH 07/11] Fix init_index to write shards to audio/ subdirectory Co-authored-by: tlebryk <43556997+tlebryk@users.noreply.github.com> --- .gitignore | 8 +++++ tests/test_shard_from_audio.py | 60 ++++++++++++++++++++++++++++++++++ wsds/ws_tools.py | 11 +++++-- 3 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cc72b12 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +.pytest_cache/ +*.egg-info/ +dist/ +build/ diff --git a/tests/test_shard_from_audio.py b/tests/test_shard_from_audio.py index 9a9c31c..21bcb39 100644 --- a/tests/test_shard_from_audio.py +++ b/tests/test_shard_from_audio.py @@ -202,3 +202,63 @@ def test_shard_naming(self, tmp_path): shard_names = sorted(p.name for p in output_dir.glob("*.wsds")) assert shard_names == ["audio-00000.wsds", "audio-00001.wsds"] + + def test_init_index_creates_audio_subdir(self, tmp_path): + """When init_index=True, shards are written to audio/ subdirectory and index is created.""" + input_dir = tmp_path / "in" + output_dir = tmp_path / "dataset" + input_dir.mkdir() + + # Create some test audio files + for i in range(3): + make_wav(input_dir / f"file{i}.wav") + + shard_from_audio_dir( + str(input_dir), + str(output_dir), + max_files_per_shard=2, + init_index=True, + require_audio_duration=False, # Skip audio duration requirement for test + ) + + # Shards should be in output_dir/audio/ + audio_dir = output_dir / "audio" + assert audio_dir.exists() + assert audio_dir.is_dir() + + # Check shards are in the audio subdirectory + shard_files = sorted(audio_dir.glob("*.wsds")) + assert len(shard_files) == 2 # 3 files / 2 per shard = 2 shards + + # Check index was created at dataset root + index_file = output_dir / "index.sqlite3" + assert index_file.exists() + + def test_init_index_with_audio_named_output(self, tmp_path): + """When init_index=True and output_dir is already named 'audio', don't create nested audio/audio/.""" + input_dir = tmp_path / "in" + dataset_root = tmp_path / "dataset" + output_dir = dataset_root / "audio" + input_dir.mkdir() + + make_wav(input_dir / "test.wav") + + shard_from_audio_dir( + str(input_dir), + str(output_dir), + init_index=True, + require_audio_duration=False, + ) + + # Shards should be in output_dir (which is already named 'audio') + shard_files = sorted(output_dir.glob("*.wsds")) + assert len(shard_files) == 1 + + # Check we didn't create audio/audio/ + nested_audio = output_dir / "audio" + assert not nested_audio.exists() + + # Index should be at dataset_root (parent of audio/) + index_file = dataset_root / "index.sqlite3" + assert index_file.exists() + diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index cd4e4cd..78ca891 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -637,6 +637,15 @@ def shard_from_audio_dir( """ input_dir, output_dir = Path(input_dir), Path(output_dir) + + # When init_index=True, ensure shards are written to audio/ subdirectory + # so that list_all_shards() can find them (it only looks in subdirectories) + if init_index and output_dir.name != "audio": + dataset_root = output_dir + output_dir = output_dir / "audio" + else: + dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir + output_dir.mkdir(parents=True, exist_ok=True) exts = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus") @@ -692,8 +701,6 @@ def flush_batch(): print(f"[DONE] Wrote {shard_idx} WSDS shards -> {output_dir}") - dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir - if key_mapping: mapping_path = dataset_root / "key_mapping.json" with open(mapping_path, "w") as f: From d6eb097c868f47c1fe7ec3c78770b352773a4cba Mon Sep 17 00:00:00 2001 From: Theo Lebryk <43556997+tlebryk@users.noreply.github.com> Date: Tue, 10 Feb 2026 21:23:00 -0500 Subject: [PATCH 08/11] Delete .gitignore --- .gitignore | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index cc72b12..0000000 --- a/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -__pycache__/ -*.pyc -*.pyo -*.pyd -.pytest_cache/ -*.egg-info/ -dist/ -build/ From 04e4717775ea85892db9c688f93559c0156e9181 Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Tue, 10 Feb 2026 21:40:57 -0500 Subject: [PATCH 09/11] remove test that depnds on ipython --- tests/test_shard_from_audio.py | 58 ---------------------------------- wsds/ws_tools.py | 3 +- 2 files changed, 2 insertions(+), 59 deletions(-) diff --git a/tests/test_shard_from_audio.py b/tests/test_shard_from_audio.py index 21bcb39..3d7f5a9 100644 --- a/tests/test_shard_from_audio.py +++ b/tests/test_shard_from_audio.py @@ -203,62 +203,4 @@ def test_shard_naming(self, tmp_path): shard_names = sorted(p.name for p in output_dir.glob("*.wsds")) assert shard_names == ["audio-00000.wsds", "audio-00001.wsds"] - def test_init_index_creates_audio_subdir(self, tmp_path): - """When init_index=True, shards are written to audio/ subdirectory and index is created.""" - input_dir = tmp_path / "in" - output_dir = tmp_path / "dataset" - input_dir.mkdir() - - # Create some test audio files - for i in range(3): - make_wav(input_dir / f"file{i}.wav") - - shard_from_audio_dir( - str(input_dir), - str(output_dir), - max_files_per_shard=2, - init_index=True, - require_audio_duration=False, # Skip audio duration requirement for test - ) - - # Shards should be in output_dir/audio/ - audio_dir = output_dir / "audio" - assert audio_dir.exists() - assert audio_dir.is_dir() - - # Check shards are in the audio subdirectory - shard_files = sorted(audio_dir.glob("*.wsds")) - assert len(shard_files) == 2 # 3 files / 2 per shard = 2 shards - - # Check index was created at dataset root - index_file = output_dir / "index.sqlite3" - assert index_file.exists() - - def test_init_index_with_audio_named_output(self, tmp_path): - """When init_index=True and output_dir is already named 'audio', don't create nested audio/audio/.""" - input_dir = tmp_path / "in" - dataset_root = tmp_path / "dataset" - output_dir = dataset_root / "audio" - input_dir.mkdir() - - make_wav(input_dir / "test.wav") - - shard_from_audio_dir( - str(input_dir), - str(output_dir), - init_index=True, - require_audio_duration=False, - ) - - # Shards should be in output_dir (which is already named 'audio') - shard_files = sorted(output_dir.glob("*.wsds")) - assert len(shard_files) == 1 - - # Check we didn't create audio/audio/ - nested_audio = output_dir / "audio" - assert not nested_audio.exists() - - # Index should be at dataset_root (parent of audio/) - index_file = dataset_root / "index.sqlite3" - assert index_file.exists() diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index 78ca891..aac5821 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -632,9 +632,10 @@ def shard_from_audio_dir( key_prefix: Optional prefix to prepend to keys before hashing. Useful to avoid collisions when processing multiple directories with files that share the same names (e.g., "egyptian", "saudi"). - sort_files: If True, sorts files by name before processing. + sort_files: If True, sorts files by name before processing. This can be memory intensive for large datasets, but ensures deterministic shard assignment. """ + from tqdm import tqdm input_dir, output_dir = Path(input_dir), Path(output_dir) From e285afe419d2799bce86016b792526a28e81f01a Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Fri, 13 Mar 2026 11:08:37 -0400 Subject: [PATCH 10/11] Add convert_mka_to_audio utility with aresample=async=1 for duration-safe conversion MKA files from WebRTC recordings contain timestamp gaps that cause ffmpeg to silently drop audio, producing shorter output files. The aresample=async=1 filter fills gaps with silence to preserve the original duration. Also adds .mka to shard_from_audio_dir supported extensions. --- wsds/ws_tools.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index aac5821..dad160e 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -610,6 +610,62 @@ def write_batch(batch, out_path, compression: str | None = None): sink.write(row) +def convert_mka_to_audio( + input_dir: str, + output_dir: str, + output_format: str = "flac", + sort_files: bool = False, +) -> list[Path]: + """Convert .mka files to another audio format, preserving duration. + + MKA (Matroska audio) files from WebRTC recordings often contain timestamp + gaps. A naive ``ffmpeg -i in.mka out.flac`` silently drops the gaps, + producing a shorter file. This function uses ``-af aresample=async=1`` + to fill gaps with silence so that the output duration matches the source. + + Returns the list of output file paths. + """ + import subprocess + + from tqdm import tqdm + + input_dir, output_dir = Path(input_dir), Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + file_iter = input_dir.rglob("*.mka") + files = sorted(file_iter) if sort_files else list(file_iter) + print(f"[INFO] Found {len(files)} .mka files under {input_dir}") + + outputs: list[Path] = [] + failed: list[Path] = [] + for mka in tqdm(files, ncols=90, desc=f"Converting MKA → {output_format.upper()}"): + rel = mka.relative_to(input_dir).with_suffix(f".{output_format}") + out = output_dir / rel + out.parent.mkdir(parents=True, exist_ok=True) + + r = subprocess.run( + [ + "ffmpeg", "-nostdin", "-y", "-hide_banner", "-loglevel", "error", + "-i", str(mka), + "-af", "aresample=async=1", + str(out), + ], + capture_output=True, + ) + if r.returncode != 0: + tqdm.write(f"[WARN] FAILED: {mka}") + tqdm.write(r.stderr.decode()[-200:]) + failed.append(mka) + else: + outputs.append(out) + + print(f"[DONE] {len(outputs)} converted, {len(failed)} failed.") + if failed: + for f in failed: + print(f" FAILED: {f}") + return outputs + + @command def shard_from_audio_dir( input_dir: str, @@ -649,7 +705,7 @@ def shard_from_audio_dir( output_dir.mkdir(parents=True, exist_ok=True) - exts = (".wav", ".flac", ".mp3", ".m4a", ".ogg", ".opus") + exts = (".wav", ".flac", ".mp3", ".m4a", ".mka", ".ogg", ".opus") file_iter = (p for p in input_dir.rglob("*") if p.suffix.lower() in exts) if sort_files: files = sorted(file_iter) From 779ac08f1342916e32517c24794edc92da79ed85 Mon Sep 17 00:00:00 2001 From: Theo Lebryk Date: Mon, 16 Mar 2026 11:50:08 -0400 Subject: [PATCH 11/11] =?UTF-8?q?Remove=20convert=5Fmka=5Fto=5Faudio=20?= =?UTF-8?q?=E2=80=94=20MKA=20conversion=20with=20start=5Ftime=20handling?= =?UTF-8?q?=20belongs=20in=20consumer=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- wsds/ws_tools.py | 56 ------------------------------------------------ 1 file changed, 56 deletions(-) diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py index dad160e..0b64bd3 100644 --- a/wsds/ws_tools.py +++ b/wsds/ws_tools.py @@ -610,62 +610,6 @@ def write_batch(batch, out_path, compression: str | None = None): sink.write(row) -def convert_mka_to_audio( - input_dir: str, - output_dir: str, - output_format: str = "flac", - sort_files: bool = False, -) -> list[Path]: - """Convert .mka files to another audio format, preserving duration. - - MKA (Matroska audio) files from WebRTC recordings often contain timestamp - gaps. A naive ``ffmpeg -i in.mka out.flac`` silently drops the gaps, - producing a shorter file. This function uses ``-af aresample=async=1`` - to fill gaps with silence so that the output duration matches the source. - - Returns the list of output file paths. - """ - import subprocess - - from tqdm import tqdm - - input_dir, output_dir = Path(input_dir), Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - file_iter = input_dir.rglob("*.mka") - files = sorted(file_iter) if sort_files else list(file_iter) - print(f"[INFO] Found {len(files)} .mka files under {input_dir}") - - outputs: list[Path] = [] - failed: list[Path] = [] - for mka in tqdm(files, ncols=90, desc=f"Converting MKA → {output_format.upper()}"): - rel = mka.relative_to(input_dir).with_suffix(f".{output_format}") - out = output_dir / rel - out.parent.mkdir(parents=True, exist_ok=True) - - r = subprocess.run( - [ - "ffmpeg", "-nostdin", "-y", "-hide_banner", "-loglevel", "error", - "-i", str(mka), - "-af", "aresample=async=1", - str(out), - ], - capture_output=True, - ) - if r.returncode != 0: - tqdm.write(f"[WARN] FAILED: {mka}") - tqdm.write(r.stderr.decode()[-200:]) - failed.append(mka) - else: - outputs.append(out) - - print(f"[DONE] {len(outputs)} converted, {len(failed)} failed.") - if failed: - for f in failed: - print(f" FAILED: {f}") - return outputs - - @command def shard_from_audio_dir( input_dir: str,