Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,18 @@ def build_and_upload_states(
print(f"Uploading {state_code}.h5 to GCP...")
upload_local_area_file(str(output_path), "states", skip_hf=True)

# Upload HDFStore file if it exists
hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5")
if os.path.exists(hdfstore_path):
print(f"Uploading {state_code}.hdfstore.h5 to GCP...")
upload_local_area_file(
hdfstore_path, "states_hdfstore", skip_hf=True
)

# Queue for batched HuggingFace upload
hf_queue.append((str(output_path), "states"))
if os.path.exists(hdfstore_path):
hf_queue.append((hdfstore_path, "states_hdfstore"))

record_completed_state(state_code)
print(f"Completed {state_code}")
Expand Down Expand Up @@ -352,8 +362,18 @@ def build_and_upload_districts(
print(f"Uploading {friendly_name}.h5 to GCP...")
upload_local_area_file(str(output_path), "districts", skip_hf=True)

# Upload HDFStore file if it exists
hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5")
if os.path.exists(hdfstore_path):
print(f"Uploading {friendly_name}.hdfstore.h5 to GCP...")
upload_local_area_file(
hdfstore_path, "districts_hdfstore", skip_hf=True
)

# Queue for batched HuggingFace upload
hf_queue.append((str(output_path), "districts"))
if os.path.exists(hdfstore_path):
hf_queue.append((hdfstore_path, "districts_hdfstore"))

record_completed_district(friendly_name)
print(f"Completed {friendly_name}")
Expand Down Expand Up @@ -424,8 +444,20 @@ def build_and_upload_cities(
str(output_path), "cities", skip_hf=True
)

# Upload HDFStore file if it exists
hdfstore_path = str(output_path).replace(
".h5", ".hdfstore.h5"
)
if os.path.exists(hdfstore_path):
print("Uploading NYC.hdfstore.h5 to GCP...")
upload_local_area_file(
hdfstore_path, "cities_hdfstore", skip_hf=True
)

# Queue for batched HuggingFace upload
hf_queue.append((str(output_path), "cities"))
if os.path.exists(hdfstore_path):
hf_queue.append((hdfstore_path, "cities_hdfstore"))

record_completed_city("NYC")
print("Completed NYC")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,156 @@ def get_county_name(county_index: int) -> str:
return County._member_names_[county_index]


def _split_into_entity_dfs(combined_df, system, vars_to_save, time_period):
"""Split person-level DataFrame into entity-level DataFrames.

The combined_df has columns named ``variable__period`` (e.g.
``employment_income__2024``). This function strips the period suffix,
classifies each variable by entity, and returns one DataFrame per
entity with clean column names.

For group entities the rows are deduplicated by entity ID so that each
entity appears exactly once.
"""

suffix = f"__{time_period}"

# Build a mapping from clean variable name -> column in combined_df
col_map = {}
for col in combined_df.columns:
if col.endswith(suffix):
clean = col[: -len(suffix)]
col_map[clean] = col

# Entity classification buckets
ENTITIES = [
"person",
"household",
"tax_unit",
"spm_unit",
"family",
"marital_unit",
]
entity_cols = {e: [] for e in ENTITIES}

# Person-level entity membership ID columns (person_household_id, etc.)
person_ref_cols = []

for var in sorted(vars_to_save):
if var not in col_map:
continue
if var in system.variables:
entity_key = system.variables[var].entity.key
entity_cols[entity_key].append(var)
else:
# Geography/custom vars without system entry go to household
entity_cols["household"].append(var)

# --- Person DataFrame ---
person_vars = ["person_id"] + entity_cols["person"]
# Add person-level entity membership columns
for entity in ENTITIES[1:]: # skip person
ref_col = f"person_{entity}_id"
if ref_col in col_map:
person_vars.append(ref_col)
person_ref_cols.append(ref_col)

person_src_cols = [col_map[v] for v in person_vars if v in col_map]
person_df = combined_df[person_src_cols].copy()
person_df.columns = [
c[: -len(suffix)] if c.endswith(suffix) else c
for c in person_df.columns
]

entity_dfs = {"person": person_df}

# --- Group entity DataFrames: deduplicate by entity ID ---
for entity in ENTITIES[1:]:
id_col = f"{entity}_id"
person_ref = f"person_{entity}_id"
# Use person_ref column if available, else id_col
src_id = person_ref if person_ref in col_map else id_col

if src_id not in col_map:
continue

# Collect columns for this entity
cols_to_use = [src_id] + [
v for v in entity_cols[entity] if v != id_col and v in col_map
]
src_cols = [col_map[v] for v in cols_to_use]
df = combined_df[src_cols].copy()
# Strip period suffix
df.columns = [
c[: -len(suffix)] if c.endswith(suffix) else c
for c in df.columns
]
# Rename person_X_id -> X_id if needed
if src_id == person_ref and person_ref != id_col:
df = df.rename(columns={person_ref: id_col})
# Deduplicate
df = df.drop_duplicates(subset=[id_col]).reset_index(drop=True)
entity_dfs[entity] = df

return entity_dfs


def _build_uprating_manifest(vars_to_save, system):
"""Build manifest of variable metadata for embedding in HDFStore."""
records = []
for var in sorted(vars_to_save):
entity = (
system.variables[var].entity.key
if var in system.variables
else "unknown"
)
uprating = ""
if var in system.variables:
uprating = getattr(system.variables[var], "uprating", None) or ""
records.append(
{"variable": var, "entity": entity, "uprating": uprating}
)
return pd.DataFrame(records)


def _save_hdfstore(entity_dfs, manifest_df, output_path, time_period):
    """Save entity DataFrames and manifest to a Pandas HDFStore file.

    Args:
        entity_dfs: Mapping of entity key -> DataFrame (as produced by
            ``_split_into_entity_dfs``). The caller's DataFrames are NOT
            modified; object columns are stringified on a copy.
        manifest_df: Variable-metadata DataFrame stored under
            ``_variable_metadata``.
        output_path: Path of the companion ``.h5`` file (str or
            path-like); the store is written next to it with a
            ``.hdfstore.h5`` suffix.
        time_period: Period string stored under ``_time_period``.

    Returns:
        The path of the written HDFStore file as a string.
    """
    import warnings

    # Coerce to str: callers elsewhere in this file pass Path-like
    # output_path, and str.replace is what we need here.
    hdfstore_path = str(output_path).replace(".h5", ".hdfstore.h5")

    print(f"\nSaving HDFStore to {hdfstore_path}...")

    with warnings.catch_warnings():
        # Object columns are converted below, but suppress the PyTables
        # pickle warning for anything that slips through.
        warnings.filterwarnings(
            "ignore",
            category=pd.errors.PerformanceWarning,
            message=".*PyTables will pickle object types.*",
        )
        with pd.HDFStore(hdfstore_path, mode="w") as store:
            for entity_name, df in entity_dfs.items():
                # Convert object columns to string for HDFStore
                # compatibility — on a copy, so the caller's DataFrames
                # keep their original dtypes.
                df = df.copy()
                for col in df.columns:
                    if df[col].dtype == object:
                        df[col] = df[col].astype(str)
                store.put(entity_name, df, format="table")

            store.put("_variable_metadata", manifest_df, format="table")
            store.put(
                "_time_period",
                pd.Series([time_period]),
                format="table",
            )

    # Print summary
    for entity_name, df in entity_dfs.items():
        print(f" {entity_name}: {len(df):,} rows, {len(df.columns)} cols")
    print(f" manifest: {len(manifest_df)} variables")

    print("HDFStore saved successfully!")
    return hdfstore_path


def create_sparse_cd_stacked_dataset(
w,
cds_to_calibrate,
Expand Down Expand Up @@ -738,6 +888,20 @@ def create_sparse_cd_stacked_dataset(
f" Average persons per household: {np.sum(person_weights) / np.sum(weights):.2f}"
)

# --- HDFStore output (entity-level format) ---
# Split the person-level combined_df into per-entity DataFrames and save
# alongside the h5py file. This format is consumed by the API v2 alpha
# and by policyengine-us's extend_single_year_dataset().
entity_dfs = _split_into_entity_dfs(
combined_df, base_sim.tax_benefit_system, vars_to_save, time_period
)
manifest_df = _build_uprating_manifest(
vars_to_save, base_sim.tax_benefit_system
)
hdfstore_path = _save_hdfstore(
entity_dfs, manifest_df, output_path, time_period
)

return output_path


Expand Down
Loading