Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 158 additions & 13 deletions src/policyengine_api/api/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def get_traceparent() -> str | None:
router = APIRouter(prefix="/household", tags=["household"])


class AxisSpec(BaseModel):
"""Specification for a single axis in an axes group."""

name: str = Field(description="Variable name to vary, e.g. 'employment_income'")
min: float = Field(description="Minimum value of the range")
max: float = Field(description="Maximum value of the range")
count: int = Field(description="Number of evenly-spaced steps (e.g. 401)")
index: int = Field(
default=0, description="Which person (by index) to vary. Default 0."
)


class HouseholdCalculateRequest(BaseModel):
"""Request body for household calculation.

Expand Down Expand Up @@ -136,6 +148,10 @@ class HouseholdCalculateRequest(BaseModel):
dynamic_id: UUID | None = Field(
default=None, description="Optional behavioural response ID"
)
axes: list[list[AxisSpec]] | None = Field(
default=None,
description="Optional axes for earnings variation. List of axis groups; each group is a list of parallel axes.",
)


class HouseholdCalculateResponse(BaseModel):
Expand Down Expand Up @@ -251,6 +267,7 @@ def _run_local_household_uk(
year: int,
policy_data: dict | None,
session: Session,
axes: list[list[dict]] | None = None,
) -> None:
"""Run UK household calculation locally.

