From 828570f117230ccae8a1f2370967a0b535933903 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 11:02:33 +0100 Subject: [PATCH 01/13] load dataset only once --- sdgym/benchmark.py | 37 ++++++++++----------- tests/unit/test_benchmark.py | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 30d8e608..b4665d7d 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -365,7 +365,7 @@ def _generate_job_args_list( paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) - job_tuples = [] + job_tuples_by_dataset = defaultdict(list) for dataset in datasets: for synthesizer in synthesizers: if paths: @@ -377,29 +377,30 @@ def _generate_job_args_list( final_name = synthesizer['name'] synthesizer['name'] = final_name - job_tuples.append((synthesizer, dataset)) + job_tuples_by_dataset[dataset].append(synthesizer) job_args_list = [] - for synthesizer, dataset in job_tuples: + for dataset, synthesizers in job_tuples_by_dataset.items(): data, metadata_dict = _load_dataset_with_client( modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client ) - path = paths.get(dataset.name, {}).get(synthesizer['name'], None) - job_args_list.append( - JobArgs( - synthesizer=synthesizer, - data=data, - metadata=metadata_dict, - metrics=sdmetrics, - timeout=timeout, - compute_quality_score=compute_quality_score, - compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=compute_privacy_score, - dataset_name=dataset.name, - modality=modality, - output_directions=path, + for synthesizer in synthesizers: + path = paths.get(dataset.name, {}).get(synthesizer['name'], None) + job_args_list.append( + JobArgs( + synthesizer=synthesizer, + data=data, + metadata=metadata_dict, + metrics=sdmetrics, + timeout=timeout, + compute_quality_score=compute_quality_score, + compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=compute_privacy_score, + dataset_name=dataset.name, + modality=modality, + output_directions=path, + ) ) - ) return job_args_list diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 9c8065af..f020637b 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -1097,6 +1097,69 @@ def test__generate_job_args_list_local_root_additional_folder( ) +@patch('sdgym.benchmark.get_dataset_paths') +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark._load_dataset_with_client') +def test__generate_job_args_list_loads_each_dataset_once( + mock_load_dataset, + mock__setup_output_destination, + mock_get_dataset_paths, +): + """Test that each dataset is loaded once even when there are multiple synthesizers.""" + # Setup + dataset_a = Path('/dummy/single_table/datasetA') + dataset_b = Path('/dummy/single_table/datasetB') + mock_get_dataset_paths.return_value = [dataset_a, dataset_b] + mock__setup_output_destination.return_value = {} + data_a = Mock(name='data_a') + metadata_a = Mock(name='metadata_a') + data_b = Mock(name='data_b') + metadata_b = Mock(name='metadata_b') + mock_load_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] + synthesizers = [ + {'name': 'GaussianCopulaSynthesizer'}, + {'name': 'UniformSynthesizer'}, + ] + s3_client = Mock() + + # Run + job_args_list = _generate_job_args_list( + limit_dataset_size=True, + sdv_datasets=['datasetA', 'datasetB'], + additional_datasets_folder=None, + sdmetrics=None, + timeout=None, + output_destination=None, + compute_quality_score=False, + compute_diagnostic_score=False, + compute_privacy_score=False, + synthesizers=synthesizers, + s3_client=s3_client, + modality='single_table', + ) + + # Assert + mock_load_dataset.assert_has_calls([ + call('single_table', dataset_a, limit_dataset_size=True, s3_client=s3_client), + call('single_table', dataset_b, limit_dataset_size=True, s3_client=s3_client), + ]) + assert mock_load_dataset.call_count == 2 + assert len(job_args_list) == 4 + assert [job.dataset_name for job in job_args_list] == [ + 'datasetA', + 'datasetA', + 'datasetB', + 'datasetB', + ] + assert [job.data for job in job_args_list] == [data_a, data_a, data_b, data_b] + assert [job.metadata for job in job_args_list] == [ + metadata_a, + metadata_a, + metadata_b, + metadata_b, + ] + + @patch('sdgym.benchmark.get_dataset_paths') @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark._load_dataset_with_client') From cba80aaffcaddf65862d7488e1ce27be6aaea429 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 13:48:42 +0100 Subject: [PATCH 02/13] use downlad_demo from sdv --- .github/workflows/integration.yml | 3 + .github/workflows/minimum.yml | 3 + pyproject.toml | 3 +- sdgym/benchmark.py | 102 ++++++++++++---- sdgym/datasets.py | 117 ++++++++++++++++++ tests/integration/test_benchmark.py | 24 ++++ tests/unit/test_benchmark.py | 86 ++++++++++--- tests/unit/test_datasets.py | 179 ++++++++++++++++++++++++++++ tests/unit/test_s3.py | 8 +- 9 files changed, 484 insertions(+), 41 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e344e5bd..d475d6b6 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -28,6 +28,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --no-cache-dir invoke .[test] - name: Run integration tests + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke integration - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.14 name: Upload integration codecov report diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 7eaf8edc..8e143056 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -39,4 +39,7 @@ jobs: python -m pip install --no-cache-dir invoke .[test] - name: Test with minimum versions + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} run: invoke minimum diff --git a/pyproject.toml b/pyproject.toml index 46c6f0a9..f1c918cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,8 +66,7 @@ dependencies = [ "rdt>=1.20.0;python_version>='3.14'", "sdmetrics>=0.21.0;python_version<'3.14'", "sdmetrics>=0.26.0;python_version>='3.14'", - "sdv>=1.21.0;python_version<'3.14'", - "sdv>=1.33.0;python_version>='3.14'", + "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] [project.urls] diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index b4665d7d..7fa7be98 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -40,7 +40,14 @@ ) from sdmetrics.single_table import DCRBaselineProtection -from sdgym.datasets import _load_dataset_with_client, get_dataset_paths +from sdgym.datasets import ( + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, + _get_dataset_bucket_mapping, + _load_dataset_with_client, + _load_sdv_demo_dataset, + get_dataset_paths, +) from sdgym.errors import BenchmarkError, SDGymError from sdgym.metrics import get_metrics from sdgym.progress import TqdmLogger @@ -123,6 +130,14 @@ class JobArgs(NamedTuple): output_directions: Optional[dict] +class ResolvedDataset(NamedTuple): + """Resolved dataset data and metadata for benchmark job creation.""" + + name: str + data: Any + metadata: Any + + def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality): """Import user-provided synthesizer and validate modality and uniqueness. @@ -323,6 +338,33 @@ def _setup_output_destination( return paths +def _resolve_dataset( + modality, + dataset, + limit_dataset_size, + source, + s3_client=None, + dataset_bucket_mapping=None, +): + if source == 'sdv_demo': + data, metadata = _load_sdv_demo_dataset( + modality=modality, + dataset_name=dataset, + dataset_bucket_mapping=dataset_bucket_mapping, + s3_client=s3_client, + limit_dataset_size=limit_dataset_size, + ) + return ResolvedDataset(dataset, data, metadata) + + data, metadata = _load_dataset_with_client( + modality, + dataset, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, + ) + return ResolvedDataset(dataset.name, data, metadata) + + def _generate_job_args_list( limit_dataset_size, sdv_datasets, @@ -337,15 +379,7 @@ def _generate_job_args_list( s3_client, modality, ): - sdv_datasets = ( - [] - if sdv_datasets is None - else get_dataset_paths( - modality=modality, - datasets=sdv_datasets, - s3_client=s3_client, - ) - ) + sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets additional_datasets = ( [] if additional_datasets_folder is None @@ -359,13 +393,45 @@ def _generate_job_args_list( s3_client=s3_client, ) ) - datasets = sdv_datasets + additional_datasets + if not synthesizers: + return [] + + dataset_bucket_mapping = None + if sdv_dataset_names: + dataset_bucket_mapping = _get_dataset_bucket_mapping( + modality, + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client, + skip_inaccessible=True, + ) + + datasets = [ + _resolve_dataset( + modality=modality, + dataset=dataset, + limit_dataset_size=limit_dataset_size, + source='sdv_demo', + s3_client=s3_client, + dataset_bucket_mapping=dataset_bucket_mapping, + ) + for dataset in sdv_dataset_names + ] + datasets.extend( + _resolve_dataset( + modality=modality, + dataset=dataset, + limit_dataset_size=limit_dataset_size, + source='additional', + s3_client=s3_client, + ) + for dataset in additional_datasets + ) synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) - job_tuples_by_dataset = defaultdict(list) + job_args_list = [] for dataset in datasets: for synthesizer in synthesizers: if paths: @@ -377,20 +443,12 @@ def _generate_job_args_list( final_name = synthesizer['name'] synthesizer['name'] = final_name - job_tuples_by_dataset[dataset].append(synthesizer) - - job_args_list = [] - for dataset, synthesizers in job_tuples_by_dataset.items(): - data, metadata_dict = _load_dataset_with_client( - modality, dataset, limit_dataset_size=limit_dataset_size, s3_client=s3_client - ) - for synthesizer in synthesizers: path = paths.get(dataset.name, {}).get(synthesizer['name'], None) job_args_list.append( JobArgs( synthesizer=synthesizer, - data=data, - metadata=metadata_dict, + data=dataset.data, + metadata=dataset.metadata, metrics=sdmetrics, timeout=timeout, compute_quality_score=compute_quality_score, diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 8b923b11..6be54e48 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -1,12 +1,23 @@ """SDGym module to handle datasets.""" +import io import logging import os from pathlib import Path import appdirs +import botocore import numpy as np import pandas as pd +from sdv.datasets.demo import ( + _find_data_zip_key, + _get_data_from_bucket, + _get_first_v1_metadata_bytes, + _get_metadata, + _list_objects, + _load_data_from_zip, + download_demo, +) from sdgym._dataset_utils import ( _get_dataset_subset, @@ -35,6 +46,13 @@ def _get_bucket_name(bucket): return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket +def _metadata_to_dict(metadata): + if isinstance(metadata, dict): + return metadata + + return metadata.to_dict() + + def _raise_dataset_not_found_error( s3_client, bucket_name, @@ -251,6 +269,105 @@ def _get_available_datasets( return pd.DataFrame(datasets_info) +def _get_dataset_bucket_mapping(modality, buckets, s3_client, skip_inaccessible=False): + """Map SDV demo dataset names to the bucket they should be loaded from.""" + dataset_buckets = {} + for bucket in buckets: + try: + available_datasets = _get_available_datasets( + modality, + bucket=bucket, + s3_client=s3_client, + ) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as error: + if skip_inaccessible: + LOGGER.info("Skipping inaccessible bucket '%s': %s", bucket, error) + continue + + raise ValueError( + f"Bucket '{bucket}' is not accessible with the provided credentials." + ) from error + + for dataset_name in available_datasets['dataset_name'].tolist(): + existing_bucket = dataset_buckets.get(dataset_name) + if existing_bucket and bucket != SDV_DATASETS_PRIVATE_BUCKET: + continue + + dataset_buckets[dataset_name] = bucket + + return dataset_buckets + + +def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=None): + """Load an SDV demo dataset from a private bucket with an SDGym S3 client.""" + bucket_name = _get_bucket_name(bucket) + s3_client = s3_client or get_s3_client() + dataset_prefix = f'{modality}/{dataset_name}/' + contents = _list_objects(dataset_prefix, bucket=bucket_name, client=s3_client) + data_key = _find_data_zip_key(contents, dataset_prefix, bucket_name) + data_bytes = io.BytesIO(_get_data_from_bucket(data_key, bucket=bucket_name, client=s3_client)) + metadata_bytes = _get_first_v1_metadata_bytes( + contents, dataset_prefix, bucket=bucket_name, client=s3_client + ) + data = _load_data_from_zip(data_bytes, bucket_name, dataset_name) + if modality != 'multi_table': + data = data.popitem()[1] + + metadata = _get_metadata(metadata_bytes, dataset_name) + return data, _metadata_to_dict(metadata) + + +def _load_sdv_demo_dataset( + modality, + dataset_name, + dataset_bucket_mapping=None, + s3_client=None, + limit_dataset_size=False, +): + """Load an SDV demo dataset from the resolved public or private bucket.""" + _validate_modality(modality) + buckets = [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET] + if dataset_bucket_mapping is None: + dataset_bucket_mapping = _get_dataset_bucket_mapping( + modality, + buckets, + s3_client or get_s3_client(), + skip_inaccessible=True, + ) + + bucket = dataset_bucket_mapping.get(dataset_name) + if bucket is None: + buckets_list = ', '.join(buckets) + raise ValueError( + f"Dataset '{dataset_name}' not found in SDV demo buckets for modality " + f"'{modality}'. Checked buckets: {buckets_list}." + ) + + bucket_name = _get_bucket_name(bucket) + try: + data, metadata = download_demo( + modality=modality, + dataset_name=dataset_name, + s3_bucket_name=bucket_name, + ) + metadata = _metadata_to_dict(metadata) + except ValueError: + if bucket != SDV_DATASETS_PRIVATE_BUCKET: + raise + + data, metadata = _load_private_sdv_demo_dataset( + modality, + dataset_name, + bucket, + s3_client=s3_client, + ) + + if limit_dataset_size: + data, metadata = _get_dataset_subset(data, metadata, modality=modality) + + return data, metadata + + def load_dataset( modality, dataset, diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index 0c7fe86a..cdf87685 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -929,6 +929,30 @@ def test_benchmark_multi_table_basic_synthesizers(): ] +@pytest.mark.skipif( + not os.getenv('AWS_ACCESS_KEY_ID') or not os.getenv('AWS_SECRET_ACCESS_KEY'), + reason='MovieLens benchmark requires AWS credentials for private dataset access.', +) +def test_benchmark_multi_table_private_dataset(): + """Test multi-table benchmark with private dataset `MovieLens`.""" + # Setup + datasets = ['MovieLens'] + synthesizers = ['HMASynthesizer'] + timeout = 10 + + # Run + result = benchmark_multi_table( + synthesizers=synthesizers, + sdv_datasets=datasets, + timeout=timeout, + ) + + # Assert + assert result['Dataset'].tolist() == ['MovieLens', 'MovieLens'] + assert result['Synthesizer'].tolist() == ['HMASynthesizer', 'MultiTableUniformSynthesizer'] + assert result['Quality_Score'].tolist() == [None, None] + + def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path): """Test saving in ``output_destination`` with multiple runs in multi-table mode. diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index f020637b..2e3ee068 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -202,10 +202,18 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path): assert result_3 == 3 +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark.tqdm.tqdm') -def test_benchmark_single_table_progress_bar(tqdm_mock): +def test_benchmark_single_table_progress_bar( + tqdm_mock, mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): """Test that the benchmarking function updates the progress bar on one line.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata scores_mock = MagicMock() scores_mock.__iter__.return_value = [ pd.DataFrame({ @@ -228,11 +236,22 @@ def test_benchmark_single_table_progress_bar(tqdm_mock): tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._score') @patch('sdgym.benchmark.multiprocessing') -def test_benchmark_single_table_with_timeout(mock_multiprocessing, mock__score): +def test_benchmark_single_table_with_timeout( + mock_multiprocessing, + mock__score, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, +): """Test that benchmark runs with timeout.""" # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} + mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata mocked_process = mock_multiprocessing.Process.return_value manager = mock_multiprocessing.Manager.return_value manager_dict = {'timeout': True, 'Error': 'Synthesizer Timeout'} @@ -1097,25 +1116,26 @@ def test__generate_job_args_list_local_root_additional_folder( ) -@patch('sdgym.benchmark.get_dataset_paths') +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._setup_output_destination') -@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_loads_each_dataset_once( - mock_load_dataset, - mock__setup_output_destination, mock_get_dataset_paths, + mock__setup_output_destination, + mock_load_sdv_demo_dataset, + mock_get_dataset_bucket_mapping, ): """Test that each dataset is loaded once even when there are multiple synthesizers.""" # Setup - dataset_a = Path('/dummy/single_table/datasetA') - dataset_b = Path('/dummy/single_table/datasetB') - mock_get_dataset_paths.return_value = [dataset_a, dataset_b] + mock_get_dataset_paths.return_value = [] + mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket', 'datasetB': 'bucket'} mock__setup_output_destination.return_value = {} data_a = Mock(name='data_a') metadata_a = Mock(name='metadata_a') data_b = Mock(name='data_b') metadata_b = Mock(name='metadata_b') - mock_load_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] + mock_load_sdv_demo_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] synthesizers = [ {'name': 'GaussianCopulaSynthesizer'}, {'name': 'UniformSynthesizer'}, @@ -1139,11 +1159,24 @@ def test__generate_job_args_list_loads_each_dataset_once( ) # Assert - mock_load_dataset.assert_has_calls([ - call('single_table', dataset_a, limit_dataset_size=True, s3_client=s3_client), - call('single_table', dataset_b, limit_dataset_size=True, s3_client=s3_client), + mock_get_dataset_paths.assert_not_called() + mock_load_sdv_demo_dataset.assert_has_calls([ + call( + modality='single_table', + dataset_name='datasetA', + dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ), + call( + modality='single_table', + dataset_name='datasetB', + dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ), ]) - assert mock_load_dataset.call_count == 2 + assert mock_load_sdv_demo_dataset.call_count == 2 assert len(job_args_list) == 4 assert [job.dataset_name for job in job_args_list] == [ 'datasetA', @@ -1175,7 +1208,10 @@ def test__generate_job_args_list_s3_root_additional_folder( get_dataset_paths_mock.return_value = [dataset_path] s3_client = Mock() mock__setup_output_destination.return_value = {} - mock_load_dataset.return_value = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) + mock_load_dataset.return_value = ( + pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}), + {'tables': {}}, + ) # Run _generate_job_args_list( @@ -1211,9 +1247,27 @@ def test__generate_job_args_list_s3_root_additional_folder( ) -def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test_benchmark_single_table_no_warning_uniform_synthesizer( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping, recwarn +): """Test that no UserWarning is raised when running `UniformSynthesizer`.""" # Setup + data = pd.DataFrame({'column': [1, 2, 3]}) + metadata = { + 'tables': { + 'fake_hotel_guests': { + 'columns': { + 'column': { + 'sdtype': 'numerical', + } + } + } + } + } + mock_get_dataset_bucket_mapping.return_value = {'fake_hotel_guests': 'bucket'} + mock_load_sdv_demo_dataset.return_value = data, metadata expected_result = pd.DataFrame({ 'Synthesizer': {0: 'UniformSynthesizer'}, 'Dataset': {0: 'fake_hotel_guests'}, diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index f03c7d00..c770846b 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -1,16 +1,22 @@ from pathlib import Path from unittest.mock import Mock, call, patch +import botocore import numpy as np +import pandas as pd import pytest from sdgym.datasets import ( DATASETS_PATH, + SDV_DATASETS_PRIVATE_BUCKET, + SDV_DATASETS_PUBLIC_BUCKET, _download_dataset, _genereate_dataset_info, _get_bucket_name, + _get_dataset_bucket_mapping, _get_dataset_path_and_download, _load_dataset_with_client, + _load_sdv_demo_dataset, _path_contains_data_and_metadata, _validate_modality, get_data_and_metadata_from_path, @@ -362,6 +368,179 @@ def test_get_bucket_name_local_folder(): assert bucket_name == 'bucket-name' +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): + """Test that datasets are mapped to private when duplicated across buckets.""" + # Setup + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only', 'duplicate']}), + pd.DataFrame({'dataset_name': ['private_only', 'duplicate']}), + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + # Assert + assert result == { + 'public_only': SDV_DATASETS_PUBLIC_BUCKET, + 'private_only': SDV_DATASETS_PRIVATE_BUCKET, + 'duplicate': SDV_DATASETS_PRIVATE_BUCKET, + } + get_available_mock.assert_has_calls([ + call('single_table', bucket=SDV_DATASETS_PUBLIC_BUCKET, s3_client='s3_client'), + call('single_table', bucket=SDV_DATASETS_PRIVATE_BUCKET, s3_client='s3_client'), + ]) + + +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets can be skipped while building the mapping.""" + # Setup + error = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + get_available_mock.side_effect = [ + pd.DataFrame({'dataset_name': ['public_only']}), + error, + ] + + # Run + result = _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + skip_inaccessible=True, + ) + + # Assert + assert result == {'public_only': SDV_DATASETS_PUBLIC_BUCKET} + + +@patch('sdgym.datasets._get_available_datasets') +def test__get_dataset_bucket_mapping_raises_inaccessible_bucket(get_available_mock): + """Test inaccessible buckets raise by default.""" + # Setup + get_available_mock.side_effect = botocore.exceptions.ClientError( + {'Error': {'Code': 'AccessDenied', 'Message': 'denied'}}, + 'ListObjectsV2', + ) + + # Run and Assert + with pytest.raises(ValueError, match="Bucket 's3://sdv-datasets-private' is not accessible"): + _get_dataset_bucket_mapping( + 'single_table', + [SDV_DATASETS_PRIVATE_BUCKET], + s3_client='s3_client', + ) + + +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): + """Test SDV demo datasets are loaded through SDV's download_demo.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == metadata.to_dict.return_value + download_demo_mock.assert_called_once_with( + modality='single_table', + dataset_name='demo', + s3_bucket_name='sdv-datasets-public', + ) + + +@patch('sdgym.datasets._load_private_sdv_demo_dataset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_falls_back_for_private_bucket( + download_demo_mock, load_private_mock +): + """Test SDV private-bucket errors fall back to SDGym private loading.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.side_effect = ValueError('Private buckets are only supported') + load_private_mock.return_value = data, metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PRIVATE_BUCKET}, + s3_client='s3_client', + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, data) + assert result_metadata == metadata + load_private_mock.assert_called_once_with( + 'single_table', + 'demo', + SDV_DATASETS_PRIVATE_BUCKET, + s3_client='s3_client', + ) + + +@patch('sdgym.datasets._get_dataset_subset') +@patch('sdgym.datasets.download_demo') +def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_mock): + """Test SDV demo dataset loading applies the dataset size limit.""" + # Setup + data = pd.DataFrame({'column': [1, 2]}) + metadata = Mock() + metadata.to_dict.return_value = {'tables': {'demo': {'columns': {'column': {}}}}} + limited_data = pd.DataFrame({'column': [1]}) + limited_metadata = {'tables': {'demo': {'columns': {'column': {}}}}} + download_demo_mock.return_value = data, metadata + subset_mock.return_value = limited_data, limited_metadata + + # Run + result = _load_sdv_demo_dataset( + modality='single_table', + dataset_name='demo', + dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + limit_dataset_size=True, + ) + + # Assert + result_data, result_metadata = result + pd.testing.assert_frame_equal(result_data, limited_data) + assert result_metadata == limited_metadata + subset_mock.assert_called_once_with( + data, + metadata.to_dict.return_value, + modality='single_table', + ) + + +def test__load_sdv_demo_dataset_raises_when_dataset_not_found(): + """Test a clear error is raised when a demo dataset is absent from all buckets.""" + # Run and Assert + with pytest.raises(ValueError, match="Dataset 'missing' not found in SDV demo buckets"): + _load_sdv_demo_dataset( + modality='single_table', + dataset_name='missing', + dataset_bucket_mapping={}, + ) + + @patch('sdgym.datasets._get_dataset_path_and_download') @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) @patch('sdgym.datasets.Path') diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index 00bb7d8d..3dd0a20f 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -320,11 +320,14 @@ def test__get_s3_client_with_credentials(mock_boto_client): mock_s3_client.head_bucket.assert_called_once_with(Bucket='my-bucket') -def test__get_s3_client_errors(): +@patch('sdgym.s3.boto3.client') +def test__get_s3_client_errors(mock_boto_client): """Test `_get_s3_client` raises error for invalid input.""" # Setup output_destination = 's3:/' expected_error = re.escape(f'Invalid S3 URL: {output_destination}') + mock_s3_client = mock_boto_client.return_value + mock_s3_client.head_bucket.side_effect = NoCredentialsError() # Run and Assert with pytest.raises(ValueError, match=expected_error): @@ -333,6 +336,9 @@ def test__get_s3_client_errors(): with pytest.raises(NoCredentialsError, match='Unable to locate credentials'): _get_s3_client('s3://bucket_name/') + mock_boto_client.assert_called_once_with('s3') + mock_s3_client.head_bucket.assert_called_once_with(Bucket='bucket_name') + def test__read_data_from_bucket_key_reads_body(): """Test that the function reads data from S3 object body.""" From 3803c1947689f0c2d727a88b3b21893f14745cd2 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 14:11:37 +0100 Subject: [PATCH 03/13] update _resolve_dataset --- pyproject.toml | 3 +- sdgym/_benchmark/benchmark.py | 3 +- sdgym/benchmark.py | 100 ++++++++++++------------ tests/unit/test_benchmark.py | 138 +++++++++++++++++++++++++++++++++- 4 files changed, 187 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f1c918cd..83b325d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,7 @@ dependencies = [ 'XlsxWriter>=1.2.8', "rdt>=1.18.2;python_version<'3.14'", "rdt>=1.20.0;python_version>='3.14'", - "sdmetrics>=0.21.0;python_version<'3.14'", - "sdmetrics>=0.26.0;python_version>='3.14'", + "sdmetrics>=0.28.0", "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 19fb8de5..9ad30906 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -13,6 +13,7 @@ DEFAULT_SINGLE_TABLE_DATASETS, DEFAULT_SINGLE_TABLE_SYNTHESIZERS, S3_REGION, + SDGYM_BRANCH_INSTALL_COMMAND, _ensure_uniform_included, _generate_job_args_list, _get_empty_dataframe, @@ -206,7 +207,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - pip install "sdgym[all]" + {SDGYM_BRANCH_INSTALL_COMMAND} {gpu_block} diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 7fa7be98..7191eee3 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -73,6 +73,10 @@ TIMEOUT = 345600 LOGGER = logging.getLogger(__name__) +SDGYM_BRANCH_INSTALL_COMMAND = ( + 'pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git' + '@issue-604-2-private-bucket"' +) DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ 'GaussianCopulaSynthesizer', 'CTGANSynthesizer', @@ -340,44 +344,10 @@ def _setup_output_destination( def _resolve_dataset( modality, - dataset, - limit_dataset_size, - source, - s3_client=None, - dataset_bucket_mapping=None, -): - if source == 'sdv_demo': - data, metadata = _load_sdv_demo_dataset( - modality=modality, - dataset_name=dataset, - dataset_bucket_mapping=dataset_bucket_mapping, - s3_client=s3_client, - limit_dataset_size=limit_dataset_size, - ) - return ResolvedDataset(dataset, data, metadata) - - data, metadata = _load_dataset_with_client( - modality, - dataset, - limit_dataset_size=limit_dataset_size, - s3_client=s3_client, - ) - return ResolvedDataset(dataset.name, data, metadata) - - -def _generate_job_args_list( - limit_dataset_size, sdv_datasets, additional_datasets_folder, - sdmetrics, - timeout, - output_destination, - compute_quality_score, - compute_diagnostic_score, - compute_privacy_score, - synthesizers, - s3_client, - modality, + limit_dataset_size, + s3_client=None, ): sdv_dataset_names = [] if sdv_datasets is None else sdv_datasets additional_datasets = ( @@ -393,8 +363,6 @@ def _generate_job_args_list( s3_client=s3_client, ) ) - if not synthesizers: - return [] dataset_bucket_mapping = None if sdv_dataset_names: @@ -405,26 +373,52 @@ def _generate_job_args_list( skip_inaccessible=True, ) - datasets = [ - _resolve_dataset( + datasets = [] + for dataset_name in sdv_dataset_names: + data, metadata = _load_sdv_demo_dataset( modality=modality, - dataset=dataset, - limit_dataset_size=limit_dataset_size, - source='sdv_demo', - s3_client=s3_client, + dataset_name=dataset_name, dataset_bucket_mapping=dataset_bucket_mapping, + s3_client=s3_client, + limit_dataset_size=limit_dataset_size, ) - for dataset in sdv_dataset_names - ] - datasets.extend( - _resolve_dataset( - modality=modality, - dataset=dataset, + datasets.append(ResolvedDataset(dataset_name, data, metadata)) + + for dataset in additional_datasets: + data, metadata = _load_dataset_with_client( + modality, + dataset, limit_dataset_size=limit_dataset_size, - source='additional', s3_client=s3_client, ) - for dataset in additional_datasets + datasets.append(ResolvedDataset(dataset.name, data, metadata)) + + return datasets + + +def _generate_job_args_list( + limit_dataset_size, + sdv_datasets, + additional_datasets_folder, + sdmetrics, + timeout, + output_destination, + compute_quality_score, + compute_diagnostic_score, + compute_privacy_score, + synthesizers, + s3_client, + modality, +): + if not synthesizers: + return [] + + datasets = _resolve_dataset( + modality=modality, + sdv_datasets=sdv_datasets, + additional_datasets_folder=additional_datasets_folder, + limit_dataset_size=limit_dataset_size, + s3_client=s3_client, ) synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] @@ -1427,7 +1421,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - pip install sdgym[all] + {SDGYM_BRANCH_INSTALL_COMMAND} pip install s3fs echo "======== Write Script ===========" diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 2e3ee068..b8d3a40b 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -21,6 +21,7 @@ _generate_job_args_list, _get_metainfo_increment, _import_and_validate_synthesizers, + _resolve_dataset, _setup_output_destination, _setup_output_destination_aws, _store_job_args_in_s3, @@ -233,6 +234,19 @@ def test_benchmark_single_table_progress_bar( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + dataset_bucket_mapping={'student_placements': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) @@ -265,6 +279,19 @@ def test_benchmark_single_table_with_timeout( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='student_placements', + dataset_bucket_mapping={'student_placements': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) mocked_process.start.assert_called_once_with() mocked_process.join.assert_called_once_with(1) mocked_process.terminate.assert_called_once_with() @@ -1078,10 +1105,78 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) +@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.get_dataset_paths') +def test__resolve_dataset_loads_sdv_and_additional_datasets( + mock_get_dataset_paths, + mock_get_dataset_bucket_mapping, + mock_load_sdv_demo_dataset, + mock_load_dataset, + tmp_path, +): + """Test the `_resolve_dataset` method.""" + # Setup + additional_folder = tmp_path / 'additional' + additional_dataset_path = additional_folder / 'single_table' / 'custom_dataset' + sdv_data = Mock(name='sdv_data') + sdv_metadata = Mock(name='sdv_metadata') + additional_data = Mock(name='additional_data') + additional_metadata = Mock(name='additional_metadata') + s3_client = Mock() + mock_get_dataset_paths.return_value = [additional_dataset_path] + mock_get_dataset_bucket_mapping.return_value = {'sdv_dataset': 'bucket'} + mock_load_sdv_demo_dataset.return_value = sdv_data, sdv_metadata + mock_load_dataset.return_value = additional_data, additional_metadata + + # Run + result = _resolve_dataset( + modality='single_table', + sdv_datasets=['sdv_dataset'], + additional_datasets_folder=str(additional_folder), + limit_dataset_size=True, + s3_client=s3_client, + ) + + # Assert + mock_get_dataset_paths.assert_called_once_with( + modality='single_table', + bucket=str(additional_folder / 'single_table'), + s3_client=s3_client, + ) + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='sdv_dataset', + dataset_bucket_mapping={'sdv_dataset': 'bucket'}, + s3_client=s3_client, + limit_dataset_size=True, + ) + mock_load_dataset.assert_called_once_with( + 'single_table', + additional_dataset_path, + limit_dataset_size=True, + s3_client=s3_client, + ) + assert [dataset.name for dataset in result] == ['sdv_dataset', 'custom_dataset'] + assert [dataset.data for dataset in result] == [sdv_data, additional_data] + assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] + + @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark._setup_output_destination') +@patch('sdgym.benchmark._load_dataset_with_client') @patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_local_root_additional_folder( get_dataset_paths_mock, + mock_load_dataset, + mock__setup_output_destination, tmp_path, modality, ): @@ -1091,6 +1186,8 @@ def test__generate_job_args_list_local_root_additional_folder( local_root.mkdir() dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] + mock_load_dataset.return_value = Mock(), Mock() + mock__setup_output_destination.return_value = {} # Run _generate_job_args_list( @@ -1103,7 +1200,7 @@ def test__generate_job_args_list_local_root_additional_folder( compute_quality_score=False, compute_diagnostic_score=False, compute_privacy_score=False, - synthesizers=[], + synthesizers=[{'name': 'UniformSynthesizer'}], s3_client=None, modality=modality, ) @@ -1114,6 +1211,19 @@ def test__generate_job_args_list_local_root_additional_folder( bucket=str(local_root / modality), s3_client=None, ) + mock_load_dataset.assert_called_once_with( + modality, + dataset_path, + limit_dataset_size=False, + s3_client=None, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['UniformSynthesizer'], + ['datasetA'], + modality=modality, + s3_client=None, + ) @patch('sdgym.benchmark._get_dataset_bucket_mapping') @@ -1160,6 +1270,19 @@ def test__generate_job_args_list_loads_each_dataset_once( # Assert mock_get_dataset_paths.assert_not_called() + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + s3_client, + skip_inaccessible=True, + ) + mock__setup_output_destination.assert_called_once_with( + None, + ['GaussianCopulaSynthesizer', 'UniformSynthesizer'], + ['datasetA', 'datasetB'], + modality='single_table', + s3_client=s3_client, + ) mock_load_sdv_demo_dataset.assert_has_calls([ call( modality='single_table', @@ -1289,6 +1412,19 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( ) # Assert + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + None, + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name='fake_hotel_guests', + dataset_bucket_mapping={'fake_hotel_guests': 'bucket'}, + s3_client=None, + limit_dataset_size=False, + ) warnings_text = ' '.join(str(w.message) for w in recwarn) assert 'is incompatible with transformer' not in warnings_text pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) From d040382e737e8f0e77ed1420b7a8454b451beac3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 15:40:18 +0100 Subject: [PATCH 04/13] fix lint --- sdgym/_benchmark/benchmark.py | 3 +-- sdgym/benchmark.py | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 9ad30906..19fb8de5 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -13,7 +13,6 @@ DEFAULT_SINGLE_TABLE_DATASETS, DEFAULT_SINGLE_TABLE_SYNTHESIZERS, S3_REGION, - SDGYM_BRANCH_INSTALL_COMMAND, _ensure_uniform_included, _generate_job_args_list, _get_empty_dataframe, @@ -207,7 +206,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - {SDGYM_BRANCH_INSTALL_COMMAND} + pip install "sdgym[all]" {gpu_block} diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 7191eee3..8b19987d 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -73,10 +73,6 @@ TIMEOUT = 345600 LOGGER = logging.getLogger(__name__) -SDGYM_BRANCH_INSTALL_COMMAND = ( - 'pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git' - '@issue-604-2-private-bucket"' -) DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ 'GaussianCopulaSynthesizer', 'CTGANSynthesizer', @@ -1421,7 +1417,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - {SDGYM_BRANCH_INSTALL_COMMAND} + pip install sdgym[all] pip install s3fs echo "======== Write Script ===========" From 2118c9a3379eca16d430e2542c44e23972395485 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 25 May 2026 20:49:40 +0100 Subject: [PATCH 05/13] cleaning --- sdgym/datasets.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 6be54e48..6ef494dc 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -46,13 +46,6 @@ def _get_bucket_name(bucket): return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket -def _metadata_to_dict(metadata): - if isinstance(metadata, dict): - return metadata - - return metadata.to_dict() - - def _raise_dataset_not_found_error( s3_client, bucket_name, @@ -314,7 +307,7 @@ def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=Non data = data.popitem()[1] metadata = _get_metadata(metadata_bytes, dataset_name) - return data, _metadata_to_dict(metadata) + return data, metadata.to_dict() def _load_sdv_demo_dataset( @@ -350,7 +343,7 @@ def _load_sdv_demo_dataset( dataset_name=dataset_name, s3_bucket_name=bucket_name, ) - metadata = _metadata_to_dict(metadata) + metadata = metadata.to_dict() except ValueError: if bucket != SDV_DATASETS_PRIVATE_BUCKET: raise From b3481aa6081563fe08bdb863b0ffb54b1b1d4330 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Tue, 26 May 2026 09:57:00 +0100 Subject: [PATCH 06/13] add validation --- sdgym/benchmark.py | 9 +++- sdgym/datasets.py | 19 +------- tests/unit/_benchmark_launcher/test_utils.py | 26 ++++++++++- tests/unit/test_benchmark.py | 48 +++++++++++++++++--- tests/unit/test_datasets.py | 17 ++----- 5 files changed, 78 insertions(+), 41 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 8b19987d..8638b9b3 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -368,13 +368,20 @@ def _resolve_dataset( s3_client, skip_inaccessible=True, ) + missing_names = [name for name in sdv_dataset_names if name not in dataset_bucket_mapping] + if missing_names: + missing_to_print = "', '".join(missing_names) + raise ValueError( + f'The following SDV demo datasets were not found in the expected buckets: ' + f"'{missing_to_print}'. Please check that the dataset names are correct." + ) datasets = [] for dataset_name in sdv_dataset_names: data, metadata = _load_sdv_demo_dataset( modality=modality, dataset_name=dataset_name, - dataset_bucket_mapping=dataset_bucket_mapping, + bucket=dataset_bucket_mapping.get(dataset_name), s3_client=s3_client, limit_dataset_size=limit_dataset_size, ) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 6ef494dc..4d9f32b9 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -313,29 +313,12 @@ def _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client=Non def _load_sdv_demo_dataset( modality, dataset_name, - dataset_bucket_mapping=None, + bucket, s3_client=None, limit_dataset_size=False, ): """Load an SDV demo dataset from the resolved public or private bucket.""" _validate_modality(modality) - buckets = [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET] - if dataset_bucket_mapping is None: - dataset_bucket_mapping = _get_dataset_bucket_mapping( - modality, - buckets, - s3_client or get_s3_client(), - skip_inaccessible=True, - ) - - bucket = dataset_bucket_mapping.get(dataset_name) - if bucket is None: - buckets_list = ', '.join(buckets) - raise ValueError( - f"Dataset '{dataset_name}' not found in SDV demo buckets for modality " - f"'{modality}'. Checked buckets: {buckets_list}." - ) - bucket_name = _get_bucket_name(bucket) try: data, metadata = download_demo( diff --git a/tests/unit/_benchmark_launcher/test_utils.py b/tests/unit/_benchmark_launcher/test_utils.py index 10f26f4d..d2626196 100644 --- a/tests/unit/_benchmark_launcher/test_utils.py +++ b/tests/unit/_benchmark_launcher/test_utils.py @@ -516,10 +516,33 @@ def test_resolve_credentials_with_filepath_deep_merges_file_over_env( assert credentials == expected -def test_resolve_credentials_file_mode(tmp_path): +@patch('sdgym._benchmark_launcher.utils._get_env_credentials') +def test_resolve_credentials_file_mode(mock_get_env_credentials, tmp_path): """Test `resolve_credentials` returns credentials from a file merged over env defaults.""" # Setup credential_file = tmp_path / 'credentials.json' + mock_get_env_credentials.return_value = { + 'aws': { + 'AWS_ACCESS_KEY_ID': None, + 'AWS_SECRET_ACCESS_KEY': None, + }, + 'gcp': { + 'type': None, + 'project_id': None, + 'private_key_id': None, + 'private_key': None, + 'client_email': None, + 'client_id': None, + 'auth_uri': None, + 'token_uri': None, + 'auth_provider_x509_cert_url': None, + 'client_x509_cert_url': None, + }, + 'sdv_enterprise': { + 'SDV_ENTERPRISE_USERNAME': None, + 'SDV_ENTERPRISE_LICENSE_KEY': None, + }, + } file_credentials = { 'aws': { 'AWS_ACCESS_KEY_ID': 'FILE_AKIA', @@ -559,6 +582,7 @@ def test_resolve_credentials_file_mode(tmp_path): credentials = resolve_credentials(str(credential_file)) # Assert + mock_get_env_credentials.assert_called_once_with() assert credentials == expected_credentials diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index b8d3a40b..5f5fc166 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -243,7 +243,7 @@ def test_benchmark_single_table_progress_bar( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='student_placements', - dataset_bucket_mapping={'student_placements': 'bucket'}, + bucket='bucket', s3_client=None, limit_dataset_size=False, ) @@ -288,7 +288,7 @@ def test_benchmark_single_table_with_timeout( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='student_placements', - dataset_bucket_mapping={'student_placements': 'bucket'}, + bucket='bucket', s3_client=None, limit_dataset_size=False, ) @@ -1154,7 +1154,7 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='sdv_dataset', - dataset_bucket_mapping={'sdv_dataset': 'bucket'}, + bucket='bucket', s3_client=s3_client, limit_dataset_size=True, ) @@ -1169,6 +1169,40 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_sdv_demo_dataset') +def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( + mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping +): + """Test `_resolve_dataset` raises when an SDV dataset is not found in any bucket.""" + # Setup + mock_get_dataset_bucket_mapping.return_value = {'available_dataset': 'bucket'} + + # Run and Assert + with pytest.raises( + ValueError, + match=( + 'The following SDV demo datasets were not found in the expected buckets: ' + "'missing_dataset'. Please check that the dataset names are correct." + ), + ): + _resolve_dataset( + modality='single_table', + sdv_datasets=['available_dataset', 'missing_dataset'], + additional_datasets_folder=None, + limit_dataset_size=False, + s3_client='s3_client', + ) + + mock_get_dataset_bucket_mapping.assert_called_once_with( + 'single_table', + ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], + 's3_client', + skip_inaccessible=True, + ) + mock_load_sdv_demo_dataset.assert_not_called() + + @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark._load_dataset_with_client') @@ -1239,7 +1273,7 @@ def test__generate_job_args_list_loads_each_dataset_once( """Test that each dataset is loaded once even when there are multiple synthesizers.""" # Setup mock_get_dataset_paths.return_value = [] - mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket', 'datasetB': 'bucket'} + mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} mock__setup_output_destination.return_value = {} data_a = Mock(name='data_a') metadata_a = Mock(name='metadata_a') @@ -1287,14 +1321,14 @@ def test__generate_job_args_list_loads_each_dataset_once( call( modality='single_table', dataset_name='datasetA', - dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + bucket='bucket-a', s3_client=s3_client, limit_dataset_size=True, ), call( modality='single_table', dataset_name='datasetB', - dataset_bucket_mapping={'datasetA': 'bucket', 'datasetB': 'bucket'}, + bucket='bucket-b', s3_client=s3_client, limit_dataset_size=True, ), @@ -1421,7 +1455,7 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( mock_load_sdv_demo_dataset.assert_called_once_with( modality='single_table', dataset_name='fake_hotel_guests', - dataset_bucket_mapping={'fake_hotel_guests': 'bucket'}, + bucket='bucket', s3_client=None, limit_dataset_size=False, ) diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index c770846b..e38efbb7 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -452,7 +452,7 @@ def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + bucket=SDV_DATASETS_PUBLIC_BUCKET, ) # Assert @@ -482,7 +482,7 @@ def test__load_sdv_demo_dataset_falls_back_for_private_bucket( result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PRIVATE_BUCKET}, + bucket=SDV_DATASETS_PRIVATE_BUCKET, s3_client='s3_client', ) @@ -515,7 +515,7 @@ def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_m result = _load_sdv_demo_dataset( modality='single_table', dataset_name='demo', - dataset_bucket_mapping={'demo': SDV_DATASETS_PUBLIC_BUCKET}, + bucket=SDV_DATASETS_PUBLIC_BUCKET, limit_dataset_size=True, ) @@ -530,17 +530,6 @@ def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_m ) -def test__load_sdv_demo_dataset_raises_when_dataset_not_found(): - """Test a clear error is raised when a demo dataset is absent from all buckets.""" - # Run and Assert - with pytest.raises(ValueError, match="Dataset 'missing' not found in SDV demo buckets"): - _load_sdv_demo_dataset( - modality='single_table', - dataset_name='missing', - dataset_bucket_mapping={}, - ) - - @patch('sdgym.datasets._get_dataset_path_and_download') @patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) @patch('sdgym.datasets.Path') From a62376b851bbbfe8699c847878b80247f755ab08 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Wed, 27 May 2026 15:42:36 +0100 Subject: [PATCH 07/13] move dataset loading to execution --- pyproject.toml | 6 +- sdgym/benchmark.py | 173 +++++++++------ sdgym/dataset_explorer.py | 5 +- sdgym/datasets.py | 8 +- sdgym/result_explorer/result_explorer.py | 2 +- .../result_explorer/test_result_explorer.py | 4 +- tests/unit/test_benchmark.py | 206 ++++++++++-------- tests/unit/test_dataset_explorer.py | 6 +- tests/unit/test_datasets.py | 38 +--- 9 files changed, 251 insertions(+), 197 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 83b325d7..46c6f0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,10 @@ dependencies = [ 'XlsxWriter>=1.2.8', "rdt>=1.18.2;python_version<'3.14'", "rdt>=1.20.0;python_version>='3.14'", - "sdmetrics>=0.28.0", - "sdv @ git+https://github.com/sdv-dev/SDV.git@main", + "sdmetrics>=0.21.0;python_version<'3.14'", + "sdmetrics>=0.26.0;python_version>='3.14'", + "sdv>=1.21.0;python_version<'3.14'", + "sdv>=1.33.0;python_version>='3.14'", ] [project.urls] diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 8638b9b3..021b5109 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -114,12 +114,21 @@ SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS +class DatasetInfo(NamedTuple): + """Information needed to load a dataset when a benchmark job executes.""" + + name: str + source: str + dataset_path: Any + bucket: Optional[str] + limit_dataset_size: bool + + class JobArgs(NamedTuple): """Arguments needed to run a single synthesizer + dataset benchmark job.""" synthesizer: dict - data: Any - metadata: Any + dataset_info: Any metrics: Any timeout: Optional[int] compute_quality_score: bool @@ -130,14 +139,6 @@ class JobArgs(NamedTuple): output_directions: Optional[dict] -class ResolvedDataset(NamedTuple): - """Resolved dataset data and metadata for benchmark job creation.""" - - name: str - data: Any - metadata: Any - - def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality): """Import user-provided synthesizer and validate modality and uniqueness. @@ -378,23 +379,28 @@ def _resolve_dataset( datasets = [] for dataset_name in sdv_dataset_names: - data, metadata = _load_sdv_demo_dataset( - modality=modality, - dataset_name=dataset_name, - bucket=dataset_bucket_mapping.get(dataset_name), - s3_client=s3_client, - limit_dataset_size=limit_dataset_size, + datasets.append( + DatasetInfo( + name=dataset_name, + source='sdv_demo', + dataset_path=None, + bucket=dataset_bucket_mapping.get(dataset_name), + limit_dataset_size=limit_dataset_size, + ) ) - datasets.append(ResolvedDataset(dataset_name, data, metadata)) for dataset in additional_datasets: - data, metadata = _load_dataset_with_client( - modality, - dataset, - limit_dataset_size=limit_dataset_size, - s3_client=s3_client, + bucket = additional_datasets_folder if is_s3_path(additional_datasets_folder) else None + dataset_to_load = dataset.name if bucket else dataset + datasets.append( + DatasetInfo( + name=dataset.name, + source='additional', + dataset_path=dataset_to_load, + bucket=bucket, + limit_dataset_size=limit_dataset_size, + ) ) - datasets.append(ResolvedDataset(dataset.name, data, metadata)) return datasets @@ -416,7 +422,7 @@ def _generate_job_args_list( if not synthesizers: return [] - datasets = _resolve_dataset( + dataset_infos = _resolve_dataset( modality=modality, sdv_datasets=sdv_datasets, additional_datasets_folder=additional_datasets_folder, @@ -424,34 +430,37 @@ def _generate_job_args_list( s3_client=s3_client, ) synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] - dataset_names = [dataset.name for dataset in datasets] + dataset_names = [dataset_info.name for dataset_info in dataset_infos] paths = _setup_output_destination( output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) job_args_list = [] - for dataset in datasets: + for dataset_info in dataset_infos: for synthesizer in synthesizers: if paths: final_name = next( - (name for name in paths[dataset.name] if name.startswith(synthesizer['name'])), + ( + name + for name in paths[dataset_info.name] + if name.startswith(synthesizer['name']) + ), synthesizer['name'], ) else: final_name = synthesizer['name'] synthesizer['name'] = final_name - path = paths.get(dataset.name, {}).get(synthesizer['name'], None) + path = paths.get(dataset_info.name, {}).get(synthesizer['name'], None) job_args_list.append( JobArgs( synthesizer=synthesizer, - data=dataset.data, - metadata=dataset.metadata, + dataset_info=dataset_info, metrics=sdmetrics, timeout=timeout, compute_quality_score=compute_quality_score, compute_diagnostic_score=compute_diagnostic_score, compute_privacy_score=compute_privacy_score, - dataset_name=dataset.name, + dataset_name=dataset_info.name, modality=modality, output_directions=path, ) @@ -878,13 +887,43 @@ def _format_output( return scores +def _get_s3_client_from_result_writer(result_writer): + if isinstance(result_writer, S3ResultsWriter): + return result_writer.s3_client + + return None + + +def _load_from_dataset_info(dataset_info, modality, result_writer=None): + """Load data and metadata based on the provided dataset information.""" + s3_client = _get_s3_client_from_result_writer(result_writer) + if dataset_info.source == 'sdv_demo': + return _load_sdv_demo_dataset( + modality=modality, + dataset_name=dataset_info.name, + bucket=dataset_info.bucket, + s3_client=s3_client, + limit_dataset_size=dataset_info.limit_dataset_size, + ) + + if dataset_info.source == 'additional': + return _load_dataset_with_client( + modality=modality, + dataset_name=dataset_info.name, + datasets_path=dataset_info.dataset_path, + bucket=dataset_info.bucket, + limit_dataset_size=dataset_info.limit_dataset_size, + s3_client=s3_client, + ) + + raise ValueError(f"Unknown dataset source: '{dataset_info.source}'.") + + def _run_job(job_args, result_writer=None): # Reset random seed np.random.seed() synthesizer = job_args.synthesizer - data = job_args.data - metadata = job_args.metadata metrics = job_args.metrics timeout = job_args.timeout compute_quality_score = job_args.compute_quality_score @@ -893,6 +932,7 @@ def _run_job(job_args, result_writer=None): dataset_name = job_args.dataset_name modality = job_args.modality synthesizer_path = job_args.output_directions + dataset_info = job_args.dataset_info name = synthesizer['name'] LOGGER.info( @@ -904,37 +944,44 @@ def _run_job(job_args, result_writer=None): ) output = {} try: - if timeout: - output = _score_with_timeout( - timeout=timeout, - synthesizer=synthesizer, - data=data, - metadata=metadata, - metrics=metrics, - compute_quality_score=compute_quality_score, - compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=compute_privacy_score, - modality=modality, - dataset_name=dataset_name, - synthesizer_path=synthesizer_path, - result_writer=result_writer, - ) - else: - output = _score( - synthesizer=synthesizer, - data=data, - metadata=metadata, - metrics=metrics, - compute_quality_score=compute_quality_score, - compute_diagnostic_score=compute_diagnostic_score, - compute_privacy_score=compute_privacy_score, - modality=modality, - dataset_name=dataset_name, - synthesizer_path=synthesizer_path, - result_writer=result_writer, - ) + data, metadata = _load_from_dataset_info( + dataset_info, modality, result_writer=result_writer + ) + try: + if timeout: + output = _score_with_timeout( + timeout=timeout, + synthesizer=synthesizer, + data=data, + metadata=metadata, + metrics=metrics, + compute_quality_score=compute_quality_score, + compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=compute_privacy_score, + modality=modality, + dataset_name=dataset_name, + synthesizer_path=synthesizer_path, + result_writer=result_writer, + ) + else: + output = _score( + synthesizer=synthesizer, + data=data, + metadata=metadata, + metrics=metrics, + compute_quality_score=compute_quality_score, + compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=compute_privacy_score, + modality=modality, + dataset_name=dataset_name, + synthesizer_path=synthesizer_path, + result_writer=result_writer, + ) + except Exception as error: + output['Error'] = error + except Exception as error: - output['exception'] = error + output['Error'] = error scores = _format_output( output, diff --git a/sdgym/dataset_explorer.py b/sdgym/dataset_explorer.py index b2620e32..fb253bed 100644 --- a/sdgym/dataset_explorer.py +++ b/sdgym/dataset_explorer.py @@ -245,7 +245,10 @@ def _load_and_summarize_datasets(self, modality, datasets=None): dataset_size_mb = dataset_row['size_MB'] dataset_num_table = dataset_row['num_tables'] data, metadata_dict = _load_dataset_with_client( - modality, dataset=dataset_name, bucket=self._bucket_name, s3_client=self.s3_client + modality=modality, + dataset_name=dataset_name, + bucket=self._bucket_name, + s3_client=self.s3_client, ) metadata_stats = DatasetExplorer.get_metadata_summary(metadata_dict) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 4d9f32b9..7e15c290 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -386,7 +386,7 @@ def load_dataset( ) return _load_dataset_with_client( modality=modality, - dataset=dataset, + dataset_name=dataset, datasets_path=datasets_path, bucket=bucket, s3_client=s3_client, @@ -396,7 +396,7 @@ def load_dataset( def _load_dataset_with_client( modality, - dataset, + dataset_name, datasets_path=None, bucket=None, s3_client=None, @@ -405,7 +405,7 @@ def _load_dataset_with_client( """Get the data and metadata of a dataset using a given s3 client.""" _validate_modality(modality) dataset_path = _get_dataset_path_and_download( - modality, dataset, datasets_path, bucket, s3_client=s3_client + modality, dataset_name, datasets_path, bucket, s3_client=s3_client ) data, metadata_dict = get_data_and_metadata_from_path(dataset_path, modality) @@ -451,7 +451,7 @@ def get_dataset_paths( if datasets is None: if not is_remote and Path(bucket).exists(): datasets = [] - folder_items = list(Path(bucket).iterdir()) + folder_items = sorted(Path(bucket).iterdir()) for dataset in folder_items: if _path_contains_data_and_metadata(dataset) and dataset not in datasets: datasets.append(dataset) diff --git a/sdgym/result_explorer/result_explorer.py b/sdgym/result_explorer/result_explorer.py index 000297a3..d0e6dd16 100644 --- a/sdgym/result_explorer/result_explorer.py +++ b/sdgym/result_explorer/result_explorer.py @@ -167,7 +167,7 @@ def load_real_data(self, dataset_name): """ data, _ = _load_dataset_with_client( modality=self.modality, - dataset=dataset_name, + dataset_name=dataset_name, s3_client=self.s3_client, ) return data diff --git a/tests/unit/result_explorer/test_result_explorer.py b/tests/unit/result_explorer/test_result_explorer.py index 3656b02d..cb0ab0fd 100644 --- a/tests/unit/result_explorer/test_result_explorer.py +++ b/tests/unit/result_explorer/test_result_explorer.py @@ -297,7 +297,7 @@ def test_load_real_data(self, mock_load_dataset, tmp_path): # Assert mock_load_dataset.assert_called_once_with( - modality='single_table', dataset='adult', s3_client=None + modality='single_table', dataset_name='adult', s3_client=None ) pd.testing.assert_frame_equal(real_data, expected_data) @@ -317,7 +317,7 @@ def test_load_real_data_multi_table(self, mock_load_dataset, tmp_path): # Assert mock_load_dataset.assert_called_once_with( - modality='multi_table', dataset='synthea', s3_client=None + modality='multi_table', dataset_name='synthea', s3_client=None ) assert real_data == expected_data finally: diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 5f5fc166..a1ae9775 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -12,6 +12,7 @@ import yaml from sdgym.benchmark import ( + DatasetInfo, JobArgs, _add_adjusted_scores, _check_write_permissions, @@ -20,7 +21,9 @@ _format_output, _generate_job_args_list, _get_metainfo_increment, + _get_s3_client_from_result_writer, _import_and_validate_synthesizers, + _load_from_dataset_info, _resolve_dataset, _setup_output_destination, _setup_output_destination_aws, @@ -34,7 +37,7 @@ benchmark_single_table, benchmark_single_table_aws, ) -from sdgym.result_writer import LocalResultsWriter +from sdgym.result_writer import LocalResultsWriter, S3ResultsWriter from sdgym.s3 import S3_REGION from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer @@ -204,17 +207,11 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path): @patch('sdgym.benchmark._get_dataset_bucket_mapping') -@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark.tqdm.tqdm') -def test_benchmark_single_table_progress_bar( - tqdm_mock, mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping -): +def test_benchmark_single_table_progress_bar(tqdm_mock, mock_get_dataset_bucket_mapping): """Test that the benchmarking function updates the progress bar on one line.""" # Setup - data = pd.DataFrame({'column': [1, 2]}) - metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} - mock_load_sdv_demo_dataset.return_value = data, metadata scores_mock = MagicMock() scores_mock.__iter__.return_value = [ pd.DataFrame({ @@ -240,13 +237,6 @@ def test_benchmark_single_table_progress_bar( None, skip_inaccessible=True, ) - mock_load_sdv_demo_dataset.assert_called_once_with( - modality='single_table', - dataset_name='student_placements', - bucket='bucket', - s3_client=None, - limit_dataset_size=False, - ) tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) @@ -440,6 +430,70 @@ def test__format_output(): pd.testing.assert_frame_equal(scores, expected_scores) +def test__get_s3_client_from_result_writer(): + """Test the `get_s3_client_from_result_writer` method.""" + # Setup + mock_s3_client = Mock() + result_writer = S3ResultsWriter(mock_s3_client) + other_result_writer = LocalResultsWriter() + + # Run + result = _get_s3_client_from_result_writer(result_writer) + other_result = _get_s3_client_from_result_writer(other_result_writer) + + # Assert + assert result == mock_s3_client + assert other_result is None + + +@patch('sdgym.benchmark._load_sdv_demo_dataset') +@patch('sdgym.benchmark._load_dataset_with_client') +@patch('sdgym.benchmark._get_s3_client_from_result_writer') +@pytest.mark.parametrize('source', ['sdv_demo', 'additional']) +def test__load_from_dataset_info( + mock_get_s3_client, mock_load_dataset_with_client, mock_load_sdv_demo_dataset, source +): + """Test the `_load_from_dataset_info` function.""" + # Setup + mock_dataset_info = Mock() + mock_dataset_info.source = source + result_writer = Mock() + expected_result = Mock() + mock_get_s3_client.return_value = 's3_client' + mock_load_sdv_demo_dataset.return_value = expected_result + mock_load_dataset_with_client.return_value = expected_result + dataset_info = DatasetInfo( + name='dataset', + source=source, + dataset_path='dataset_path', + bucket='bucket_name', + limit_dataset_size=True, + ) + + # Run + result = _load_from_dataset_info(dataset_info, 'single_table', result_writer) + + # Assert + assert result == expected_result + if source == 'sdv_demo': + mock_load_sdv_demo_dataset.assert_called_once_with( + modality='single_table', + dataset_name=dataset_info.name, + bucket=dataset_info.bucket, + s3_client='s3_client', + limit_dataset_size=dataset_info.limit_dataset_size, + ) + else: + mock_load_dataset_with_client.assert_called_once_with( + modality='single_table', + dataset_name=dataset_info.name, + bucket=dataset_info.bucket, + datasets_path=dataset_info.dataset_path, + s3_client='s3_client', + limit_dataset_size=dataset_info.limit_dataset_size, + ) + + def test__validate_output_destination(tmp_path): """Test the `_validate_output_destination` function.""" # Setup @@ -584,8 +638,7 @@ def test__write_metainfo_file(mock_datetime, mock_open, mock_safe_load, tmp_path jobs = [ JobArgs( synthesizer={'name': 'GaussianCopulaSynthesizer'}, - data=None, - metadata=None, + dataset_info=DatasetInfo('adult', 'sdv_demo', 'adult', 'bucket', False), metrics=None, timeout=None, compute_quality_score=False, @@ -597,8 +650,7 @@ def test__write_metainfo_file(mock_datetime, mock_open, mock_safe_load, tmp_path ), JobArgs( synthesizer={'name': 'CTGANSynthesizer'}, - data=None, - metadata=None, + dataset_info=DatasetInfo('census', 'sdv_demo', 'census', 'bucket', False), metrics=None, timeout=None, compute_quality_score=False, @@ -1105,30 +1157,20 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) -@patch('sdgym.benchmark._load_dataset_with_client') -@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._get_dataset_bucket_mapping') @patch('sdgym.benchmark.get_dataset_paths') def test__resolve_dataset_loads_sdv_and_additional_datasets( mock_get_dataset_paths, mock_get_dataset_bucket_mapping, - mock_load_sdv_demo_dataset, - mock_load_dataset, tmp_path, ): """Test the `_resolve_dataset` method.""" # Setup additional_folder = tmp_path / 'additional' additional_dataset_path = additional_folder / 'single_table' / 'custom_dataset' - sdv_data = Mock(name='sdv_data') - sdv_metadata = Mock(name='sdv_metadata') - additional_data = Mock(name='additional_data') - additional_metadata = Mock(name='additional_metadata') s3_client = Mock() mock_get_dataset_paths.return_value = [additional_dataset_path] mock_get_dataset_bucket_mapping.return_value = {'sdv_dataset': 'bucket'} - mock_load_sdv_demo_dataset.return_value = sdv_data, sdv_metadata - mock_load_dataset.return_value = additional_data, additional_metadata # Run result = _resolve_dataset( @@ -1151,22 +1193,11 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( s3_client, skip_inaccessible=True, ) - mock_load_sdv_demo_dataset.assert_called_once_with( - modality='single_table', - dataset_name='sdv_dataset', - bucket='bucket', - s3_client=s3_client, - limit_dataset_size=True, - ) - mock_load_dataset.assert_called_once_with( - 'single_table', - additional_dataset_path, - limit_dataset_size=True, - s3_client=s3_client, - ) assert [dataset.name for dataset in result] == ['sdv_dataset', 'custom_dataset'] - assert [dataset.data for dataset in result] == [sdv_data, additional_data] - assert [dataset.metadata for dataset in result] == [sdv_metadata, additional_metadata] + assert result == [ + DatasetInfo('sdv_dataset', 'sdv_demo', None, 'bucket', True), + DatasetInfo('custom_dataset', 'additional', additional_dataset_path, None, True), + ] @patch('sdgym.benchmark._get_dataset_bucket_mapping') @@ -1205,11 +1236,9 @@ def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( @pytest.mark.parametrize('modality', ['single_table', 'multi_table']) @patch('sdgym.benchmark._setup_output_destination') -@patch('sdgym.benchmark._load_dataset_with_client') @patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_local_root_additional_folder( get_dataset_paths_mock, - mock_load_dataset, mock__setup_output_destination, tmp_path, modality, @@ -1220,7 +1249,6 @@ def test__generate_job_args_list_local_root_additional_folder( local_root.mkdir() dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] - mock_load_dataset.return_value = Mock(), Mock() mock__setup_output_destination.return_value = {} # Run @@ -1245,12 +1273,6 @@ def test__generate_job_args_list_local_root_additional_folder( bucket=str(local_root / modality), s3_client=None, ) - mock_load_dataset.assert_called_once_with( - modality, - dataset_path, - limit_dataset_size=False, - s3_client=None, - ) mock__setup_output_destination.assert_called_once_with( None, ['UniformSynthesizer'], @@ -1261,25 +1283,18 @@ def test__generate_job_args_list_local_root_additional_folder( @patch('sdgym.benchmark._get_dataset_bucket_mapping') -@patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark.get_dataset_paths') -def test__generate_job_args_list_loads_each_dataset_once( +def test__generate_job_args_list_stores_dataset_infos( mock_get_dataset_paths, mock__setup_output_destination, - mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping, ): - """Test that each dataset is loaded once even when there are multiple synthesizers.""" + """Test that each job stores dataset infos.""" # Setup mock_get_dataset_paths.return_value = [] mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} mock__setup_output_destination.return_value = {} - data_a = Mock(name='data_a') - metadata_a = Mock(name='metadata_a') - data_b = Mock(name='data_b') - metadata_b = Mock(name='metadata_b') - mock_load_sdv_demo_dataset.side_effect = [(data_a, metadata_a), (data_b, metadata_b)] synthesizers = [ {'name': 'GaussianCopulaSynthesizer'}, {'name': 'UniformSynthesizer'}, @@ -1317,23 +1332,6 @@ def test__generate_job_args_list_loads_each_dataset_once( modality='single_table', s3_client=s3_client, ) - mock_load_sdv_demo_dataset.assert_has_calls([ - call( - modality='single_table', - dataset_name='datasetA', - bucket='bucket-a', - s3_client=s3_client, - limit_dataset_size=True, - ), - call( - modality='single_table', - dataset_name='datasetB', - bucket='bucket-b', - s3_client=s3_client, - limit_dataset_size=True, - ), - ]) - assert mock_load_sdv_demo_dataset.call_count == 2 assert len(job_args_list) == 4 assert [job.dataset_name for job in job_args_list] == [ 'datasetA', @@ -1341,20 +1339,17 @@ def test__generate_job_args_list_loads_each_dataset_once( 'datasetB', 'datasetB', ] - assert [job.data for job in job_args_list] == [data_a, data_a, data_b, data_b] - assert [job.metadata for job in job_args_list] == [ - metadata_a, - metadata_a, - metadata_b, - metadata_b, + assert [job.dataset_info for job in job_args_list] == [ + DatasetInfo('datasetA', 'sdv_demo', None, 'bucket-a', True), + DatasetInfo('datasetA', 'sdv_demo', None, 'bucket-a', True), + DatasetInfo('datasetB', 'sdv_demo', None, 'bucket-b', True), + DatasetInfo('datasetB', 'sdv_demo', None, 'bucket-b', True), ] @patch('sdgym.benchmark.get_dataset_paths') @patch('sdgym.benchmark._setup_output_destination') -@patch('sdgym.benchmark._load_dataset_with_client') def test__generate_job_args_list_s3_root_additional_folder( - mock_load_dataset, mock__setup_output_destination, get_dataset_paths_mock, ): @@ -1365,13 +1360,9 @@ def test__generate_job_args_list_s3_root_additional_folder( get_dataset_paths_mock.return_value = [dataset_path] s3_client = Mock() mock__setup_output_destination.return_value = {} - mock_load_dataset.return_value = ( - pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}), - {'tables': {}}, - ) # Run - _generate_job_args_list( + job_args_list = _generate_job_args_list( limit_dataset_size=False, sdv_datasets=None, additional_datasets_folder=s3_root, @@ -1399,8 +1390,8 @@ def test__generate_job_args_list_s3_root_additional_folder( modality='single_table', s3_client=s3_client, ) - mock_load_dataset.assert_called_once_with( - 'single_table', dataset_path, limit_dataset_size=False, s3_client=s3_client + assert job_args_list[0].dataset_info == DatasetInfo( + 'datasetA', 'additional', 'datasetA', s3_root, False ) @@ -1714,3 +1705,30 @@ def test_benchmark_multi_table_aws_no_jobs( compute_privacy_score=None, sdmetrics=None, ) + + +@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark._load_from_dataset_info') +def test_benchmark_single_table_error_loading_data( + mock_load_from_dataset_info, mock_get_dataset_bucket_mapping +): + """Test that `benchmark_single_table` handles errors when loading data.""" + # Setup + error = ValueError('Failed to load dataset') + mock_get_dataset_bucket_mapping.return_value = {'census': 'bucket'} + mock_load_from_dataset_info.side_effect = error + expected_result = pd.DataFrame({ + 'Synthesizer': ['UniformSynthesizer'], + 'Dataset': ['census'], + 'Error': [error], + }) + + # Run + result = benchmark_single_table( + synthesizers=['UniformSynthesizer'], + custom_synthesizers=None, + sdv_datasets=['census'], + ) + + # Assert + pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) diff --git a/tests/unit/test_dataset_explorer.py b/tests/unit/test_dataset_explorer.py index 5ebe138e..e8bc40ba 100644 --- a/tests/unit/test_dataset_explorer.py +++ b/tests/unit/test_dataset_explorer.py @@ -269,8 +269,8 @@ def test__load_and_summarize_datasets(self, mock_load_dataset, mock_get_datasets s3_client=None, ) mock_load_dataset.assert_called_once_with( - 'single_table', - dataset='test', + modality='single_table', + dataset_name='test', bucket='sdv-datasets-public', s3_client=None, ) @@ -299,7 +299,7 @@ def test__load_and_summarize_datasets_with_datasets(self, mock_load_dataset, moc # Assert assert mock_load_dataset.call_count == 2 - loaded_names = [call.kwargs['dataset'] for call in mock_load_dataset.call_args_list] + loaded_names = [call.kwargs['dataset_name'] for call in mock_load_dataset.call_args_list] assert set(loaded_names) == {'ds1', 'ds3'} assert [result['Dataset'] for result in results] == ['ds1', 'ds3'] assert len(results) == 2 diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index e38efbb7..914155a7 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -7,7 +7,6 @@ import pytest from sdgym.datasets import ( - DATASETS_PATH, SDV_DATASETS_PRIVATE_BUCKET, SDV_DATASETS_PUBLIC_BUCKET, _download_dataset, @@ -530,38 +529,23 @@ def test__load_sdv_demo_dataset_limits_dataset_size(download_demo_mock, subset_m ) -@patch('sdgym.datasets._get_dataset_path_and_download') -@patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True) -@patch('sdgym.datasets.Path') -def test_get_dataset_paths_local_bucket(path_mock, contains_mock, download_mock): +def test_get_dataset_paths_local_bucket(tmp_path): """Test datasets are discovered locally when bucket path exists.""" - # Setup - def path_side_effect(arg=None): - """Return the mocked bucket path if matching bucket name, else datasets folder.""" - if arg == bucket: - return bucket_path_mock - return Path('datasets_folder') - - path_mock.side_effect = path_side_effect - modality = 'single_table' - bucket = 'local_bucket' - - bucket_path_mock = Mock() - bucket_path_mock.exists.return_value = True - dataset1 = Path('dataset_1') - dataset2 = Path('dataset_2') - bucket_path_mock.iterdir.return_value = [dataset1, dataset2] + bucket = tmp_path / 'local_bucket' + dataset1 = bucket / 'dataset_1' + dataset2 = bucket / 'dataset_2' + for dataset in (dataset1, dataset2): + dataset.mkdir(parents=True) + (dataset / 'metadata.json').touch() + (dataset / 'data.zip').touch() # Run - get_dataset_paths(modality, None, None, bucket) + result = get_dataset_paths(modality, None, None, str(bucket)) # Assert - download_mock.assert_has_calls([ - call(modality, dataset1, DATASETS_PATH / 'single_table', bucket=bucket, s3_client=None), - call(modality, dataset2, DATASETS_PATH / 'single_table', bucket=bucket, s3_client=None), - ]) + assert result == [dataset1, dataset2] @patch('sdgym.datasets.get_s3_client') @@ -587,7 +571,7 @@ def test_load_dataset_mock(mock_load_dataset_with_client, mock_get_s3_client): ) mock_load_dataset_with_client.assert_called_once_with( modality=modality, - dataset=dataset_name, + dataset_name=dataset_name, datasets_path=None, bucket=None, s3_client='s3_client', From f0eedea146c44330f325608c05d8c5beeb2a924e Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Wed, 27 May 2026 15:59:34 +0100 Subject: [PATCH 08/13] fix tests --- pyproject.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46c6f0a9..83b325d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,10 +64,8 @@ dependencies = [ 'XlsxWriter>=1.2.8', "rdt>=1.18.2;python_version<'3.14'", "rdt>=1.20.0;python_version>='3.14'", - "sdmetrics>=0.21.0;python_version<'3.14'", - "sdmetrics>=0.26.0;python_version>='3.14'", - "sdv>=1.21.0;python_version<'3.14'", - "sdv>=1.33.0;python_version>='3.14'", + "sdmetrics>=0.28.0", + "sdv @ git+https://github.com/sdv-dev/SDV.git@main", ] [project.urls] From c349d704ae83eae493066bcc112fe1e86a6f6cd3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 28 May 2026 09:01:01 +0100 Subject: [PATCH 09/13] rename _get_dataset_bucket_mapping -> dataset_to_bucket --- sdgym/benchmark.py | 10 +++---- sdgym/datasets.py | 2 +- tests/unit/test_benchmark.py | 54 ++++++++++++++++++------------------ tests/unit/test_datasets.py | 14 +++++----- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 021b5109..163c3ccf 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -43,9 +43,9 @@ from sdgym.datasets import ( SDV_DATASETS_PRIVATE_BUCKET, SDV_DATASETS_PUBLIC_BUCKET, - _get_dataset_bucket_mapping, _load_dataset_with_client, _load_sdv_demo_dataset, + dataset_to_bucket, get_dataset_paths, ) from sdgym.errors import BenchmarkError, SDGymError @@ -361,15 +361,15 @@ def _resolve_dataset( ) ) - dataset_bucket_mapping = None + dataset_name_to_bucket = None if sdv_dataset_names: - dataset_bucket_mapping = _get_dataset_bucket_mapping( + dataset_name_to_bucket = dataset_to_bucket( modality, [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], s3_client, skip_inaccessible=True, ) - missing_names = [name for name in sdv_dataset_names if name not in dataset_bucket_mapping] + missing_names = [name for name in sdv_dataset_names if name not in dataset_name_to_bucket] if missing_names: missing_to_print = "', '".join(missing_names) raise ValueError( @@ -384,7 +384,7 @@ def _resolve_dataset( name=dataset_name, source='sdv_demo', dataset_path=None, - bucket=dataset_bucket_mapping.get(dataset_name), + bucket=dataset_name_to_bucket.get(dataset_name), limit_dataset_size=limit_dataset_size, ) ) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 7e15c290..e03bff51 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -262,7 +262,7 @@ def _get_available_datasets( return pd.DataFrame(datasets_info) -def _get_dataset_bucket_mapping(modality, buckets, s3_client, skip_inaccessible=False): +def dataset_to_bucket(modality, buckets, s3_client, skip_inaccessible=False): """Map SDV demo dataset names to the bucket they should be loaded from.""" dataset_buckets = {} for bucket in buckets: diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index a1ae9775..7e11536b 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -206,12 +206,12 @@ def test__get_metainfo_increment_local(mock_logger, tmp_path): assert result_3 == 3 -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark.tqdm.tqdm') -def test_benchmark_single_table_progress_bar(tqdm_mock, mock_get_dataset_bucket_mapping): +def test_benchmark_single_table_progress_bar(tqdm_mock, mockdataset_to_bucket): """Test that the benchmarking function updates the progress bar on one line.""" # Setup - mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mockdataset_to_bucket.return_value = {'student_placements': 'bucket'} scores_mock = MagicMock() scores_mock.__iter__.return_value = [ pd.DataFrame({ @@ -231,7 +231,7 @@ def test_benchmark_single_table_progress_bar(tqdm_mock, mock_get_dataset_bucket_ ) # Assert - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], None, @@ -240,7 +240,7 @@ def test_benchmark_single_table_progress_bar(tqdm_mock, mock_get_dataset_bucket_ tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True) -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark._load_sdv_demo_dataset') @patch('sdgym.benchmark._score') @patch('sdgym.benchmark.multiprocessing') @@ -248,13 +248,13 @@ def test_benchmark_single_table_with_timeout( mock_multiprocessing, mock__score, mock_load_sdv_demo_dataset, - mock_get_dataset_bucket_mapping, + mockdataset_to_bucket, ): """Test that benchmark runs with timeout.""" # Setup data = pd.DataFrame({'column': [1, 2]}) metadata = {'tables': {'student_placements': {'columns': {'column': {'sdtype': 'numerical'}}}}} - mock_get_dataset_bucket_mapping.return_value = {'student_placements': 'bucket'} + mockdataset_to_bucket.return_value = {'student_placements': 'bucket'} mock_load_sdv_demo_dataset.return_value = data, metadata mocked_process = mock_multiprocessing.Process.return_value manager = mock_multiprocessing.Manager.return_value @@ -269,7 +269,7 @@ def test_benchmark_single_table_with_timeout( ) # Assert - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], None, @@ -1157,11 +1157,11 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark.get_dataset_paths') def test__resolve_dataset_loads_sdv_and_additional_datasets( mock_get_dataset_paths, - mock_get_dataset_bucket_mapping, + mockdataset_to_bucket, tmp_path, ): """Test the `_resolve_dataset` method.""" @@ -1170,7 +1170,7 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( additional_dataset_path = additional_folder / 'single_table' / 'custom_dataset' s3_client = Mock() mock_get_dataset_paths.return_value = [additional_dataset_path] - mock_get_dataset_bucket_mapping.return_value = {'sdv_dataset': 'bucket'} + mockdataset_to_bucket.return_value = {'sdv_dataset': 'bucket'} # Run result = _resolve_dataset( @@ -1187,7 +1187,7 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( bucket=str(additional_folder / 'single_table'), s3_client=s3_client, ) - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], s3_client, @@ -1200,14 +1200,14 @@ def test__resolve_dataset_loads_sdv_and_additional_datasets( ] -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark._load_sdv_demo_dataset') def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( - mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping + mock_load_sdv_demo_dataset, mockdataset_to_bucket ): """Test `_resolve_dataset` raises when an SDV dataset is not found in any bucket.""" # Setup - mock_get_dataset_bucket_mapping.return_value = {'available_dataset': 'bucket'} + mockdataset_to_bucket.return_value = {'available_dataset': 'bucket'} # Run and Assert with pytest.raises( @@ -1225,7 +1225,7 @@ def test__resolve_dataset_raises_when_sdv_dataset_is_missing_from_buckets( s3_client='s3_client', ) - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], 's3_client', @@ -1282,18 +1282,18 @@ def test__generate_job_args_list_local_root_additional_folder( ) -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark._setup_output_destination') @patch('sdgym.benchmark.get_dataset_paths') def test__generate_job_args_list_stores_dataset_infos( mock_get_dataset_paths, mock__setup_output_destination, - mock_get_dataset_bucket_mapping, + mockdataset_to_bucket, ): """Test that each job stores dataset infos.""" # Setup mock_get_dataset_paths.return_value = [] - mock_get_dataset_bucket_mapping.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} + mockdataset_to_bucket.return_value = {'datasetA': 'bucket-a', 'datasetB': 'bucket-b'} mock__setup_output_destination.return_value = {} synthesizers = [ {'name': 'GaussianCopulaSynthesizer'}, @@ -1319,7 +1319,7 @@ def test__generate_job_args_list_stores_dataset_infos( # Assert mock_get_dataset_paths.assert_not_called() - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], s3_client, @@ -1395,10 +1395,10 @@ def test__generate_job_args_list_s3_root_additional_folder( ) -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark._load_sdv_demo_dataset') def test_benchmark_single_table_no_warning_uniform_synthesizer( - mock_load_sdv_demo_dataset, mock_get_dataset_bucket_mapping, recwarn + mock_load_sdv_demo_dataset, mockdataset_to_bucket, recwarn ): """Test that no UserWarning is raised when running `UniformSynthesizer`.""" # Setup @@ -1414,7 +1414,7 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( } } } - mock_get_dataset_bucket_mapping.return_value = {'fake_hotel_guests': 'bucket'} + mockdataset_to_bucket.return_value = {'fake_hotel_guests': 'bucket'} mock_load_sdv_demo_dataset.return_value = data, metadata expected_result = pd.DataFrame({ 'Synthesizer': {0: 'UniformSynthesizer'}, @@ -1437,7 +1437,7 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer( ) # Assert - mock_get_dataset_bucket_mapping.assert_called_once_with( + mockdataset_to_bucket.assert_called_once_with( 'single_table', ['s3://sdv-datasets-public', 's3://sdv-datasets-private'], None, @@ -1707,15 +1707,15 @@ def test_benchmark_multi_table_aws_no_jobs( ) -@patch('sdgym.benchmark._get_dataset_bucket_mapping') +@patch('sdgym.benchmark.dataset_to_bucket') @patch('sdgym.benchmark._load_from_dataset_info') def test_benchmark_single_table_error_loading_data( - mock_load_from_dataset_info, mock_get_dataset_bucket_mapping + mock_load_from_dataset_info, mockdataset_to_bucket ): """Test that `benchmark_single_table` handles errors when loading data.""" # Setup error = ValueError('Failed to load dataset') - mock_get_dataset_bucket_mapping.return_value = {'census': 'bucket'} + mockdataset_to_bucket.return_value = {'census': 'bucket'} mock_load_from_dataset_info.side_effect = error expected_result = pd.DataFrame({ 'Synthesizer': ['UniformSynthesizer'], diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index 914155a7..8bd105a4 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -12,12 +12,12 @@ _download_dataset, _genereate_dataset_info, _get_bucket_name, - _get_dataset_bucket_mapping, _get_dataset_path_and_download, _load_dataset_with_client, _load_sdv_demo_dataset, _path_contains_data_and_metadata, _validate_modality, + dataset_to_bucket, get_data_and_metadata_from_path, get_dataset_paths, load_dataset, @@ -368,7 +368,7 @@ def test_get_bucket_name_local_folder(): @patch('sdgym.datasets._get_available_datasets') -def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): +def test_dataset_to_bucket_prefers_private(get_available_mock): """Test that datasets are mapped to private when duplicated across buckets.""" # Setup get_available_mock.side_effect = [ @@ -377,7 +377,7 @@ def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): ] # Run - result = _get_dataset_bucket_mapping( + result = dataset_to_bucket( 'single_table', [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], s3_client='s3_client', @@ -396,7 +396,7 @@ def test__get_dataset_bucket_mapping_prefers_private(get_available_mock): @patch('sdgym.datasets._get_available_datasets') -def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_mock): +def test_dataset_to_bucket_skips_inaccessible_bucket(get_available_mock): """Test inaccessible buckets can be skipped while building the mapping.""" # Setup error = botocore.exceptions.ClientError( @@ -409,7 +409,7 @@ def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_moc ] # Run - result = _get_dataset_bucket_mapping( + result = dataset_to_bucket( 'single_table', [SDV_DATASETS_PUBLIC_BUCKET, SDV_DATASETS_PRIVATE_BUCKET], s3_client='s3_client', @@ -421,7 +421,7 @@ def test__get_dataset_bucket_mapping_skips_inaccessible_bucket(get_available_moc @patch('sdgym.datasets._get_available_datasets') -def test__get_dataset_bucket_mapping_raises_inaccessible_bucket(get_available_mock): +def test_dataset_to_bucket_raises_inaccessible_bucket(get_available_mock): """Test inaccessible buckets raise by default.""" # Setup get_available_mock.side_effect = botocore.exceptions.ClientError( @@ -431,7 +431,7 @@ def test__get_dataset_bucket_mapping_raises_inaccessible_bucket(get_available_mo # Run and Assert with pytest.raises(ValueError, match="Bucket 's3://sdv-datasets-private' is not accessible"): - _get_dataset_bucket_mapping( + dataset_to_bucket( 'single_table', [SDV_DATASETS_PRIVATE_BUCKET], s3_client='s3_client', From 2884e39ff9bb0fa4109623ab88d58746bb70d599 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 28 May 2026 10:11:01 +0100 Subject: [PATCH 10/13] add tests --- sdgym/datasets.py | 7 ++-- tests/unit/test_datasets.py | 69 ++++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/sdgym/datasets.py b/sdgym/datasets.py index e03bff51..ff5d1b39 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -283,10 +283,9 @@ def dataset_to_bucket(modality, buckets, s3_client, skip_inaccessible=False): for dataset_name in available_datasets['dataset_name'].tolist(): existing_bucket = dataset_buckets.get(dataset_name) - if existing_bucket and bucket != SDV_DATASETS_PRIVATE_BUCKET: - continue - - dataset_buckets[dataset_name] = bucket + # If a dataset is available in multiple buckets, prefer the private one. + if existing_bucket is None or bucket == SDV_DATASETS_PRIVATE_BUCKET: + dataset_buckets[dataset_name] = bucket return dataset_buckets diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index 8bd105a4..1eb33fb5 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import botocore import numpy as np @@ -14,6 +14,7 @@ _get_bucket_name, _get_dataset_path_and_download, _load_dataset_with_client, + _load_private_sdv_demo_dataset, _load_sdv_demo_dataset, _path_contains_data_and_metadata, _validate_modality, @@ -438,6 +439,72 @@ def test_dataset_to_bucket_raises_inaccessible_bucket(get_available_mock): ) +@patch('sdgym.datasets.get_s3_client') +@patch('sdgym.datasets._get_metadata') +@patch('sdgym.datasets._load_data_from_zip') +@patch('sdgym.datasets._get_first_v1_metadata_bytes') +@patch('sdgym.datasets._get_data_from_bucket') +@patch('sdgym.datasets._find_data_zip_key') +@patch('sdgym.datasets._list_objects') +def test__load_private_sdv_demo_dataset( + list_objects_mock, + find_data_zip_key_mock, + get_data_from_bucket_mock, + get_first_v1_metadata_bytes_mock, + load_data_from_zip_mock, + get_metadata_mock, + get_s3_client_mock, +): + """Test the `_load_private_sdv_demo_dataset` method.""" + # Setup + modality = 'single_table' + dataset_name = 'demo' + bucket = SDV_DATASETS_PRIVATE_BUCKET + bucket_name = 'sdv-datasets-private' + dataset_prefix = f'{modality}/{dataset_name}/' + data_key = f'{dataset_prefix}data.zip' + contents = [ + {'Key': f'{dataset_prefix}metadata.json'}, + {'Key': data_key}, + ] + raw_data = b'fake zipped data' + metadata_bytes = b'{"meta": "data"}' + table_data = pd.DataFrame({'column': [1, 2, 3]}) + metadata_mock = Mock() + + s3_client_mock = Mock() + list_objects_mock.return_value = contents + find_data_zip_key_mock.return_value = data_key + get_data_from_bucket_mock.return_value = raw_data + get_first_v1_metadata_bytes_mock.return_value = metadata_bytes + load_data_from_zip_mock.return_value = {'table_name': table_data} + metadata_mock.to_dict.return_value = {'meta': 'data'} + get_metadata_mock.return_value = metadata_mock + + # Run + data, metadata = _load_private_sdv_demo_dataset(modality, dataset_name, bucket, s3_client_mock) + + # Assert + get_s3_client_mock.assert_not_called() + list_objects_mock.assert_called_once_with( + dataset_prefix, bucket=bucket_name, client=s3_client_mock + ) + find_data_zip_key_mock.assert_called_once_with(contents, dataset_prefix, bucket_name) + get_data_from_bucket_mock.assert_called_once_with( + data_key, bucket=bucket_name, client=s3_client_mock + ) + get_first_v1_metadata_bytes_mock.assert_called_once_with( + contents, dataset_prefix, bucket=bucket_name, client=s3_client_mock + ) + load_data_from_zip_mock.assert_called_once_with(ANY, bucket_name, dataset_name) + data_bytes = load_data_from_zip_mock.call_args.args[0] + assert data_bytes.getvalue() == raw_data + get_metadata_mock.assert_called_once_with(metadata_bytes, dataset_name) + metadata_mock.to_dict.assert_called_once_with() + pd.testing.assert_frame_equal(data, table_data) + assert metadata == {'meta': 'data'} + + @patch('sdgym.datasets.download_demo') def test__load_sdv_demo_dataset_uses_download_demo(download_demo_mock): """Test SDV demo datasets are loaded through SDV's download_demo.""" From 5555441e23077a4a8472b842bc18b524e2add3fa Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 28 May 2026 10:27:14 +0100 Subject: [PATCH 11/13] stop compressing job_arg_list --- sdgym/_benchmark/benchmark.py | 2 +- sdgym/_benchmark_launcher/utils.py | 4 +--- sdgym/benchmark.py | 15 ++++---------- .../test_benchmark_launcher.py | 20 +++++++++---------- tests/unit/_benchmark_launcher/test_utils.py | 6 +++--- tests/unit/test_benchmark.py | 12 ++++------- 6 files changed, 23 insertions(+), 36 deletions(-) diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 19fb8de5..3e1ea3ea 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -206,7 +206,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - pip install "sdgym[all]" + pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-604-2-private-bucket" {gpu_block} diff --git a/sdgym/_benchmark_launcher/utils.py b/sdgym/_benchmark_launcher/utils.py index abfd9189..6212cf07 100644 --- a/sdgym/_benchmark_launcher/utils.py +++ b/sdgym/_benchmark_launcher/utils.py @@ -285,9 +285,7 @@ def _build_instance_artifact_filepaths( return ( _build_s3_uri(output_destination, f'{artifact_key_prefix}/{metainfo_name}.yaml'), _build_s3_uri(output_destination, f'{artifact_key_prefix}/{results_name}.csv'), - _build_s3_uri( - output_destination, f'{modality_prefix}/job_args_list_{metainfo_name}.pkl.gz' - ), + _build_s3_uri(output_destination, f'{modality_prefix}/job_args_list_{metainfo_name}.pkl'), ) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 163c3ccf..aedf20a1 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -1,7 +1,6 @@ """Main SDGym benchmarking module.""" import functools -import gzip import logging import math import multiprocessing @@ -1402,12 +1401,11 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client): filename = os.path.basename(job_args_list[0].output_directions['metainfo']) modality = job_args_list[0].modality metainfo = os.path.splitext(filename)[0] - job_args_key = f'{modality}/job_args_list_{metainfo}.pkl.gz' + job_args_key = f'{modality}/job_args_list_{metainfo}.pkl' job_args_key = f'{path}{job_args_key}' if path else job_args_key serialized_data = cloudpickle.dumps(job_args_list) - compressed = gzip.compress(serialized_data, compresslevel=1) - s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=compressed) + s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data) return bucket_name, job_args_key @@ -1418,7 +1416,6 @@ def _get_s3_script_content( return f""" import boto3 import cloudpickle -import gzip from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file from sdgym.result_writer import S3ResultsWriter @@ -1429,11 +1426,7 @@ def _get_s3_script_content( region_name='{region_name}' ) response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}') -blob = response['Body'].read() -if blob[:2] == b'\\x1f\\x8b': - blob = gzip.decompress(blob) - -job_args_list = cloudpickle.loads(blob) +job_args_list = cloudpickle.loads(response['Body'].read()) modality = job_args_list[0].modality result_writer = S3ResultsWriter(s3_client=s3_client) _write_metainfo_file({synthesizers}, job_args_list, modality, result_writer=result_writer) @@ -1471,7 +1464,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - pip install sdgym[all] + pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-604-2-private-bucket" pip install s3fs echo "======== Write Script ===========" diff --git a/tests/unit/_benchmark_launcher/test_benchmark_launcher.py b/tests/unit/_benchmark_launcher/test_benchmark_launcher.py index 8d11ce33..bb527d04 100644 --- a/tests/unit/_benchmark_launcher/test_benchmark_launcher.py +++ b/tests/unit/_benchmark_launcher/test_benchmark_launcher.py @@ -196,7 +196,7 @@ def test_build_instance_artifacts( mock_build_instance_artifact_filepaths.return_value = ( 's3://bucket/prefix/metainfo(1).yaml', 's3://bucket/prefix/results(1).csv', - 's3://bucket/prefix_job/job_args_list_metainfo(1).pkl.gz', + 's3://bucket/prefix_job/job_args_list_metainfo(1).pkl', ) # Run @@ -251,7 +251,7 @@ def test_build_instance_artifacts( 'output_destination': 's3://bucket/path', 'metainfo_filepath': 's3://bucket/prefix/metainfo(1).yaml', 'result_filepath': 's3://bucket/prefix/results(1).csv', - 'job_arg_filepath': 's3://bucket/prefix_job/job_args_list_metainfo(1).pkl.gz', + 'job_arg_filepath': 's3://bucket/prefix_job/job_args_list_metainfo(1).pkl', 'jobs': [ { 'dataset': 'dataset1', @@ -417,12 +417,12 @@ def test_launch_internal_calls_method_for_each_job( ( 's3://bucket/artifact-prefix/metainfo.yaml', 's3://bucket/artifact-prefix/results.csv', - 's3://bucket/modality_prefix/job_args_list_metainfo.pkl.gz', + 's3://bucket/modality_prefix/job_args_list_metainfo.pkl', ), ( 's3://bucket/artifact-prefix/metainfo(1).yaml', 's3://bucket/artifact-prefix/results(1).csv', - 's3://bucket/modality_prefix/job_args_list_metainfo(1).pkl.gz', + 's3://bucket/modality_prefix/job_args_list_metainfo(1).pkl', ), ] mock_resolve_compute.return_value = { @@ -470,7 +470,7 @@ def test_launch_internal_calls_method_for_each_job( 'output_destination': output_destination, 'metainfo_filepath': 's3://bucket/artifact-prefix/metainfo.yaml', 'result_filepath': 's3://bucket/artifact-prefix/results.csv', - 'job_arg_filepath': 's3://bucket/modality_prefix/job_args_list_metainfo.pkl.gz', + 'job_arg_filepath': 's3://bucket/modality_prefix/job_args_list_metainfo.pkl', 'jobs': [ { 'dataset': 'd1', @@ -497,7 +497,7 @@ def test_launch_internal_calls_method_for_each_job( 'output_destination': output_destination, 'metainfo_filepath': 's3://bucket/artifact-prefix/metainfo(1).yaml', 'result_filepath': 's3://bucket/artifact-prefix/results(1).csv', - 'job_arg_filepath': 's3://bucket/modality_prefix/job_args_list_metainfo(1).pkl.gz', + 'job_arg_filepath': 's3://bucket/modality_prefix/job_args_list_metainfo(1).pkl', 'jobs': [ { 'dataset': 'd2', @@ -1333,13 +1333,13 @@ def test_finalize(self): 'instance-1': { 'jobs': [{'dataset': 'adult', 'synthesizer': 'CTGAN'}], 'result_filepath': 's3://bucket/path/prefix/results.csv', - 'job_arg_filepath': 's3://bucket/path/single_table/job_args_list_metainfo.pkl.gz', + 'job_arg_filepath': 's3://bucket/path/single_table/job_args_list_metainfo.pkl', }, 'instance-2': { 'jobs': [], 'result_filepath': 's3://bucket/other-path/prefix/results(1).csv', 'job_arg_filepath': ( - 's3://bucket/other-path/single_table/job_args_list_metainfo(1).pkl.gz' + 's3://bucket/other-path/single_table/job_args_list_metainfo(1).pkl' ), }, } @@ -1360,8 +1360,8 @@ def test_finalize(self): call(result=result_df, filepath='s3://bucket/other-path/prefix/results(1).csv'), ] assert launcher._storage_manager.delete.call_args_list == [ - call('s3://bucket/path/single_table/job_args_list_metainfo.pkl.gz'), - call('s3://bucket/other-path/single_table/job_args_list_metainfo(1).pkl.gz'), + call('s3://bucket/path/single_table/job_args_list_metainfo.pkl'), + call('s3://bucket/other-path/single_table/job_args_list_metainfo(1).pkl'), ] assert launcher._update_instance_metainfo.call_args_list == [ call('instance-1'), diff --git a/tests/unit/_benchmark_launcher/test_utils.py b/tests/unit/_benchmark_launcher/test_utils.py index d2626196..29305e74 100644 --- a/tests/unit/_benchmark_launcher/test_utils.py +++ b/tests/unit/_benchmark_launcher/test_utils.py @@ -829,7 +829,7 @@ def test__build_instance_artifact_filepaths(mock_build_s3_uri): mock_build_s3_uri.side_effect = [ 's3://bucket/prefix/metainfo.yaml', 's3://bucket/prefix/results.csv', - 's3://bucket/modality/job_args_list_metainfo.pkl.gz', + 's3://bucket/modality/job_args_list_metainfo.pkl', ] # Run @@ -845,12 +845,12 @@ def test__build_instance_artifact_filepaths(mock_build_s3_uri): assert mock_build_s3_uri.call_args_list == [ call('s3://bucket/root', 'prefix/metainfo.yaml'), call('s3://bucket/root', 'prefix/results.csv'), - call('s3://bucket/root', 'modality/job_args_list_metainfo.pkl.gz'), + call('s3://bucket/root', 'modality/job_args_list_metainfo.pkl'), ] assert result == ( 's3://bucket/prefix/metainfo.yaml', 's3://bucket/prefix/results.csv', - 's3://bucket/modality/job_args_list_metainfo.pkl.gz', + 's3://bucket/modality/job_args_list_metainfo.pkl', ) diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 7e11536b..63f2ca57 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -848,9 +848,8 @@ def test_validate_aws_inputs_permission_error(mock_check_write_permissions, mock _validate_aws_inputs(valid_url, None, None) -@patch('sdgym.benchmark.gzip.compress') @patch('sdgym.benchmark.cloudpickle.dumps') -def test_store_job_args_in_s3_stores_compressed_job_args(mock_dumps, mock_compress): +def test_store_job_args_in_s3_stores_job_args(mock_dumps): """Test the `_store_job_args_in_s3` method.""" # Setup output_destination = 's3://my-bucket/some/path/' @@ -863,9 +862,7 @@ def test_store_job_args_in_s3_stores_compressed_job_args(mock_dumps, mock_compre job_args_list = [job_args] serialized = b'serialized-bytes' - compressed = b'compressed-bytes' mock_dumps.return_value = serialized - mock_compress.return_value = compressed # Run bucket_name, job_args_key = _store_job_args_in_s3( @@ -874,15 +871,14 @@ def test_store_job_args_in_s3_stores_compressed_job_args(mock_dumps, mock_compre # Assert mock_dumps.assert_called_once_with(job_args_list) - mock_compress.assert_called_once_with(serialized, compresslevel=1) s3_client_mock.put_object.assert_called_once_with( Bucket='my-bucket', - Key='some/path/single_table/job_args_list_metainfo.pkl.gz', - Body=compressed, + Key='some/path/single_table/job_args_list_metainfo.pkl', + Body=serialized, ) assert bucket_name == 'my-bucket' - assert job_args_key == 'some/path/single_table/job_args_list_metainfo.pkl.gz' + assert job_args_key == 'some/path/single_table/job_args_list_metainfo.pkl' @patch('sdgym.benchmark._import_and_validate_synthesizers') From e3c2e90bc8de4faabe76bc8df7cfa5b910c14b47 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 28 May 2026 10:47:45 +0100 Subject: [PATCH 12/13] undo pip install command --- sdgym/_benchmark/benchmark.py | 2 +- sdgym/benchmark.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdgym/_benchmark/benchmark.py b/sdgym/_benchmark/benchmark.py index 3e1ea3ea..19fb8de5 100644 --- a/sdgym/_benchmark/benchmark.py +++ b/sdgym/_benchmark/benchmark.py @@ -206,7 +206,7 @@ def _get_user_data_script( log "======== Install Dependencies ==========" pip install --upgrade pip {sdv_install} - pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-604-2-private-bucket" + pip install "sdgym[all]" {gpu_block} diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index aedf20a1..ad32a4ce 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -1464,7 +1464,7 @@ def _get_user_data_script(access_key, secret_key, region_name, script_content): echo "======== Install Dependencies in venv ============" pip install --upgrade pip - pip install "sdgym[all] @ git+https://github.com/sdv-dev/SDGym.git@issue-604-2-private-bucket" + pip install sdgym[all] pip install s3fs echo "======== Write Script ===========" From b97fa0d89ed2c7db69789c5d315941fa4d7ead62 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 28 May 2026 15:07:38 +0100 Subject: [PATCH 13/13] remove dataset_name from JobArgs --- sdgym/benchmark.py | 10 ++++------ tests/unit/test_benchmark.py | 6 ++---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index ad32a4ce..49c4e569 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -127,13 +127,12 @@ class JobArgs(NamedTuple): """Arguments needed to run a single synthesizer + dataset benchmark job.""" synthesizer: dict - dataset_info: Any + dataset_info: DatasetInfo metrics: Any timeout: Optional[int] compute_quality_score: bool compute_diagnostic_score: bool compute_privacy_score: bool - dataset_name: str modality: str output_directions: Optional[dict] @@ -459,7 +458,6 @@ def _generate_job_args_list( compute_quality_score=compute_quality_score, compute_diagnostic_score=compute_diagnostic_score, compute_privacy_score=compute_privacy_score, - dataset_name=dataset_info.name, modality=modality, output_directions=path, ) @@ -928,10 +926,10 @@ def _run_job(job_args, result_writer=None): compute_quality_score = job_args.compute_quality_score compute_diagnostic_score = job_args.compute_diagnostic_score compute_privacy_score = job_args.compute_privacy_score - dataset_name = job_args.dataset_name modality = job_args.modality synthesizer_path = job_args.output_directions dataset_info = job_args.dataset_info + dataset_name = dataset_info.name name = synthesizer['name'] LOGGER.info( @@ -1086,7 +1084,7 @@ def _validate_output_destination(output_destination, aws_keys=None): def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=None): - jobs = [[job.dataset_name, job.synthesizer['name']] for job in job_args_list] + jobs = [[job.dataset_info.name, job.synthesizer['name']] for job in job_args_list] if not job_args_list or not job_args_list[0].output_directions: return @@ -1756,7 +1754,7 @@ def benchmark_multi_table( ) if output_destination and job_args_list: - metainfo_filename = job_args_list[0][-1]['metainfo'] + metainfo_filename = job_args_list[0].output_directions['metainfo'] _update_metainfo_file(metainfo_filename, result_writer) return scores diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 63f2ca57..e1c363fd 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -644,7 +644,6 @@ def test__write_metainfo_file(mock_datetime, mock_open, mock_safe_load, tmp_path compute_quality_score=False, compute_diagnostic_score=False, compute_privacy_score=False, - dataset_name='adult', modality='single_table', output_directions=file_name, ), @@ -656,7 +655,6 @@ def test__write_metainfo_file(mock_datetime, mock_open, mock_safe_load, tmp_path compute_quality_score=False, compute_diagnostic_score=False, compute_privacy_score=False, - dataset_name='census', modality='single_table', output_directions=None, ), @@ -1329,7 +1327,7 @@ def test__generate_job_args_list_stores_dataset_infos( s3_client=s3_client, ) assert len(job_args_list) == 4 - assert [job.dataset_name for job in job_args_list] == [ + assert [job.dataset_info.name for job in job_args_list] == [ 'datasetA', 'datasetA', 'datasetB', @@ -1471,7 +1469,7 @@ def test_benchmark_multi_table_with_jobs( # Setup fake_scores = pd.DataFrame({'a': [1]}) mock__run_jobs.return_value = fake_scores - job_args = ('arg1', 'arg2', {'metainfo': 'meta.yaml'}) + job_args = Mock(output_directions={'metainfo': 'meta.yaml'}) mock__generate_job_args_list.return_value = [job_args] expected_valid_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer', 'CustomSynth']