From 4eacf7112ba2aab88ee3418633bb397d097a548c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Mar 2026 21:15:58 +0100 Subject: [PATCH] Add entity-level HDFStore output format alongside h5py The stacked_dataset_builder now produces a Pandas HDFStore file (.hdfstore.h5) in addition to the existing h5py file. The HDFStore contains one table per entity (person, household, tax_unit, spm_unit, family, marital_unit) plus an embedded _variable_metadata manifest recording each variable's entity and uprating parameter path. The upload pipeline uploads HDFStore files to dedicated subdirectories (states_hdfstore/, districts_hdfstore/, cities_hdfstore/). A comparison test (test_format_comparison.py) validates that both formats contain identical data for all variables. Co-Authored-By: Claude Opus 4.6 --- .../publish_local_area.py | 32 ++ .../stacked_dataset_builder.py | 164 ++++++++++ .../tests/test_format_comparison.py | 285 ++++++++++++++++++ 3 files changed, 481 insertions(+) create mode 100644 policyengine_us_data/tests/test_format_comparison.py diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py index 4963f397..42bfd1b7 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py @@ -280,8 +280,18 @@ def build_and_upload_states( print(f"Uploading {state_code}.h5 to GCP...") upload_local_area_file(str(output_path), "states", skip_hf=True) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + if os.path.exists(hdfstore_path): + print(f"Uploading {state_code}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "states_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "states")) + if os.path.exists(hdfstore_path): + 
hf_queue.append((hdfstore_path, "states_hdfstore")) record_completed_state(state_code) print(f"Completed {state_code}") @@ -352,8 +362,18 @@ def build_and_upload_districts( print(f"Uploading {friendly_name}.h5 to GCP...") upload_local_area_file(str(output_path), "districts", skip_hf=True) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5") + if os.path.exists(hdfstore_path): + print(f"Uploading {friendly_name}.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "districts_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "districts")) + if os.path.exists(hdfstore_path): + hf_queue.append((hdfstore_path, "districts_hdfstore")) record_completed_district(friendly_name) print(f"Completed {friendly_name}") @@ -424,8 +444,20 @@ def build_and_upload_cities( str(output_path), "cities", skip_hf=True ) + # Upload HDFStore file if it exists + hdfstore_path = str(output_path).replace( + ".h5", ".hdfstore.h5" + ) + if os.path.exists(hdfstore_path): + print("Uploading NYC.hdfstore.h5 to GCP...") + upload_local_area_file( + hdfstore_path, "cities_hdfstore", skip_hf=True + ) + # Queue for batched HuggingFace upload hf_queue.append((str(output_path), "cities")) + if os.path.exists(hdfstore_path): + hf_queue.append((hdfstore_path, "cities_hdfstore")) record_completed_city("NYC") print("Completed NYC") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py index 010e151f..c4d449ce 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py @@ -59,6 +59,156 @@ def get_county_name(county_index: int) -> str: return County._member_names_[county_index] +def _split_into_entity_dfs(combined_df, system, vars_to_save, 
time_period): + """Split person-level DataFrame into entity-level DataFrames. + + The combined_df has columns named ``variable__period`` (e.g. + ``employment_income__2024``). This function strips the period suffix, + classifies each variable by entity, and returns one DataFrame per + entity with clean column names. + + For group entities the rows are deduplicated by entity ID so that each + entity appears exactly once. + """ + + suffix = f"__{time_period}" + + # Build a mapping from clean variable name -> column in combined_df + col_map = {} + for col in combined_df.columns: + if col.endswith(suffix): + clean = col[: -len(suffix)] + col_map[clean] = col + + # Entity classification buckets + ENTITIES = [ + "person", + "household", + "tax_unit", + "spm_unit", + "family", + "marital_unit", + ] + entity_cols = {e: [] for e in ENTITIES} + + # Person-level entity membership ID columns (person_household_id, etc.) + person_ref_cols = [] + + for var in sorted(vars_to_save): + if var not in col_map: + continue + if var in system.variables: + entity_key = system.variables[var].entity.key + entity_cols[entity_key].append(var) + else: + # Geography/custom vars without system entry go to household + entity_cols["household"].append(var) + + # --- Person DataFrame --- + person_vars = ["person_id"] + entity_cols["person"] + # Add person-level entity membership columns + for entity in ENTITIES[1:]: # skip person + ref_col = f"person_{entity}_id" + if ref_col in col_map: + person_vars.append(ref_col) + person_ref_cols.append(ref_col) + + person_src_cols = [col_map[v] for v in person_vars if v in col_map] + person_df = combined_df[person_src_cols].copy() + person_df.columns = [ + c[: -len(suffix)] if c.endswith(suffix) else c + for c in person_df.columns + ] + + entity_dfs = {"person": person_df} + + # --- Group entity DataFrames: deduplicate by entity ID --- + for entity in ENTITIES[1:]: + id_col = f"{entity}_id" + person_ref = f"person_{entity}_id" + # Use person_ref column if 
available, else id_col + src_id = person_ref if person_ref in col_map else id_col + + if src_id not in col_map: + continue + + # Collect columns for this entity + cols_to_use = [src_id] + [ + v for v in entity_cols[entity] if v != id_col and v in col_map + ] + src_cols = [col_map[v] for v in cols_to_use] + df = combined_df[src_cols].copy() + # Strip period suffix + df.columns = [ + c[: -len(suffix)] if c.endswith(suffix) else c + for c in df.columns + ] + # Rename person_X_id -> X_id if needed + if src_id == person_ref and person_ref != id_col: + df = df.rename(columns={person_ref: id_col}) + # Deduplicate + df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True) + entity_dfs[entity] = df + + return entity_dfs + + +def _build_uprating_manifest(vars_to_save, system): + """Build manifest of variable metadata for embedding in HDFStore.""" + records = [] + for var in sorted(vars_to_save): + entity = ( + system.variables[var].entity.key + if var in system.variables + else "unknown" + ) + uprating = "" + if var in system.variables: + uprating = getattr(system.variables[var], "uprating", None) or "" + records.append( + {"variable": var, "entity": entity, "uprating": uprating} + ) + return pd.DataFrame(records) + + +def _save_hdfstore(entity_dfs, manifest_df, output_path, time_period): + """Save entity DataFrames and manifest to a Pandas HDFStore file.""" + import warnings + + hdfstore_path = output_path.replace(".h5", ".hdfstore.h5") + + print(f"\nSaving HDFStore to {hdfstore_path}...") + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + message=".*PyTables will pickle object types.*", + ) + with pd.HDFStore(hdfstore_path, mode="w") as store: + for entity_name, df in entity_dfs.items(): + # Convert object columns to string for HDFStore compatibility + for col in df.columns: + if df[col].dtype == object: + df[col] = df[col].astype(str) + store.put(entity_name, df, format="table") + + 
store.put("_variable_metadata", manifest_df, format="table") + store.put( + "_time_period", + pd.Series([time_period]), + format="table", + ) + + # Print summary + for entity_name, df in entity_dfs.items(): + print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols") + print(f" manifest: {len(manifest_df)} variables") + + print(f"HDFStore saved successfully!") + return hdfstore_path + + def create_sparse_cd_stacked_dataset( w, cds_to_calibrate, @@ -738,6 +888,20 @@ def create_sparse_cd_stacked_dataset( f" Average persons per household: {np.sum(person_weights) / np.sum(weights):.2f}" ) + # --- HDFStore output (entity-level format) --- + # Split the person-level combined_df into per-entity DataFrames and save + # alongside the h5py file. This format is consumed by the API v2 alpha + # and by policyengine-us's extend_single_year_dataset(). + entity_dfs = _split_into_entity_dfs( + combined_df, base_sim.tax_benefit_system, vars_to_save, time_period + ) + manifest_df = _build_uprating_manifest( + vars_to_save, base_sim.tax_benefit_system + ) + hdfstore_path = _save_hdfstore( + entity_dfs, manifest_df, output_path, time_period + ) + return output_path diff --git a/policyengine_us_data/tests/test_format_comparison.py b/policyengine_us_data/tests/test_format_comparison.py new file mode 100644 index 00000000..4741d41e --- /dev/null +++ b/policyengine_us_data/tests/test_format_comparison.py @@ -0,0 +1,285 @@ +""" +Compare h5py (variable-centric) and HDFStore (entity-level) output formats. + +Verifies that both formats produced by stacked_dataset_builder contain +identical data for all variables. 
"""
Compare h5py (variable-centric) and HDFStore (entity-level) output formats.

Verifies that both formats produced by stacked_dataset_builder contain
identical data for all variables.

Usage as pytest (the --h5py-path / --hdfstore-path options must be
registered via pytest_addoption in a conftest.py — see note below):
    pytest test_format_comparison.py --h5py-path path/to/STATE.h5 \
        --hdfstore-path path/to/STATE.hdfstore.h5

Usage as standalone script:
    python -m policyengine_us_data.tests.test_format_comparison \
        --h5py-path path/to/STATE.h5 \
        --hdfstore-path path/to/STATE.hdfstore.h5
"""

