From 9b1da24f6e5552b2a04bc0d86a6e6bde4f68f8ef Mon Sep 17 00:00:00 2001 From: SakshiKekre Date: Mon, 16 Mar 2026 18:39:34 +0530 Subject: [PATCH] feat: Add household axes support for MTR and earnings variation charts Add axes parameter to /household/calculate endpoint enabling the frontend to sweep a person's input variable (e.g. employment_income) across a range and receive 401-point arrays for all output variables. This powers MTR charts and earnings variation charts, removing the frontend's dependency on the v1 API's calculate-full endpoint. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/policyengine_api/api/household.py | 171 +++++++++- src/policyengine_api/modal_app.py | 171 +++++++--- src/policyengine_api/utils/axes.py | 209 ++++++++++++ tests/test_axes.py | 462 ++++++++++++++++++++++++++ 4 files changed, 961 insertions(+), 52 deletions(-) create mode 100644 src/policyengine_api/utils/axes.py create mode 100644 tests/test_axes.py diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 3abdc10..3da3249 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -47,6 +47,18 @@ def get_traceparent() -> str | None: router = APIRouter(prefix="/household", tags=["household"]) +class AxisSpec(BaseModel): + """Specification for a single axis in an axes group.""" + + name: str = Field(description="Variable name to vary, e.g. 'employment_income'") + min: float = Field(description="Minimum value of the range") + max: float = Field(description="Maximum value of the range") + count: int = Field(description="Number of evenly-spaced steps (e.g. 401)") + index: int = Field( + default=0, description="Which person (by index) to vary. Default 0." + ) + + class HouseholdCalculateRequest(BaseModel): """Request body for household calculation. 
@@ -136,6 +148,10 @@ class HouseholdCalculateRequest(BaseModel): dynamic_id: UUID | None = Field( default=None, description="Optional behavioural response ID" ) + axes: list[list[AxisSpec]] | None = Field( + default=None, + description="Optional axes for earnings variation. List of axis groups; each group is a list of parallel axes.", + ) class HouseholdCalculateResponse(BaseModel): @@ -251,6 +267,7 @@ def _run_local_household_uk( year: int, policy_data: dict | None, session: Session, + axes: list[list[dict]] | None = None, ) -> None: """Run UK household calculation locally. @@ -259,7 +276,9 @@ def _run_local_household_uk( from datetime import datetime, timezone try: - result = _calculate_household_uk(people, benunit, household, year, policy_data) + result = _calculate_household_uk( + people, benunit, household, year, policy_data, axes=axes + ) # Update job with result job = session.get(HouseholdJob, job_id) @@ -290,6 +309,7 @@ def _calculate_household_uk( household: list[dict], year: int, policy_data: dict | None, + axes: list[list[dict]] | None = None, ) -> dict: """Calculate UK household(s) and return result dict. 
@@ -353,6 +373,30 @@ def _calculate_household_uk( household_data[key] = [0.0] * n_households household_data[key][i] = value + # Save original counts for axes reshape + n_original_people = n_people + n_original_benunits = n_benunits + n_original_households = n_households + axis_count = 0 + + # Expand data for axes if provided + if axes is not None: + from policyengine_api.utils.axes import expand_dataframes_for_axes + + entity_datas = {"benunit": benunit_data, "household": household_data} + person_entity_id_keys = { + "benunit": "person_benunit_id", + "household": "person_household_id", + } + person_data, expanded_entities, axis_count = expand_dataframes_for_axes( + axes, person_data, entity_datas, person_entity_id_keys + ) + benunit_data = expanded_entities["benunit"] + household_data = expanded_entities["household"] + n_people = len(person_data["person_id"]) + n_benunits = len(benunit_data["benunit_id"]) + n_households = len(household_data["household_id"]) + # Create MicroDataFrames person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") @@ -442,12 +486,25 @@ def safe_convert(value): household_dict[var] = safe_convert(output_data.household[var].iloc[i]) household_outputs.append(household_dict) - return { + result = { "person": person_outputs, "benunit": benunit_outputs, "household": household_outputs, } + # Reshape output for axes + if axes is not None: + from policyengine_api.utils.axes import reshape_axes_output + + n_original = { + "person": n_original_people, + "benunit": n_original_benunits, + "household": n_original_households, + } + result = reshape_axes_output(result, n_original, axis_count) + + return result + def _run_local_household_us( job_id: str, @@ -460,6 +517,7 @@ def _run_local_household_us( year: int, policy_data: dict | None, session: Session, + axes: list[list[dict]] | None = None, ) -> None: """Run US household calculation locally. 
@@ -477,6 +535,7 @@ def _run_local_household_us( household, year, policy_data, + axes=axes, ) # Update job with result @@ -511,6 +570,7 @@ def _calculate_household_us( household: list[dict], year: int, policy_data: dict | None, + axes: list[list[dict]] | None = None, ) -> dict: """Calculate US household(s) and return result dict. @@ -608,6 +668,48 @@ def _calculate_household_us( tax_unit_data[key] = [0.0] * n_tax_units tax_unit_data[key][i] = value + # Save original counts for axes reshape + n_original_people = n_people + n_original_households = n_households + n_original_marital_units = n_marital_units + n_original_families = n_families + n_original_spm_units = n_spm_units + n_original_tax_units = n_tax_units + axis_count = 0 + + # Expand data for axes if provided + if axes is not None: + from policyengine_api.utils.axes import expand_dataframes_for_axes + + entity_datas = { + "household": household_data, + "marital_unit": marital_unit_data, + "family": family_data, + "spm_unit": spm_unit_data, + "tax_unit": tax_unit_data, + } + person_entity_id_keys = { + "household": "person_household_id", + "marital_unit": "person_marital_unit_id", + "family": "person_family_id", + "spm_unit": "person_spm_unit_id", + "tax_unit": "person_tax_unit_id", + } + person_data, expanded_entities, axis_count = expand_dataframes_for_axes( + axes, person_data, entity_datas, person_entity_id_keys + ) + household_data = expanded_entities["household"] + marital_unit_data = expanded_entities["marital_unit"] + family_data = expanded_entities["family"] + spm_unit_data = expanded_entities["spm_unit"] + tax_unit_data = expanded_entities["tax_unit"] + n_people = len(person_data["person_id"]) + n_households = len(household_data["household_id"]) + n_marital_units = len(marital_unit_data["marital_unit_id"]) + n_families = len(family_data["family_id"]) + n_spm_units = len(spm_unit_data["spm_unit_id"]) + n_tax_units = len(tax_unit_data["tax_unit_id"]) + # Create MicroDataFrames person_df = 
MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") household_df = MicroDataFrame( @@ -695,7 +797,7 @@ def extract_entity_outputs( outputs.append(row_dict) return outputs - return { + result = { "person": extract_entity_outputs("person", output_data.person, n_people), "marital_unit": extract_entity_outputs( "marital_unit", output_data.marital_unit, len(output_data.marital_unit) @@ -714,6 +816,22 @@ def extract_entity_outputs( ), } + # Reshape output for axes + if axes is not None: + from policyengine_api.utils.axes import reshape_axes_output + + n_original = { + "person": n_original_people, + "household": n_original_households, + "marital_unit": n_original_marital_units, + "family": n_original_families, + "spm_unit": n_original_spm_units, + "tax_unit": n_original_tax_units, + } + result = reshape_axes_output(result, n_original, axis_count) + + return result + def _trigger_modal_household( job_id: str, @@ -725,6 +843,11 @@ def _trigger_modal_household( """Trigger household simulation - Modal or local based on settings.""" from policyengine_api.config import settings + # Serialize axes to dicts for passing to Modal/local functions + axes_dicts: list[list[dict]] | None = None + if request.axes is not None: + axes_dicts = [[axis.model_dump() for axis in group] for group in request.axes] + if not settings.agent_use_modal and session is not None: # Run locally if request.country_id == "uk": @@ -736,6 +859,7 @@ def _trigger_modal_household( year=request.year or 2026, policy_data=policy_data, session=session, + axes=axes_dicts, ) else: _run_local_household_us( @@ -749,6 +873,7 @@ def _trigger_modal_household( year=request.year or 2024, policy_data=policy_data, session=session, + axes=axes_dicts, ) else: # Use Modal @@ -771,6 +896,7 @@ def _trigger_modal_household( policy_data=policy_data, dynamic_data=dynamic_data, traceparent=traceparent, + axes=axes_dicts, ) else: fn = modal.Function.from_name( @@ -790,6 +916,7 @@ def _trigger_modal_household( 
policy_data=policy_data, dynamic_data=dynamic_data, traceparent=traceparent, + axes=axes_dicts, ) @@ -871,23 +998,41 @@ def calculate_household( has_policy=request.policy_id is not None, has_dynamic=request.dynamic_id is not None, ): + # Validate axes if provided + if request.axes is not None: + from policyengine_api.utils.axes import validate_axes + + axes_dicts = [ + [axis.model_dump() for axis in group] for group in request.axes + ] + try: + validate_axes(axes_dicts, n_people=len(request.people)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + # Get policy and dynamic data for Modal policy_data = _get_policy_data(request.policy_id, session) dynamic_data = _get_dynamic_data(request.dynamic_id, session) # Create job record + request_data = { + "people": request.people, + "benunit": request.benunit, + "marital_unit": request.marital_unit, + "family": request.family, + "spm_unit": request.spm_unit, + "tax_unit": request.tax_unit, + "household": request.household, + "year": request.year, + } + if request.axes is not None: + request_data["axes"] = [ + [axis.model_dump() for axis in group] for group in request.axes + ] + job = HouseholdJob( country_id=request.country_id, - request_data={ - "people": request.people, - "benunit": request.benunit, - "marital_unit": request.marital_unit, - "family": request.family, - "spm_unit": request.spm_unit, - "tax_unit": request.tax_unit, - "household": request.household, - "year": request.year, - }, + request_data=request_data, policy_id=request.policy_id, dynamic_id=request.dynamic_id, status=HouseholdJobStatus.PENDING, diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index d0f23f1..af433aa 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -221,6 +221,7 @@ def simulate_household_uk( policy_data: dict | None, dynamic_data: dict | None, traceparent: str | None = None, + axes: list | None = None, ) -> None: """Calculate UK 
household(s) and write result to database. @@ -294,6 +295,30 @@ def simulate_household_uk( household_data[key] = [0.0] * n_households household_data[key][i] = value + # Save original counts for axes reshape + n_original_people = n_people + n_original_benunits = n_benunits + n_original_households = n_households + axis_count = 0 + + # Expand data for axes if provided + if axes is not None: + from policyengine_api.utils.axes import expand_dataframes_for_axes + + entity_datas = {"benunit": benunit_data, "household": household_data} + person_entity_id_keys = { + "benunit": "person_benunit_id", + "household": "person_household_id", + } + person_data, expanded_entities, axis_count = expand_dataframes_for_axes( + axes, person_data, entity_datas, person_entity_id_keys + ) + benunit_data = expanded_entities["benunit"] + household_data = expanded_entities["household"] + n_people = len(person_data["person_id"]) + n_benunits = len(benunit_data["benunit_id"]) + n_households = len(household_data["household_id"]) + # Create MicroDataFrames person_df = MicroDataFrame( pd.DataFrame(person_data), weights="person_weight" @@ -396,6 +421,23 @@ def safe_convert(value): ) household_outputs.append(household_dict) + result_data = { + "person": person_outputs, + "benunit": benunit_outputs, + "household": household_outputs, + } + + # Reshape output for axes + if axes is not None: + from policyengine_api.utils.axes import reshape_axes_output + + n_original = { + "person": n_original_people, + "benunit": n_original_benunits, + "household": n_original_households, + } + result_data = reshape_axes_output(result_data, n_original, axis_count) + # Write result to database with Session(engine) as session: from sqlmodel import text @@ -410,13 +452,7 @@ def safe_convert(value): """), params={ "job_id": job_id, - "result": json.dumps( - { - "person": person_outputs, - "benunit": benunit_outputs, - "household": household_outputs, - } - ), + "result": json.dumps(result_data), "completed_at": 
datetime.now(timezone.utc), }, ) @@ -466,6 +502,7 @@ def simulate_household_us( policy_data: dict | None, dynamic_data: dict | None, traceparent: str | None = None, + axes: list | None = None, ) -> None: """Calculate US household(s) and write result to database. @@ -574,6 +611,48 @@ def simulate_household_us( tax_unit_data[key] = [0.0] * n_tax_units tax_unit_data[key][i] = value + # Save original counts for axes reshape + n_original_people = n_people + n_original_households = n_households + n_original_marital_units = n_marital_units + n_original_families = n_families + n_original_spm_units = n_spm_units + n_original_tax_units = n_tax_units + axis_count = 0 + + # Expand data for axes if provided + if axes is not None: + from policyengine_api.utils.axes import expand_dataframes_for_axes + + entity_datas = { + "household": household_data, + "marital_unit": marital_unit_data, + "family": family_data, + "spm_unit": spm_unit_data, + "tax_unit": tax_unit_data, + } + person_entity_id_keys = { + "household": "person_household_id", + "marital_unit": "person_marital_unit_id", + "family": "person_family_id", + "spm_unit": "person_spm_unit_id", + "tax_unit": "person_tax_unit_id", + } + person_data, expanded_entities, axis_count = expand_dataframes_for_axes( + axes, person_data, entity_datas, person_entity_id_keys + ) + household_data = expanded_entities["household"] + marital_unit_data = expanded_entities["marital_unit"] + family_data = expanded_entities["family"] + spm_unit_data = expanded_entities["spm_unit"] + tax_unit_data = expanded_entities["tax_unit"] + n_people = len(person_data["person_id"]) + n_households = len(household_data["household_id"]) + n_marital_units = len(marital_unit_data["marital_unit_id"]) + n_families = len(family_data["family_id"]) + n_spm_units = len(spm_unit_data["spm_unit_id"]) + n_tax_units = len(tax_unit_data["tax_unit_id"]) + # Create MicroDataFrames person_df = MicroDataFrame( pd.DataFrame(person_data), weights="person_weight" @@ -674,6 +753,51 
@@ def extract_entity_outputs( outputs.append(row_dict) return outputs + result_data = { + "person": extract_entity_outputs( + "person", output_data.person, n_people + ), + "marital_unit": extract_entity_outputs( + "marital_unit", + output_data.marital_unit, + len(output_data.marital_unit), + ), + "family": extract_entity_outputs( + "family", + output_data.family, + len(output_data.family), + ), + "spm_unit": extract_entity_outputs( + "spm_unit", + output_data.spm_unit, + len(output_data.spm_unit), + ), + "tax_unit": extract_entity_outputs( + "tax_unit", + output_data.tax_unit, + len(output_data.tax_unit), + ), + "household": extract_entity_outputs( + "household", + output_data.household, + len(output_data.household), + ), + } + + # Reshape output for axes + if axes is not None: + from policyengine_api.utils.axes import reshape_axes_output + + n_original = { + "person": n_original_people, + "household": n_original_households, + "marital_unit": n_original_marital_units, + "family": n_original_families, + "spm_unit": n_original_spm_units, + "tax_unit": n_original_tax_units, + } + result_data = reshape_axes_output(result_data, n_original, axis_count) + # Write result to database with Session(engine) as session: from sqlmodel import text @@ -688,38 +812,7 @@ def extract_entity_outputs( """), params={ "job_id": job_id, - "result": json.dumps( - { - "person": extract_entity_outputs( - "person", output_data.person, n_people - ), - "marital_unit": extract_entity_outputs( - "marital_unit", - output_data.marital_unit, - len(output_data.marital_unit), - ), - "family": extract_entity_outputs( - "family", - output_data.family, - len(output_data.family), - ), - "spm_unit": extract_entity_outputs( - "spm_unit", - output_data.spm_unit, - len(output_data.spm_unit), - ), - "tax_unit": extract_entity_outputs( - "tax_unit", - output_data.tax_unit, - len(output_data.tax_unit), - ), - "household": extract_entity_outputs( - "household", - output_data.household, - 
"""Utility functions for household axes (earnings variation) support.

Axes allow varying a person's variable (e.g. employment_income) across a
linspace range, replicating all entities so the simulation covers the full
sweep in a single run.
"""

import numpy as np


def validate_axes(axes: list[list[dict]], n_people: int) -> None:
    """Validate axes specification.

    Args:
        axes: List of axis groups. Each group is a list of axis dicts with
            keys: name, min, max, count, index.
        n_people: Number of people in the household.

    Raises:
        ValueError: If axes spec is invalid.
    """
    # Exactly one group is supported: parallel axes only, no cross products.
    if len(axes) != 1:
        raise ValueError(f"axes must contain exactly 1 axis group, got {len(axes)}")

    group = axes[0]
    if len(group) == 0:
        raise ValueError("Axis group must contain at least one axis")

    counts: set[int] = set()
    for axis in group:
        name = axis.get("name", "")
        if not name or not isinstance(name, str):
            raise ValueError("Each axis must have a non-empty 'name' string")

        if not isinstance(axis.get("min"), (int, float)):
            raise ValueError(f"Axis '{name}': 'min' must be numeric")
        if not isinstance(axis.get("max"), (int, float)):
            raise ValueError(f"Axis '{name}': 'max' must be numeric")

        count = axis.get("count")
        # 1000 caps the fan-out so a single request can't explode the sim size.
        if not isinstance(count, int) or not 2 <= count <= 1000:
            raise ValueError(
                f"Axis '{name}': 'count' must be an integer between 2 and 1000"
            )

        index = axis.get("index", 0)
        if not isinstance(index, int) or not 0 <= index < n_people:
            raise ValueError(f"Axis '{name}': 'index' must be in [0, {n_people})")

        counts.add(count)

    # Parallel axes are applied to the same replicated rows, so their step
    # counts must agree.
    if len(counts) > 1:
        raise ValueError(
            f"All parallel axes in a group must have the same count, got {counts}"
        )


def expand_dataframes_for_axes(
    axes: list[list[dict]],
    person_data: dict[str, list],
    entity_datas: dict[str, dict[str, list]],
    person_entity_id_keys: dict[str, str],
) -> tuple[dict[str, list], dict[str, dict[str, list]], int]:
    """Expand person and entity data for axes simulation.

    Every person and entity row is replicated ``axis_count`` times; copy ``i``
    of each person is linked to copy ``i`` of its original entities, and the
    varied variable is set to ``linspace(min, max, count)`` on the target
    person's copies.

    Args:
        axes: Validated axes spec (exactly 1 group).
        person_data: Dict of column_name -> list of values for persons.
        entity_datas: Dict of entity_name -> {column_name -> list of values}.
            e.g. {"benunit": {"benunit_id": [0], ...}, "household": {...}}
        person_entity_id_keys: Mapping from entity_name to the FK column in
            person_data. e.g. {"benunit": "person_benunit_id",
            "household": "person_household_id"}

    Returns:
        (expanded_person_data, expanded_entity_datas, axis_count)
    """
    group = axes[0]
    axis_count = group[0]["count"]
    n_people = len(person_data["person_id"])
    n_expanded = n_people * axis_count

    def _repeat(values: list) -> list:
        # Replicate each row axis_count times, preserving row order.
        out: list = []
        for val in values:
            out.extend([val] * axis_count)
        return out

    # Replicate person rows: each person repeated axis_count times, then
    # renumber IDs to 0..n_people*axis_count-1.
    expanded_person = {col: _repeat(vals) for col, vals in person_data.items()}
    expanded_person["person_id"] = list(range(n_expanded))
    if "person_weight" in expanded_person:
        # Axes copies are synthetic scenarios, not weighted survey rows.
        expanded_person["person_weight"] = [1.0] * n_expanded

    # Replicate entity rows, renumber IDs, and reset weights the same way.
    expanded_entities: dict[str, dict[str, list]] = {}
    for entity_name, entity_data in entity_datas.items():
        n_entities = len(next(iter(entity_data.values())))
        expanded_entity = {col: _repeat(vals) for col, vals in entity_data.items()}

        id_col = f"{entity_name}_id"
        if id_col in expanded_entity:
            expanded_entity[id_col] = list(range(n_entities * axis_count))

        weight_col = f"{entity_name}_weight"
        if weight_col in expanded_entity:
            expanded_entity[weight_col] = [1.0] * (n_entities * axis_count)

        expanded_entities[entity_name] = expanded_entity

    # Update person-to-entity FK mappings: copy i of person p must point to
    # copy i of that person's original entity. Since entity IDs were
    # renumbered by *position* in their table, map each original FK value to
    # the entity's position rather than assuming IDs are already 0..n-1
    # (fixes a silent mismatch for non-zero-based IDs).
    for entity_name, fk_col in person_entity_id_keys.items():
        if fk_col not in expanded_person:
            continue
        id_list = entity_datas.get(entity_name, {}).get(f"{entity_name}_id")
        position_of = (
            {eid: pos for pos, eid in enumerate(id_list)}
            if id_list is not None
            else None
        )
        new_fks: list[int] = []
        for orig_fk in person_data[fk_col]:
            # Fall back to the old numeric assumption if the ID column is absent.
            pos = position_of[orig_fk] if position_of is not None else int(orig_fk)
            new_fks.extend(pos * axis_count + i for i in range(axis_count))
        expanded_person[fk_col] = new_fks

    # Apply linspace values for each axis in the group.
    for axis in group:
        var_name = axis["name"]
        count = axis["count"]
        index = axis.get("index", 0)
        linspace_values = np.linspace(axis["min"], axis["max"], count).tolist()

        # Create the column if it doesn't exist yet.
        if var_name not in expanded_person:
            expanded_person[var_name] = [0.0] * n_expanded

        # The target person's copies occupy positions
        # index*axis_count .. index*axis_count + count - 1.
        start = index * axis_count
        expanded_person[var_name][start : start + count] = linspace_values

    return expanded_person, expanded_entities, axis_count


def reshape_axes_output(
    result: dict[str, list],
    n_original: dict[str, int],
    axis_count: int,
) -> dict[str, list]:
    """Reshape flat simulation output back into axes format.

    Groups axis_count consecutive rows per original entity into a single dict
    with array values.

    Args:
        result: Standard result dict like {"person": [{var: scalar}, ...]}
            with n_original * axis_count rows per entity.
        n_original: Dict of entity_name -> original count before expansion.
            e.g. {"person": 1, "benunit": 1, "household": 1}
        axis_count: Number of axis steps.

    Returns:
        Dict with same structure but array values per variable.
        e.g. {"person": [{"employment_income": [0, 500, ...]}]}
    """
    reshaped: dict[str, list] = {}
    for entity_name, rows in result.items():
        orig_count = n_original.get(entity_name)
        if (
            not isinstance(rows, list)
            or orig_count is None
            or len(rows) != orig_count * axis_count
        ):
            # Non-list value, unknown entity, or mismatched row count:
            # pass through unchanged rather than guessing a grouping.
            reshaped[entity_name] = rows
            continue

        grouped = []
        for orig_idx in range(orig_count):
            chunk = rows[orig_idx * axis_count : (orig_idx + 1) * axis_count]
            # Merge the chunk into a single dict with one array per variable.
            merged = (
                {var: [row[var] for row in chunk] for var in chunk[0]}
                if chunk
                else {}
            )
            grouped.append(merged)

        reshaped[entity_name] = grouped

    return reshaped
"""Tests for household axes utility functions."""

import pytest

from policyengine_api.utils.axes import (
    expand_dataframes_for_axes,
    reshape_axes_output,
    validate_axes,
)


def _axis(name="employment_income", lo=0, hi=100000, count=5, index=0):
    """Build a single axis dict, defaulting to a typical employment sweep."""
    return {"name": name, "min": lo, "max": hi, "count": count, "index": index}


def _household_entities(**extra_columns):
    """Minimal single-household entity table, with optional extra columns."""
    data = {"household_id": [0], "household_weight": [1.0]}
    data.update(extra_columns)
    return {"household": data}


_HOUSEHOLD_KEYS = {"household": "person_household_id"}


class TestValidateAxes:
    """Tests for validate_axes()."""

    def test_valid_single_axis(self):
        validate_axes([[_axis()]], n_people=1)  # Should not raise

    def test_valid_parallel_axes(self):
        group = [_axis(), _axis(name="self_employment_income", hi=50000)]
        validate_axes([group], n_people=1)  # Should not raise

    def test_empty_axes(self):
        with pytest.raises(ValueError, match="exactly 1 axis group, got 0"):
            validate_axes([], n_people=1)

    def test_multiple_groups(self):
        group = [_axis()]
        with pytest.raises(ValueError, match="exactly 1 axis group, got 2"):
            validate_axes([group, group], n_people=1)

    def test_count_too_low(self):
        with pytest.raises(ValueError, match="between 2 and 1000"):
            validate_axes([[_axis(count=1)]], n_people=1)

    def test_count_too_high(self):
        with pytest.raises(ValueError, match="between 2 and 1000"):
            validate_axes([[_axis(count=1001)]], n_people=1)

    def test_index_out_of_bounds(self):
        with pytest.raises(ValueError, match="index"):
            validate_axes([[_axis(index=2)]], n_people=2)

    def test_index_negative(self):
        with pytest.raises(ValueError, match="index"):
            validate_axes([[_axis(index=-1)]], n_people=1)

    def test_mismatched_counts(self):
        group = [_axis(), _axis(name="self_employment_income", hi=50000, count=10)]
        with pytest.raises(ValueError, match="same count"):
            validate_axes([group], n_people=1)

    def test_empty_name(self):
        with pytest.raises(ValueError, match="non-empty 'name'"):
            validate_axes([[_axis(name="")]], n_people=1)

    def test_empty_group(self):
        with pytest.raises(ValueError, match="at least one axis"):
            validate_axes([[]], n_people=1)


class TestExpandDataframes:
    """Tests for expand_dataframes_for_axes()."""

    def test_single_person_single_axis(self):
        spec = [[_axis(lo=0, hi=1000, count=5)]]
        persons = {
            "person_id": [0],
            "person_household_id": [0],
            "person_weight": [1.0],
            "employment_income": [50000.0],
        }
        out_person, out_entities, steps = expand_dataframes_for_axes(
            spec, persons, _household_entities(), _HOUSEHOLD_KEYS
        )

        assert steps == 5
        assert len(out_person["person_id"]) == 5
        assert out_person["person_id"] == [0, 1, 2, 3, 4]
        # employment_income should be linspace(0, 1000, 5)
        assert out_person["employment_income"] == [0.0, 250.0, 500.0, 750.0, 1000.0]
        assert out_person["person_weight"] == [1.0] * 5

        # Household should be replicated
        assert len(out_entities["household"]["household_id"]) == 5
        assert out_entities["household"]["household_id"] == [0, 1, 2, 3, 4]

    def test_two_person_vary_first(self):
        spec = [[_axis(lo=0, hi=400, count=3)]]
        persons = {
            "person_id": [0, 1],
            "person_household_id": [0, 0],
            "person_weight": [1.0, 1.0],
            "employment_income": [50000.0, 30000.0],
            "age": [40.0, 30.0],
        }
        out_person, _, steps = expand_dataframes_for_axes(
            spec, persons, _household_entities(), _HOUSEHOLD_KEYS
        )

        assert steps == 3
        # 2 people * 3 steps = 6 person rows
        assert len(out_person["person_id"]) == 6
        assert out_person["person_id"] == [0, 1, 2, 3, 4, 5]

        # Person 0 copies: indices 0,1,2 -> employment_income = linspace(0,400,3)
        assert out_person["employment_income"][0] == 0.0
        assert out_person["employment_income"][1] == 200.0
        assert out_person["employment_income"][2] == 400.0

        # Person 1 copies: indices 3,4,5 -> employment_income stays at 30000
        assert out_person["employment_income"][3] == 30000.0
        assert out_person["employment_income"][4] == 30000.0
        assert out_person["employment_income"][5] == 30000.0

        # Age should be replicated
        assert out_person["age"][0:3] == [40.0, 40.0, 40.0]
        assert out_person["age"][3:6] == [30.0, 30.0, 30.0]

    def test_entity_replication(self):
        spec = [[_axis(lo=0, hi=100, count=3)]]
        persons = {
            "person_id": [0],
            "person_benunit_id": [0],
            "person_household_id": [0],
            "person_weight": [1.0],
        }
        entities = {
            "benunit": {"benunit_id": [0], "benunit_weight": [1.0]},
            "household": {
                "household_id": [0],
                "household_weight": [1.0],
                "region": ["LONDON"],
            },
        }
        fk_columns = {
            "benunit": "person_benunit_id",
            "household": "person_household_id",
        }
        out_person, out_entities, steps = expand_dataframes_for_axes(
            spec, persons, entities, fk_columns
        )

        assert steps == 3

        # Benunit replicated
        assert len(out_entities["benunit"]["benunit_id"]) == 3
        assert out_entities["benunit"]["benunit_id"] == [0, 1, 2]

        # Household replicated with region
        assert len(out_entities["household"]["household_id"]) == 3
        assert out_entities["household"]["region"] == ["LONDON", "LONDON", "LONDON"]

        # FK mapping updated
        assert out_person["person_benunit_id"] == [0, 1, 2]
        assert out_person["person_household_id"] == [0, 1, 2]

    def test_parallel_axes(self):
        spec = [
            [
                _axis(lo=0, hi=1000, count=3),
                _axis(name="self_employment_income", lo=100, hi=500, count=3),
            ]
        ]
        persons = {
            "person_id": [0],
            "person_household_id": [0],
            "person_weight": [1.0],
        }
        out_person, _, steps = expand_dataframes_for_axes(
            spec, persons, _household_entities(), _HOUSEHOLD_KEYS
        )

        assert steps == 3
        assert out_person["employment_income"] == [0.0, 500.0, 1000.0]
        assert out_person["self_employment_income"] == [100.0, 300.0, 500.0]

    def test_variable_not_in_person_data_created(self):
        spec = [[_axis(name="new_variable", lo=0, hi=10, count=3)]]
        persons = {
            "person_id": [0],
            "person_household_id": [0],
            "person_weight": [1.0],
        }
        out_person, _, _ = expand_dataframes_for_axes(
            spec, persons, _household_entities(), _HOUSEHOLD_KEYS
        )

        assert "new_variable" in out_person
        assert out_person["new_variable"] == [0.0, 5.0, 10.0]


class TestReshapeAxesOutput:
    """Tests for reshape_axes_output()."""

    def test_single_person(self):
        # 1 person * 3 steps = 3 rows
        raw = {
            "person": [
                {"employment_income": 0.0, "tax": 0.0},
                {"employment_income": 500.0, "tax": 100.0},
                {"employment_income": 1000.0, "tax": 200.0},
            ],
            "household": [
                {"income": 0.0},
                {"income": 500.0},
                {"income": 1000.0},
            ],
        }
        shaped = reshape_axes_output(
            raw, {"person": 1, "household": 1}, axis_count=3
        )

        assert len(shaped["person"]) == 1
        assert shaped["person"][0]["employment_income"] == [0.0, 500.0, 1000.0]
        assert shaped["person"][0]["tax"] == [0.0, 100.0, 200.0]

        assert len(shaped["household"]) == 1
        assert shaped["household"][0]["income"] == [0.0, 500.0, 1000.0]

    def test_two_person(self):
        # 2 people * 2 steps = 4 rows
        raw = {
            "person": [
                {"income": 0.0},
                {"income": 100.0},
                {"income": 50.0},
                {"income": 50.0},
            ],
        }
        shaped = reshape_axes_output(raw, {"person": 2}, axis_count=2)

        assert len(shaped["person"]) == 2
        assert shaped["person"][0]["income"] == [0.0, 100.0]
        assert shaped["person"][1]["income"] == [50.0, 50.0]

    def test_string_values(self):
        raw = {
            "person": [
                {"status": "employed"},
                {"status": "unemployed"},
            ],
        }
        shaped = reshape_axes_output(raw, {"person": 1}, axis_count=2)

        assert shaped["person"][0]["status"] == ["employed", "unemployed"]

    def test_unknown_entity_passthrough(self):
        raw = {
            "person": [{"income": 0.0}, {"income": 100.0}],
            "unknown_entity": [{"x": 1}],
        }
        shaped = reshape_axes_output(raw, {"person": 1}, axis_count=2)

        # person is reshaped
        assert len(shaped["person"]) == 1
        # unknown_entity passes through unchanged
        assert shaped["unknown_entity"] == [{"x": 1}]

    def test_mismatched_rows_passthrough(self):
        raw = {
            "person": [{"income": 0.0}, {"income": 100.0}, {"income": 200.0}],
        }
        # n_original * axis_count = 1 * 2 = 2, but we have 3 rows
        shaped = reshape_axes_output(raw, {"person": 1}, axis_count=2)

        # Mismatched, so pass through unchanged
        assert len(shaped["person"]) == 3

    def test_non_list_value_passthrough(self):
        raw = {
            "person": [{"income": 0.0}, {"income": 100.0}],
            "metadata": "some_string",
        }
        shaped = reshape_axes_output(raw, {"person": 1}, axis_count=2)

        assert shaped["metadata"] == "some_string"