Expand All @@ -259,7 +276,9 @@ def _run_local_household_uk(
from datetime import datetime, timezone

try:
result = _calculate_household_uk(people, benunit, household, year, policy_data)
result = _calculate_household_uk(
people, benunit, household, year, policy_data, axes=axes
)

# Update job with result
job = session.get(HouseholdJob, job_id)
Expand Down Expand Up @@ -290,6 +309,7 @@ def _calculate_household_uk(
household: list[dict],
year: int,
policy_data: dict | None,
axes: list[list[dict]] | None = None,
) -> dict:
"""Calculate UK household(s) and return result dict.

Expand Down Expand Up @@ -353,6 +373,30 @@ def _calculate_household_uk(
household_data[key] = [0.0] * n_households
household_data[key][i] = value

# Save original counts for axes reshape
n_original_people = n_people
n_original_benunits = n_benunits
n_original_households = n_households
axis_count = 0

# Expand data for axes if provided
if axes is not None:
from policyengine_api.utils.axes import expand_dataframes_for_axes

entity_datas = {"benunit": benunit_data, "household": household_data}
person_entity_id_keys = {
"benunit": "person_benunit_id",
"household": "person_household_id",
}
person_data, expanded_entities, axis_count = expand_dataframes_for_axes(
axes, person_data, entity_datas, person_entity_id_keys
)
benunit_data = expanded_entities["benunit"]
household_data = expanded_entities["household"]
n_people = len(person_data["person_id"])
n_benunits = len(benunit_data["benunit_id"])
n_households = len(household_data["household_id"])

# Create MicroDataFrames
person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight")
benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight")
Expand Down Expand Up @@ -442,12 +486,25 @@ def safe_convert(value):
household_dict[var] = safe_convert(output_data.household[var].iloc[i])
household_outputs.append(household_dict)

return {
result = {
"person": person_outputs,
"benunit": benunit_outputs,
"household": household_outputs,
}

# Reshape output for axes
if axes is not None:
from policyengine_api.utils.axes import reshape_axes_output

n_original = {
"person": n_original_people,
"benunit": n_original_benunits,
"household": n_original_households,
}
result = reshape_axes_output(result, n_original, axis_count)

return result


def _run_local_household_us(
job_id: str,
Expand All @@ -460,6 +517,7 @@ def _run_local_household_us(
year: int,
policy_data: dict | None,
session: Session,
axes: list[list[dict]] | None = None,
) -> None:
"""Run US household calculation locally.

Expand All @@ -477,6 +535,7 @@ def _run_local_household_us(
household,
year,
policy_data,
axes=axes,
)

# Update job with result
Expand Down Expand Up @@ -511,6 +570,7 @@ def _calculate_household_us(
household: list[dict],
year: int,
policy_data: dict | None,
axes: list[list[dict]] | None = None,
) -> dict:
"""Calculate US household(s) and return result dict.

Expand Down Expand Up @@ -608,6 +668,48 @@ def _calculate_household_us(
tax_unit_data[key] = [0.0] * n_tax_units
tax_unit_data[key][i] = value

# Save original counts for axes reshape
n_original_people = n_people
n_original_households = n_households
n_original_marital_units = n_marital_units
n_original_families = n_families
n_original_spm_units = n_spm_units
n_original_tax_units = n_tax_units
axis_count = 0

# Expand data for axes if provided
if axes is not None:
from policyengine_api.utils.axes import expand_dataframes_for_axes

entity_datas = {
"household": household_data,
"marital_unit": marital_unit_data,
"family": family_data,
"spm_unit": spm_unit_data,
"tax_unit": tax_unit_data,
}
person_entity_id_keys = {
"household": "person_household_id",
"marital_unit": "person_marital_unit_id",
"family": "person_family_id",
"spm_unit": "person_spm_unit_id",
"tax_unit": "person_tax_unit_id",
}
person_data, expanded_entities, axis_count = expand_dataframes_for_axes(
axes, person_data, entity_datas, person_entity_id_keys
)
household_data = expanded_entities["household"]
marital_unit_data = expanded_entities["marital_unit"]
family_data = expanded_entities["family"]
spm_unit_data = expanded_entities["spm_unit"]
tax_unit_data = expanded_entities["tax_unit"]
n_people = len(person_data["person_id"])
n_households = len(household_data["household_id"])
n_marital_units = len(marital_unit_data["marital_unit_id"])
n_families = len(family_data["family_id"])
n_spm_units = len(spm_unit_data["spm_unit_id"])
n_tax_units = len(tax_unit_data["tax_unit_id"])

# Create MicroDataFrames
person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight")
household_df = MicroDataFrame(
Expand Down Expand Up @@ -695,7 +797,7 @@ def extract_entity_outputs(
outputs.append(row_dict)
return outputs

return {
result = {
"person": extract_entity_outputs("person", output_data.person, n_people),
"marital_unit": extract_entity_outputs(
"marital_unit", output_data.marital_unit, len(output_data.marital_unit)
Expand All @@ -714,6 +816,22 @@ def extract_entity_outputs(
),
}

# Reshape output for axes
if axes is not None:
from policyengine_api.utils.axes import reshape_axes_output

n_original = {
"person": n_original_people,
"household": n_original_households,
"marital_unit": n_original_marital_units,
"family": n_original_families,
"spm_unit": n_original_spm_units,
"tax_unit": n_original_tax_units,
}
result = reshape_axes_output(result, n_original, axis_count)

return result


def _trigger_modal_household(
job_id: str,
Expand All @@ -725,6 +843,11 @@ def _trigger_modal_household(
"""Trigger household simulation - Modal or local based on settings."""
from policyengine_api.config import settings

# Serialize axes to dicts for passing to Modal/local functions
axes_dicts: list[list[dict]] | None = None
if request.axes is not None:
axes_dicts = [[axis.model_dump() for axis in group] for group in request.axes]

if not settings.agent_use_modal and session is not None:
# Run locally
if request.country_id == "uk":
Expand All @@ -736,6 +859,7 @@ def _trigger_modal_household(
year=request.year or 2026,
policy_data=policy_data,
session=session,
axes=axes_dicts,
)
else:
_run_local_household_us(
Expand All @@ -749,6 +873,7 @@ def _trigger_modal_household(
year=request.year or 2024,
policy_data=policy_data,
session=session,
axes=axes_dicts,
)
else:
# Use Modal
Expand All @@ -771,6 +896,7 @@ def _trigger_modal_household(
policy_data=policy_data,
dynamic_data=dynamic_data,
traceparent=traceparent,
axes=axes_dicts,
)
else:
fn = modal.Function.from_name(
Expand All @@ -790,6 +916,7 @@ def _trigger_modal_household(
policy_data=policy_data,
dynamic_data=dynamic_data,
traceparent=traceparent,
axes=axes_dicts,
)


Expand Down Expand Up @@ -871,23 +998,41 @@ def calculate_household(
has_policy=request.policy_id is not None,
has_dynamic=request.dynamic_id is not None,
):
# Validate axes if provided
if request.axes is not None:
from policyengine_api.utils.axes import validate_axes

axes_dicts = [
[axis.model_dump() for axis in group] for group in request.axes
]
try:
validate_axes(axes_dicts, n_people=len(request.people))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

# Get policy and dynamic data for Modal
policy_data = _get_policy_data(request.policy_id, session)
dynamic_data = _get_dynamic_data(request.dynamic_id, session)

# Create job record
request_data = {
"people": request.people,
"benunit": request.benunit,
"marital_unit": request.marital_unit,
"family": request.family,
"spm_unit": request.spm_unit,
"tax_unit": request.tax_unit,
"household": request.household,
"year": request.year,
}
if request.axes is not None:
request_data["axes"] = [
[axis.model_dump() for axis in group] for group in request.axes
]

job = HouseholdJob(
country_id=request.country_id,
request_data={
"people": request.people,
"benunit": request.benunit,
"marital_unit": request.marital_unit,
"family": request.family,
"spm_unit": request.spm_unit,
"tax_unit": request.tax_unit,
"household": request.household,
"year": request.year,
},
request_data=request_data,
policy_id=request.policy_id,
dynamic_id=request.dynamic_id,
status=HouseholdJobStatus.PENDING,
Expand Down
Loading
Loading