import argparse
import sys

import h5py
import numpy as np
import pandas as pd
import pytest


def _as_str_array(values):
    """Coerce an array of bytes/str/other to a plain-str numpy array."""
    return np.array(
        [x.decode() if isinstance(x, bytes) else str(x) for x in values]
    )


def _compare_values(h5_values, hdf_values, entity_name):
    """Compare one variable's arrays from the two formats.

    Returns ``(matched, reason)`` where ``reason`` is ``None`` on a
    match and a short human-readable explanation otherwise.
    """
    is_stringy = h5_values.dtype.kind in ("U", "S", "O")

    if entity_name != "person" and len(hdf_values) != len(h5_values):
        # h5py stores group variables at person level while the HDFStore
        # is deduplicated by entity ID, so a direct element-wise
        # comparison is impossible — verify the unique values match.
        h5_unique = np.unique(h5_values)
        hdf_unique = np.unique(hdf_values)
        if is_stringy:
            # Decode both sides: h5py yields bytes for "S" dtypes, the
            # HDFStore yields str, and bytes never compare equal to str.
            match = set(_as_str_array(h5_unique)) == set(
                _as_str_array(hdf_unique)
            )
        else:
            match = np.allclose(
                np.sort(h5_unique.astype(float)),
                np.sort(hdf_unique.astype(float)),
                rtol=1e-5,
                equal_nan=True,
            )
        if match:
            return True, None
        return False, (
            f"unique values differ "
            f"(h5py: {len(h5_unique)}, "
            f"hdfstore: {len(hdf_unique)})"
        )

    # Same length — direct element-wise comparison.
    if is_stringy:
        h5_str = _as_str_array(h5_values)
        hdf_str = np.array([str(x) for x in hdf_values])
        if np.array_equal(h5_str, hdf_str):
            return True, None
        mismatches = np.sum(h5_str != hdf_str)
        return False, f"{mismatches} string mismatches"

    h5_float = h5_values.astype(float)
    hdf_float = hdf_values.astype(float)
    if np.allclose(h5_float, hdf_float, rtol=1e-5, equal_nan=True):
        return True, None
    max_diff = np.max(np.abs(h5_float - hdf_float))
    n_diff = np.sum(
        ~np.isclose(h5_float, hdf_float, rtol=1e-5, equal_nan=True)
    )
    return False, f"{n_diff} values differ, max diff={max_diff:.6f}"


