diff --git a/pyproject.toml b/pyproject.toml
index f064d25..95695bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ allow-direct-references = true
 files = ["requirements.txt"]
 
 [tool.hatch.build.targets.wheel]
-packages = ["."]
+packages = ["wsds"]
 
 [tool.ruff]
 line-length = 120
@@ -24,6 +24,12 @@ indent-width = 4
 ignore = ["E203", "E501", "E731"]
 extend-select = ["I"]
 
+[project.optional-dependencies]
+test = ["pytest"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+
 # --- build-data --- #
 [build-system]
 requires = ["hatchling", "hatch-requirements-txt"]
diff --git a/requirements.txt b/requirements.txt
index 4fe9e43..20365fb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ polars>=1.36.1
 pyarrow>=20
 torch
 torchaudio
+tqdm
 # torchcodec – optional, it causes serious performance regressions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_shard_from_audio.py b/tests/test_shard_from_audio.py
new file mode 100644
index 0000000..3d7f5a9
--- /dev/null
+++ b/tests/test_shard_from_audio.py
@@ -0,0 +1,203 @@
+import struct
+
+import pyarrow as pa
+import pyarrow.ipc
+
+from wsds.ws_tools import shard_from_audio_dir
+
+
+def make_wav(path, num_samples=100, sample_rate=16000, num_channels=1):
+    """Write a minimal valid WAV file."""
+    bits_per_sample = 16
+    data_size = num_samples * num_channels * (bits_per_sample // 8)
+    header = struct.pack(
+        "<4sI4s4sIHHIIHH4sI",
+        b"RIFF",
+        36 + data_size,
+        b"WAVE",
+        b"fmt ",
+        16,
+        1,  # PCM
+        num_channels,
+        sample_rate,
+        sample_rate * num_channels * bits_per_sample // 8,
+        num_channels * bits_per_sample // 8,
+        bits_per_sample,
+        b"data",
+        data_size,
+    )
+    pcm = b"\x00\x01" * num_samples * num_channels
+    path.write_bytes(header + pcm)
+
+
+def _collect_shards(output_dir):
+    """Read all .wsds shards in output_dir, return list of (keys, audio_bytes, audio_types) per shard."""
+    shards = []
+    for shard_path in sorted(output_dir.glob("*.wsds")):
+        reader = pa.ipc.open_file(str(shard_path))
+        table = reader.read_all()
+        keys = table.column("__key__").to_pylist()
+        audio = [v.as_py() for v in table.column("audio")]
+        audio_types = table.column("audio_type").to_pylist()
+        shards.append((keys, audio, audio_types))
+    return shards
+
+
+class TestShardFromAudioDir:
+    def test_basic_sharding(self, tmp_path):
+        """Files are split into correct number of shards and content matches."""
+        input_dir = tmp_path / "audio_in"
+        output_dir = tmp_path / "audio_out"
+        input_dir.mkdir()
+
+        stems = [f"clip_{i:03d}" for i in range(5)]
+        original_bytes = {}
+        for stem in stems:
+            p = input_dir / f"{stem}.wav"
+            make_wav(p, num_samples=50 + len(stem))
+            original_bytes[stem] = p.read_bytes()
+
+        shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2)
+
+        shards = _collect_shards(output_dir)
+        # 5 files / 2 per shard = 3 shards
+        assert len(shards) == 3
+
+        all_keys = []
+        all_audio = {}
+        all_types = []
+        for keys, audio, audio_types in shards:
+            all_keys.extend(keys)
+            for k, a in zip(keys, audio):
+                all_audio[k] = a
+            all_types.extend(audio_types)
+
+        assert sorted(all_keys) == sorted(stems)
+        for stem in stems:
+            assert all_audio[stem] == original_bytes[stem]
+        assert all(t == "wav" for t in all_types)
+
+    def test_key_prefix(self, tmp_path):
+        """key_prefix is prepended to each key."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        make_wav(input_dir / "hello.wav")
+
+        shard_from_audio_dir(str(input_dir), str(output_dir), key_prefix="dataset1")
+
+        shards = _collect_shards(output_dir)
+        keys = shards[0][0]
+        assert keys == ["dataset1/hello"]
+
+    def test_key_fn(self, tmp_path):
+        """key_fn transforms the key."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        make_wav(input_dir / "original.wav")
+
+        shard_from_audio_dir(
+            str(input_dir), str(output_dir), key_fn=lambda s: s.upper()
+        )
+
+        shards = _collect_shards(output_dir)
+        keys = shards[0][0]
+        assert keys == ["ORIGINAL"]
+
+    def test_key_fn_with_prefix(self, tmp_path):
+        """key_fn receives the prefixed stem."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        make_wav(input_dir / "file.wav")
+
+        shard_from_audio_dir(
+            str(input_dir),
+            str(output_dir),
+            key_prefix="pfx",
+            key_fn=lambda s: s.replace("/", "__"),
+        )
+
+        shards = _collect_shards(output_dir)
+        keys = shards[0][0]
+        assert keys == ["pfx__file"]
+
+    def test_oversized_files_skipped(self, tmp_path, monkeypatch):
+        """Files exceeding the Arrow byte limit are skipped."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        make_wav(input_dir / "small.wav", num_samples=10)
+        make_wav(input_dir / "big.wav", num_samples=100)
+
+        small_size = (input_dir / "small.wav").stat().st_size
+        large_size = (input_dir / "big.wav").stat().st_size
+
+        # Pick a fake limit between the two file sizes so big.wav gets skipped
+        fake_limit = (small_size + large_size) // 2
+
+        # MAX_ARROW_BYTES is a local constant inside shard_from_audio_dir, so it
+        # cannot be monkeypatched as a module attribute. Instead, swap the
+        # literal 2_140_000_000 in the function's code-object constants
+        # (CPython-specific) for the fake limit.
+        import wsds.ws_tools as mod
+
+        orig_code = mod.shard_from_audio_dir.__code__
+
+        # Replace the constant in the code object's co_consts
+        new_consts = tuple(
+            fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts
+        )
+        new_code = orig_code.replace(co_consts=new_consts)
+        monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code)
+
+        shard_from_audio_dir(str(input_dir), str(output_dir))
+
+        shards = _collect_shards(output_dir)
+        all_keys = [k for keys, _, _ in shards for k in keys]
+        assert "small" in all_keys
+        assert "big" not in all_keys
+
+    def test_empty_input_dir(self, tmp_path):
+        """Empty input directory produces no shards."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        shard_from_audio_dir(str(input_dir), str(output_dir))
+
+        assert list(output_dir.glob("*.wsds")) == []
+
+    def test_subdirectory_files(self, tmp_path):
+        """Audio files in subdirectories use relative path as key."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        sub = input_dir / "speaker1"
+        sub.mkdir(parents=True)
+
+        make_wav(sub / "utt.wav")
+
+        shard_from_audio_dir(str(input_dir), str(output_dir))
+
+        shards = _collect_shards(output_dir)
+        keys = shards[0][0]
+        assert keys == ["speaker1/utt"]
+
+    def test_shard_naming(self, tmp_path):
+        """Shard files are named audio-NNNNN.wsds."""
+        input_dir = tmp_path / "in"
+        output_dir = tmp_path / "out"
+        input_dir.mkdir()
+
+        for i in range(4):
+            make_wav(input_dir / f"f{i}.wav")
+
+        shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2)
+
+        shard_names = sorted(p.name for p in output_dir.glob("*.wsds"))
+        assert shard_names == ["audio-00000.wsds", "audio-00001.wsds"]
diff --git a/wsds/ws_tools.py b/wsds/ws_tools.py
index 7cb0ff3..0b64bd3 100644
--- a/wsds/ws_tools.py
+++ b/wsds/ws_tools.py
@@ -1,17 +1,17 @@
+from __future__ import annotations
+
 import functools
 import json
 import os
 import sys