def compare_formats(h5py_path: str, hdfstore_path: str) -> dict:
    """Compare all variables between h5py and HDFStore formats.

    Returns a dict with keys: passed, failed, skipped, total_h5py_vars.
    """
    passed = []
    failed = []
    skipped = []

    with h5py.File(h5py_path, "r") as f:
        h5_vars = sorted(f.keys())
        if not h5_vars:
            # Empty h5py file — nothing to compare.
            return {
                "passed": [],
                "failed": [],
                "skipped": [],
                "total_h5py_vars": 0,
            }
        # The year is the sub-group key under each variable.
        year = list(f[h5_vars[0]].keys())[0]

        with pd.HDFStore(hdfstore_path, "r") as store:
            # Entity tables only; "/_"-prefixed keys hold metadata.
            entity_dfs = {
                key.lstrip("/"): store[key]
                for key in store.keys()
                if not key.startswith("/_")
            }

            for var in h5_vars:
                h5_values = f[var][year][:]
                # Find the entity DataFrame containing this variable.
                for entity_name, df in entity_dfs.items():
                    if var in df.columns:
                        matched, reason = _compare_values(
                            h5_values, df[var].values, entity_name
                        )
                        if matched:
                            passed.append(var)
                        else:
                            failed.append((var, reason))
                        break
                else:
                    skipped.append(var)

    return {
        "passed": passed,
        "failed": failed,
        "skipped": skipped,
        "total_h5py_vars": len(h5_vars),
    }