-import tarfile
-from collections import defaultdict
+from collections.abc import Callable
 from pathlib import Path
-
 import numpy as np
 import polars as pl
 import pyarrow as pa
-import webdataset as wds
 
 from . import WSSample, WSSink
+from .ws_audio import AudioReader
 
 commands = {}
 
@@ -400,6 +400,7 @@ def init(
     source_dataset: Path | None = None,
     vad_column: str | None = None,
     num_workers: int = 32,
+    require_audio_duration: bool = True,
 ):
     """Initialize a new dataset, from scratch or from a segmentation of an existing one."""
     import multiprocessing
@@ -417,8 +418,13 @@ def init(
         source_dataset = new_dataset
 
     ds = WSDataset(source_dataset)
-    shard_extractor = functools.partial(extract_index_for_shard, source_dataset, vad_column=vad_column)
-    all_shards = ds.get_shard_list(ignore_index = True)
+    shard_extractor = functools.partial(
+        extract_index_for_shard,
+        source_dataset,
+        vad_column=vad_column,
+        require_audio_duration=require_audio_duration,
+    )
+    all_shards = ds.get_shard_list()
 
     with AtomicFile(new_dataset / "index.sqlite3") as fname:
         with WSDSIndexWriter(fname) as index:
@@ -510,15 +516,29 @@ def init_split(
                 new_fields["audio"] = [("audio.wsds-computed", "audio")]
             index.append_metadata({"fields": new_fields})
 
-def extract_index_for_shard(dataset, shard, vad_column=None):
+def extract_index_for_shard(
+    dataset: str | Path,
+    shard: tuple[str, str] | list[str] | str,
+    vad_column: str | None = None,
+    require_audio_duration: bool = True,
+) -> dict:
     from . import WSDataset
 
     ds = WSDataset(dataset)
     index = []
     i = 0
+    if isinstance(shard, (tuple, list)) and len(shard) == 2:
+        dataset_path, shard_name = shard
+    else:
+        dataset_path, shard_name = "", shard
 
-    for s in ds.iter_shard(shard):
-        key = s["__key__"]
+    sample = WSSample(ds, (dataset_path, shard_name), 0)
+
+    for s in ds.sequential_from(sample, 0):
+        try:
+            key = str(s["__key__"])
+        except IndexError:
+            break
 
         if not vad_column:
             n = 1
@@ -530,7 +550,21 @@ def extract_index_for_shard(dataset, shard, vad_column=None):
             if vad.size > 0:
                 speech_duration = float((vad[:, -1] - vad[:, -2]).sum())  # tend - tstart
 
-        audio_duration = s['load_duration'] or s['est_duration'] or -1
+        if not require_audio_duration:
+            audio_duration = -1.0
+        elif "inspected_duration" in s:
+            audio_duration = float(s["inspected_duration"])
+        else:
+            try:
+                audio_reader = s.get_audio()
+                meta = audio_reader.metadata
+                audio_duration = _duration_seconds_from_metadata(meta, audio_reader)
+                if audio_duration is None:
+                    raise ValueError("could not infer duration from audio metadata")
+            except Exception as e:
+                print("Audio loading error:", e)
+                print("    for sample:", s)
+                raise
 
         if (
             n > 0
@@ -539,12 +573,164 @@ def extract_index_for_shard(dataset, shard, vad_column=None):
         i += n
 
     return {
-        "shard_name": shard[1],
+        "shard_name": shard_name,
+        "dataset_path": dataset_path,
         "index": index,
         "n_samples": i,
     }
 
 
+def _duration_seconds_from_metadata(meta, audio_reader: AudioReader | None = None) -> float | None:
+    """Best-effort audio duration in seconds derived from a metadata object.
+
+    Tries, in order: an explicit duration attribute on ``meta``, then
+    ``num_frames``/``num_samples`` divided by ``sample_rate``, and finally an
+    estimate from the size of the raw bytes returned by ``audio_reader.unwrap()``.
+    Returns ``None`` when none of these yields a value.
+    """
+    for attr in ("duration_seconds_from_header", "duration", "duration_seconds"):
+        val = getattr(meta, attr, None)
+        if val is not None:
+            return float(val)
+    num_frames = getattr(meta, "num_frames", None) or getattr(meta, "num_samples", None)
+    sample_rate = getattr(meta, "sample_rate", None)
+    if num_frames is not None and sample_rate:
+        return float(num_frames) / float(sample_rate)
+    if audio_reader is not None and sample_rate:
+        try:
+            raw_bytes = audio_reader.unwrap()
+            bits_per_sample = getattr(meta, "bits_per_sample", None) or 16
+            num_channels = getattr(meta, "num_channels", None) or 1
+            bytes_per_sample = (bits_per_sample // 8) * num_channels
+            # 44 = size of a canonical PCM WAV header; this is only an estimate
+            # and presumably only meaningful for uncompressed WAV payloads.
+            data_bytes = len(raw_bytes) - 44
+            if data_bytes > 0 and bytes_per_sample > 0:
+                num_samples = data_bytes // bytes_per_sample
+                return float(num_samples) / float(sample_rate)
+        except Exception:
+            # Deliberate best-effort fallback: any failure means "duration unknown".
+            pass
+    return None
+
+
+def write_batch(batch, out_path, compression: str | None = None):
+    """Write every row of ``batch`` to a single shard at ``out_path`` via ``WSSink``."""
+    with WSSink(out_path, compression=compression) as sink:
+        for row in batch:
+            sink.write(row)
+
+
+@command
+def shard_from_audio_dir(
+    input_dir: str,
+    output_dir: str,
+    max_files_per_shard: int = 300,
+    init_index: bool = False,
+    require_audio_duration: bool = True,
+    key_fn: Callable[[str], str] | None = None,
+    write_key_mapping: bool = False,
+    key_prefix: str = "",
+    sort_files: bool = False,
+):
+    """Write batched Feather (.wsds) shards with up to N audio files each.
+
+    Args:
+        input_dir: Directory searched recursively for audio files.
+        output_dir: Directory receiving the ``audio-NNNNN.wsds`` shards.
+        max_files_per_shard: Maximum number of audio files per shard.
+        init_index: If True, shards go to an ``audio/`` subdirectory (unless
+            output_dir is already named "audio") and ``init`` is run on the
+            dataset root afterwards.
+        require_audio_duration: Forwarded to ``init`` when init_index is True.
+        key_fn: Optional function to transform the file stem into a key.
+            E.g., a hash function for obfuscation.
+        write_key_mapping: If True and key_fn is provided, writes a JSON file
+            mapping transformed keys back to original stems.
+        key_prefix: Optional prefix to prepend to keys before hashing. Useful to
+            avoid collisions when processing multiple directories with
+            files that share the same names (e.g., "egyptian", "saudi").
+        sort_files: If True, sorts files by name before processing.
+            This can be memory intensive for large datasets, but ensures deterministic shard assignment.
+    """
+    from tqdm import tqdm
+
+    input_dir, output_dir = Path(input_dir), Path(output_dir)
+
+    # When init_index=True, ensure shards are written to audio/ subdirectory
+    # so that list_all_shards() can find them (it only looks in subdirectories)
+    if init_index and output_dir.name != "audio":
+        dataset_root = output_dir
+        output_dir = output_dir / "audio"
+    else:
+        dataset_root = output_dir.parent if output_dir.name == "audio" else output_dir
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    exts = (".wav", ".flac", ".mp3", ".m4a", ".mka", ".ogg", ".opus")
+    file_iter = (p for p in input_dir.rglob("*") if p.suffix.lower() in exts)
+    if sort_files:
+        files = sorted(file_iter)
+        print(f"[INFO] Found {len(files):,} audio files under {input_dir}")
+    else:
+        files = file_iter
+        print(f"[INFO] Processing audio files under {input_dir}")
+    MAX_ARROW_BYTES = 2_140_000_000  # ~2.1 GB Arrow cell limit
+
+    shard_idx = 0
+    batch = []
+    key_mapping = {}
+
+    def flush_batch():
+        nonlocal shard_idx, batch
+        if not batch:
+            return
+        out_path = output_dir / f"audio-{shard_idx:05d}.wsds"
+        write_batch(batch, out_path, compression=None)
+        shard_idx += 1
+        batch = []
+
+    for path in tqdm(files, ncols=90, desc="Writing WSDS shards"):
+        rel_path = path.relative_to(input_dir).with_suffix('')
+        # as_posix() keeps "/" as the key separator on every platform.
+        stem = rel_path.as_posix()
+        if key_prefix:
+            stem = f"{key_prefix}/{stem}"
+        key = key_fn(stem) if key_fn else stem
+        if key_fn and write_key_mapping:
+            key_mapping[key] = stem
+        ext = path.suffix.lower().lstrip(".")
+        try:
+            audio_bytes = path.read_bytes()
+        except Exception as e:
+            print(f"[WARN] Skipping {path}: {e}")
+            continue
+
+        size = len(audio_bytes)
+        if size > MAX_ARROW_BYTES:
+            print(f"[SKIP] {path.name}: {size / 1e6:.1f} MB exceeds 2 GB Arrow limit")
+            continue
+
+        batch.append({"__key__": key, "audio": audio_bytes, "audio_type": ext})
+
+        if len(batch) >= max_files_per_shard:
+            flush_batch()
+
+    if batch:
+        flush_batch()
+
+    print(f"[DONE] Wrote {shard_idx} WSDS shards -> {output_dir}")
+
+    if key_mapping:
+        mapping_path = dataset_root / "key_mapping.json"
+        with open(mapping_path, "w") as f:
+            json.dump(key_mapping, f, indent=2)
+        print(f"[INFO] Wrote key mapping ({len(key_mapping):,} entries) -> {mapping_path}")
+
+    if init_index:
+        init(dataset_root, require_audio_duration=require_audio_duration)
+
+
 @command
 def _remove_columns(*fnames, remove: str = ""):
     """