def _print_report(result):
    """Print the comparison summary shared by the pytest and CLI paths."""
    print(f"\n{'='*60}")
    print("Format Comparison Results")
    print(f"{'='*60}")
    print(f"Total h5py variables: {result['total_h5py_vars']}")
    print(f"Passed: {len(result['passed'])}")
    print(f"Failed: {len(result['failed'])}")
    print(f"Skipped (not in HDFStore): {len(result['skipped'])}")

    if result["failed"]:
        print("\nFailed variables:")
        for var, reason in result["failed"]:
            print(f"  {var}: {reason}")

    if result["skipped"]:
        print("\nSkipped variables (not found in HDFStore):")
        for var in result["skipped"]:
            print(f"  {var}")


def pytest_addoption(parser):
    # NOTE: pytest only honors this hook when it lives in a conftest.py
    # (or an installed plugin) — defining it here has no effect.  It is
    # kept as documentation of the options; copy it into conftest.py.
    parser.addoption("--h5py-path", action="store", default=None)
    parser.addoption("--hdfstore-path", action="store", default=None)


def _option_or_skip(request, name):
    """Fetch a CLI option, skipping the test when absent or unregistered."""
    try:
        path = request.config.getoption(name)
    except ValueError:
        # Option not registered (pytest_addoption not in conftest.py);
        # without this guard the fixture errors instead of skipping.
        path = None
    if path is None:
        pytest.skip(f"{name} not provided")
    return path


@pytest.fixture
def h5py_path(request):
    return _option_or_skip(request, "--h5py-path")


@pytest.fixture
def hdfstore_path(request):
    return _option_or_skip(request, "--hdfstore-path")


def test_formats_match(h5py_path, hdfstore_path):
    """Verify h5py and HDFStore formats contain identical data."""
    result = compare_formats(h5py_path, hdfstore_path)
    _print_report(result)

    assert len(result["failed"]) == 0, (
        f"{len(result['failed'])} variables have mismatched values"
    )
    assert len(result["skipped"]) == 0, (
        f"{len(result['skipped'])} variables missing from HDFStore"
    )


def test_manifest_present(hdfstore_path):
    """Verify the HDFStore contains a variable metadata manifest."""
    with pd.HDFStore(hdfstore_path, "r") as store:
        assert "/_variable_metadata" in store.keys(), (
            "Missing _variable_metadata table"
        )
        manifest = store["/_variable_metadata"]
        assert "variable" in manifest.columns
        assert "entity" in manifest.columns
        assert "uprating" in manifest.columns
        assert len(manifest) > 0, "Manifest is empty"
        print(f"\nManifest has {len(manifest)} variables")
        print(f"Entities: {manifest['entity'].unique().tolist()}")
        n_uprated = (manifest["uprating"] != "").sum()
        print(f"Variables with uprating: {n_uprated}")


def test_all_entities_present(hdfstore_path):
    """Verify the HDFStore contains all expected entity tables."""
    expected = {
        "person",
        "household",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
    }
    with pd.HDFStore(hdfstore_path, "r") as store:
        actual = {
            k.lstrip("/") for k in store.keys() if not k.startswith("/_")
        }
        missing = expected - actual
        assert not missing, f"Missing entity tables: {missing}"
        for entity in expected:
            df = store[f"/{entity}"]
            assert len(df) > 0, f"Entity {entity} has 0 rows"
            assert f"{entity}_id" in df.columns, (
                f"Entity {entity} missing {entity}_id column"
            )
            print(f"  {entity}: {len(df):,} rows, {len(df.columns)} cols")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compare h5py and HDFStore dataset formats"
    )
    parser.add_argument(
        "--h5py-path", required=True, help="Path to h5py format file"
    )
    parser.add_argument(
        "--hdfstore-path", required=True, help="Path to HDFStore format file"
    )
    args = parser.parse_args()

    result = compare_formats(args.h5py_path, args.hdfstore_path)
    _print_report(result)

    if result["failed"] or result["skipped"]:
        sys.exit(1)
    print("\nAll variables match!")
    sys.exit(0)