diff --git a/.github/workflows/pipeline.yaml b/.github/workflows/pipeline.yaml new file mode 100644 index 00000000..cec5a748 --- /dev/null +++ b/.github/workflows/pipeline.yaml @@ -0,0 +1,58 @@ +name: Run Pipeline + +on: + push: + branches: [main] + workflow_dispatch: + inputs: + gpu: + description: "GPU type for regional calibration" + default: "T4" + type: string + epochs: + description: "Epochs for regional calibration" + default: "1000" + type: string + national_epochs: + description: "Epochs for national calibration" + default: "4000" + type: string + num_workers: + description: "Number of parallel H5 workers" + default: "8" + type: string + skip_national: + description: "Skip national calibration/H5" + default: false + type: boolean + +jobs: + pipeline: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install Modal + run: pip install modal + + - name: Launch pipeline on Modal + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + run: | + ARGS="--action run --branch main" + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + ARGS="$ARGS --gpu ${{ inputs.gpu }}" + ARGS="$ARGS --epochs ${{ inputs.epochs }}" + ARGS="$ARGS --national-epochs ${{ inputs.national_epochs }}" + ARGS="$ARGS --num-workers ${{ inputs.num_workers }}" + if [ "${{ inputs.skip_national }}" = "true" ]; then + ARGS="$ARGS --skip-national" + fi + fi + modal run --detach modal_app/pipeline.py::main $ARGS diff --git a/.github/workflows/pr_code_changes.yaml b/.github/workflows/pr_code_changes.yaml index cf335694..83f0866b 100644 --- a/.github/workflows/pr_code_changes.yaml +++ b/.github/workflows/pr_code_changes.yaml @@ -84,7 +84,7 @@ jobs: needs: [check-fork, Lint] uses: ./.github/workflows/reusable_test.yaml with: - full_suite: true + full_suite: false upload_data: false deploy_docs: false secrets: inherit \ No newline at end of file diff --git a/Makefile b/Makefile index 602afe3d..09d85db2 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,12 @@ -.PHONY: all format test install download upload docker documentation data validate-data calibrate calibrate-build publish-local-area upload-calibration upload-dataset upload-database push-to-modal build-matrices calibrate-modal calibrate-modal-national calibrate-both stage-h5s stage-national-h5 stage-all-h5s pipeline validate-staging validate-staging-full upload-validation check-staging check-sanity clean build paper clean-paper presentations database database-refresh promote-database promote-dataset promote build-h5s validate-local +.PHONY: all format test install download upload docker documentation data validate-data calibrate calibrate-build publish-local-area upload-calibration upload-dataset upload-database push-to-modal build-data-modal build-matrices calibrate-modal calibrate-modal-national calibrate-both stage-h5s stage-national-h5 stage-all-h5s pipeline validate-staging validate-staging-full upload-validation check-staging check-sanity clean build paper clean-paper presentations database database-refresh promote-database promote-dataset promote build-h5s validate-local -GPU ?= A100-80GB -EPOCHS ?= 200 +GPU ?= T4 +EPOCHS ?= 1000 NATIONAL_GPU ?= T4 -NATIONAL_EPOCHS ?= 200 +NATIONAL_EPOCHS ?= 4000 BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD) NUM_WORKERS ?= 8 +N_CLONES ?= 430 VERSION ?= HF_CLONE_DIR ?= $(HOME)/huggingface/policyengine-us-data @@ -87,9 +88,11 @@ promote-database: @echo "Copied DB and raw_inputs to HF clone. Now cd to HF repo, commit, and push." promote-dataset: - cp policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \ - $(HF_CLONE_DIR)/calibration/source_imputed_stratified_extended_cps.h5 - @echo "Copied dataset to HF clone. Now cd to HF repo, commit, and push." + python -c "from policyengine_us_data.utils.huggingface import upload; \ + upload('policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5', \ + 'policyengine/policyengine-us-data', \ + 'calibration/source_imputed_stratified_extended_cps.h5')" + @echo "Dataset promoted to HF." data: download python policyengine_us_data/utils/uprating.py @@ -155,59 +158,55 @@ upload-database: @echo "Database uploaded to HF." push-to-modal: - modal volume put local-area-staging \ + modal volume put pipeline-artifacts \ policyengine_us_data/storage/calibration/calibration_weights.npy \ - calibration_inputs/calibration/calibration_weights.npy --force - modal volume put local-area-staging \ - policyengine_us_data/storage/calibration/stacked_blocks.npy \ - calibration_inputs/calibration/stacked_blocks.npy --force - modal volume put local-area-staging \ - policyengine_us_data/storage/calibration/stacked_takeup.npz \ - calibration_inputs/calibration/stacked_takeup.npz --force - modal volume put local-area-staging \ + artifacts/calibration_weights.npy --force + modal volume put pipeline-artifacts \ policyengine_us_data/storage/calibration/policy_data.db \ - calibration_inputs/calibration/policy_data.db --force - modal volume put local-area-staging \ - policyengine_us_data/storage/calibration/geo_labels.json \ - calibration_inputs/calibration/geo_labels.json --force - modal volume put local-area-staging \ + artifacts/policy_data.db --force + modal volume put pipeline-artifacts \ policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5 \ - calibration_inputs/calibration/source_imputed_stratified_extended_cps.h5 --force - @echo "All calibration inputs pushed to Modal volume." + artifacts/source_imputed_stratified_extended_cps.h5 --force + @echo "All pipeline artifacts pushed to Modal volume." build-matrices: - modal run modal_app/remote_calibration_runner.py::build_package \ - --branch $(BRANCH) + modal run --detach modal_app/remote_calibration_runner.py::build_package \ + --branch $(BRANCH) --county-level --n-clones $(N_CLONES) calibrate-modal: - modal run modal_app/remote_calibration_runner.py::main \ + modal run --detach modal_app/remote_calibration_runner.py::main \ --branch $(BRANCH) --gpu $(GPU) --epochs $(EPOCHS) \ + --beta 0.65 --lambda-l0 1e-7 --lambda-l2 1e-8 --log-freq 500 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ --push-results calibrate-modal-national: - modal run modal_app/remote_calibration_runner.py::main \ + modal run --detach modal_app/remote_calibration_runner.py::main \ --branch $(BRANCH) --gpu $(NATIONAL_GPU) \ --epochs $(NATIONAL_EPOCHS) \ + --beta 0.65 --lambda-l0 1e-4 --lambda-l2 1e-12 --log-freq 500 \ + --target-config policyengine_us_data/calibration/target_config.yaml \ --push-results --national calibrate-both: $(MAKE) calibrate-modal & $(MAKE) calibrate-modal-national & wait stage-h5s: - modal run modal_app/local_area.py::main \ - --branch $(BRANCH) --num-workers $(NUM_WORKERS) \ - $(if $(SKIP_DOWNLOAD),--skip-download) + modal run --detach modal_app/local_area.py::main \ + --branch $(BRANCH) --num-workers $(NUM_WORKERS) --n-clones $(N_CLONES) stage-national-h5: - modal run modal_app/local_area.py::main_national \ - --branch $(BRANCH) + modal run --detach modal_app/local_area.py::main_national \ + --branch $(BRANCH) --n-clones $(N_CLONES) stage-all-h5s: $(MAKE) stage-h5s & $(MAKE) stage-national-h5 & wait promote: + @echo "This will run the full Modal promote pipeline (local_area.py::main_promote)." + @read -p "Are you sure? [y/N] " confirm && [ "$$confirm" = "y" ] || (echo "Aborted."; exit 1) $(eval VERSION := $(or $(VERSION),$(shell python -c "import tomllib; print(tomllib.load(open('pyproject.toml','rb'))['project']['version'])"))) - modal run modal_app/local_area.py::main_promote \ + modal run --detach modal_app/local_area.py::main_promote \ --branch $(BRANCH) --version $(VERSION) validate-staging: @@ -231,12 +230,15 @@ check-sanity: python -m policyengine_us_data.calibration.validate_staging \ --sanity-only --area-type states --areas NC -pipeline: data upload-dataset build-matrices calibrate-both stage-all-h5s - @echo "" - @echo "========================================" - @echo "Pipeline complete. H5s are in HF staging." - @echo "Run 'Promote Local Area H5 Files' workflow in GitHub to publish." - @echo "========================================" +build-data-modal: + modal run --detach modal_app/data_build.py::main --branch $(BRANCH) --upload --skip-tests + +pipeline: + modal run --detach modal_app.pipeline::main \ + --action run --branch $(BRANCH) --gpu $(GPU) \ + --epochs $(EPOCHS) --national-gpu $(NATIONAL_GPU) \ + --national-epochs $(NATIONAL_EPOCHS) \ + --num-workers $(NUM_WORKERS) --n-clones $(N_CLONES) clean: rm -f policyengine_us_data/storage/*.h5 diff --git a/modal_app/README.md b/modal_app/README.md index 876f3610..730142f7 100644 --- a/modal_app/README.md +++ b/modal_app/README.md @@ -78,7 +78,6 @@ Every run produces these local files (whichever the calibration script emits): - **unified_diagnostics.csv** — Final per-target diagnostics - **calibration_log.csv** — Per-target metrics across epochs (requires `--log-freq`) - **unified_run_config.json** — Run configuration and summary stats -- **stacked_blocks.npy** — Census block assignments for stacked records ## Artifact Upload to HuggingFace @@ -88,7 +87,6 @@ atomic commit after writing them locally: | Local file | HF path | |------------|---------| | `calibration_weights.npy` | `calibration/calibration_weights.npy` | -| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` | | `calibration_log.csv` | `calibration/logs/calibration_log.csv` | | `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` | | `unified_run_config.json` | `calibration/logs/unified_run_config.json` | @@ -205,7 +203,6 @@ Artifacts uploaded to HF by `--push-results`: | Local file | HF path | |------------|---------| | `calibration_weights.npy` | `calibration/calibration_weights.npy` | -| `stacked_blocks.npy` | `calibration/stacked_blocks.npy` | | `calibration_log.csv` | `calibration/logs/calibration_log.csv` | | `unified_diagnostics.csv` | `calibration/logs/unified_diagnostics.csv` | | `unified_run_config.json` | `calibration/logs/unified_run_config.json` | diff --git a/modal_app/data_build.py b/modal_app/data_build.py index 20314e4d..8997ef57 100644 --- a/modal_app/data_build.py +++ b/modal_app/data_build.py @@ -20,14 +20,47 @@ create_if_missing=True, ) +# Shared pipeline volume for inter-step artifact transport +pipeline_volume = modal.Volume.from_name( + "pipeline-artifacts", + create_if_missing=True, +) +PIPELINE_MOUNT = "/pipeline" + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_IGNORE = [ + ".git", + "__pycache__", + "*.egg-info", + ".pytest_cache", + "*.h5", + "*.npy", + "*.pkl", + "*.db", + "node_modules", + "venv", + ".venv", + "docs/_build", + "paper", + "presentations", +] image = ( - modal.Image.debian_slim(python_version="3.13").apt_install("git").pip_install("uv") + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .pip_install("uv>=0.8") + .add_local_dir( + str(_REPO_ROOT), + remote_path="/root/policyengine-us-data", + copy=True, + ignore=_IGNORE, + ) + .run_commands( + "cd /root/policyengine-us-data && UV_HTTP_TIMEOUT=300 uv sync --frozen" + ) ) -REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" VOLUME_MOUNT = "/checkpoints" _volume_lock = threading.Lock() -_DEFAULT_UV_HTTP_TIMEOUT = "1800" # Script to output file mapping for checkpointing # Values can be a single file path (str) or a list of file paths @@ -88,17 +121,25 @@ def setup_gcp_credentials(): return None -def _run_uv_sync(*args: str) -> None: - """Run uv sync with a higher default network timeout for large wheels.""" - env = os.environ.copy() - env.setdefault("UV_HTTP_TIMEOUT", _DEFAULT_UV_HTTP_TIMEOUT) - subprocess.run(["uv", "sync", *args], check=True, env=env) - - @functools.cache def get_current_commit() -> str: - """Get the current git commit SHA (cached per process).""" - return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + """Get the current git commit SHA (cached per process). + + Falls back to a hash of pyproject.toml version when .git + is not available (pre-baked Modal images exclude .git). + """ + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], text=True, stderr=subprocess.DEVNULL + ).strip() + except (subprocess.CalledProcessError, FileNotFoundError): + import hashlib + + version_file = Path("/root/policyengine-us-data/pyproject.toml") + if version_file.exists(): + content = version_file.read_bytes() + return hashlib.sha256(content).hexdigest()[:12] + return "unknown" def get_checkpoint_path(branch: str, output_file: str) -> Path: @@ -278,7 +319,10 @@ def run_tests_with_checkpoints( @app.function( image=image, secrets=[hf_secret, gcp_secret], - volumes={VOLUME_MOUNT: checkpoint_volume}, + volumes={ + VOLUME_MOUNT: checkpoint_volume, + PIPELINE_MOUNT: pipeline_volume, + }, memory=32768, cpu=8.0, timeout=14400, @@ -288,6 +332,8 @@ def build_datasets( branch: str = "main", sequential: bool = False, clear_checkpoints: bool = False, + skip_tests: bool = False, + skip_enhanced_cps: bool = False, ): """Build all datasets with preemption-resilient checkpointing. @@ -296,6 +342,9 @@ def build_datasets( branch: Git branch to build from. sequential: Use sequential (non-parallel) execution. clear_checkpoints: Clear existing checkpoints before starting. + skip_tests: Skip running the test suite (useful for calibration runs). + skip_enhanced_cps: Skip enhanced_cps.py and small_enhanced_cps.py + (useful for calibration runs that only need source_imputed H5). """ setup_gcp_credentials() @@ -309,9 +358,7 @@ def build_datasets( checkpoint_volume.commit() print(f"Cleared checkpoints for branch: {branch}") - os.chdir("/root") - subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) - os.chdir("policyengine-us-data") + os.chdir("/root/policyengine-us-data") # Clean stale checkpoints from other commits branch_dir = Path(VOLUME_MOUNT) / branch @@ -323,9 +370,6 @@ def build_datasets( print(f"Removed stale checkpoint dir: {entry.name[:12]}") checkpoint_volume.commit() - # Use uv sync to install exact versions from uv.lock. - _run_uv_sync("--locked") - env = os.environ.copy() # Download prerequisites @@ -333,9 +377,22 @@ def build_datasets( "policyengine_us_data/storage/download_private_prerequisites.py", env=env, ) + # Checkpoint policy_data.db immediately after download so it survives + # test failures and can be restored on retries. + save_checkpoint( + branch, + "policyengine_us_data/storage/calibration/policy_data.db", + checkpoint_volume, + ) if sequential: for script, output in SCRIPT_OUTPUTS.items(): + if skip_enhanced_cps and script in ( + "policyengine_us_data/datasets/cps/enhanced_cps.py", + "policyengine_us_data/datasets/cps/small_enhanced_cps.py", + ): + print(f"Skipping {script} (--skip-enhanced-cps)") + continue run_script_with_checkpoint( script, output, @@ -417,55 +474,126 @@ def build_datasets( # GROUP 3: After extended_cps - run in parallel # enhanced_cps and stratified_cps both depend on extended_cps print("=== Phase 4: Building enhanced and stratified CPS (parallel) ===") + phase4_futures = [] with ThreadPoolExecutor(max_workers=2) as executor: - futures = [ + if not skip_enhanced_cps: + phase4_futures.append( + executor.submit( + run_script_with_checkpoint, + "policyengine_us_data/datasets/cps/enhanced_cps.py", + SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/enhanced_cps.py" + ], + branch, + checkpoint_volume, + env=env, + ) + ) + else: + print("Skipping enhanced_cps.py (--skip-enhanced-cps)") + phase4_futures.append( executor.submit( run_script_with_checkpoint, - "policyengine_us_data/datasets/cps/enhanced_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/enhanced_cps.py"], + "policyengine_us_data/calibration/create_stratified_cps.py", + SCRIPT_OUTPUTS[ + "policyengine_us_data/calibration/create_stratified_cps.py" + ], branch, checkpoint_volume, env=env, - ), + ) + ) + for future in as_completed(phase4_futures): + future.result() + + # GROUP 4: After Phase 4 - run in parallel + # create_source_imputed_cps needs stratified_cps + # small_enhanced_cps needs enhanced_cps + print( + "=== Phase 5: Building source imputed CPS " + "and small enhanced CPS (parallel) ===" + ) + phase5_futures = [] + with ThreadPoolExecutor(max_workers=2) as executor: + phase5_futures.append( executor.submit( run_script_with_checkpoint, - "policyengine_us_data/calibration/create_stratified_cps.py", + "policyengine_us_data/calibration/create_source_imputed_cps.py", SCRIPT_OUTPUTS[ - "policyengine_us_data/calibration/create_stratified_cps.py" + "policyengine_us_data/calibration/create_source_imputed_cps.py" ], branch, checkpoint_volume, env=env, - ), - ] - for future in as_completed(futures): + ) + ) + if not skip_enhanced_cps: + phase5_futures.append( + executor.submit( + run_script_with_checkpoint, + "policyengine_us_data/datasets/cps/small_enhanced_cps.py", + SCRIPT_OUTPUTS[ + "policyengine_us_data/datasets/cps/small_enhanced_cps.py" + ], + branch, + checkpoint_volume, + env=env, + ) + ) + else: + print("Skipping small_enhanced_cps.py (--skip-enhanced-cps)") + for future in as_completed(phase5_futures): future.result() - # SEQUENTIAL: Small enhanced CPS (needs enhanced_cps) - print("=== Phase 5: Building small enhanced CPS ===") - run_script_with_checkpoint( - "policyengine_us_data/datasets/cps/small_enhanced_cps.py", - SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/small_enhanced_cps.py"], - branch, - checkpoint_volume, - env=env, + # Copy pipeline artifacts to shared volume before tests so that a test + # failure does not block downstream calibration steps. + # Files selected: + # - source_imputed H5: main dataset for calibration and local area builds + # - policy_data.db: calibration target database + # - calibration_weights.npy: pre-existing weights for re-runs (if present) + print("Copying pipeline artifacts to shared volume...") + artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts" + artifacts_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2( + "policyengine_us_data/storage/source_imputed_stratified_extended_cps_2024.h5", + artifacts_dir / "source_imputed_stratified_extended_cps.h5", + ) + shutil.copy2( + "policyengine_us_data/storage/calibration/policy_data.db", + artifacts_dir / "policy_data.db", + ) + cal_weights = Path("policyengine_us_data/storage/calibration_weights.npy") + if cal_weights.exists(): + shutil.copy2( + cal_weights, + artifacts_dir / "calibration_weights.npy", ) + print("Copied existing calibration_weights.npy to pipeline volume") + pipeline_volume.commit() + print("Pipeline artifacts committed to shared volume") # Run tests with checkpointing - print("=== Running tests with checkpointing ===") - run_tests_with_checkpoints(branch, checkpoint_volume, env) + if skip_tests: + print("Skipping tests (--skip-tests)") + else: + print("=== Running tests with checkpointing ===") + run_tests_with_checkpoints(branch, checkpoint_volume, env) - # Upload if requested + # Upload if requested (HF publication only) if upload: + upload_args = [] + if skip_enhanced_cps: + upload_args.append("--no-require-enhanced-cps") run_script( "policyengine_us_data/storage/upload_completed_datasets.py", + args=upload_args, env=env, ) # Clean up checkpoints after successful completion cleanup_checkpoints(branch, checkpoint_volume) - return "Data build and tests completed successfully" + return "Data build completed successfully" @app.local_entrypoint() @@ -474,11 +602,15 @@ def main( branch: str = "main", sequential: bool = False, clear_checkpoints: bool = False, + skip_tests: bool = False, + skip_enhanced_cps: bool = False, ): result = build_datasets.remote( upload=upload, branch=branch, sequential=sequential, clear_checkpoints=clear_checkpoints, + skip_tests=skip_tests, + skip_enhanced_cps=skip_enhanced_cps, ) print(result) diff --git a/modal_app/images.py b/modal_app/images.py new file mode 100644 index 00000000..5a1bac20 --- /dev/null +++ b/modal_app/images.py @@ -0,0 +1,51 @@ +"""Shared pre-baked Modal images for policyengine-us-data. + +Bakes source code and dependencies into image layers at build time. +Modal caches layers by content hash of copied files -- if code +changes, the image rebuilds; if not, the cached layer is reused. +""" + +import modal +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parent.parent + +_ignore = [ + ".git", + "__pycache__", + "*.egg-info", + ".pytest_cache", + "*.h5", + "*.npy", + "*.pkl", + "*.db", + "node_modules", + "venv", + ".venv", + "docs/_build", + "paper", + "presentations", +] + + +def _base_image(extras: list[str] | None = None): + extra_flags = " ".join(f"--extra {e}" for e in (extras or [])) + return ( + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .pip_install("uv>=0.8") + .add_local_dir( + str(REPO_ROOT), + remote_path="/root/policyengine-us-data", + copy=True, + ignore=_ignore, + ) + .run_commands( + f"cd /root/policyengine-us-data && " + f"UV_HTTP_TIMEOUT=300 uv sync --frozen {extra_flags}" + ) + ) + + +cpu_image = _base_image() +gpu_image = _base_image(extras=["l0"]) diff --git a/modal_app/local_area.py b/modal_app/local_area.py index 7755615f..ea3355a1 100644 --- a/modal_app/local_area.py +++ b/modal_app/local_area.py @@ -28,15 +28,44 @@ create_if_missing=True, ) +pipeline_volume = modal.Volume.from_name( + "pipeline-artifacts", + create_if_missing=True, +) + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_IGNORE = [ + ".git", + "__pycache__", + "*.egg-info", + ".pytest_cache", + "*.h5", + "*.npy", + "*.pkl", + "*.db", + "node_modules", + "venv", + ".venv", + "docs/_build", + "paper", + "presentations", +] image = ( modal.Image.debian_slim(python_version="3.13") .apt_install("git") - .pip_install("uv", "tomli") + .pip_install("uv>=0.8") + .add_local_dir( + str(_REPO_ROOT), + remote_path="/root/policyengine-us-data", + copy=True, + ignore=_IGNORE, + ) + .run_commands( + "cd /root/policyengine-us-data && UV_HTTP_TIMEOUT=300 uv sync --frozen" + ) ) -REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" VOLUME_MOUNT = "/staging" -_DEFAULT_UV_HTTP_TIMEOUT = "1800" def setup_gcp_credentials(): @@ -51,47 +80,29 @@ def setup_gcp_credentials(): return None -def _run_uv_sync(*args: str) -> None: - """Run uv sync with a higher default network timeout for large wheels.""" - env = os.environ.copy() - env.setdefault("UV_HTTP_TIMEOUT", _DEFAULT_UV_HTTP_TIMEOUT) - subprocess.run(["uv", "sync", *args], check=True, env=env) - - def setup_repo(branch: str): - """Clone the repo at the requested branch and install deps. + """Change to the pre-baked repo directory. - Always clones fresh from GitHub so every container runs the - latest code — no stale image cache issues. + The branch parameter is kept for API compatibility but is + no longer used for cloning -- code is baked into the image. """ - repo_dir = Path("/root/policyengine-us-data") - - if repo_dir.exists(): - import shutil - - shutil.rmtree(repo_dir) - - os.chdir("/root") - subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) - os.chdir("policyengine-us-data") - sha = subprocess.run( - ["git", "rev-parse", "HEAD"], - capture_output=True, - text=True, - ).stdout.strip() - print(f"Checked out {branch} at {sha[:8]}") - _run_uv_sync("--locked") + os.chdir("/root/policyengine-us-data") def validate_artifacts( config_path: Path, artifact_dir: Path, + filename_remap: Dict[str, str] = None, ) -> None: """Verify artifact checksums against unified_run_config.json. Args: config_path: Path to unified_run_config.json. artifact_dir: Directory containing the artifact files. + filename_remap: Optional mapping from config filenames to + actual filenames on disk (e.g. national weights are + stored as national_calibration_weights.npy but the + config records calibration_weights.npy). Raises: RuntimeError: If any artifact is missing or has a @@ -115,11 +126,13 @@ def validate_artifacts( print("WARNING: No artifacts section in run config, skipping validation") return + remap = filename_remap or {} for filename, expected_hash in artifacts.items(): - filepath = artifact_dir / filename + actual_filename = remap.get(filename, filename) + filepath = artifact_dir / actual_filename if not filepath.exists(): raise RuntimeError( - f"Artifact validation failed: {filename} not found in {artifact_dir}" + f"Artifact validation failed: {actual_filename} not found in {artifact_dir}" ) h = hashlib.sha256() with open(filepath, "rb") as fh: @@ -213,8 +226,16 @@ def run_phase( version: str, calibration_inputs: Dict[str, str], version_dir: Path, -) -> set: - """Run a single build phase, spawning workers and collecting results.""" + validate: bool = True, +) -> tuple: + """Run a single build phase, spawning workers and collecting results. + + Returns: + A tuple of (volume_completed, phase_errors, validation_rows) + where phase_errors is a list of error dicts from workers + and crashes, and validation_rows is a list of per-target + validation result dicts. + """ work_chunks = partition_work(states, districts, cities, num_workers, completed) total_remaining = sum(len(c) for c in work_chunks) @@ -223,7 +244,7 @@ def run_phase( if total_remaining == 0: print(f"All {phase_name} items already built!") - return completed + return completed, [], [] handles = [] for i, chunk in enumerate(work_chunks): @@ -233,12 +254,14 @@ def run_phase( version=version, work_items=chunk, calibration_inputs=calibration_inputs, + validate=validate, ) handles.append(handle) print(f"Waiting for {phase_name} workers to complete...") all_results = [] all_errors = [] + all_validation_rows = [] for i, handle in enumerate(handles): try: @@ -250,6 +273,11 @@ def run_phase( ) if result["errors"]: all_errors.extend(result["errors"]) + # Collect validation rows + v_rows = result.get("validation_rows", []) + if v_rows: + all_validation_rows.extend(v_rows) + print(f" Worker {i}: {len(v_rows)} validation rows") except Exception as e: all_errors.append({"worker": i, "error": str(e)}) print(f" Worker {i}: CRASHED - {e}") @@ -276,13 +304,16 @@ def run_phase( if len(all_errors) > 5: print(f" ... and {len(all_errors) - 5} more") - return volume_completed + return volume_completed, all_errors, all_validation_rows @app.function( image=image, secrets=[hf_secret, gcp_secret], - volumes={VOLUME_MOUNT: staging_volume}, + volumes={ + VOLUME_MOUNT: staging_volume, + "/pipeline": pipeline_volume, + }, memory=16384, cpu=4.0, timeout=14400, @@ -292,6 +323,7 @@ def build_areas_worker( version: str, work_items: List[Dict], calibration_inputs: Dict[str, str], + validate: bool = True, ) -> Dict: """ Worker function that builds a subset of H5 files. @@ -321,17 +353,26 @@ def build_areas_worker( "--output-dir", str(output_dir), ] - if "geography" not in calibration_inputs: - raise RuntimeError( - "geography.npz path missing from calibration_inputs. " - "Re-run calibration to generate this artifact." - ) + if "n_clones" in calibration_inputs: + worker_cmd.extend(["--n-clones", str(calibration_inputs["n_clones"])]) + if "seed" in calibration_inputs: + worker_cmd.extend(["--seed", str(calibration_inputs["seed"])]) + repo_root = Path("/root/policyengine-us-data") + cal_dir = repo_root / "policyengine_us_data" / "calibration" worker_cmd.extend( [ - "--geography-path", - calibration_inputs["geography"], + "--target-config", + str(cal_dir / "target_config.yaml"), ] ) + worker_cmd.extend( + [ + "--validation-config", + str(cal_dir / "target_config_full.yaml"), + ] + ) + if not validate: + worker_cmd.append("--no-validate") result = subprocess.run( worker_cmd, capture_output=True, @@ -575,7 +616,10 @@ def promote_publish(branch: str = "main", version: str = "") -> str: @app.function( image=image, secrets=[hf_secret, gcp_secret], - volumes={VOLUME_MOUNT: staging_volume}, + volumes={ + VOLUME_MOUNT: staging_volume, + "/pipeline": pipeline_volume, + }, memory=8192, timeout=86400, ) @@ -583,8 +627,9 @@ def coordinate_publish( branch: str = "main", num_workers: int = 8, skip_upload: bool = False, - skip_download: bool = False, -) -> str: + n_clones: int = 430, + validate: bool = True, +) -> Dict: """Coordinate the full publishing workflow.""" setup_gcp_credentials() setup_repo(branch) @@ -602,79 +647,34 @@ def coordinate_publish( shutil.rmtree(version_dir) version_dir.mkdir(parents=True, exist_ok=True) - calibration_dir = staging_dir / "calibration_inputs" - - # hf_hub_download preserves directory structure, so files are in calibration/ subdir - weights_path = calibration_dir / "calibration" / "calibration_weights.npy" - db_path = calibration_dir / "calibration" / "policy_data.db" - - if skip_download: - print("Verifying pre-pushed calibration inputs...") - staging_volume.reload() - dataset_path = ( - calibration_dir - / "calibration" - / "source_imputed_stratified_extended_cps.h5" - ) - required = { - "weights": weights_path, - "dataset": dataset_path, - "database": db_path, - "geography": (calibration_dir / "calibration" / "geography.npz"), - "run_config": (calibration_dir / "calibration" / "unified_run_config.json"), - } - for label, p in required.items(): - if not p.exists(): - raise RuntimeError(f"Missing required calibration input ({label}): {p}") - print("All required calibration inputs found on volume.") - else: - if calibration_dir.exists(): - shutil.rmtree(calibration_dir) - calibration_dir.mkdir(parents=True, exist_ok=True) - - print("Downloading calibration inputs from HuggingFace...") - result = subprocess.run( - [ - "uv", - "run", - "python", - "-c", - f""" -from policyengine_us_data.utils.huggingface import download_calibration_inputs -download_calibration_inputs("{calibration_dir}") -print("Done") -""", - ], - text=True, - env=os.environ.copy(), - ) - if result.returncode != 0: - raise RuntimeError(f"Download failed: {result.stderr}") - staging_volume.commit() - print("Calibration inputs downloaded") - - dataset_path = ( - calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" - ) + pipeline_volume.reload() + artifacts = Path("/pipeline/artifacts") + weights_path = artifacts / "calibration_weights.npy" + db_path = artifacts / "policy_data.db" + dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" + config_json_path = artifacts / "unified_run_config.json" + + required = { + "weights": weights_path, + "dataset": dataset_path, + "database": db_path, + } + for label, p in required.items(): + if not p.exists(): + raise RuntimeError( + f"Missing {label} on pipeline volume: {p}. " + f"Run upstream pipeline steps first." + ) + print("All required pipeline artifacts found on volume.") - geo_npz_path = calibration_dir / "calibration" / "geography.npz" - config_json_path = calibration_dir / "calibration" / "unified_run_config.json" calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), "database": str(db_path), + "n_clones": n_clones, + "seed": 42, } - if not geo_npz_path.exists(): - raise RuntimeError( - f"geography.npz not found at {geo_npz_path}. " - f"Re-run calibration to generate this artifact." - ) - calibration_inputs["geography"] = str(geo_npz_path) - print(f"Geography artifact found: {geo_npz_path}") - validate_artifacts( - config_json_path, - calibration_dir / "calibration", - ) + validate_artifacts(config_json_path, artifacts) result = subprocess.run( [ "uv", @@ -721,9 +721,13 @@ def coordinate_publish( version=version, calibration_inputs=calibration_inputs, version_dir=version_dir, + validate=validate, ) - completed = run_phase( + accumulated_errors = [] + accumulated_validation_rows = [] + + completed, phase_errors, v_rows = run_phase( "States", states=states, districts=[], @@ -731,8 +735,10 @@ def coordinate_publish( completed=completed, **phase_args, ) + accumulated_errors.extend(phase_errors) + accumulated_validation_rows.extend(v_rows) - completed = run_phase( + completed, phase_errors, v_rows = run_phase( "Districts", states=[], districts=districts, @@ -740,8 +746,10 @@ def coordinate_publish( completed=completed, **phase_args, ) + accumulated_errors.extend(phase_errors) + accumulated_validation_rows.extend(v_rows) - completed = run_phase( + completed, phase_errors, v_rows = run_phase( "Cities", states=[], districts=[], @@ -749,6 +757,18 @@ def coordinate_publish( completed=completed, **phase_args, ) + accumulated_errors.extend(phase_errors) + accumulated_validation_rows.extend(v_rows) + + # Fail if any workers crashed (not just missing files) + if accumulated_errors: + crash_errors = [e for e in accumulated_errors if "worker" in e] + if crash_errors: + raise RuntimeError( + f"Build failed: {len(crash_errors)} worker " + f"crash(es) detected across all phases. " + f"Errors: {crash_errors[:3]}" + ) expected_total = len(states) + len(districts) + len(cities) if len(completed) < expected_total: @@ -761,7 +781,10 @@ def coordinate_publish( if skip_upload: print("\nSkipping upload (--skip-upload flag set)") - return f"Build complete for version {version}. Upload skipped." + return { + "message": (f"Build complete for version {version}. Upload skipped."), + "validation_rows": accumulated_validation_rows, + } print("\nValidating staging...") manifest = validate_staging.remote(branch=branch, version=version) @@ -793,7 +816,10 @@ def coordinate_publish( ) print("=" * 60) - return result + return { + "message": result, + "validation_rows": accumulated_validation_rows, + } @app.local_entrypoint() @@ -801,28 +827,36 @@ def main( branch: str = "main", num_workers: int = 8, skip_upload: bool = False, - skip_download: bool = False, + n_clones: int = 430, ): """Local entrypoint for Modal CLI.""" result = coordinate_publish.remote( branch=branch, num_workers=num_workers, skip_upload=skip_upload, - skip_download=skip_download, + n_clones=n_clones, ) - print(result) + if isinstance(result, dict): + print(result.get("message", result)) + else: + print(result) @app.function( image=image, secrets=[hf_secret, gcp_secret], - volumes={VOLUME_MOUNT: staging_volume}, + volumes={ + VOLUME_MOUNT: staging_volume, + "/pipeline": pipeline_volume, + }, memory=16384, timeout=14400, ) def coordinate_national_publish( branch: str = "main", -) -> str: + n_clones: int = 430, + validate: bool = True, +) -> Dict: """Build and upload a national US.h5 from national weights.""" setup_gcp_credentials() setup_repo(branch) @@ -830,63 +864,41 @@ def coordinate_national_publish( version = get_version() print(f"Building national H5 for version {version} from branch {branch}") - import shutil - staging_dir = Path(VOLUME_MOUNT) - calibration_dir = staging_dir / "national_calibration_inputs" - if calibration_dir.exists(): - shutil.rmtree(calibration_dir) - calibration_dir.mkdir(parents=True, exist_ok=True) - - print("Downloading national calibration inputs from HF...") - result = subprocess.run( - [ - "uv", - "run", - "python", - "-c", - f""" -from policyengine_us_data.utils.huggingface import ( - download_calibration_inputs, -) -download_calibration_inputs("{calibration_dir}", prefix="national_") -print("Done") -""", - ], - text=True, - env=os.environ.copy(), - ) - if result.returncode != 0: - raise RuntimeError(f"Download failed: {result.stderr}") - staging_volume.commit() - print("National calibration inputs downloaded") - weights_path = calibration_dir / "calibration" / "national_calibration_weights.npy" - db_path = calibration_dir / "calibration" / "policy_data.db" - dataset_path = ( - calibration_dir / "calibration" / "source_imputed_stratified_extended_cps.h5" - ) + pipeline_volume.reload() + artifacts = Path("/pipeline/artifacts") + weights_path = artifacts / "national_calibration_weights.npy" + db_path = artifacts / "policy_data.db" + dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5" + config_json_path = artifacts / "national_unified_run_config.json" + + required = { + "weights": weights_path, + "dataset": dataset_path, + "database": db_path, + } + for label, p in required.items(): + if not p.exists(): + raise RuntimeError( + f"Missing {label} on pipeline volume: {p}. " + f"Run upstream pipeline steps first." + ) + print("All required national pipeline artifacts found.") - geo_npz_path = calibration_dir / "calibration" / "national_geography.npz" - config_json_path = ( - calibration_dir / "calibration" / "national_unified_run_config.json" - ) calibration_inputs = { "weights": str(weights_path), "dataset": str(dataset_path), "database": str(db_path), + "n_clones": n_clones, + "seed": 42, } - if not geo_npz_path.exists(): - raise RuntimeError( - f"national_geography.npz not found at " - f"{geo_npz_path}. Re-run national calibration " - f"to generate this artifact." - ) - calibration_inputs["geography"] = str(geo_npz_path) - print(f"National geography artifact found: {geo_npz_path}") validate_artifacts( config_json_path, - calibration_dir / "calibration", + artifacts, + filename_remap={ + "calibration_weights.npy": "national_calibration_weights.npy", + }, ) version_dir = staging_dir / version version_dir.mkdir(parents=True, exist_ok=True) @@ -898,6 +910,7 @@ def coordinate_national_publish( version=version, work_items=work_items, calibration_inputs=calibration_inputs, + validate=validate, ) print( @@ -914,6 +927,45 @@ def coordinate_national_publish( if not national_h5.exists(): raise RuntimeError(f"Expected {national_h5} not found after build") + # Compute SHA256 checksum before upload for integrity verification + import hashlib + + h = hashlib.sha256() + with open(national_h5, "rb") as fh: + for chunk in iter(lambda: fh.read(1 << 20), b""): + h.update(chunk) + national_checksum = f"sha256:{h.hexdigest()}" + national_size = national_h5.stat().st_size + print(f"National H5 checksum: {national_checksum} ({national_size:,} bytes)") + + # ── National validation ── + national_validation_output = "" + if validate: + print("Running national H5 validation...") + val_result = subprocess.run( + [ + "uv", + "run", + "python", + "-m", + "policyengine_us_data.calibration.validate_national_h5", + "--h5-path", + str(national_h5), + ], + capture_output=True, + text=True, + env=os.environ.copy(), + ) + national_validation_output = val_result.stdout + print(val_result.stdout) + if val_result.stderr: + print(val_result.stderr) + if val_result.returncode != 0: + print( + "WARNING: National validation returned " + f"non-zero exit code: {val_result.returncode}" + ) + print(f"Uploading {national_h5} to HF staging...") result = subprocess.run( [ @@ -938,18 +990,33 @@ def coordinate_national_publish( if result.returncode != 0: raise RuntimeError(f"Staging upload failed: {result.stderr}") - print("National H5 staged. Run promote workflow to publish.") - return ( - f"National US.h5 built and staged for version {version}. " - f"Run main_national_promote to publish." + # Verify the file still exists on the volume after upload + staging_volume.reload() + if not national_h5.exists(): + raise RuntimeError("National H5 disappeared from staging volume after upload") + print( + f"Post-upload verification passed: {national_h5} " + f"(checksum: {national_checksum})" ) + print("National H5 staged. Run promote workflow to publish.") + return { + "message": ( + f"National US.h5 built and staged for version " + f"{version}. Run main_national_promote to publish." + ), + "national_validation": national_validation_output, + } + @app.local_entrypoint() -def main_national(branch: str = "main"): +def main_national(branch: str = "main", n_clones: int = 430): """Build and stage national US.h5.""" - result = coordinate_national_publish.remote(branch=branch) - print(result) + result = coordinate_national_publish.remote(branch=branch, n_clones=n_clones) + if isinstance(result, dict): + print(result.get("message", result)) + else: + print(result) @app.function( diff --git a/modal_app/pipeline.py b/modal_app/pipeline.py new file mode 100644 index 00000000..83cac802 --- /dev/null +++ b/modal_app/pipeline.py @@ -0,0 +1,1262 @@ +""" +End-to-end versioned pipeline orchestrator for Modal. + +Chains all dataset-building steps (build datasets, build calibration +package, fit weights, build H5s, stage, promote) into a single +coordinated run with diagnostics, resume support, and atomic +promotion. + +**Stability assumption**: This pipeline is designed for production +use when the target branch is stable and not expected to change +during the run. All steps clone from branch tip independently; +artifacts flow through the shared pipeline volume. The run's +metadata records the SHA at orchestrator start for auditability. +If the branch changes mid-run, intermediate artifacts may come +from different commits. For development branches that are actively +changing, run individual steps manually instead. + +Usage: + # Full pipeline run + modal run --detach modal_app/pipeline.py::main \\ + --action run --branch main --gpu T4 --epochs 200 + + # Check status + modal run modal_app/pipeline.py::main --action status + + # Resume a failed run + modal run --detach modal_app/pipeline.py::main \\ + --action run --resume-run-id + + # Promote a completed run + modal run modal_app/pipeline.py::main \\ + --action promote --run-id +""" + +import json +import os +import subprocess +import time +import traceback +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from io import BytesIO +from pathlib import Path +from typing import Optional + +import modal + +# ── Modal resources ────────────────────────────────────────────── + +app = modal.App("policyengine-us-data-pipeline") + +hf_secret = modal.Secret.from_name("huggingface-token") +gcp_secret = modal.Secret.from_name("gcp-credentials") + +pipeline_volume = modal.Volume.from_name("pipeline-artifacts", create_if_missing=True) +staging_volume = modal.Volume.from_name("local-area-staging", create_if_missing=True) + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_IGNORE = [ + ".git", + "__pycache__", + "*.egg-info", + ".pytest_cache", + "*.h5", + "*.npy", + "*.pkl", + "*.db", + "node_modules", + "venv", + ".venv", + "docs/_build", + "paper", + "presentations", +] +image = ( + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .pip_install("uv>=0.8") + .add_local_dir( + str(_REPO_ROOT), + remote_path="/root/policyengine-us-data", + copy=True, + ignore=_IGNORE, + ) + .run_commands( + "cd /root/policyengine-us-data && UV_HTTP_TIMEOUT=300 uv sync --frozen" + ) +) + +REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" +PIPELINE_MOUNT = "/pipeline" +STAGING_MOUNT = "/staging" +ARTIFACTS_DIR = f"{PIPELINE_MOUNT}/artifacts" +RUNS_DIR = f"{PIPELINE_MOUNT}/runs" + + +# ── Run metadata ───────────────────────────────────────────────── + + +@dataclass +class RunMetadata: + """Metadata for a pipeline run. + + Tracks run identity, progress, and diagnostics for + auditability and resume support. + """ + + run_id: str + branch: str + sha: str + version: str + start_time: str + status: str # running | completed | failed | promoted + step_timings: dict = field(default_factory=dict) + error: Optional[str] = None + + def to_dict(self) -> dict: + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "RunMetadata": + return cls(**data) + + +def generate_run_id(version: str, sha: str) -> str: + """Generate a unique run ID. + + Format: {version}_{sha[:8]}_{YYYYMMDD_HHMMSS} + """ + ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + return f"{version}_{sha[:8]}_{ts}" + + +def write_run_meta( + meta: RunMetadata, + vol: modal.Volume, +) -> None: + """Write run metadata to the pipeline volume.""" + run_dir = Path(RUNS_DIR) / meta.run_id + run_dir.mkdir(parents=True, exist_ok=True) + meta_path = run_dir / "meta.json" + with open(meta_path, "w") as f: + json.dump(meta.to_dict(), f, indent=2) + vol.commit() + + +def read_run_meta( + run_id: str, + vol: modal.Volume, +) -> RunMetadata: + """Read run metadata from the pipeline volume.""" + vol.reload() + meta_path = Path(RUNS_DIR) / run_id / "meta.json" + if not meta_path.exists(): + raise FileNotFoundError(f"No metadata found for run {run_id} at {meta_path}") + with open(meta_path) as f: + return RunMetadata.from_dict(json.load(f)) + + +def get_pinned_sha(branch: str) -> str: + """Get the current tip SHA for a branch from GitHub.""" + result = subprocess.run( + [ + "git", + "ls-remote", + REPO_URL, + f"refs/heads/{branch}", + ], + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to get SHA for branch {branch}: {result.stderr}") + line = result.stdout.strip() + if not line: + raise RuntimeError(f"Branch {branch} not found in remote") + return line.split()[0] + + +def get_version_from_branch(branch: str) -> str: + """Get the package version from the pre-baked pyproject.toml. + + The branch parameter is kept for API compatibility but is + no longer used -- version comes from the baked source. + """ + import tomllib + + pyproject_path = "/root/policyengine-us-data/pyproject.toml" + with open(pyproject_path, "rb") as f: + pyproject = tomllib.load(f) + return pyproject["project"]["version"] + + +def archive_diagnostics( + run_id: str, + result_bytes: dict, + vol: modal.Volume, + prefix: str = "", +) -> None: + """Archive calibration diagnostics to the run directory.""" + diag_dir = Path(RUNS_DIR) / run_id / "diagnostics" + diag_dir.mkdir(parents=True, exist_ok=True) + + file_map = { + "log": f"{prefix}unified_diagnostics.csv", + "cal_log": f"{prefix}calibration_log.csv", + "config": f"{prefix}unified_run_config.json", + } + + for key, filename in file_map.items(): + data = result_bytes.get(key) + if data: + path = diag_dir / filename + with open(path, "wb") as f: + f.write(data) + print(f" Archived {filename} ({len(data):,} bytes)") + + vol.commit() + + +def _step_completed(meta: RunMetadata, step: str) -> bool: + """Check if a step is marked completed in metadata.""" + timing = meta.step_timings.get(step, {}) + return timing.get("status") == "completed" + + +def _record_step( + meta: RunMetadata, + step: str, + start: float, + vol: modal.Volume, + status: str = "completed", +) -> None: + """Record step timing and status in metadata.""" + meta.step_timings[step] = { + "start": datetime.fromtimestamp(start, tz=timezone.utc).isoformat(), + "end": datetime.now(timezone.utc).isoformat(), + "duration_s": round(time.time() - start, 1), + "status": status, + } + write_run_meta(meta, vol) + + +# ── Include other Modal apps ───────────────────────────────────── +# app.include() merges functions from other apps into this one, +# ensuring Modal mounts their files and registers their functions +# (with their GPU/memory/volume configs) in the ephemeral run. +# +# Inside Modal containers the auto-mounted package root may not be +# on sys.path when the module first loads; ensure it is importable. +import sys + +_parent = str(Path(__file__).resolve().parent.parent) +if _parent not in sys.path: + sys.path.insert(0, _parent) +# The image bakes the repo at /root/policyengine-us-data, but Modal +# auto-mounts the entrypoint elsewhere, so _parent may not contain +# modal_app/. Ensure the baked repo root is always importable. +_baked = "/root/policyengine-us-data" +if _baked not in sys.path: + sys.path.insert(0, _baked) + +from modal_app.data_build import app as _data_build_app +from modal_app.data_build import build_datasets + +app.include(_data_build_app) + +from modal_app.remote_calibration_runner import app as _calibration_app +from modal_app.remote_calibration_runner import ( + build_package_remote, + PACKAGE_GPU_FUNCTIONS, +) + +app.include(_calibration_app) + +from modal_app.local_area import app as _local_area_app +from modal_app.local_area import ( + coordinate_publish, + coordinate_national_publish, + promote_publish, + promote_national_publish, +) + +app.include(_local_area_app) + + +# ── Stage base datasets ───────────────────────────────────────── + + +def _setup_repo() -> None: + """Change to the pre-baked repo directory.""" + os.chdir("/root/policyengine-us-data") + + +def stage_base_datasets( + run_id: str, + version: str, + branch: str, +) -> None: + """Upload source_imputed + policy_data.db from pipeline + volume to HF staging/. + + Clones the repo and shells out to upload_to_staging_hf() + via subprocess, consistent with other Modal apps. + + Args: + run_id: The current run ID (for logging). + version: Package version string for the commit. + branch: Git branch for repo clone. + """ + artifacts = Path(ARTIFACTS_DIR) + + source_imputed = artifacts / "source_imputed_stratified_extended_cps.h5" + policy_db = artifacts / "policy_data.db" + + files_with_paths = [] + if source_imputed.exists(): + files_with_paths.append( + ( + str(source_imputed), + "calibration/source_imputed_stratified_extended_cps.h5", + ) + ) + print(f" source_imputed: {source_imputed.stat().st_size:,} bytes") + else: + print(" WARNING: source_imputed not found, skipping") + + if policy_db.exists(): + files_with_paths.append((str(policy_db), "calibration/policy_data.db")) + print(f" policy_data.db: {policy_db.stat().st_size:,} bytes") + else: + print(" WARNING: policy_data.db not found, skipping") + + if not files_with_paths: + print(" No base datasets to stage") + return + + _setup_repo() + + # Build the upload script as a Python snippet + import json as _json + + pairs_json = _json.dumps(files_with_paths) + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +import json +from policyengine_us_data.utils.data_upload import ( + upload_to_staging_hf, +) + +pairs = json.loads('''{pairs_json}''') +files_with_paths = [(p, r) for p, r in pairs] +count = upload_to_staging_hf(files_with_paths, "{version}") +print(f"Staged {{count}} base dataset(s) to HF") +""", + ], + cwd="/root/policyengine-us-data", + text=True, + capture_output=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"Base dataset staging failed: {result.stderr}") + print(f" {result.stdout.strip()}") + + +def upload_run_diagnostics( + run_id: str, + branch: str, +) -> None: + """Upload run diagnostics to HF for archival. + + Shells out via subprocess for consistency with other + Modal apps and to avoid package dependencies in the + orchestrator image. + + Args: + run_id: The current run ID. + branch: Git branch for repo clone. + """ + diag_dir = Path(RUNS_DIR) / run_id / "diagnostics" + if not diag_dir.exists(): + print(" No diagnostics to upload") + return + + files = list(diag_dir.glob("*")) + if not files: + print(" No diagnostic files found") + return + + print(f" Found {len(files)} diagnostic file(s) to upload") + + # Build file list as JSON for the subprocess + import json as _json + + file_entries = [ + (str(f), f"calibration/runs/{run_id}/diagnostics/{f.name}") for f in files + ] + entries_json = _json.dumps(file_entries) + + _setup_repo() + + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +import json, os +from huggingface_hub import HfApi + +entries = json.loads('''{entries_json}''') +api = HfApi() +token = os.environ.get("HUGGING_FACE_TOKEN") +for local_path, repo_path in entries: + api.upload_file( + path_or_fileobj=local_path, + path_in_repo=repo_path, + repo_id="policyengine/policyengine-us-data", + repo_type="model", + token=token, + ) + print(f"Uploaded {{repo_path}}") +""", + ], + cwd="/root/policyengine-us-data", + capture_output=True, + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(f"Diagnostics upload failed: {result.stderr}") + print(f" {result.stdout.strip()}") + + +def _write_validation_diagnostics( + run_id: str, + regional_result, + national_result, + meta: RunMetadata, + vol: modal.Volume, +) -> None: + """Aggregate validation rows into a diagnostics CSV. + + Extracts validation_rows from coordinate_publish and + national_validation from coordinate_national_publish, + writes them to runs/{run_id}/diagnostics/validation_results.csv, + and records a summary in meta.json. + """ + import csv + + validation_rows = [] + + # Extract regional validation rows + if isinstance(regional_result, dict): + v_rows = regional_result.get("validation_rows", []) + if v_rows: + validation_rows.extend(v_rows) + print(f" Collected {len(v_rows)} regional validation rows") + + # Extract national validation output + national_output = "" + if isinstance(national_result, dict): + national_output = national_result.get("national_validation", "") + if national_output: + print(" National validation output captured") + + if not validation_rows and not national_output: + print(" No validation data to write") + return + + diag_dir = Path(RUNS_DIR) / run_id / "diagnostics" + diag_dir.mkdir(parents=True, exist_ok=True) + + # Write regional validation CSV + if validation_rows: + csv_columns = [ + "area_type", + "area_id", + "district", + "variable", + "target_name", + "period", + "target_value", + "sim_value", + "error", + "rel_error", + "abs_error", + "rel_abs_error", + "sanity_check", + "sanity_reason", + "in_training", + ] + csv_path = diag_dir / "validation_results.csv" + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=csv_columns) + writer.writeheader() + for row in validation_rows: + writer.writerow({k: row.get(k, "") for k in csv_columns}) + print(f" Wrote {len(validation_rows)} rows to {csv_path}") + + # Compute summary + n_sanity_fail = sum( + 1 for r in validation_rows if r.get("sanity_check") == "FAIL" + ) + rae_vals = [ + r["rel_abs_error"] + for r in validation_rows + if isinstance(r.get("rel_abs_error"), (int, float)) + and r["rel_abs_error"] != float("inf") + ] + mean_rae = sum(rae_vals) / len(rae_vals) if rae_vals else 0.0 + + # Per-area summaries for worst areas + area_stats = {} + for r in validation_rows: + key = f"{r.get('area_type', '')}:{r.get('area_id', '')}" + if key not in area_stats: + area_stats[key] = {"rae_vals": [], "fails": 0} + if r.get("sanity_check") == "FAIL": + area_stats[key]["fails"] += 1 + rae = r.get("rel_abs_error") + if isinstance(rae, (int, float)) and rae != float("inf"): + area_stats[key]["rae_vals"].append(rae) + + worst_areas = sorted( + area_stats.items(), + key=lambda x: ( + sum(x[1]["rae_vals"]) / len(x[1]["rae_vals"]) if x[1]["rae_vals"] else 0 + ), + reverse=True, + )[:5] + + validation_summary = { + "total_targets": len(validation_rows), + "sanity_failures": n_sanity_fail, + "mean_rel_abs_error": round(mean_rae, 4), + "worst_areas": [ + { + "area": k, + "mean_rae": round( + ( + sum(v["rae_vals"]) / len(v["rae_vals"]) + if v["rae_vals"] + else 0 + ), + 4, + ), + "sanity_fails": v["fails"], + } + for k, v in worst_areas + ], + } + + print( + f" Validation summary: " + f"{len(validation_rows)} targets, " + f"{n_sanity_fail} sanity failures, " + f"mean RAE={mean_rae:.4f}" + ) + + # Record in meta.json + meta.step_timings["validation"] = validation_summary + write_run_meta(meta, vol) + + # Write national validation output + if national_output: + nat_path = diag_dir / "national_validation.txt" + with open(nat_path, "w") as f: + f.write(national_output) + print(f" Wrote national validation to {nat_path}") + + vol.commit() + + +# ── Orchestrator ───────────────────────────────────────────────── + + +@app.function( + image=image, + cpu=2, + memory=4096, + timeout=86400, # 24 hours (Modal max) + volumes={ + PIPELINE_MOUNT: pipeline_volume, + STAGING_MOUNT: staging_volume, + }, + secrets=[hf_secret, gcp_secret], +) +def run_pipeline( + branch: str = "main", + gpu: str = "T4", + epochs: int = 1000, + national_gpu: str = "T4", + national_epochs: int = 4000, + num_workers: int = 8, + n_clones: int = 430, + skip_national: bool = False, + resume_run_id: str = None, +) -> str: + """Run the full pipeline end-to-end. + + Args: + branch: Git branch to build from. + gpu: GPU type for regional calibration. + epochs: Training epochs for regional calibration. + national_gpu: GPU type for national calibration. + national_epochs: Training epochs for national. + num_workers: Number of parallel H5 workers. + n_clones: Number of clones for H5 building. + skip_national: Skip national calibration/H5. + resume_run_id: Resume a previously failed run. + + Returns: + The run ID for use with promote. + """ + # ── Setup GCP credentials ── + creds_json = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON") + if creds_json: + creds_path = "/tmp/gcp-credentials.json" + with open(creds_path, "w") as f: + f.write(creds_json) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_path + + # ── Initialize or resume run ── + sha = get_pinned_sha(branch) + version = get_version_from_branch(branch) + + if resume_run_id: + print(f"Resuming run {resume_run_id}...") + meta = read_run_meta(resume_run_id, pipeline_volume) + if meta.sha != sha: + raise RuntimeError( + f"Branch {branch} has moved since run " + f"started.\n" + f" Run SHA: {meta.sha[:12]}\n" + f" Current SHA: {sha[:12]}\n" + f"Start a fresh run instead." + ) + meta.status = "running" + run_id = resume_run_id + else: + run_id = generate_run_id(version, sha) + meta = RunMetadata( + run_id=run_id, + branch=branch, + sha=sha, + version=version, + start_time=datetime.now(timezone.utc).isoformat(), + status="running", + ) + + # Create run directory + run_dir = Path(RUNS_DIR) / run_id + run_dir.mkdir(parents=True, exist_ok=True) + (run_dir / "diagnostics").mkdir(exist_ok=True) + + # Create artifacts directory + Path(ARTIFACTS_DIR).mkdir(parents=True, exist_ok=True) + + write_run_meta(meta, pipeline_volume) + + print("=" * 60) + print("PIPELINE RUN") + print("=" * 60) + print(f" Run ID: {run_id}") + print(f" Branch: {branch}") + print(f" SHA: {sha[:12]}") + print(f" Version: {version}") + print(f" GPU: {gpu} (regional)") + if not skip_national: + print(f" GPU: {national_gpu} (national)") + print(f" Epochs: {epochs}") + print(f" Workers: {num_workers}") + print(f" Clones: {n_clones}") + if resume_run_id: + completed = [ + s for s, t in meta.step_timings.items() if t.get("status") == "completed" + ] + print(f" Resume: skipping {completed}") + print("=" * 60) + + try: + # ── Step 1: Build datasets ── + if not _step_completed(meta, "build_datasets"): + print("\n[Step 1/5] Building datasets...") + step_start = time.time() + + build_datasets.remote( + upload=True, + branch=branch, + sequential=False, + skip_tests=True, + skip_enhanced_cps=False, + ) + + # The build_datasets step produces files in its + # own volume. Key outputs (source_imputed, + # policy_data.db) are staged to HF in step 4. + # TODO(#617): When pipeline_artifacts.py lands, + # call mirror_to_pipeline() here for audit trail. + _record_step( + meta, + "build_datasets", + step_start, + pipeline_volume, + ) + print( + f" Completed in {meta.step_timings['build_datasets']['duration_s']}s" + ) + else: + print("\n[Step 1/5] Build datasets (skipped - completed)") + + # ── Step 2: Build calibration package ── + if not _step_completed(meta, "build_package"): + print("\n[Step 2/5] Building calibration package...") + step_start = time.time() + + pkg_path = build_package_remote.remote( + branch=branch, + workers=num_workers, + n_clones=n_clones, + ) + print(f" Package at: {pkg_path}") + + _record_step( + meta, + "build_package", + step_start, + pipeline_volume, + ) + print(f" Completed in {meta.step_timings['build_package']['duration_s']}s") + else: + print("\n[Step 2/5] Build package (skipped - completed)") + + # ── Step 3: Fit weights (parallel) ── + if not _step_completed(meta, "fit_weights"): + print("\n[Step 3/5] Fitting calibration weights...") + step_start = time.time() + + vol_path = "/pipeline/artifacts/calibration_package.pkl" + target_cfg = "policyengine_us_data/calibration/target_config.yaml" + + # Spawn regional fit + regional_func = PACKAGE_GPU_FUNCTIONS[gpu] + print(f" Spawning regional fit ({gpu}, {epochs} epochs)...") + regional_handle = regional_func.spawn( + branch=branch, + epochs=epochs, + volume_package_path=vol_path, + target_config=target_cfg, + beta=0.65, + lambda_l0=1e-7, + lambda_l2=1e-8, + log_freq=500, + ) + + # Spawn national fit (if enabled) + national_handle = None + if not skip_national: + national_func = PACKAGE_GPU_FUNCTIONS[national_gpu] + print( + f" Spawning national fit " + f"({national_gpu}, " + f"{national_epochs} epochs)..." + ) + national_handle = national_func.spawn( + branch=branch, + epochs=national_epochs, + volume_package_path=vol_path, + target_config=target_cfg, + beta=0.65, + lambda_l0=1e-4, + lambda_l2=1e-12, + log_freq=500, + ) + + # Collect regional results + print(" Waiting for regional fit...") + regional_result = regional_handle.get() + print(" Regional fit complete. Writing to volume...") + + # Write regional results to pipeline volume + with pipeline_volume.batch_upload(force=True) as batch: + batch.put_file( + BytesIO(regional_result["weights"]), + "artifacts/calibration_weights.npy", + ) + if regional_result.get("config"): + batch.put_file( + BytesIO(regional_result["config"]), + "artifacts/unified_run_config.json", + ) + + archive_diagnostics( + run_id, + regional_result, + pipeline_volume, + prefix="", + ) + + # Collect national results + if national_handle is not None: + print(" Waiting for national fit...") + national_result = national_handle.get() + print(" National fit complete. Writing to volume...") + + with pipeline_volume.batch_upload(force=True) as batch: + batch.put_file( + BytesIO(national_result["weights"]), + "artifacts/national_calibration_weights.npy", + ) + if national_result.get("config"): + batch.put_file( + BytesIO(national_result["config"]), + "artifacts/national_unified_run_config.json", + ) + + archive_diagnostics( + run_id, + national_result, + pipeline_volume, + prefix="national_", + ) + + _record_step( + meta, + "fit_weights", + step_start, + pipeline_volume, + ) + print(f" Completed in {meta.step_timings['fit_weights']['duration_s']}s") + else: + print("\n[Step 3/5] Fit weights (skipped - completed)") + + # ── Step 4: Build H5s + stage + diagnostics (parallel) ── + # Per plan: all four tasks run in parallel: + # 4a. coordinate_publish (regional H5s) + # 4b. coordinate_national_publish (national H5) + # 4c. stage_base_datasets (datasets → HF staging) + # 4d. upload_run_diagnostics (diagnostics → HF) + if not _step_completed(meta, "publish_and_stage"): + print( + "\n[Step 4/5] Building H5s, staging datasets, " + "uploading diagnostics (parallel)..." + ) + step_start = time.time() + + # Spawn H5 builds (run on separate Modal containers) + print(f" Spawning regional H5 build ({num_workers} workers)...") + regional_h5_handle = coordinate_publish.spawn( + branch=branch, + num_workers=num_workers, + skip_upload=False, + n_clones=n_clones, + validate=True, + ) + + national_h5_handle = None + if not skip_national: + print(" Spawning national H5 build...") + national_h5_handle = coordinate_national_publish.spawn( + branch=branch, + n_clones=n_clones, + validate=True, + ) + + # While H5 builds run, stage base datasets + # and upload diagnostics in this container + pipeline_volume.reload() + + print(" Staging base datasets to HF...") + stage_base_datasets(run_id, version, branch) + + print(" Uploading run diagnostics...") + upload_run_diagnostics(run_id, branch) + + # Now wait for H5 builds to finish + print(" Waiting for regional H5 build...") + regional_h5_result = regional_h5_handle.get() + regional_msg = ( + regional_h5_result.get("message", regional_h5_result) + if isinstance(regional_h5_result, dict) + else regional_h5_result + ) + print(f" Regional H5: {regional_msg}") + + national_h5_result = None + if national_h5_handle is not None: + print(" Waiting for national H5 build...") + national_h5_result = national_h5_handle.get() + national_msg = ( + national_h5_result.get("message", national_h5_result) + if isinstance(national_h5_result, dict) + else national_h5_result + ) + print(f" National H5: {national_msg}") + + # ── Aggregate validation results ── + _write_validation_diagnostics( + run_id=run_id, + regional_result=regional_h5_result, + national_result=national_h5_result, + meta=meta, + vol=pipeline_volume, + ) + + _record_step( + meta, + "publish_and_stage", + step_start, + pipeline_volume, + ) + print( + f" Completed in " + f"{meta.step_timings['publish_and_stage']['duration_s']}s" + ) + else: + print("\n[Step 4/5] Publish + stage (skipped - completed)") + + # ── Step 5: Finalize ── + print("\n[Step 5/5] Finalizing run...") + meta.status = "completed" + write_run_meta(meta, pipeline_volume) + + print("\n" + "=" * 60) + print("PIPELINE COMPLETE") + print("=" * 60) + print(f" Run ID: {run_id}") + print(f" Status: {meta.status}") + _print_step_timings(meta) + print( + f"\nTo promote, run:\n" + f" modal run modal_app/pipeline.py" + f"::main --action promote " + f"--run-id {run_id}" + ) + print("=" * 60) + + return run_id + + except Exception as e: + meta.status = "failed" + meta.error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}" + write_run_meta(meta, pipeline_volume) + print(f"\nPIPELINE FAILED: {e}") + print(f"Resume with: --resume-run-id {run_id}") + raise + + +def _print_step_timings(meta: RunMetadata) -> None: + """Print formatted step timings.""" + total = 0.0 + for step, timing in meta.step_timings.items(): + dur = timing.get("duration_s", 0) + total += dur + status = timing.get("status", "unknown") + print(f" {step}: {dur}s ({status})") + hours = total / 3600 + print(f" TOTAL: {total:.0f}s ({hours:.1f}h)") + + +# ── Promote ────────────────────────────────────────────────────── + + +@app.function( + image=image, + cpu=2, + memory=4096, + timeout=7200, + volumes={ + PIPELINE_MOUNT: pipeline_volume, + STAGING_MOUNT: staging_volume, + }, + secrets=[hf_secret, gcp_secret], +) +def promote_run( + run_id: str, + version: str = None, +) -> str: + """Promote a completed pipeline run to production. + + 1. Verify run status is "completed" + 2. Promote H5s (regional + national) via existing + promote functions + 3. Register version in version_manifest.json + 4. Update run status to "promoted" + + Args: + run_id: The run ID to promote. + version: Override version (default: from run + metadata). + + Returns: + Summary message. + """ + # Setup GCP + creds_json = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_JSON") + if creds_json: + creds_path = "/tmp/gcp-credentials.json" + with open(creds_path, "w") as f: + f.write(creds_json) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = creds_path + + meta = read_run_meta(run_id, pipeline_volume) + + if meta.status not in ("completed", "promoted"): + raise RuntimeError( + f"Run {run_id} has status " + f"'{meta.status}'. Only completed runs " + f"can be promoted." + ) + + if meta.status == "promoted": + print(f"WARNING: Run {run_id} was already promoted. Re-promoting...") + + version = version or meta.version + + print("=" * 60) + print("PROMOTING PIPELINE RUN") + print("=" * 60) + print(f" Run ID: {run_id}") + print(f" Version: {version}") + print(f" Branch: {meta.branch}") + print(f" SHA: {meta.sha[:12]}") + print("=" * 60) + + # Clone repo for subprocess calls + _setup_repo() + + # Promote base datasets from staging → production + print("\nPromoting base datasets (staging → production)...") + try: + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +from policyengine_us_data.utils.data_upload import ( + promote_staging_to_production_hf, +) + +base_files = [ + "calibration/source_imputed_stratified_extended_cps.h5", + "calibration/policy_data.db", +] +count = promote_staging_to_production_hf(base_files, "{version}") +print(f"Promoted {{count}} base dataset(s)") +""", + ], + cwd="/root/policyengine-us-data", + capture_output=True, + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(result.stderr) + print(f" {result.stdout.strip()}") + except Exception as e: + print(f" WARNING: Base dataset promotion: {e}") + + # Promote H5s via existing functions + print("\nPromoting regional H5s...") + try: + regional_result = promote_publish.remote( + branch=meta.branch, + version=version, + ) + print(f" {regional_result}") + except Exception as e: + print(f" WARNING: Regional promote: {e}") + + print("\nPromoting national H5...") + try: + national_result = promote_national_publish.remote( + branch=meta.branch, + ) + print(f" {national_result}") + except Exception as e: + print(f" WARNING: National promote: {e}") + + # Register version in manifest + print("\nRegistering version in manifest...") + try: + result = subprocess.run( + [ + "uv", + "run", + "python", + "-c", + f""" +from policyengine_us_data.utils.version_manifest import ( + build_manifest, + upload_manifest, +) + +blob_names = [ + "calibration/source_imputed_stratified_extended_cps.h5", + "calibration/policy_data.db", + "calibration/calibration_weights.npy", +] +manifest = build_manifest( + version="{version}", + blob_names=blob_names, +) +manifest.pipeline_run_id = "{run_id}" +manifest.diagnostics_path = "calibration/runs/{run_id}/diagnostics/" +upload_manifest(manifest) +print("Registered version {version} in version_manifest.json") +""", + ], + cwd="/root/policyengine-us-data", + capture_output=True, + text=True, + env=os.environ.copy(), + ) + if result.returncode != 0: + raise RuntimeError(result.stderr) + print(f" {result.stdout.strip()}") + except Exception as e: + print(f" WARNING: Version registration failed: {e}") + print(" This can be done manually later via version_manifest.py") + + # Update run status + meta.status = "promoted" + write_run_meta(meta, pipeline_volume) + + print("\n" + "=" * 60) + print("PROMOTION COMPLETE") + print("=" * 60) + print(f" Version {version} is now live.") + print("=" * 60) + + return f"Promoted run {run_id} as version {version}" + + +# ── Status ─────────────────────────────────────────────────────── + + +@app.function( + image=image, + timeout=60, + volumes={PIPELINE_MOUNT: pipeline_volume}, +) +def pipeline_status( + run_id: str = None, +) -> str: + """Get pipeline status. + + If run_id is provided, show that run's details. + Otherwise, list all runs. + """ + pipeline_volume.reload() + runs_dir = Path(RUNS_DIR) + + if not runs_dir.exists(): + return "No pipeline runs found." + + if run_id: + meta = read_run_meta(run_id, pipeline_volume) + lines = [ + f"Run: {meta.run_id}", + f" Branch: {meta.branch}", + f" SHA: {meta.sha[:12]}", + f" Version: {meta.version}", + f" Status: {meta.status}", + f" Started: {meta.start_time}", + ] + if meta.error: + lines.append(f" Error: {meta.error[:200]}") + if meta.step_timings: + lines.append(" Steps:") + for step, timing in meta.step_timings.items(): + dur = timing.get("duration_s", "?") + status = timing.get("status", "unknown") + lines.append(f" {step}: {dur}s ({status})") + return "\n".join(lines) + + # List all runs + runs = [] + for entry in sorted(runs_dir.iterdir()): + meta_path = entry / "meta.json" + if meta_path.exists(): + with open(meta_path) as f: + data = json.load(f) + runs.append( + f" {data['run_id']}: " + f"{data['status']} " + f"(branch={data['branch']}, " + f"v={data['version']})" + ) + + if not runs: + return "No pipeline runs found." + + return "Pipeline runs:\n" + "\n".join(runs) + + +# ── Local entrypoint ───────────────────────────────────────────── + + +@app.local_entrypoint() +def main( + action: str = "run", + branch: str = "main", + run_id: str = None, + resume_run_id: str = None, + gpu: str = "T4", + epochs: int = 1000, + national_gpu: str = "T4", + national_epochs: int = 4000, + num_workers: int = 8, + n_clones: int = 430, + skip_national: bool = False, + version: str = None, +): + """Pipeline entrypoint. + + Actions: + run - Run the full pipeline + status - Show pipeline status + promote - Promote a completed run + """ + if action == "run": + result = run_pipeline.remote( + branch=branch, + gpu=gpu, + epochs=epochs, + national_gpu=national_gpu, + national_epochs=national_epochs, + num_workers=num_workers, + n_clones=n_clones, + skip_national=skip_national, + resume_run_id=resume_run_id, + ) + print(f"\nPipeline run complete: {result}") + + elif action == "status": + result = pipeline_status.remote( + run_id=run_id, + ) + print(result) + + elif action == "promote": + if not run_id: + raise ValueError("--run-id is required for promote") + result = promote_run.remote( + run_id=run_id, + version=version, + ) + print(result) + + else: + raise ValueError(f"Unknown action: {action}. Use: run, status, promote") diff --git a/modal_app/remote_calibration_runner.py b/modal_app/remote_calibration_runner.py index 2a8e5277..55196947 100644 --- a/modal_app/remote_calibration_runner.py +++ b/modal_app/remote_calibration_runner.py @@ -5,15 +5,44 @@ app = modal.App("policyengine-us-data-fit-weights") hf_secret = modal.Secret.from_name("huggingface-token") -calibration_vol = modal.Volume.from_name("calibration-data", create_if_missing=True) - +pipeline_vol = modal.Volume.from_name("pipeline-artifacts", create_if_missing=True) + +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_IGNORE = [ + ".git", + "__pycache__", + "*.egg-info", + ".pytest_cache", + "*.h5", + "*.npy", + "*.pkl", + "*.db", + "node_modules", + "venv", + ".venv", + "docs/_build", + "paper", + "presentations", +] image = ( - modal.Image.debian_slim(python_version="3.11").apt_install("git").pip_install("uv") + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .pip_install("uv>=0.8") + .add_local_dir( + str(_REPO_ROOT), + remote_path="/root/policyengine-us-data", + copy=True, + ignore=_IGNORE, + ) + .run_commands( + "cd /root/policyengine-us-data && " + "UV_HTTP_TIMEOUT=300 uv sync --frozen --extra l0" + ) ) -REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git" -VOLUME_MOUNT = "/calibration-data" -_DEFAULT_UV_HTTP_TIMEOUT = "1800" +PIPELINE_MOUNT = "/pipeline" def _run_streaming(cmd, env=None, label=""): @@ -41,19 +70,9 @@ def _run_streaming(cmd, env=None, label=""): return proc.returncode, lines -def _run_uv_sync(*args: str) -> None: - """Run uv sync with a higher default network timeout for large wheels.""" - env = os.environ.copy() - env.setdefault("UV_HTTP_TIMEOUT", _DEFAULT_UV_HTTP_TIMEOUT) - subprocess.run(["uv", "sync", *args], check=True, env=env) - - -def _clone_and_install(branch: str): - """Clone the repo and install dependencies.""" - os.chdir("/root") - subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True) - os.chdir("policyengine-us-data") - _run_uv_sync("--extra", "l0") +def _setup_repo(): + """Change to the pre-baked repo directory.""" + os.chdir("/root/policyengine-us-data") def _append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq=None): @@ -76,9 +95,6 @@ def _collect_outputs(cal_lines): log_path = None cal_log_path = None config_path = None - blocks_path = None - geo_labels_path = None - geography_path = None for line in cal_lines: if "OUTPUT_PATH:" in line: output_path = line.split("OUTPUT_PATH:")[1].strip() @@ -86,12 +102,6 @@ def _collect_outputs(cal_lines): config_path = line.split("CONFIG_PATH:")[1].strip() elif "CAL_LOG_PATH:" in line: cal_log_path = line.split("CAL_LOG_PATH:")[1].strip() - elif "GEO_LABELS_PATH:" in line: - geo_labels_path = line.split("GEO_LABELS_PATH:")[1].strip() - elif "GEOGRAPHY_PATH:" in line: - geography_path = line.split("GEOGRAPHY_PATH:")[1].strip() - elif "BLOCKS_PATH:" in line: - blocks_path = line.split("BLOCKS_PATH:")[1].strip() elif "LOG_PATH:" in line: log_path = line.split("LOG_PATH:")[1].strip() @@ -113,29 +123,11 @@ def _collect_outputs(cal_lines): with open(config_path, "rb") as f: config_bytes = f.read() - blocks_bytes = None - if blocks_path and os.path.exists(blocks_path): - with open(blocks_path, "rb") as f: - blocks_bytes = f.read() - - geo_labels_bytes = None - if geo_labels_path and os.path.exists(geo_labels_path): - with open(geo_labels_path, "rb") as f: - geo_labels_bytes = f.read() - - geography_bytes = None - if geography_path and os.path.exists(geography_path): - with open(geography_path, "rb") as f: - geography_bytes = f.read() - return { "weights": weights_bytes, "log": log_bytes, "cal_log": cal_log_bytes, "config": config_bytes, - "blocks": blocks_bytes, - "geo_labels": geo_labels_bytes, - "geography": geography_bytes, } @@ -177,40 +169,6 @@ def _trigger_repository_dispatch(event_type: str = "calibration-updated"): return True -def _upload_source_imputed(lines): - """Parse SOURCE_IMPUTED_PATH from output and upload to HF.""" - source_path = None - for line in lines: - if "SOURCE_IMPUTED_PATH:" in line: - raw = line.split("SOURCE_IMPUTED_PATH:")[1].strip() - source_path = raw.split("]")[-1].strip() if "]" in raw else raw - if not source_path or not os.path.exists(source_path): - return - print(f"Uploading source-imputed dataset: {source_path}", flush=True) - rc, _ = _run_streaming( - [ - "uv", - "run", - "python", - "-c", - "from policyengine_us_data.utils.huggingface import upload; " - f"upload('{source_path}', " - "'policyengine/policyengine-us-data', " - "'calibration/" - "source_imputed_stratified_extended_cps.h5')", - ], - env=os.environ.copy(), - label="upload-source-imputed", - ) - if rc != 0: - print( - "WARNING: Failed to upload source-imputed dataset", - flush=True, - ) - else: - print("Source-imputed dataset uploaded to HF", flush=True) - - def _fit_weights_impl( branch: str, epochs: int, @@ -223,34 +181,18 @@ def _fit_weights_impl( skip_county: bool = True, workers: int = 8, ) -> dict: - """Full pipeline: download data, build matrix, fit weights.""" - _clone_and_install(branch) - - print("Downloading calibration inputs from HuggingFace...", flush=True) - dl_rc, dl_lines = _run_streaming( - [ - "uv", - "run", - "python", - "-c", - "from policyengine_us_data.utils.huggingface import " - "download_calibration_inputs; " - "paths = download_calibration_inputs('/root/calibration_data'); " - "print(f\"DB: {paths['database']}\"); " - "print(f\"DATASET: {paths['dataset']}\")", - ], - env=os.environ.copy(), - label="download", - ) - if dl_rc != 0: - raise RuntimeError(f"Download failed with code {dl_rc}") - - db_path = dataset_path = None - for line in dl_lines: - if "DB:" in line: - db_path = line.split("DB:")[1].strip() - elif "DATASET:" in line: - dataset_path = line.split("DATASET:")[1].strip() + """Full pipeline: read data from pipeline volume, build matrix, fit.""" + _setup_repo() + + pipeline_vol.reload() + artifacts = f"{PIPELINE_MOUNT}/artifacts" + db_path = f"{artifacts}/policy_data.db" + dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5" + for label, p in [("database", db_path), ("dataset", dataset_path)]: + if not os.path.exists(p): + raise RuntimeError( + f"Missing {label} on pipeline volume: {p}. Run data_build first." + ) script_path = "policyengine_us_data/calibration/unified_calibration.py" cmd = [ @@ -283,8 +225,6 @@ def _fit_weights_impl( if cal_rc != 0: raise RuntimeError(f"Script failed with code {cal_rc}") - _upload_source_imputed(cal_lines) - return _collect_outputs(cal_lines) @@ -303,7 +243,7 @@ def _fit_from_package_impl( if not volume_package_path: raise ValueError("volume_package_path is required") - _clone_and_install(branch) + _setup_repo() pkg_path = "/root/calibration_package.pkl" import shutil @@ -370,9 +310,14 @@ def _print_provenance_from_meta(meta: dict, current_branch: str = None) -> None: ) -def _write_package_sidecar(pkg_path: str) -> None: - """Extract metadata from a pickle package and write a JSON sidecar.""" +def _write_package_sidecar(pkg_path: str) -> bool: + """Extract metadata from a pickle package and write a JSON sidecar. + + Returns: + True if sidecar was written successfully, False otherwise. + """ import json + import logging import pickle sidecar_path = pkg_path.replace(".pkl", "_meta.json") @@ -387,11 +332,14 @@ def _write_package_sidecar(pkg_path: str) -> None: f"Sidecar metadata written to {sidecar_path}", flush=True, ) + return True except Exception as e: - print( - f"WARNING: Failed to write sidecar: {e}", - flush=True, + logging.warning( + "Failed to write package sidecar for %s: %s", + pkg_path, + e, ) + return False def _build_package_impl( @@ -399,41 +347,22 @@ def _build_package_impl( target_config: str = None, skip_county: bool = True, workers: int = 8, + n_clones: int = 430, ) -> str: - """Download data, build X matrix, save package to volume.""" - _clone_and_install(branch) - - print( - "Downloading calibration inputs from HuggingFace...", - flush=True, - ) - dl_rc, dl_lines = _run_streaming( - [ - "uv", - "run", - "python", - "-c", - "from policyengine_us_data.utils.huggingface import " - "download_calibration_inputs; " - "paths = download_calibration_inputs(" - "'/root/calibration_data'); " - "print(f\"DB: {paths['database']}\"); " - "print(f\"DATASET: {paths['dataset']}\")", - ], - env=os.environ.copy(), - label="download", - ) - if dl_rc != 0: - raise RuntimeError(f"Download failed with code {dl_rc}") - - db_path = dataset_path = None - for line in dl_lines: - if "DB:" in line: - db_path = line.split("DB:")[1].strip() - elif "DATASET:" in line: - dataset_path = line.split("DATASET:")[1].strip() + """Read data from pipeline volume, build X matrix, save package.""" + _setup_repo() + + pipeline_vol.reload() + artifacts = f"{PIPELINE_MOUNT}/artifacts" + db_path = f"{artifacts}/policy_data.db" + dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5" + for label, p in [("database", db_path), ("dataset", dataset_path)]: + if not os.path.exists(p): + raise RuntimeError( + f"Missing {label} on pipeline volume: {p}. Run data_build first." + ) - pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + pkg_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" script_path = "policyengine_us_data/calibration/unified_calibration.py" cmd = [ "uv", @@ -458,6 +387,7 @@ def _build_package_impl( cmd.append("--county-level") if workers > 1: cmd.extend(["--workers", str(workers)]) + cmd.extend(["--n-clones", str(n_clones)]) build_rc, build_lines = _run_streaming( cmd, @@ -467,16 +397,20 @@ def _build_package_impl( if build_rc != 0: raise RuntimeError(f"Package build failed with code {build_rc}") - _upload_source_imputed(build_lines) - - _write_package_sidecar(pkg_path) + sidecar_ok = _write_package_sidecar(pkg_path) + if not sidecar_ok: + print( + "WARNING: Package sidecar (provenance metadata) " + "was not written. The package itself is still valid.", + flush=True, + ) size = os.path.getsize(pkg_path) print( f"Package saved to volume at {pkg_path} ({size:,} bytes)", flush=True, ) - calibration_vol.commit() + pipeline_vol.commit() return pkg_path @@ -486,26 +420,28 @@ def _build_package_impl( memory=65536, cpu=8.0, timeout=50400, - volumes={VOLUME_MOUNT: calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def build_package_remote( branch: str = "main", target_config: str = None, skip_county: bool = True, workers: int = 8, + n_clones: int = 430, ) -> str: return _build_package_impl( branch, target_config=target_config, skip_county=skip_county, workers=workers, + n_clones=n_clones, ) @app.function( image=image, timeout=30, - volumes={VOLUME_MOUNT: calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def check_volume_package() -> dict: """Check if a calibration package exists on the volume. @@ -516,8 +452,8 @@ def check_volume_package() -> dict: import datetime import json - pkg_path = f"{VOLUME_MOUNT}/calibration_package.pkl" - sidecar_path = f"{VOLUME_MOUNT}/calibration_package_meta.json" + pkg_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" + sidecar_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package_meta.json" if not os.path.exists(pkg_path): return {"exists": False} @@ -558,6 +494,7 @@ def check_volume_package() -> dict: cpu=8.0, gpu="T4", timeout=14400, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_weights_t4( branch: str = "main", @@ -592,6 +529,7 @@ def fit_weights_t4( cpu=8.0, gpu="A10", timeout=14400, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_weights_a10( branch: str = "main", @@ -626,6 +564,7 @@ def fit_weights_a10( cpu=8.0, gpu="A100-40GB", timeout=14400, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_weights_a100_40( branch: str = "main", @@ -660,6 +599,7 @@ def fit_weights_a100_40( cpu=8.0, gpu="A100-80GB", timeout=14400, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_weights_a100_80( branch: str = "main", @@ -694,6 +634,7 @@ def fit_weights_a100_80( cpu=8.0, gpu="H100", timeout=14400, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_weights_h100( branch: str = "main", @@ -739,7 +680,7 @@ def fit_weights_h100( cpu=8.0, gpu="T4", timeout=14400, - volumes={"/calibration-data": calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_from_package_t4( branch: str = "main", @@ -771,7 +712,7 @@ def fit_from_package_t4( cpu=8.0, gpu="A10", timeout=14400, - volumes={"/calibration-data": calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_from_package_a10( branch: str = "main", @@ -803,7 +744,7 @@ def fit_from_package_a10( cpu=8.0, gpu="A100-40GB", timeout=14400, - volumes={"/calibration-data": calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_from_package_a100_40( branch: str = "main", @@ -835,7 +776,7 @@ def fit_from_package_a100_40( cpu=8.0, gpu="A100-80GB", timeout=14400, - volumes={"/calibration-data": calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_from_package_a100_80( branch: str = "main", @@ -867,7 +808,7 @@ def fit_from_package_a100_80( cpu=8.0, gpu="H100", timeout=14400, - volumes={"/calibration-data": calibration_vol}, + volumes={PIPELINE_MOUNT: pipeline_vol}, ) def fit_from_package_h100( branch: str = "main", @@ -936,7 +877,7 @@ def main( ) if package_path: - vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" print(f"Reading package from {package_path}...", flush=True) import json as _json import pickle as _pkl @@ -944,25 +885,24 @@ def main( with open(package_path, "rb") as f: package_bytes = f.read() size = len(package_bytes) - # Extract metadata for sidecar pkg_meta = _pkl.loads(package_bytes).get("metadata", {}) sidecar_bytes = _json.dumps(pkg_meta, indent=2).encode() print( f"Uploading package ({size:,} bytes) to Modal volume...", flush=True, ) - with calibration_vol.batch_upload(force=True) as batch: + with pipeline_vol.batch_upload(force=True) as batch: from io import BytesIO - batch.put( + batch.put_file( BytesIO(package_bytes), - "calibration_package.pkl", + "artifacts/calibration_package.pkl", ) - batch.put( + batch.put_file( BytesIO(sidecar_bytes), - "calibration_package_meta.json", + "artifacts/calibration_package_meta.json", ) - calibration_vol.commit() + pipeline_vol.commit() del package_bytes print("Upload complete.", flush=True) _print_provenance_from_meta(pkg_meta, branch) @@ -984,7 +924,7 @@ def main( flush=True, ) print( - "Mode: full pipeline (download, build matrix, fit)", + "Mode: full pipeline (read from volume, build matrix, fit)", flush=True, ) print( @@ -1009,7 +949,7 @@ def main( workers=workers, ) else: - vol_path = f"{VOLUME_MOUNT}/calibration_package.pkl" + vol_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl" vol_info = check_volume_package.remote() if not vol_info["exists"]: raise SystemExit( @@ -1040,10 +980,6 @@ def main( f" - calibration/{prefix}calibration_weights.npy", flush=True, ) - print( - f" - calibration/{prefix}stacked_blocks.npy", - flush=True, - ) print( f" - calibration/logs/{prefix}* (diagnostics, " "config, calibration log)", @@ -1087,23 +1023,22 @@ def main( f.write(result["config"]) print(f"Run config saved to: {config_output}") - blocks_output = f"{prefix}stacked_blocks.npy" - if result.get("blocks"): - with open(blocks_output, "wb") as f: - f.write(result["blocks"]) - print(f"Stacked blocks saved to: {blocks_output}") - - geo_labels_output = f"{prefix}geo_labels.json" - if result.get("geo_labels"): - with open(geo_labels_output, "wb") as f: - f.write(result["geo_labels"]) - print(f"Geo labels saved to: {geo_labels_output}") + # Push weights to pipeline volume for downstream steps + from io import BytesIO - geography_output = f"{prefix}geography.npz" - if result.get("geography"): - with open(geography_output, "wb") as f: - f.write(result["geography"]) - print(f"Geography saved to: {geography_output}") + print("Pushing weights to pipeline volume...", flush=True) + with pipeline_vol.batch_upload(force=True) as batch: + batch.put_file( + BytesIO(result["weights"]), + f"artifacts/{prefix}calibration_weights.npy", + ) + if result.get("config"): + batch.put_file( + BytesIO(result["config"]), + f"artifacts/{prefix}unified_run_config.json", + ) + pipeline_vol.commit() + print("Weights committed to pipeline volume", flush=True) if push_results: from policyengine_us_data.utils.huggingface import ( @@ -1112,9 +1047,6 @@ def main( upload_calibration_artifacts( weights_path=output, - blocks_path=(blocks_output if result.get("blocks") else None), - geo_labels_path=(geo_labels_output if result.get("geo_labels") else None), - geography_path=(geography_output if result.get("geography") else None), log_dir=".", prefix=prefix, ) @@ -1129,6 +1061,7 @@ def build_package( target_config: str = None, county_level: bool = False, workers: int = 8, + n_clones: int = 430, ): """Build the calibration package (X matrix) on CPU and save to Modal volume. Then run main() to fit.""" @@ -1155,6 +1088,7 @@ def build_package( target_config=target_config, skip_county=not county_level, workers=workers, + n_clones=n_clones, ) print( f"Package built and saved to Modal volume at {vol_path}", diff --git a/modal_app/resilience.py b/modal_app/resilience.py new file mode 100644 index 00000000..59991ae3 --- /dev/null +++ b/modal_app/resilience.py @@ -0,0 +1,44 @@ +"""Subprocess retry wrapper for network-dependent operations.""" + +import subprocess +import time +from typing import Optional + + +def run_with_retry( + cmd: list[str], + max_retries: int = 3, + backoff: float = 5.0, + env: Optional[dict] = None, + label: str = "", +) -> subprocess.CompletedProcess: + """Run a subprocess command with retries on failure. + + Args: + cmd: Command and arguments. + max_retries: Maximum number of retry attempts. + backoff: Base delay between retries (doubled each attempt). + env: Environment variables. + label: Label for log messages. + + Returns: + CompletedProcess on success. + + Raises: + subprocess.CalledProcessError: If all retries exhausted. + """ + tag = f"[{label}] " if label else "" + for attempt in range(max_retries + 1): + result = subprocess.run(cmd, env=env) + if result.returncode == 0: + return result + if attempt < max_retries: + delay = backoff * (2**attempt) + print( + f"{tag}Attempt {attempt + 1} failed " + f"(rc={result.returncode}), " + f"retrying in {delay:.0f}s..." + ) + time.sleep(delay) + else: + raise subprocess.CalledProcessError(result.returncode, cmd) diff --git a/modal_app/worker_script.py b/modal_app/worker_script.py index f36b59a0..0c039d2d 100644 --- a/modal_app/worker_script.py +++ b/modal_app/worker_script.py @@ -13,6 +13,141 @@ from pathlib import Path +def _validate_in_subprocess( + h5_path, + area_type, + area_id, + display_id, + area_targets, + area_training, + constraints_map, + db_path, + period, +): + """Run validation for one area inside a subprocess. + + All Microsimulation memory is reclaimed when the + subprocess exits. + """ + import logging + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + ) + from policyengine_us import Microsimulation + from sqlalchemy import create_engine as _ce + from policyengine_us_data.calibration.validate_staging import ( + validate_area, + _build_variable_entity_map, + ) + + engine = _ce(f"sqlite:///{db_path}") + sim = Microsimulation(dataset=h5_path) + variable_entity_map = _build_variable_entity_map(sim) + + results = validate_area( + sim=sim, + targets_df=area_targets, + engine=engine, + area_type=area_type, + area_id=area_id, + display_id=display_id, + period=period, + training_mask=area_training, + variable_entity_map=variable_entity_map, + constraints_map=constraints_map, + ) + return results + + +def _validate_h5_subprocess( + h5_path, + item_type, + item_id, + state_fips, + candidate, + cd_subset, + validation_targets, + training_mask_full, + constraints_map, + db_path, + period, +): + """Spawn a subprocess to validate one H5 file. + + Uses multiprocessing spawn to isolate memory. + """ + import multiprocessing as _mp + + # Determine geo_level and geographic_id for filtering targets + if item_type == "state": + geo_level = "state" + geographic_id = str(state_fips) + area_type = "states" + display_id = item_id + elif item_type == "district": + geo_level = "district" + geographic_id = str(candidate) + area_type = "districts" + display_id = item_id + elif item_type == "city": + # NYC: aggregate targets for NYC CDs + geo_level = "district" + area_type = "cities" + display_id = item_id + elif item_type == "national": + geo_level = "national" + geographic_id = "US" + area_type = "national" + display_id = "US" + else: + return [] + + # Filter targets to matching area + if item_type == "city": + # Match targets for any NYC CD + nyc_cd_set = set(str(cd) for cd in cd_subset) + mask = (validation_targets["geo_level"] == geo_level) & validation_targets[ + "geographic_id" + ].astype(str).isin(nyc_cd_set) + elif item_type == "national": + mask = validation_targets["geo_level"] == geo_level + else: + mask = (validation_targets["geo_level"] == geo_level) & ( + validation_targets["geographic_id"].astype(str) == geographic_id + ) + + area_targets = validation_targets[mask].reset_index(drop=True) + area_training = training_mask_full[mask.values] + + if len(area_targets) == 0: + return [] + + # Filter constraints_map to relevant strata + area_strata = area_targets["stratum_id"].unique().tolist() + area_constraints = {int(s): constraints_map.get(int(s), []) for s in area_strata} + + ctx = _mp.get_context("spawn") + with ctx.Pool(1) as pool: + results = pool.apply( + _validate_in_subprocess, + ( + h5_path, + area_type, + item_id, + display_id, + area_targets, + area_training, + area_constraints, + db_path, + period, + ), + ) + + return results + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--work-items", required=True, help="JSON work items") @@ -21,9 +156,38 @@ def main(): parser.add_argument("--db-path", required=True) parser.add_argument("--output-dir", required=True) parser.add_argument( - "--geography-path", - required=True, - help="Path to geography.npz from calibration", + "--n-clones", + type=int, + default=430, + help="Number of clones used in calibration", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed used in calibration", + ) + parser.add_argument( + "--no-validate", + action="store_true", + default=False, + help="Skip per-item validation after each H5 build", + ) + parser.add_argument( + "--period", + type=int, + default=2024, + help="Tax year for validation targets", + ) + parser.add_argument( + "--target-config", + default=None, + help="Path to training target_config.yaml", + ) + parser.add_argument( + "--validation-config", + default=None, + help="Path to target_config_full.yaml for validation", ) args = parser.parse_args() @@ -52,37 +216,104 @@ def main(): STATE_CODES, ) from policyengine_us_data.calibration.clone_and_assign import ( - load_geography, + assign_random_geography, ) + from policyengine_us import Microsimulation weights = np.load(weights_path) - # Load geography from .npz (required) - if not args.geography_path or not Path(args.geography_path).exists(): - raise RuntimeError( - f"--geography-path is required and must exist. " - f"Got: {args.geography_path}. " - f"Re-run calibration to generate geography.npz." - ) - geography = load_geography(args.geography_path) + sim = Microsimulation(dataset=str(dataset_path)) + n_records = sim.calculate("household_id", map_to="household").shape[0] + del sim + + geography = assign_random_geography( + n_records=n_records, + n_clones=args.n_clones, + seed=args.seed, + ) cds_to_calibrate = sorted(set(geography.cd_geoid.astype(str))) geo_labels = cds_to_calibrate print( - f"Loaded geography from {args.geography_path}: " + f"Generated geography: " f"{geography.n_clones} clones x " f"{geography.n_records} records", file=sys.stderr, ) + # ── Validation setup (once per worker) ── + validation_targets = None + training_mask_full = None + constraints_map = None + if not args.no_validate: + from sqlalchemy import create_engine + from policyengine_us_data.calibration.validate_staging import ( + _query_all_active_targets, + _batch_stratum_constraints, + CSV_COLUMNS, + ) + from policyengine_us_data.calibration.unified_calibration import ( + load_target_config, + _match_rules, + ) + + engine = create_engine(f"sqlite:///{db_path}") + validation_targets = _query_all_active_targets(engine, args.period) + print( + f"Loaded {len(validation_targets)} validation targets", + file=sys.stderr, + ) + + # Apply exclude/include from validation config + if args.validation_config: + val_cfg = load_target_config(args.validation_config) + exc_rules = val_cfg.get("exclude", []) + if exc_rules: + exc_mask = _match_rules(validation_targets, exc_rules) + validation_targets = validation_targets[~exc_mask].reset_index( + drop=True + ) + inc_rules = val_cfg.get("include", []) + if inc_rules: + inc_mask = _match_rules(validation_targets, inc_rules) + validation_targets = validation_targets[inc_mask].reset_index(drop=True) + + # Compute training mask from training config + if args.target_config: + tr_cfg = load_target_config(args.target_config) + tr_inc = tr_cfg.get("include", []) + if tr_inc: + training_mask_full = np.asarray( + _match_rules(validation_targets, tr_inc), + dtype=bool, + ) + else: + training_mask_full = np.ones(len(validation_targets), dtype=bool) + else: + training_mask_full = np.ones(len(validation_targets), dtype=bool) + + # Batch-load constraints + stratum_ids = validation_targets["stratum_id"].unique().tolist() + constraints_map = _batch_stratum_constraints(engine, stratum_ids) + print( + f"Validation ready: {len(validation_targets)} targets, " + f"{len(stratum_ids)} strata", + file=sys.stderr, + ) + results = { "completed": [], "failed": [], "errors": [], + "validation_rows": [], + "validation_summary": {}, } for item in work_items: item_type = item["type"] item_id = item["id"] + state_fips = None + candidate = None + cd_subset = None try: if item_type == "state": @@ -179,9 +410,24 @@ def main(): elif item_type == "national": national_dir = output_dir / "national" national_dir.mkdir(parents=True, exist_ok=True) + n_clones_from_weights = weights.shape[0] // n_records + if n_clones_from_weights != geography.n_clones: + print( + f"National weights have {n_clones_from_weights} clones " + f"but geography has {geography.n_clones}; " + f"regenerating geography", + file=sys.stderr, + ) + national_geo = assign_random_geography( + n_records=n_records, + n_clones=n_clones_from_weights, + seed=args.seed, + ) + else: + national_geo = geography path = build_h5( weights=weights, - geography=geography, + geography=national_geo, dataset_path=dataset_path, output_path=national_dir / "US.h5", ) @@ -195,6 +441,59 @@ def main(): file=sys.stderr, ) + # ── Per-item validation ── + if not args.no_validate and validation_targets is not None: + try: + v_rows = _validate_h5_subprocess( + h5_path=str(path), + item_type=item_type, + item_id=item_id, + state_fips=( + state_fips + if item_type in ("state", "district") + else None + ), + candidate=(candidate if item_type == "district" else None), + cd_subset=(cd_subset if item_type == "city" else None), + validation_targets=validation_targets, + training_mask_full=training_mask_full, + constraints_map=constraints_map, + db_path=str(db_path), + period=args.period, + ) + results["validation_rows"].extend(v_rows) + key = f"{item_type}:{item_id}" + n_fail = sum( + 1 for r in v_rows if r.get("sanity_check") == "FAIL" + ) + rae_vals = [ + r["rel_abs_error"] + for r in v_rows + if isinstance( + r.get("rel_abs_error"), + (int, float), + ) + and r["rel_abs_error"] != float("inf") + ] + mean_rae = sum(rae_vals) / len(rae_vals) if rae_vals else 0.0 + results["validation_summary"][key] = { + "n_targets": len(v_rows), + "n_sanity_fail": n_fail, + "mean_rel_abs_error": round(mean_rae, 4), + } + print( + f" Validated {key}: " + f"{len(v_rows)} targets, " + f"{n_fail} sanity fails, " + f"mean RAE={mean_rae:.4f}", + file=sys.stderr, + ) + except Exception as ve: + print( + f" Validation failed for {item_type}:{item_id}: {ve}", + file=sys.stderr, + ) + except Exception as e: results["failed"].append(f"{item_type}:{item_id}") results["errors"].append( diff --git a/policyengine_us_data/calibration/publish_local_area.py b/policyengine_us_data/calibration/publish_local_area.py index 72594631..0c4fcf11 100644 --- a/policyengine_us_data/calibration/publish_local_area.py +++ b/policyengine_us_data/calibration/publish_local_area.py @@ -8,6 +8,10 @@ python publish_local_area.py [--skip-download] [--states-only] [--upload] """ +import hashlib +import json +import shutil + import numpy as np from pathlib import Path from typing import List @@ -66,6 +70,49 @@ ] +META_FILE = WORK_DIR / "checkpoint_meta.json" + + +def compute_input_fingerprint( + weights_path: Path, dataset_path: Path, n_clones: int, seed: int +) -> str: + h = hashlib.sha256() + for p in [weights_path, dataset_path]: + with open(p, "rb") as f: + while chunk := f.read(8192): + h.update(chunk) + h.update(f"{n_clones}:{seed}".encode()) + return h.hexdigest()[:16] + + +def validate_or_clear_checkpoints(fingerprint: str): + if META_FILE.exists(): + stored = json.loads(META_FILE.read_text()) + if stored.get("fingerprint") == fingerprint: + print(f"Inputs unchanged ({fingerprint}), resuming...") + return + print( + f"Inputs changed " + f"({stored.get('fingerprint')} -> {fingerprint}), " + f"clearing..." + ) + else: + print(f"No checkpoint metadata, starting fresh ({fingerprint})") + for cp in [ + CHECKPOINT_FILE, + CHECKPOINT_FILE_DISTRICTS, + CHECKPOINT_FILE_CITIES, + ]: + if cp.exists(): + cp.unlink() + for subdir in ["states", "districts", "cities"]: + d = WORK_DIR / subdir + if d.exists(): + shutil.rmtree(d) + META_FILE.parent.mkdir(parents=True, exist_ok=True) + META_FILE.write_text(json.dumps({"fingerprint": fingerprint})) + + def load_completed_states() -> set: if CHECKPOINT_FILE.exists(): content = CHECKPOINT_FILE.read_text().strip() @@ -311,22 +358,6 @@ def build_h5( unique_geo = derive_geography_from_blocks(unique_blocks) clone_geo = {k: v[block_inv] for k, v in unique_geo.items()} - # === Calculate weights for all entity levels === - person_weights = np.repeat(clone_weights, persons_per_clone) - per_person_wt = clone_weights / np.maximum(persons_per_clone, 1) - - entity_weights = {} - for ek in SUB_ENTITIES: - n_ents = len(entity_clone_idx[ek]) - ent_person_counts = np.zeros(n_ents, dtype=np.int32) - np.add.at( - ent_person_counts, - new_person_entity_ids[ek], - 1, - ) - clone_ids_e = np.repeat(np.arange(n_clones), entities_per_clone[ek]) - entity_weights[ek] = per_person_wt[clone_ids_e] * ent_person_counts - # === Determine variables to save === vars_to_save = set(sim.input_variables) vars_to_save.add("county") @@ -413,16 +444,12 @@ def build_h5( } # === Override weights === + # Only write household_weight; sub-entity weights (tax_unit_weight, + # spm_unit_weight, person_weight, etc.) are formula variables in + # policyengine-us that derive from household_weight at runtime. data["household_weight"] = { time_period: clone_weights.astype(np.float32), } - data["person_weight"] = { - time_period: person_weights.astype(np.float32), - } - for ek in SUB_ENTITIES: - data[f"{ek}_weight"] = { - time_period: entity_weights[ek].astype(np.float32), - } # === Override geography === data["state_fips"] = { @@ -451,6 +478,13 @@ def build_h5( time_period: clone_geo[gv].astype("S"), } + # === Set zip_code for LA County clones (ACA rating area fix) === + la_mask = clone_geo["county_fips"].astype(str) == "06037" + if la_mask.any(): + zip_codes = np.full(len(la_mask), "UNKNOWN") + zip_codes[la_mask] = "90001" + data["zip_code"] = {time_period: zip_codes.astype("S")} + # === Gap 4: Congressional district GEOID === clone_cd_geoids = np.array([int(cd) for cd in active_clone_cds], dtype=np.int32) data["congressional_district_geoid"] = { @@ -531,6 +565,7 @@ def build_h5( hh_blocks=active_blocks, hh_state_fips=hh_state_fips, hh_ids=original_hh_ids, + hh_clone_indices=active_geo.astype(np.int64), entity_hh_indices=entity_hh_indices, entity_counts=entity_counts, time_period=time_period, @@ -861,6 +896,15 @@ def main(): print(f"Using dataset: {inputs['dataset']}") + print("Computing input fingerprint...") + fingerprint = compute_input_fingerprint( + inputs["weights"], + inputs["dataset"], + args.n_clones, + args.seed, + ) + validate_or_clear_checkpoints(fingerprint) + sim = Microsimulation(dataset=str(inputs["dataset"])) n_hh = sim.calculate("household_id", map_to="household").shape[0] del sim diff --git a/policyengine_us_data/calibration/stacked_dataset_builder.py b/policyengine_us_data/calibration/stacked_dataset_builder.py deleted file mode 100644 index 0089f0d1..00000000 --- a/policyengine_us_data/calibration/stacked_dataset_builder.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -CLI for creating CD-stacked datasets from calibration artifacts. - -Thin wrapper around build_h5/build_states/build_districts/build_cities -in publish_local_area.py. Loads a GeographyAssignment from geography.npz -and delegates all H5 building logic. -""" - -import os -import numpy as np -from pathlib import Path - -from policyengine_us_data.calibration.clone_and_assign import ( - load_geography, -) - -if __name__ == "__main__": - import argparse - - from policyengine_us import Microsimulation - from policyengine_us_data.calibration.publish_local_area import ( - build_h5, - build_states, - build_districts, - build_cities, - ) - from policyengine_us_data.utils.takeup import SIMPLE_TAKEUP_VARS - - parser = argparse.ArgumentParser( - description="Create CD-stacked datasets from calibration artifacts" - ) - parser.add_argument( - "--weights-path", - required=True, - help="Path to w_cd.npy file", - ) - parser.add_argument( - "--dataset-path", - required=True, - help="Path to stratified dataset .h5 file", - ) - parser.add_argument( - "--db-path", - required=True, - help="Path to policy_data.db", - ) - parser.add_argument( - "--geography-path", - required=True, - help="Path to geography.npz from calibration", - ) - parser.add_argument( - "--output-dir", - default="./temp", - help="Output directory for files", - ) - parser.add_argument( - "--mode", - choices=[ - "national", - "states", - "cds", - "single-cd", - "single-state", - "nyc", - ], - default="national", - help="Output mode", - ) - parser.add_argument( - "--cd", - type=str, - help="Single CD GEOID (--mode single-cd)", - ) - parser.add_argument( - "--state", - type=str, - help="State code e.g. RI, CA (--mode single-state)", - ) - - args = parser.parse_args() - weights_path = Path(args.weights_path) - dataset_path = Path(args.dataset_path) - db_path = Path(args.db_path).resolve() - output_dir = Path(args.output_dir) - mode = args.mode - - os.makedirs(output_dir, exist_ok=True) - - # === Load and validate === - w = np.load(str(weights_path)) - db_uri = f"sqlite:///{db_path}" - - # === Load geography (required) === - if not args.geography_path or not Path(args.geography_path).exists(): - raise ValueError( - f"--geography-path is required and must exist. " - f"Got: {args.geography_path}. " - f"Re-run calibration to generate geography.npz." - ) - geography = load_geography(args.geography_path) - print( - f"Loaded geography from {args.geography_path}: " - f"{geography.n_clones} clones x " - f"{geography.n_records} records" - ) - - print(f"Geography: {geography.n_clones} clones x {geography.n_records} records") - - takeup_filter = [spec["variable"] for spec in SIMPLE_TAKEUP_VARS] - - # === Dispatch === - if mode == "national": - output_path = output_dir / "US.h5" - print(f"\nCreating national dataset: {output_path}") - build_h5( - weights=w, - geography=geography, - dataset_path=dataset_path, - output_path=output_path, - takeup_filter=takeup_filter, - ) - - elif mode == "states": - build_states( - weights_path=weights_path, - dataset_path=dataset_path, - geography=geography, - output_dir=output_dir, - completed_states=set(), - takeup_filter=takeup_filter, - ) - - elif mode == "single-state": - if not args.state: - raise ValueError("--state required with --mode single-state") - build_states( - weights_path=weights_path, - dataset_path=dataset_path, - geography=geography, - output_dir=output_dir, - completed_states=set(), - takeup_filter=takeup_filter, - state_filter=args.state.upper(), - ) - - elif mode == "cds": - build_districts( - weights_path=weights_path, - dataset_path=dataset_path, - geography=geography, - output_dir=output_dir, - completed_districts=set(), - takeup_filter=takeup_filter, - ) - - elif mode == "single-cd": - if not args.cd: - raise ValueError("--cd required with --mode single-cd") - calibrated_cds = sorted(set(cd_geoid)) - if args.cd not in calibrated_cds: - raise ValueError(f"CD {args.cd} not in calibrated CDs") - output_path = output_dir / f"{args.cd}.h5" - print(f"\nCreating single CD dataset: {output_path}") - build_h5( - weights=w, - geography=geography, - dataset_path=dataset_path, - output_path=output_path, - cd_subset=[args.cd], - takeup_filter=takeup_filter, - ) - - elif mode == "nyc": - build_cities( - weights_path=weights_path, - dataset_path=dataset_path, - geography=geography, - output_dir=output_dir, - completed_cities=set(), - takeup_filter=takeup_filter, - ) - - print("\nDone!") diff --git a/policyengine_us_data/calibration/target_config.yaml b/policyengine_us_data/calibration/target_config.yaml index 175aefa7..477ae672 100644 --- a/policyengine_us_data/calibration/target_config.yaml +++ b/policyengine_us_data/calibration/target_config.yaml @@ -19,23 +19,17 @@ include: geo_level: district - variable: taxable_pension_income geo_level: district - # DISABLED: refundable_ctc formula doesn't gate on tax_unit_is_filer; - # non-filer values inflate totals beyond IRS SOI targets. - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: refundable_ctc - # geo_level: district + - variable: refundable_ctc + geo_level: district - variable: unemployment_compensation geo_level: district # === DISTRICT — ACA PTC === - # DISABLED: aca_ptc formula doesn't gate on tax_unit_is_filer; - # non-filer values inflate totals beyond IRS SOI targets. - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: aca_ptc - # geo_level: district - # - variable: tax_unit_count - # geo_level: district - # domain_variable: aca_ptc + - variable: aca_ptc + geo_level: district + - variable: tax_unit_count + geo_level: district + domain_variable: aca_ptc # === STATE === - variable: person_count @@ -54,11 +48,8 @@ include: geo_level: national - variable: child_support_received geo_level: national - # DISABLED: eitc formula doesn't gate on tax_unit_is_filer; - # non-filer values inflate totals beyond IRS SOI targets. - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: eitc - # geo_level: national + - variable: eitc + geo_level: national - variable: health_insurance_premiums_without_medicare_part_b geo_level: national - variable: medicaid @@ -97,19 +88,15 @@ include: geo_level: national # === NATIONAL — IRS SOI domain-constrained dollar targets === - # DISABLED: aca_ptc formula doesn't gate on tax_unit_is_filer - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: aca_ptc - # geo_level: national - # domain_variable: aca_ptc + - variable: aca_ptc + geo_level: national + domain_variable: aca_ptc - variable: dividend_income geo_level: national domain_variable: dividend_income - # DISABLED: eitc formula doesn't gate on tax_unit_is_filer - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: eitc - # geo_level: national - # domain_variable: eitc_child_count + - variable: eitc + geo_level: national + domain_variable: eitc_child_count - variable: income_tax_positive geo_level: national - variable: income_tax_before_credits @@ -124,11 +111,9 @@ include: - variable: qualified_dividend_income geo_level: national domain_variable: qualified_dividend_income - # DISABLED: refundable_ctc formula doesn't gate on tax_unit_is_filer - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: refundable_ctc - # geo_level: national - # domain_variable: refundable_ctc + - variable: refundable_ctc + geo_level: national + domain_variable: refundable_ctc - variable: rental_income geo_level: national domain_variable: rental_income @@ -161,19 +146,15 @@ include: domain_variable: unemployment_compensation # === NATIONAL — IRS SOI filer count targets === - # DISABLED: aca_ptc inflated by non-filers - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: tax_unit_count - # geo_level: national - # domain_variable: aca_ptc + - variable: tax_unit_count + geo_level: national + domain_variable: aca_ptc - variable: tax_unit_count geo_level: national domain_variable: dividend_income - # DISABLED: eitc inflated by non-filers - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: tax_unit_count - # geo_level: national - # domain_variable: eitc_child_count + - variable: tax_unit_count + geo_level: national + domain_variable: eitc_child_count - variable: tax_unit_count geo_level: national domain_variable: income_tax @@ -195,11 +176,9 @@ include: - variable: tax_unit_count geo_level: national domain_variable: real_estate_taxes - # DISABLED: refundable_ctc inflated by non-filers - # See https://github.com/PolicyEngine/policyengine-us/issues/7748 - # - variable: tax_unit_count - # geo_level: national - # domain_variable: refundable_ctc + - variable: tax_unit_count + geo_level: national + domain_variable: refundable_ctc - variable: tax_unit_count geo_level: national domain_variable: rental_income diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py index 66bc1f9b..f81d92bc 100644 --- a/policyengine_us_data/calibration/unified_calibration.py +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -1274,29 +1274,7 @@ def main(argv=None): logger.info("Weights saved to %s", output_path) print(f"OUTPUT_PATH:{output_path}") - # Save full geography for local-area pipeline - from policyengine_us_data.calibration.clone_and_assign import ( - GeographyAssignment, - save_geography, - ) - - geography = GeographyAssignment( - block_geoid=geography_info["block_geoid"], - cd_geoid=geography_info["cd_geoid"], - county_fips=np.array([b[:5] for b in geography_info["block_geoid"]]), - state_fips=np.array( - [int(b[:2]) for b in geography_info["block_geoid"]], - dtype=np.int32, - ), - n_records=geography_info["base_n_records"], - n_clones=len(sorted(set(geography_info["cd_geoid"].astype(str)))), - ) - geo_path = output_dir / "geography.npz" - save_geography(geography, geo_path) - logger.info("Geography saved to %s", geo_path) - print(f"GEOGRAPHY_PATH:{geo_path}") - - # Also save legacy artifacts for backward compatibility + # Save legacy block artifact for backward compatibility blocks_path = output_dir / "stacked_blocks.npy" np.save(str(blocks_path), geography_info["block_geoid"]) logger.info("Blocks saved to %s", blocks_path) @@ -1348,7 +1326,6 @@ def _sha256(filepath): "elapsed_seconds": round(t_end - t_start, 1), "artifacts": { "calibration_weights.npy": _sha256(output_path), - "geography.npz": _sha256(geo_path), }, } run_config.update(get_git_provenance()) diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py index 04d785ff..7fa80322 100644 --- a/policyengine_us_data/calibration/unified_matrix_builder.py +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -152,7 +152,38 @@ def _compute_single_state( exc, ) - return (state, {"hh": hh, "person": person, "entity": entity_vals}) + entity_wf_false = {} + if rerandomize_takeup: + has_tu_target = any( + info["entity"] == "tax_unit" for info in affected_targets.values() + ) + if has_tu_target: + n_tu = len(state_sim.calculate("tax_unit_id", map_to="tax_unit").values) + state_sim.set_input( + "would_file_taxes_voluntarily", + time_period, + np.zeros(n_tu, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + for tvar, info in affected_targets.items(): + if info["entity"] != "tax_unit": + continue + entity_wf_false[tvar] = state_sim.calculate( + tvar, + time_period, + map_to="tax_unit", + ).values.astype(np.float32) + + return ( + state, + { + "hh": hh, + "person": person, + "entity": entity_vals, + "entity_wf_false": entity_wf_false, + }, + ) def _compute_single_state_group_counties( @@ -222,11 +253,17 @@ def _compute_single_state_group_counties( time_period, np.full(n_hh, county_idx, dtype=np.int32), ) + if county_fips == "06037": + state_sim.set_input( + "zip_code", + time_period, + np.full(n_hh, "90001"), + ) if rerandomize_takeup: for vname, (ent, orig) in original_takeup.items(): state_sim.set_input(vname, time_period, orig) for var in get_calculated_variables(state_sim): - if var != "county": + if var not in ("county", "zip_code"): state_sim.delete_arrays(var) hh = {} @@ -257,7 +294,7 @@ def _compute_single_state_group_counties( np.ones(n_ent, dtype=bool), ) for var in get_calculated_variables(state_sim): - if var != "county": + if var not in ("county", "zip_code"): state_sim.delete_arrays(var) entity_vals = {} @@ -278,7 +315,15 @@ def _compute_single_state_group_counties( exc, ) - results.append((county_fips, {"hh": hh, "entity": entity_vals})) + results.append( + ( + county_fips, + { + "hh": hh, + "entity": entity_vals, + }, + ) + ) return results @@ -552,11 +597,39 @@ def _process_single_clone( # Takeup re-randomisation if do_takeup and affected_target_info: from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS, compute_block_takeup_for_entities, ) clone_blocks = geo_blocks[col_start:col_end] + # Phase 1: compute non-target draws (would_file) FIRST + wf_draws = {} + for spec in SIMPLE_TAKEUP_VARS: + if spec.get("target") is not None: + continue + var_name = spec["variable"] + entity = spec["entity"] + rate_key = spec["rate_key"] + if rate_key not in precomputed_rates: + continue + ent_hh = entity_hh_idx_map[entity] + ent_blocks = clone_blocks[ent_hh] + ent_hh_ids = household_ids[ent_hh] + ent_ci = np.full(len(ent_hh), clone_idx, dtype=np.int64) + draws = compute_block_takeup_for_entities( + var_name, + precomputed_rates[rate_key], + ent_blocks, + ent_hh_ids, + ent_ci, + ) + wf_draws[entity] = draws + if var_name in person_vars: + pidx = entity_to_person_idx[entity] + person_vars[var_name] = draws[pidx].astype(np.float32) + + # Phase 2: target loop with would_file blending for tvar, info in affected_target_info.items(): if tvar.endswith("_count"): continue @@ -586,14 +659,44 @@ def _process_single_clone( if tvar in sv: ent_eligible[m] = sv[tvar][m] + # Blend: for tax_unit targets, select between + # all-takeup-true and would_file=false values + if entity_level == "tax_unit" and "tax_unit" in wf_draws: + ent_wf_false = np.zeros(n_ent, dtype=np.float32) + if tvar in county_dep_targets and county_values: + ent_counties = clone_counties[ent_hh] + for cfips in np.unique(ent_counties): + m = ent_counties == cfips + cv = county_values.get(cfips, {}).get("entity_wf_false", {}) + if tvar in cv: + ent_wf_false[m] = cv[tvar][m] + else: + st = int(cfips[:2]) + sv = state_values[st].get("entity_wf_false", {}) + if tvar in sv: + ent_wf_false[m] = sv[tvar][m] + else: + for st in np.unique(ent_states): + m = ent_states == st + sv = state_values[int(st)].get("entity_wf_false", {}) + if tvar in sv: + ent_wf_false[m] = sv[tvar][m] + ent_eligible = np.where( + wf_draws["tax_unit"], + ent_eligible, + ent_wf_false, + ) + ent_blocks = clone_blocks[ent_hh] ent_hh_ids = household_ids[ent_hh] + ent_ci = np.full(n_ent, clone_idx, dtype=np.int64) ent_takeup = compute_block_takeup_for_entities( takeup_var, precomputed_rates[info["rate_key"]], ent_blocks, ent_hh_ids, + ent_ci, ) ent_values = (ent_eligible * ent_takeup).astype(np.float32) @@ -950,10 +1053,43 @@ def _build_state_values( exc, ) + entity_wf_false = {} + if rerandomize_takeup: + has_tu_target = any( + info["entity"] == "tax_unit" + for info in affected_targets.values() + ) + if has_tu_target: + n_tu = len( + state_sim.calculate( + "tax_unit_id", + map_to="tax_unit", + ).values + ) + state_sim.set_input( + "would_file_taxes_voluntarily", + self.time_period, + np.zeros(n_tu, dtype=bool), + ) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + for ( + tvar, + info, + ) in affected_targets.items(): + if info["entity"] != "tax_unit": + continue + entity_wf_false[tvar] = state_sim.calculate( + tvar, + self.time_period, + map_to="tax_unit", + ).values.astype(np.float32) + state_values[state] = { "hh": hh, "person": person, "entity": entity_vals, + "entity_wf_false": entity_wf_false, } if (i + 1) % 10 == 0 or i == 0: logger.info( @@ -1104,122 +1240,20 @@ def _build_county_values( f.cancel() raise RuntimeError(f"State group {sf} failed: {exc}") from exc else: - from policyengine_us import Microsimulation - from policyengine_us_data.utils.takeup import ( - SIMPLE_TAKEUP_VARS, - ) - county_count = 0 - for state_fips, counties in sorted(state_to_counties.items()): - state_sim = Microsimulation(dataset=self.dataset_path) - - state_sim.set_input( - "state_fips", + for sf, counties in sorted(state_to_counties.items()): + results = _compute_single_state_group_counties( + self.dataset_path, self.time_period, - np.full(n_hh, state_fips, dtype=np.int32), + sf, + counties, + n_hh, + county_dep_targets_list, + rerandomize_takeup, + affected_targets, ) - - original_takeup = {} - if rerandomize_takeup: - for spec in SIMPLE_TAKEUP_VARS: - entity = spec["entity"] - original_takeup[spec["variable"]] = ( - entity, - state_sim.calculate( - spec["variable"], - self.time_period, - map_to=entity, - ).values.copy(), - ) - - for county_fips in counties: - county_idx = get_county_enum_index_from_fips(county_fips) - state_sim.set_input( - "county", - self.time_period, - np.full( - n_hh, - county_idx, - dtype=np.int32, - ), - ) - if rerandomize_takeup: - for vname, ( - ent, - orig, - ) in original_takeup.items(): - state_sim.set_input( - vname, - self.time_period, - orig, - ) - for var in get_calculated_variables(state_sim): - if var != "county": - state_sim.delete_arrays(var) - - hh = {} - for var in county_dep_targets: - if var.endswith("_count"): - continue - try: - hh[var] = state_sim.calculate( - var, - self.time_period, - map_to="household", - ).values.astype(np.float32) - except Exception as exc: - logger.warning( - "Cannot calculate '%s' for county %s: %s", - var, - county_fips, - exc, - ) - - if rerandomize_takeup: - for spec in SIMPLE_TAKEUP_VARS: - entity = spec["entity"] - n_ent = len( - state_sim.calculate( - f"{entity}_id", - map_to=entity, - ).values - ) - state_sim.set_input( - spec["variable"], - self.time_period, - np.ones(n_ent, dtype=bool), - ) - for var in get_calculated_variables(state_sim): - if var != "county": - state_sim.delete_arrays(var) - - entity_vals = {} - if rerandomize_takeup: - for ( - tvar, - info, - ) in affected_targets.items(): - entity_level = info["entity"] - try: - entity_vals[tvar] = state_sim.calculate( - tvar, - self.time_period, - map_to=entity_level, - ).values.astype(np.float32) - except Exception as exc: - logger.warning( - "Cannot calculate " - "entity-level '%s' " - "for county %s: %s", - tvar, - county_fips, - exc, - ) - - county_values[county_fips] = { - "hh": hh, - "entity": entity_vals, - } + for county_fips, vals in results: + county_values[county_fips] = vals county_count += 1 if county_count % 500 == 0 or county_count == 1: logger.info( @@ -1928,10 +1962,14 @@ def build_matrix( len(affected_target_info), ) - # Pre-compute takeup rates (constant across clones) + # Pre-compute takeup rates for ALL takeup vars + from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS as _ALL_TAKEUP, + ) + precomputed_rates = {} - for tvar, info in affected_target_info.items(): - rk = info["rate_key"] + for spec in _ALL_TAKEUP: + rk = spec["rate_key"] if rk not in precomputed_rates: precomputed_rates[rk] = load_take_up_rate(rk, self.time_period) @@ -2083,6 +2121,38 @@ def build_matrix( # for affected target variables if rerandomize_takeup and affected_target_info: clone_blocks = geography.block_geoid[col_start:col_end] + + from policyengine_us_data.utils.takeup import ( + SIMPLE_TAKEUP_VARS as _SEQ_TAKEUP, + ) + + # Phase 1: non-target draws (would_file) FIRST + wf_draws = {} + for spec in _SEQ_TAKEUP: + if spec.get("target") is not None: + continue + var_name = spec["variable"] + entity = spec["entity"] + rate_key = spec["rate_key"] + if rate_key not in precomputed_rates: + continue + ent_hh = entity_hh_idx_map[entity] + ent_blocks = clone_blocks[ent_hh] + ent_hh_ids = household_ids[ent_hh] + ent_ci = np.full(len(ent_hh), clone_idx, dtype=np.int64) + draws = compute_block_takeup_for_entities( + var_name, + precomputed_rates[rate_key], + ent_blocks, + ent_hh_ids, + ent_ci, + ) + wf_draws[entity] = draws + if var_name in person_vars: + pidx = entity_to_person_idx[entity] + person_vars[var_name] = draws[pidx].astype(np.float32) + + # Phase 2: target loop with would_file blending for ( tvar, info, @@ -2116,14 +2186,47 @@ def build_matrix( if tvar in sv: ent_eligible[m] = sv[tvar][m] + # Blend for tax_unit targets + if entity_level == "tax_unit" and "tax_unit" in wf_draws: + ent_wf_false = np.zeros(n_ent, dtype=np.float32) + if tvar in county_dep_targets and county_values: + ent_counties = clone_counties[ent_hh] + for cfips in np.unique(ent_counties): + m = ent_counties == cfips + cv = county_values.get(cfips, {}).get( + "entity_wf_false", {} + ) + if tvar in cv: + ent_wf_false[m] = cv[tvar][m] + else: + st = int(cfips[:2]) + sv = state_values[st].get("entity_wf_false", {}) + if tvar in sv: + ent_wf_false[m] = sv[tvar][m] + else: + for st in np.unique(ent_states): + m = ent_states == st + sv = state_values[int(st)].get( + "entity_wf_false", {} + ) + if tvar in sv: + ent_wf_false[m] = sv[tvar][m] + ent_eligible = np.where( + wf_draws["tax_unit"], + ent_eligible, + ent_wf_false, + ) + ent_blocks = clone_blocks[ent_hh] ent_hh_ids = household_ids[ent_hh] + ent_ci = np.full(n_ent, clone_idx, dtype=np.int64) ent_takeup = compute_block_takeup_for_entities( takeup_var, precomputed_rates[info["rate_key"]], ent_blocks, ent_hh_ids, + ent_ci, ) ent_values = (ent_eligible * ent_takeup).astype(np.float32) diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 8f6f051c..a7d782cb 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -50,6 +50,10 @@ def fetch_congressional_districts(year): df = df.drop(columns=["n_districts"]) df.loc[df["district_number"] == 0, "district_number"] = 1 + df.loc[ + (df["state_fips"] == 11) & (df["district_number"] == 98), + "district_number", + ] = 1 df["congressional_district_geoid"] = df["state_fips"] * 100 + df["district_number"] df = df[ diff --git a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py index f2b634e0..6afaa2a6 100644 --- a/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py +++ b/policyengine_us_data/storage/calibration_targets/make_block_cd_distributions.py @@ -77,9 +77,17 @@ def build_block_cd_distributions(): df["state_fips"] = df["GEOID"].str[:2] # Create CD geoid in our format: state_fips * 100 + district - # Examples: AL-1 = 101, NY-10 = 3610, DC = 1198 + # Examples: AL-1 = 101, NY-10 = 3610, DC = 1101 df["cd_geoid"] = df["state_fips"].astype(int) * 100 + df["CD119"].astype(int) + # Normalize at-large districts: Census uses 00 (and 98 for DC) → convert to 01 + district_num = df["cd_geoid"] % 100 + state_fips_int = df["state_fips"].astype(int) + at_large_mask = (district_num == 0) | ( + (state_fips_int == 11) & (district_num == 98) + ) + df.loc[at_large_mask, "cd_geoid"] = state_fips_int[at_large_mask] * 100 + 1 + # Step 4: Calculate P(block|CD) print("\nCalculating block probabilities...") cd_totals = df.groupby("cd_geoid")["POP20"].transform("sum") diff --git a/policyengine_us_data/storage/download_private_prerequisites.py b/policyengine_us_data/storage/download_private_prerequisites.py index 94586a81..4d8a977d 100644 --- a/policyengine_us_data/storage/download_private_prerequisites.py +++ b/policyengine_us_data/storage/download_private_prerequisites.py @@ -1,3 +1,5 @@ +import os + from policyengine_us_data.utils.huggingface import download from pathlib import Path @@ -27,9 +29,15 @@ local_folder=FOLDER, version=None, ) -download( - repo="policyengine/policyengine-us-data", - repo_filename="calibration/policy_data.db", - local_folder=FOLDER, - version=None, -) +if os.environ.get("SKIP_POLICY_DB_DOWNLOAD"): + print( + "SKIP_POLICY_DB_DOWNLOAD set — skipping " + "policy_data.db download from HuggingFace" + ) +else: + download( + repo="policyengine/policyengine-us-data", + repo_filename="calibration/policy_data.db", + local_folder=FOLDER, + version=None, + ) diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index 7af0da04..5a15739c 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -163,22 +163,33 @@ def _check_group_has_data(f, name): print(f" Household weight sum: {hh_weight:,.0f}") -def upload_datasets(): - dataset_files = [ - EnhancedCPS_2024.file_path, +def upload_datasets(require_enhanced_cps: bool = True): + required_files = [ CPS_2024.file_path, - STORAGE_FOLDER / "small_enhanced_cps_2024.h5", STORAGE_FOLDER / "calibration" / "policy_data.db", ] + enhanced_files = [ + EnhancedCPS_2024.file_path, + STORAGE_FOLDER / "small_enhanced_cps_2024.h5", + ] + if require_enhanced_cps: + required_files.extend(enhanced_files) - # Filter to only existing files existing_files = [] - for file_path in dataset_files: + for file_path in required_files: if file_path.exists(): existing_files.append(file_path) print(f"✓ Found: {file_path}") else: - raise FileNotFoundError(f"File not found: {file_path}") + raise FileNotFoundError(f"Required file not found: {file_path}") + + if not require_enhanced_cps: + for file_path in enhanced_files: + if file_path.exists(): + existing_files.append(file_path) + print(f"✓ Found (optional): {file_path}") + else: + print(f"⚠ Skipping (not built): {file_path}") if not existing_files: raise ValueError("No dataset files found to upload!") @@ -211,4 +222,13 @@ def validate_all_datasets(): if __name__ == "__main__": - upload_datasets() + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--no-require-enhanced-cps", + action="store_true", + help="Treat enhanced_cps and small_enhanced_cps as optional.", + ) + args = parser.parse_args() + upload_datasets(require_enhanced_cps=not args.no_require_enhanced_cps) diff --git a/policyengine_us_data/tests/conftest.py b/policyengine_us_data/tests/conftest.py new file mode 100644 index 00000000..0af57ca1 --- /dev/null +++ b/policyengine_us_data/tests/conftest.py @@ -0,0 +1,86 @@ +"""Shared fixtures and helpers for version manifest tests.""" + +import json +from unittest.mock import MagicMock + +import pytest + +from policyengine_us_data.utils.version_manifest import ( + HFVersionInfo, + GCSVersionInfo, + VersionManifest, + VersionRegistry, +) + +# -- Fixtures ------------------------------------------------------ + + +@pytest.fixture +def sample_generations() -> dict[str, int]: + return { + "enhanced_cps_2024.h5": 1710203948123456, + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + } + + +@pytest.fixture +def sample_hf_info() -> HFVersionInfo: + return HFVersionInfo( + repo="policyengine/policyengine-us-data", + commit="abc123def456", + ) + + +@pytest.fixture +def sample_manifest( + sample_generations: dict[str, int], + sample_hf_info: HFVersionInfo, +) -> VersionManifest: + return VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + + +@pytest.fixture +def sample_registry( + sample_manifest: VersionManifest, +) -> VersionRegistry: + """A registry with one version entry.""" + return VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + + +@pytest.fixture +def mock_bucket() -> MagicMock: + bucket = MagicMock() + bucket.name = "policyengine-us-data" + return bucket + + +# -- Helpers ------------------------------------------------------- + + +def make_mock_blob(generation: int) -> MagicMock: + blob = MagicMock() + blob.generation = generation + return blob + + +def setup_bucket_with_registry( + bucket: MagicMock, + registry: VersionRegistry, +) -> None: + """Configure a mock bucket to serve a registry.""" + registry_json = json.dumps(registry.to_dict()) + blob = MagicMock() + blob.download_as_text.return_value = registry_json + bucket.blob.return_value = blob diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py index 28a3c906..1283dabe 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py @@ -74,44 +74,61 @@ def test_rate_comparison_produces_booleans(self): class TestBlockSaltedDraws: """Verify compute_block_takeup_for_entities produces - reproducible, block-dependent draws.""" + reproducible, clone-dependent draws.""" - def test_same_block_same_results(self): - blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - d2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + def test_same_inputs_same_results(self): + n = 500 + blocks = np.array(["370010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci + ) + d2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci + ) np.testing.assert_array_equal(d1, d2) - def test_different_blocks_different_results(self): + def test_different_clone_idx_different_results(self): n = 500 + blocks = np.array(["370010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci0 = np.zeros(n, dtype=np.int64) + ci1 = np.ones(n, dtype=np.int64) d1 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", - 0.8, - np.array(["370010001001001"] * n), + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci0 ) d2 = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", - 0.8, - np.array(["480010002002002"] * n), + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci1 ) assert not np.array_equal(d1, d2) def test_different_vars_different_results(self): - blocks = np.array(["370010001001001"] * 500) - d1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - d2 = compute_block_takeup_for_entities("takes_up_aca_if_eligible", 0.8, blocks) + n = 500 + blocks = np.array(["370010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci + ) + d2 = compute_block_takeup_for_entities( + "takes_up_aca_if_eligible", 0.8, blocks, hh_ids, ci + ) assert not np.array_equal(d1, d2) - def test_hh_salt_differs_from_block_only(self): - blocks = np.array(["370010001001001"] * 500) - hh_ids = np.array([1] * 500) - d_block = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks + def test_different_hh_ids_different_results(self): + n = 500 + blocks = np.array(["370010001001001"] * n) + ci = np.zeros(n, dtype=np.int64) + hh_a = np.arange(n, dtype=np.int64) + hh_b = np.arange(n, dtype=np.int64) + 1000 + d1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_a, ci ) - d_hh = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks, hh_ids + d2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_b, ci ) - assert not np.array_equal(d_block, d_hh) + assert not np.array_equal(d1, d2) class TestApplyBlockTakeupToArrays: @@ -126,6 +143,7 @@ def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh): hh_blocks = np.array(["370010001001001"] * n_hh) hh_state_fips = np.array([37] * n_hh, dtype=np.int32) hh_ids = np.arange(n_hh, dtype=np.int64) + hh_clone_indices = np.zeros(n_hh, dtype=np.int64) entity_hh_indices = { "person": np.repeat(np.arange(n_hh), persons_per_hh), "tax_unit": np.repeat(np.arange(n_hh), tu_per_hh), @@ -140,6 +158,7 @@ def _make_arrays(self, n_hh, persons_per_hh, tu_per_hh, spm_per_hh): hh_blocks, hh_state_fips, hh_ids, + hh_clone_indices, entity_hh_indices, entity_counts, ) @@ -336,38 +355,61 @@ def test_county_fips_length(self): class TestBlockTakeupSeeding: """Verify compute_block_takeup_for_entities is - reproducible and block-dependent.""" + reproducible and clone-dependent.""" def test_reproducible(self): + n = 100 blocks = np.array(["010010001001001"] * 50 + ["020010001001001"] * 50) - r1 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) - r2 = compute_block_takeup_for_entities("takes_up_snap_if_eligible", 0.8, blocks) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) + r1 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci + ) + r2 = compute_block_takeup_for_entities( + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci + ) np.testing.assert_array_equal(r1, r2) - def test_different_blocks_different_draws(self): + def test_different_blocks_different_rates(self): + """With state-dependent rates, different blocks yield + different takeup because rate thresholds differ.""" n = 500 - blocks_a = np.array(["010010001001001"] * n) - blocks_b = np.array(["020010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) + rate_dict = {"AL": 0.9, "AK": 0.3} r_a = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks_a + "takes_up_snap_if_eligible", + rate_dict, + np.array(["010010001001001"] * n), + hh_ids, + ci, ) r_b = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks_b + "takes_up_snap_if_eligible", + rate_dict, + np.array(["020010001001001"] * n), + hh_ids, + ci, ) assert not np.array_equal(r_a, r_b) def test_returns_booleans(self): - blocks = np.array(["370010001001001"] * 100) + n = 100 + blocks = np.array(["370010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) result = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.8, blocks + "takes_up_snap_if_eligible", 0.8, blocks, hh_ids, ci ) assert result.dtype == bool def test_rate_respected(self): n = 10000 blocks = np.array(["370010001001001"] * n) + hh_ids = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) result = compute_block_takeup_for_entities( - "takes_up_snap_if_eligible", 0.75, blocks + "takes_up_snap_if_eligible", 0.75, blocks, hh_ids, ci ) frac = result.mean() assert 0.70 < frac < 0.80 @@ -481,6 +523,7 @@ def test_matrix_and_stacked_identical_draws(self): """Both paths must produce identical boolean arrays.""" var = "takes_up_snap_if_eligible" rate = 0.75 + clone_idx = 5 # 2 blocks, 3 households, variable entity counts per HH # HH0 has 2 entities in block A @@ -497,20 +540,23 @@ def test_matrix_and_stacked_identical_draws(self): ] ) hh_ids = np.array([100, 100, 200, 200, 200, 300]) + ci = np.full(len(blocks), clone_idx, dtype=np.int64) - # Path 1: compute_block_takeup_for_entities (stacked) - stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids) + # Path 1: compute_block_takeup_for_entities + stacked = compute_block_takeup_for_entities(var, rate, blocks, hh_ids, ci) - # Path 2: reproduce matrix builder inline logic + # Path 2: reproduce inline logic with hh_id:clone_idx salt n = len(blocks) inline_takeup = np.zeros(n, dtype=bool) - for blk in np.unique(blocks): - bm = blocks == blk - for hh_id in np.unique(hh_ids[bm]): - hh_mask = bm & (hh_ids == hh_id) - rng = seeded_rng(var, salt=f"{blk}:{int(hh_id)}") - draws = rng.random(int(hh_mask.sum())) - inline_takeup[hh_mask] = draws < rate + for hh_id in np.unique(hh_ids): + hh_mask = hh_ids == hh_id + rng = seeded_rng(var, salt=f"{int(hh_id)}:{clone_idx}") + draws = rng.random(int(hh_mask.sum())) + # Rate from block's state FIPS + blk = blocks[hh_mask][0] + sf = int(str(blk)[:2]) + r = _resolve_rate(rate, sf) + inline_takeup[hh_mask] = draws < r np.testing.assert_array_equal(stacked, inline_takeup) @@ -542,18 +588,22 @@ def test_state_specific_rate_resolved_from_block(self): n = 5000 blocks_nc = np.array(["370010001001001"] * n) - result_nc = compute_block_takeup_for_entities(var, rate_dict, blocks_nc) - # NC rate=0.9, expect ~90% + hh_ids_nc = np.arange(n, dtype=np.int64) + ci = np.zeros(n, dtype=np.int64) + result_nc = compute_block_takeup_for_entities( + var, rate_dict, blocks_nc, hh_ids_nc, ci + ) frac_nc = result_nc.mean() assert 0.85 < frac_nc < 0.95, f"NC frac={frac_nc}" blocks_tx = np.array(["480010002002002"] * n) - result_tx = compute_block_takeup_for_entities(var, rate_dict, blocks_tx) - # TX rate=0.6, expect ~60% + hh_ids_tx = np.arange(n, dtype=np.int64) + result_tx = compute_block_takeup_for_entities( + var, rate_dict, blocks_tx, hh_ids_tx, ci + ) frac_tx = result_tx.mean() assert 0.55 < frac_tx < 0.65, f"TX frac={frac_tx}" - # Verify _resolve_rate actually gives different rates assert _resolve_rate(rate_dict, 37) == 0.9 assert _resolve_rate(rate_dict, 48) == 0.6 diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py index dbc76fb1..492719d9 100644 --- a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py +++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py @@ -685,12 +685,11 @@ def test_returns_empty_when_no_targets(self): @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=["var_a"], ) @patch("policyengine_us.Microsimulation") @@ -718,12 +717,11 @@ def test_return_structure(self, mock_msim_cls, mock_gcv, mock_county_idx): @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=["var_a"], ) @patch("policyengine_us.Microsimulation") @@ -749,12 +747,11 @@ def test_sim_reuse_within_state(self, mock_msim_cls, mock_gcv, mock_county_idx): @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=[], ) @patch("policyengine_us.Microsimulation") @@ -778,12 +775,11 @@ def test_fresh_sim_across_states(self, mock_msim_cls, mock_gcv, mock_county_idx) @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=["var_a", "county"], ) @patch("policyengine_us.Microsimulation") @@ -940,12 +936,11 @@ def _make_geo(self, county_fips_list, n_records=4): ) @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=[], ) @patch("policyengine_us.Microsimulation") @@ -984,12 +979,11 @@ def test_workers_gt1_creates_pool( @patch( "policyengine_us_data.calibration" - ".unified_matrix_builder.get_county_enum_index_from_fips", + ".block_assignment.get_county_enum_index_from_fips", return_value=1, ) @patch( - "policyengine_us_data.calibration" - ".unified_matrix_builder.get_calculated_variables", + "policyengine_us_data.calibration.calibration_utils.get_calculated_variables", return_value=[], ) @patch("policyengine_us.Microsimulation") diff --git a/policyengine_us_data/tests/test_datasets/conftest.py b/policyengine_us_data/tests/test_datasets/conftest.py new file mode 100644 index 00000000..776d30d9 --- /dev/null +++ b/policyengine_us_data/tests/test_datasets/conftest.py @@ -0,0 +1,26 @@ +"""Skip dataset tests that need full data build artifacts. + +In basic CI (full_suite=false), H5 files are not built locally +and Microsimulation requires ~16GB RAM. These tests run inside +Modal containers (32GB) during full_suite=true builds. +""" + +import pytest +from policyengine_us_data.storage import STORAGE_FOLDER + +NEEDS_ECPS = not (STORAGE_FOLDER / "enhanced_cps_2024.h5").exists() +NEEDS_CPS = not (STORAGE_FOLDER / "cps_2024.h5").exists() + +collect_ignore_glob = [] +if NEEDS_ECPS: + collect_ignore_glob.extend( + [ + "test_enhanced_cps.py", + "test_dataset_sanity.py", + "test_small_enhanced_cps.py", + "test_sparse_enhanced_cps.py", + "test_sipp_assets.py", + ] + ) +if NEEDS_CPS: + collect_ignore_glob.append("test_cps.py") diff --git a/policyengine_us_data/tests/test_pipeline.py b/policyengine_us_data/tests/test_pipeline.py new file mode 100644 index 00000000..8894dc33 --- /dev/null +++ b/policyengine_us_data/tests/test_pipeline.py @@ -0,0 +1,263 @@ +"""Tests for pipeline orchestrator metadata and helpers.""" + +import json +import time +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +modal = pytest.importorskip("modal") + +from modal_app.pipeline import ( + RunMetadata, + _step_completed, + _record_step, + generate_run_id, + write_run_meta, + read_run_meta, +) + + +# -- RunMetadata tests ------------------------------------------ + + +class TestRunMetadata: + def test_to_dict(self): + meta = RunMetadata( + run_id="1.72.3_abc12345_20260319_120000", + branch="main", + sha="abc12345deadbeef", + version="1.72.3", + start_time="2026-03-19T12:00:00Z", + status="running", + ) + d = meta.to_dict() + + assert d["run_id"] == ("1.72.3_abc12345_20260319_120000") + assert d["branch"] == "main" + assert d["sha"] == "abc12345deadbeef" + assert d["version"] == "1.72.3" + assert d["status"] == "running" + assert d["step_timings"] == {} + assert d["error"] is None + + def test_from_dict(self): + data = { + "run_id": "1.72.3_abc12345_20260319_120000", + "branch": "main", + "sha": "abc12345deadbeef", + "version": "1.72.3", + "start_time": "2026-03-19T12:00:00Z", + "status": "completed", + "step_timings": { + "build_datasets": { + "status": "completed", + "duration_s": 100.0, + } + }, + "error": None, + } + meta = RunMetadata.from_dict(data) + + assert meta.run_id == ("1.72.3_abc12345_20260319_120000") + assert meta.status == "completed" + assert meta.step_timings["build_datasets"]["status"] == "completed" + + def test_roundtrip(self): + meta = RunMetadata( + run_id="1.72.3_abc12345_20260319_120000", + branch="main", + sha="abc12345deadbeef", + version="1.72.3", + start_time="2026-03-19T12:00:00Z", + status="failed", + error="RuntimeError: test", + ) + roundtripped = RunMetadata.from_dict(meta.to_dict()) + + assert roundtripped.run_id == meta.run_id + assert roundtripped.status == meta.status + assert roundtripped.error == meta.error + + def test_step_timings_default_empty(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + ) + assert meta.step_timings == {} + + +# -- generate_run_id tests ------------------------------------- + + +class TestGenerateRunId: + def test_format(self): + run_id = generate_run_id("1.72.3", "abc12345deadbeef") + + parts = run_id.split("_") + assert parts[0] == "1.72.3" + assert parts[1] == "abc12345" + assert len(parts) == 4 # version_sha_date_time + + def test_sha_truncated_to_8(self): + run_id = generate_run_id("1.0.0", "abcdef1234567890") + sha_part = run_id.split("_")[1] + assert sha_part == "abcdef12" + assert len(sha_part) == 8 + + def test_unique_ids(self): + id1 = generate_run_id("1.0.0", "abc123") + time.sleep(0.01) + id2 = generate_run_id("1.0.0", "abc123") + # Timestamps should differ (or at least + # the function doesn't reuse) + assert isinstance(id1, str) + assert isinstance(id2, str) + + +# -- _step_completed tests ------------------------------------ + + +class TestStepCompleted: + def test_completed_step(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + step_timings={ + "build_datasets": { + "status": "completed", + "duration_s": 50.0, + } + }, + ) + assert _step_completed(meta, "build_datasets") + + def test_incomplete_step(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + step_timings={ + "build_datasets": { + "status": "failed", + "duration_s": 10.0, + } + }, + ) + assert not _step_completed(meta, "build_datasets") + + def test_missing_step(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + ) + assert not _step_completed(meta, "build_datasets") + + +# -- _record_step tests ---------------------------------------- + + +class TestRecordStep: + def test_records_timing(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + ) + mock_vol = MagicMock() + start = time.time() - 5.0 + + with patch("modal_app.pipeline.write_run_meta"): + _record_step(meta, "build_datasets", start, mock_vol) + + timing = meta.step_timings["build_datasets"] + assert timing["status"] == "completed" + assert timing["duration_s"] >= 5.0 + assert "start" in timing + assert "end" in timing + + def test_records_custom_status(self): + meta = RunMetadata( + run_id="test", + branch="main", + sha="abc", + version="1.0.0", + start_time="now", + status="running", + ) + mock_vol = MagicMock() + + with patch("modal_app.pipeline.write_run_meta"): + _record_step( + meta, + "build_datasets", + time.time(), + mock_vol, + status="failed", + ) + + assert meta.step_timings["build_datasets"]["status"] == "failed" + + +# -- write/read_run_meta tests -------------------------------- + + +class TestRunMetaIO: + def test_write_and_read(self, tmp_path): + meta = RunMetadata( + run_id="test_run", + branch="main", + sha="abc123", + version="1.0.0", + start_time="2026-03-19T12:00:00Z", + status="running", + ) + mock_vol = MagicMock() + + runs_dir = tmp_path / "runs" + + with patch( + "modal_app.pipeline.RUNS_DIR", + str(runs_dir), + ): + write_run_meta(meta, mock_vol) + mock_vol.commit.assert_called_once() + + # Verify file was written + meta_path = runs_dir / "test_run" / "meta.json" + assert meta_path.exists() + + with open(meta_path) as f: + data = json.load(f) + assert data["run_id"] == "test_run" + assert data["status"] == "running" + + def test_read_nonexistent_raises(self): + mock_vol = MagicMock() + + with patch( + "modal_app.pipeline.RUNS_DIR", + "/nonexistent", + ): + with pytest.raises(FileNotFoundError): + read_run_meta("fake_run", mock_vol) diff --git a/policyengine_us_data/tests/test_version_manifest.py b/policyengine_us_data/tests/test_version_manifest.py new file mode 100644 index 00000000..573841e6 --- /dev/null +++ b/policyengine_us_data/tests/test_version_manifest.py @@ -0,0 +1,850 @@ +"""Tests for version manifest registry system.""" + +import json +from unittest.mock import MagicMock, patch, call + +import pytest +from google.api_core.exceptions import NotFound + +from policyengine_us_data.utils.version_manifest import ( + GCSVersionInfo, + VersionManifest, + VersionRegistry, + build_manifest, + upload_manifest, + get_current_version, + get_manifest, + list_versions, + download_versioned_file, + rollback, + get_data_manifest, + get_data_version, +) +from policyengine_us_data.tests.conftest import ( + make_mock_blob, + setup_bucket_with_registry, +) + +_MOD = "policyengine_us_data.utils.version_manifest" + + +# -- VersionManifest serialization tests --------------------------- + + +class TestVersionManifestSerialization: + def test_to_dict(self, sample_manifest): + result = sample_manifest.to_dict() + + assert result["version"] == "1.72.3" + assert result["created_at"] == "2026-03-10T14:30:00Z" + assert result["hf"]["repo"] == ("policyengine/policyengine-us-data") + assert result["hf"]["commit"] == "abc123def456" + assert result["gcs"]["bucket"] == ("policyengine-us-data") + assert result["gcs"]["generations"]["enhanced_cps_2024.h5"] == 1710203948123456 + + def test_from_dict(self, sample_manifest): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": { + "repo": ("policyengine/policyengine-us-data"), + "commit": "abc123def456", + }, + "gcs": { + "bucket": "policyengine-us-data", + "generations": { + "enhanced_cps_2024.h5": (1710203948123456), + "cps_2024.h5": 1710203948234567, + "states/AL.h5": 1710203948345678, + }, + }, + } + result = VersionManifest.from_dict(data) + + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + assert result.gcs.bucket == "policyengine-us-data" + + def test_roundtrip(self, sample_manifest): + roundtripped = VersionManifest.from_dict(sample_manifest.to_dict()) + + assert roundtripped.version == (sample_manifest.version) + assert roundtripped.created_at == (sample_manifest.created_at) + assert roundtripped.hf.repo == (sample_manifest.hf.repo) + assert roundtripped.hf.commit == (sample_manifest.hf.commit) + assert roundtripped.gcs.bucket == (sample_manifest.gcs.bucket) + assert roundtripped.gcs.generations == (sample_manifest.gcs.generations) + + def test_without_hf(self, sample_generations): + manifest = VersionManifest( + version="1.72.3", + created_at="2026-03-10T14:30:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + ) + data = manifest.to_dict() + assert data["hf"] is None + + roundtripped = VersionManifest.from_dict(data) + assert roundtripped.hf is None + assert roundtripped.gcs.generations == (sample_generations) + + def test_special_operation_omitted_by_default(self, sample_manifest): + data = sample_manifest.to_dict() + assert "special_operation" not in data + assert "roll_back_version" not in data + + def test_special_operation_included_when_set( + self, sample_generations, sample_hf_info + ): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + data = manifest.to_dict() + assert data["special_operation"] == "roll-back" + assert data["roll_back_version"] == "1.70.1" + + def test_special_operation_roundtrip(self, sample_generations, sample_hf_info): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + special_operation="roll-back", + roll_back_version="1.70.1", + ) + roundtripped = VersionManifest.from_dict(manifest.to_dict()) + assert roundtripped.special_operation == ("roll-back") + assert roundtripped.roll_back_version == "1.70.1" + + def test_regular_manifest_has_no_special_operation( + self, + ): + data = { + "version": "1.72.3", + "created_at": "2026-03-10T14:30:00Z", + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 123}, + }, + } + result = VersionManifest.from_dict(data) + assert result.special_operation is None + assert result.roll_back_version is None + + def test_pipeline_run_id_omitted_by_default(self, sample_manifest): + data = sample_manifest.to_dict() + assert "pipeline_run_id" not in data + assert "diagnostics_path" not in data + + def test_pipeline_run_id_included_when_set( + self, sample_generations, sample_hf_info + ): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + pipeline_run_id="1.73.0_abc12345_20260310", + diagnostics_path=("calibration/runs/1.73.0_abc12345_20260310/diagnostics/"), + ) + data = manifest.to_dict() + assert data["pipeline_run_id"] == ("1.73.0_abc12345_20260310") + assert "diagnostics/" in data["diagnostics_path"] + + def test_pipeline_run_id_roundtrip(self, sample_generations, sample_hf_info): + manifest = VersionManifest( + version="1.73.0", + created_at="2026-03-10T15:00:00Z", + hf=sample_hf_info, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations=sample_generations, + ), + pipeline_run_id="1.73.0_abc12345_20260310", + diagnostics_path="calibration/runs/x/diag/", + ) + roundtripped = VersionManifest.from_dict(manifest.to_dict()) + assert roundtripped.pipeline_run_id == ("1.73.0_abc12345_20260310") + assert roundtripped.diagnostics_path == ("calibration/runs/x/diag/") + + +# -- VersionRegistry serialization tests --------------------------- + + +class TestVersionRegistrySerialization: + def test_to_dict(self, sample_registry): + result = sample_registry.to_dict() + + assert result["current"] == "1.72.3" + assert len(result["versions"]) == 1 + assert result["versions"][0]["version"] == "1.72.3" + + def test_from_dict(self, sample_manifest): + data = { + "current": "1.72.3", + "versions": [sample_manifest.to_dict()], + } + result = VersionRegistry.from_dict(data) + + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].version == "1.72.3" + assert result.versions[0].hf.commit == ("abc123def456") + + def test_roundtrip(self, sample_registry): + roundtripped = VersionRegistry.from_dict(sample_registry.to_dict()) + assert roundtripped.current == (sample_registry.current) + assert len(roundtripped.versions) == len(sample_registry.versions) + assert roundtripped.versions[0].version == "1.72.3" + + def test_get_version(self, sample_registry): + result = sample_registry.get_version("1.72.3") + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + + def test_get_version_not_found(self, sample_registry): + with pytest.raises(ValueError, match="not found"): + sample_registry.get_version("9.9.9") + + def test_empty_registry(self): + registry = VersionRegistry() + assert registry.current == "" + assert registry.versions == [] + + data = registry.to_dict() + assert data == {"current": "", "versions": []} + + +# -- build_manifest tests ------------------------------------------ + + +class TestBuildManifest: + @patch(f"{_MOD}._get_gcs_bucket") + def test_structure(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob_names = [ + "file_a.h5", + "file_b.h5", + "file_c.h5", + ] + mock_bucket.get_blob.side_effect = [ + make_mock_blob(100), + make_mock_blob(200), + make_mock_blob(300), + ] + + result = build_manifest("1.72.3", blob_names) + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.created_at.endswith("Z") + assert result.gcs.generations == { + "file_a.h5": 100, + "file_b.h5": 200, + "file_c.h5": 300, + } + assert result.gcs.bucket == "policyengine-us-data" + assert result.hf is None + + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_subdirectories(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob_names = [ + "states/AL.h5", + "districts/CA-01.h5", + ] + mock_bucket.get_blob.side_effect = [ + make_mock_blob(111), + make_mock_blob(222), + ] + + result = build_manifest("1.72.3", blob_names) + + assert "states/AL.h5" in result.gcs.generations + assert "districts/CA-01.h5" in result.gcs.generations + assert result.gcs.generations["states/AL.h5"] == 111 + assert result.gcs.generations["districts/CA-01.h5"] == 222 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_with_hf_info( + self, + mock_get_bucket, + mock_bucket, + sample_hf_info, + ): + mock_get_bucket.return_value = mock_bucket + mock_bucket.get_blob.return_value = make_mock_blob(999) + + result = build_manifest( + "1.72.3", + ["file.h5"], + hf_info=sample_hf_info, + ) + + assert result.hf is not None + assert result.hf.commit == "abc123def456" + assert result.hf.repo == ("policyengine/policyengine-us-data") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_missing_blob_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + mock_bucket.get_blob.return_value = None + + with pytest.raises(ValueError, match="not found"): + build_manifest("1.72.3", ["missing.h5"]) + + +# -- upload_manifest tests ----------------------------------------- + + +class TestUploadManifest: + def _setup_empty_registry(self, bucket): + """Mock bucket with no existing registry.""" + written = {} + + def mock_blob(name): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.side_effect = NotFound("Not found") + written[name] = b + return b + b = MagicMock() + b.name = name + written[name] = b + return b + + bucket.blob.side_effect = mock_blob + return written + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_writes_registry_to_gcs( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(sample_manifest) + + assert "version_manifest.json" in written + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 1 + assert registry_data["versions"][0]["version"] == "1.72.3" + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_includes_hf_commit( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + written = self._setup_empty_registry(mock_bucket) + + upload_manifest(sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["versions"][0]["hf"]["commit"] == "abc123def456" + + @patch(f"{_MOD}._upload_registry_to_hf") + @patch(f"{_MOD}._get_gcs_bucket") + def test_appends_to_existing_registry( + self, + mock_get_bucket, + mock_hf, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + older = VersionManifest( + version="1.72.2", + created_at="2026-03-09T10:00:00Z", + hf=None, + gcs=GCSVersionInfo( + bucket="policyengine-us-data", + generations={"old.h5": 111}, + ), + ) + existing_registry = VersionRegistry(current="1.72.2", versions=[older]) + existing_json = json.dumps(existing_registry.to_dict()) + written = {} + + def mock_blob(name): + b = MagicMock() + b.name = name + b.download_as_text.return_value = existing_json + written[name] = b + return b + + mock_bucket.blob.side_effect = mock_blob + + upload_manifest(sample_manifest) + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.72.3" + assert len(registry_data["versions"]) == 2 + assert registry_data["versions"][0]["version"] == "1.72.3" + assert registry_data["versions"][1]["version"] == "1.72.2" + + @patch(f"{_MOD}.os") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}._get_gcs_bucket") + def test_always_uploads_to_hf( + self, + mock_get_bucket, + mock_hf_api_cls, + mock_os, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + mock_os.environ.get.return_value = "fake_token" + mock_os.unlink = MagicMock() + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + upload_manifest(sample_manifest) + + mock_api.upload_file.assert_called_once() + call_kwargs = mock_api.upload_file.call_args.kwargs + assert call_kwargs["path_in_repo"] == ("version_manifest.json") + assert call_kwargs["repo_id"] == ("policyengine/policyengine-us-data") + + +# -- get_current_version tests ------------------------------------- + + +class TestGetCurrentVersion: + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_current_version() + + assert result == "1.72.3" + mock_bucket.blob.assert_called_with("version_manifest.json") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_returns_none(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + result = get_current_version() + + assert result is None + + +# -- get_manifest tests --------------------------------------------- + + +class TestGetManifest: + @patch(f"{_MOD}._get_gcs_bucket") + def test_specific_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + result = get_manifest("1.72.3") + + assert isinstance(result, VersionManifest) + assert result.version == "1.72.3" + assert result.hf.commit == "abc123def456" + assert result.gcs.generations["enhanced_cps_2024.h5"] == 1710203948123456 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version( + self, + mock_get_bucket, + mock_bucket, + sample_registry, + ): + mock_get_bucket.return_value = mock_bucket + setup_bucket_with_registry(mock_bucket, sample_registry) + + with pytest.raises(ValueError, match="not found"): + get_manifest("9.9.9") + + @patch(f"{_MOD}._get_gcs_bucket") + def test_no_registry_raises(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + get_manifest("1.72.3") + + +# -- list_versions tests ------------------------------------------- + + +class TestListVersions: + @patch(f"{_MOD}._get_gcs_bucket") + def test_returns_sorted(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + v1 = VersionManifest( + version="1.72.1", + created_at="t1", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 1}), + ) + v2 = VersionManifest( + version="1.72.3", + created_at="t2", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 2}), + ) + v3 = VersionManifest( + version="1.72.2", + created_at="t3", + hf=None, + gcs=GCSVersionInfo(bucket="b", generations={"f.h5": 3}), + ) + registry = VersionRegistry(current="1.72.3", versions=[v2, v3, v1]) + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [ + "1.72.1", + "1.72.2", + "1.72.3", + ] + + @patch(f"{_MOD}._get_gcs_bucket") + def test_empty(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry() + setup_bucket_with_registry(mock_bucket, registry) + + result = list_versions() + + assert result == [] + + +# -- download_versioned_file tests --------------------------------- + + +class TestDownloadVersionedFile: + @patch(f"{_MOD}._get_gcs_bucket") + def test_downloads_correct_generation( + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, + ): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + blob = MagicMock() + blob.download_as_text.return_value = registry_json + return blob + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + local_path = str(tmp_path / "AL.h5") + download_versioned_file( + "states/AL.h5", + "1.72.3", + local_path, + ) + + calls = mock_bucket.blob.call_args_list + gen_call = [ + c + for c in calls + if c + == call( + "states/AL.h5", + generation=1710203948345678, + ) + ] + assert len(gen_call) == 1 + + @patch(f"{_MOD}._get_gcs_bucket") + def test_file_not_in_manifest( + self, + mock_get_bucket, + mock_bucket, + sample_manifest, + tmp_path, + ): + mock_get_bucket.return_value = mock_bucket + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + setup_bucket_with_registry(mock_bucket, registry) + + with pytest.raises(ValueError, match="not found"): + download_versioned_file( + "nonexistent.h5", + "1.72.3", + str(tmp_path / "out.h5"), + ) + + +# -- rollback tests ------------------------------------------------- + + +class TestRollback: + @patch(f"{_MOD}.CommitOperationAdd") + @patch(f"{_MOD}.hf_hub_download") + @patch(f"{_MOD}.HfApi") + @patch(f"{_MOD}.os") + @patch(f"{_MOD}._get_gcs_bucket") + def test_creates_new_version_with_old_data( + self, + mock_get_bucket, + mock_os, + mock_hf_api_cls, + mock_hf_download, + mock_commit_op, + mock_bucket, + sample_manifest, + ): + mock_get_bucket.return_value = mock_bucket + mock_os.environ.get.return_value = "fake_token" + mock_os.path.join = lambda *args: "/".join(args) + mock_os.unlink = MagicMock() + + mock_api = MagicMock() + mock_hf_api_cls.return_value = mock_api + commit_info = MagicMock() + commit_info.oid = "new_commit_sha" + mock_api.create_commit.return_value = commit_info + + registry = VersionRegistry( + current="1.72.3", + versions=[sample_manifest], + ) + registry_json = json.dumps(registry.to_dict()) + written = {} + + def mock_blob(name, generation=None): + if name == "version_manifest.json": + b = MagicMock() + b.name = name + b.download_as_text.return_value = registry_json + written[name] = b + return b + blob = MagicMock() + blob.name = name + blob.generation = generation + return blob + + mock_bucket.blob.side_effect = mock_blob + + new_gen_counter = iter([50001, 50002, 50003]) + + def mock_get_blob(name): + blob = MagicMock() + blob.generation = next(new_gen_counter) + return blob + + mock_bucket.get_blob.side_effect = mock_get_blob + + result = rollback( + target_version="1.72.3", + new_version="1.73.0", + ) + + assert isinstance(result, VersionManifest) + assert result.version == "1.73.0" + assert result.special_operation == "roll-back" + assert result.roll_back_version == "1.72.3" + + assert mock_bucket.copy_blob.call_count == 3 + + blob = written["version_manifest.json"] + written_json = blob.upload_from_string.call_args[0][0] + registry_data = json.loads(written_json) + + assert registry_data["current"] == "1.73.0" + assert len(registry_data["versions"]) == 2 + assert registry_data["versions"][0]["version"] == "1.73.0" + assert registry_data["versions"][0]["special_operation"] == "roll-back" + + mock_api.create_commit.assert_called_once() + commit_msg = mock_api.create_commit.call_args.kwargs["commit_message"] + assert "1.72.3" in commit_msg + assert "1.73.0" in commit_msg + mock_api.create_tag.assert_called_once() + + @patch(f"{_MOD}._get_gcs_bucket") + def test_nonexistent_version(self, mock_get_bucket, mock_bucket): + mock_get_bucket.return_value = mock_bucket + blob = MagicMock() + blob.download_as_text.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = blob + + with pytest.raises(ValueError, match="not found"): + rollback( + target_version="9.9.9", + new_version="9.10.0", + ) + + +# -- Consumer API tests -------------------------------------------- + + +class TestGetDataManifest: + def setup_method(self): + import policyengine_us_data.utils.version_manifest as mod + + mod._cached_registry = None + + def teardown_method(self): + import policyengine_us_data.utils.version_manifest as mod + + mod._cached_registry = None + + @patch(f"{_MOD}.hf_hub_download") + def test_returns_registry(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": { + "repo": ("policyengine/policyengine-us-data"), + "commit": "abc123", + }, + "gcs": { + "bucket": ("policyengine-us-data"), + "generations": {"file.h5": 12345}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_manifest() + + assert isinstance(result, VersionRegistry) + assert result.current == "1.72.3" + assert len(result.versions) == 1 + assert result.versions[0].hf.commit == "abc123" + mock_download.assert_called_once_with( + repo_id=("policyengine/policyengine-us-data"), + repo_type="model", + filename="version_manifest.json", + ) + + @patch(f"{_MOD}.hf_hub_download") + def test_caches_result(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + first = get_data_manifest() + second = get_data_manifest() + + assert first is second + assert mock_download.call_count == 1 + + @patch(f"{_MOD}.hf_hub_download") + def test_get_data_version(self, mock_download, tmp_path): + registry_data = { + "current": "1.72.3", + "versions": [ + { + "version": "1.72.3", + "created_at": ("2026-03-10T14:30:00Z"), + "hf": None, + "gcs": { + "bucket": "b", + "generations": {"f.h5": 1}, + }, + }, + ], + } + registry_file = tmp_path / "version_manifest.json" + registry_file.write_text(json.dumps(registry_data)) + mock_download.return_value = str(registry_file) + + result = get_data_version() + + assert result == "1.72.3" diff --git a/policyengine_us_data/utils/huggingface.py b/policyengine_us_data/utils/huggingface.py index 9b1e48cb..c73a181a 100644 --- a/policyengine_us_data/utils/huggingface.py +++ b/policyengine_us_data/utils/huggingface.py @@ -81,11 +81,7 @@ def download_calibration_inputs( # but won't exist yet when running calibration from scratch optional_files = { "weights": f"calibration/{prefix}calibration_weights.npy", - "geography": f"calibration/{prefix}geography.npz", "run_config": (f"calibration/{prefix}unified_run_config.json"), - # Legacy artifacts (for backward compatibility) - "blocks": f"calibration/{prefix}stacked_blocks.npy", - "geo_labels": f"calibration/{prefix}geo_labels.json", } for key, hf_path in optional_files.items(): try: @@ -156,9 +152,6 @@ def download_calibration_logs( def upload_calibration_artifacts( weights_path: str = None, - blocks_path: str = None, - geo_labels_path: str = None, - geography_path: str = None, log_dir: str = None, repo: str = "policyengine/policyengine-us-data", prefix: str = "", @@ -167,9 +160,6 @@ def upload_calibration_artifacts( Args: weights_path: Path to calibration_weights.npy - blocks_path: Path to stacked_blocks.npy (legacy) - geo_labels_path: Path to geo_labels.json (legacy) - geography_path: Path to geography.npz log_dir: Directory containing log files (calibration_log.csv, unified_diagnostics.csv, unified_run_config.json) @@ -189,31 +179,6 @@ def upload_calibration_artifacts( ) ) - if geography_path and os.path.exists(geography_path): - operations.append( - CommitOperationAdd( - path_in_repo=(f"calibration/{prefix}geography.npz"), - path_or_fileobj=geography_path, - ) - ) - - # Legacy artifacts - if blocks_path and os.path.exists(blocks_path): - operations.append( - CommitOperationAdd( - path_in_repo=(f"calibration/{prefix}stacked_blocks.npy"), - path_or_fileobj=blocks_path, - ) - ) - - if geo_labels_path and os.path.exists(geo_labels_path): - operations.append( - CommitOperationAdd( - path_in_repo=(f"calibration/{prefix}geo_labels.json"), - path_or_fileobj=geo_labels_path, - ) - ) - if log_dir: # Upload run config to calibration/ root for artifact validation run_config_local = os.path.join(log_dir, f"{prefix}unified_run_config.json") diff --git a/policyengine_us_data/utils/takeup.py b/policyengine_us_data/utils/takeup.py index 5e49b20a..b8db8c90 100644 --- a/policyengine_us_data/utils/takeup.py +++ b/policyengine_us_data/utils/takeup.py @@ -22,90 +22,66 @@ "variable": "takes_up_snap_if_eligible", "entity": "spm_unit", "rate_key": "snap", + "target": "snap", }, { "variable": "takes_up_aca_if_eligible", "entity": "tax_unit", "rate_key": "aca", + "target": "aca_ptc", }, { "variable": "takes_up_dc_ptc", "entity": "tax_unit", "rate_key": "dc_ptc", + "target": "dc_property_tax_credit", }, { "variable": "takes_up_head_start_if_eligible", "entity": "person", "rate_key": "head_start", + "target": "head_start", }, { "variable": "takes_up_early_head_start_if_eligible", "entity": "person", "rate_key": "early_head_start", + "target": "early_head_start", }, { "variable": "takes_up_ssi_if_eligible", "entity": "person", "rate_key": "ssi", + "target": "ssi", }, { "variable": "would_file_taxes_voluntarily", "entity": "tax_unit", "rate_key": "voluntary_filing", + "target": None, }, { "variable": "takes_up_medicaid_if_eligible", "entity": "person", "rate_key": "medicaid", + "target": "medicaid", }, { "variable": "takes_up_tanf_if_eligible", "entity": "spm_unit", "rate_key": "tanf", + "target": "tanf", }, ] TAKEUP_AFFECTED_TARGETS: Dict[str, dict] = { - "snap": { - "takeup_var": "takes_up_snap_if_eligible", - "entity": "spm_unit", - "rate_key": "snap", - }, - "tanf": { - "takeup_var": "takes_up_tanf_if_eligible", - "entity": "spm_unit", - "rate_key": "tanf", - }, - "aca_ptc": { - "takeup_var": "takes_up_aca_if_eligible", - "entity": "tax_unit", - "rate_key": "aca", - }, - "ssi": { - "takeup_var": "takes_up_ssi_if_eligible", - "entity": "person", - "rate_key": "ssi", - }, - "medicaid": { - "takeup_var": "takes_up_medicaid_if_eligible", - "entity": "person", - "rate_key": "medicaid", - }, - "head_start": { - "takeup_var": "takes_up_head_start_if_eligible", - "entity": "person", - "rate_key": "head_start", - }, - "early_head_start": { - "takeup_var": "takes_up_early_head_start_if_eligible", - "entity": "person", - "rate_key": "early_head_start", - }, - "dc_property_tax_credit": { - "takeup_var": "takes_up_dc_ptc", - "entity": "tax_unit", - "rate_key": "dc_ptc", - }, + spec["target"]: { + "takeup_var": spec["variable"], + "entity": spec["entity"], + "rate_key": spec["rate_key"], + } + for spec in SIMPLE_TAKEUP_VARS + if spec.get("target") is not None } # FIPS -> 2-letter state code for Medicaid rate lookup @@ -182,34 +158,26 @@ def compute_block_takeup_for_entities( var_name: str, rate_or_dict, entity_blocks: np.ndarray, - entity_hh_ids: np.ndarray = None, - entity_clone_ids: np.ndarray = None, + entity_hh_ids: np.ndarray, + entity_clone_indices: np.ndarray, ) -> np.ndarray: - """Compute boolean takeup via block-level seeded draws. - - Each unique (block, household) pair gets its own seeded RNG, - producing reproducible draws regardless of how many households - share the same block across clones. + """Compute boolean takeup via clone-seeded draws. - When multiple clones share the same (block, hh_id), the draws - are generated once for a single clone's entity count and tiled - so every clone gets identical draws — matching the matrix - builder, which processes each clone independently. - - State FIPS for rate resolution is derived from the first two - characters of each block GEOID. + Each unique (hh_id, clone_idx) pair gets its own seeded RNG, + producing reproducible draws tied to the donor household and + independent across clones. The rate varies by state (derived + from the block GEOID). Args: var_name: Takeup variable name. rate_or_dict: Scalar rate or {state_code: rate} dict. - entity_blocks: Block GEOID per entity (str array). - entity_hh_ids: Household ID per entity (int array). - When provided, seeds per (block, household) for - clone-independent draws. - entity_clone_ids: Clone index per entity (int array). - When provided, draws are tiled across clones sharing - the same (block, hh_id) so each clone gets identical - takeup decisions. + entity_blocks: Block GEOID per entity (str array), + used only for state FIPS rate resolution. + entity_hh_ids: Original household ID per entity. + entity_clone_indices: Clone index per entity. For the + matrix builder (single clone), a scalar broadcast + via np.full. For the H5 builder (all clones), + a per-entity array. Returns: Boolean array of shape (n_entities,). @@ -218,35 +186,22 @@ def compute_block_takeup_for_entities( draws = np.zeros(n, dtype=np.float64) rates = np.ones(n, dtype=np.float64) + # Resolve rates from block state FIPS for block in np.unique(entity_blocks): if block == "": continue blk_mask = entity_blocks == block sf = int(str(block)[:2]) - rate = _resolve_rate(rate_or_dict, sf) - rates[blk_mask] = rate - - if entity_hh_ids is not None: - for hh_id in np.unique(entity_hh_ids[blk_mask]): - hh_mask = blk_mask & (entity_hh_ids == hh_id) - n_total = int(hh_mask.sum()) - rng = seeded_rng(var_name, salt=f"{block}:{int(hh_id)}") + rates[blk_mask] = _resolve_rate(rate_or_dict, sf) - if entity_clone_ids is not None and n_total > 1: - clone_ids = entity_clone_ids[hh_mask] - first_clone = clone_ids[0] - n_per_clone = int((clone_ids == first_clone).sum()) - if n_per_clone < n_total: - base_draws = rng.random(n_per_clone) - n_copies = n_total // n_per_clone - draws[hh_mask] = np.tile(base_draws, n_copies) - else: - draws[hh_mask] = rng.random(n_total) - else: - draws[hh_mask] = rng.random(n_total) - else: - rng = seeded_rng(var_name, salt=str(block)) - draws[blk_mask] = rng.random(int(blk_mask.sum())) + # Draw per (hh_id, clone_idx) pair + for hh_id in np.unique(entity_hh_ids): + hh_mask = entity_hh_ids == hh_id + for ci in np.unique(entity_clone_indices[hh_mask]): + ci_mask = hh_mask & (entity_clone_indices == ci) + n_ent = int(ci_mask.sum()) + rng = seeded_rng(var_name, salt=f"{int(hh_id)}:{int(ci)}") + draws[ci_mask] = rng.random(n_ent) return draws < rates @@ -255,13 +210,14 @@ def apply_block_takeup_to_arrays( hh_blocks: np.ndarray, hh_state_fips: np.ndarray, hh_ids: np.ndarray, + hh_clone_indices: np.ndarray, entity_hh_indices: Dict[str, np.ndarray], entity_counts: Dict[str, int], time_period: int, takeup_filter: List[str] = None, precomputed_rates: Optional[Dict[str, Any]] = None, ) -> Dict[str, np.ndarray]: - """Compute block-level takeup draws from raw arrays. + """Compute takeup draws from raw arrays. Works without a Microsimulation instance. For each takeup variable, maps entity-level arrays from household-level block/ @@ -271,7 +227,8 @@ def apply_block_takeup_to_arrays( Args: hh_blocks: Block GEOID per cloned household (str array). hh_state_fips: State FIPS per cloned household (int array). - hh_ids: Household ID per cloned household (int array). + hh_ids: Original household ID per cloned household. + hh_clone_indices: Clone index per cloned household. entity_hh_indices: {entity_key: array} mapping each entity instance to its household index. Keys: "person", "tax_unit", "spm_unit". @@ -304,7 +261,7 @@ def apply_block_takeup_to_arrays( ent_hh_idx = entity_hh_indices[entity] ent_blocks = hh_blocks[ent_hh_idx].astype(str) ent_hh_ids = hh_ids[ent_hh_idx] - ent_clone_ids = ent_hh_idx + ent_clone_indices = hh_clone_indices[ent_hh_idx] if precomputed_rates is not None and rate_key in precomputed_rates: rate_or_dict = precomputed_rates[rate_key] @@ -315,7 +272,7 @@ def apply_block_takeup_to_arrays( rate_or_dict, ent_blocks, ent_hh_ids, - entity_clone_ids=ent_clone_ids, + ent_clone_indices, ) result[var_name] = bools diff --git a/policyengine_us_data/utils/version_manifest.py b/policyengine_us_data/utils/version_manifest.py new file mode 100644 index 00000000..49ad8d5b --- /dev/null +++ b/policyengine_us_data/utils/version_manifest.py @@ -0,0 +1,568 @@ +""" +Version registry for semver-based dataset versioning. + +Provides typed structures and functions for versioned uploads, +downloads, and rollbacks across GCS and Hugging Face. All +versions are tracked in a single registry file +(version_manifest.json) on both backends. +""" + +import json +import logging +import os +import tempfile +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +import google.auth +from google.api_core.exceptions import NotFound +from google.cloud import storage +from huggingface_hub import ( + HfApi, + CommitOperationAdd, + hf_hub_download, +) + +# -- Configuration ------------------------------------------------- + +REGISTRY_BLOB = "version_manifest.json" +GCS_BUCKET_NAME = "policyengine-us-data" +HF_REPO_NAME = "policyengine/policyengine-us-data" +HF_REPO_TYPE = "model" + + +# -- Types --------------------------------------------------------- + + +@dataclass +class HFVersionInfo: + """Hugging Face backend location for a version.""" + + repo: str + commit: str + + def to_dict(self) -> dict[str, str]: + return {"repo": self.repo, "commit": self.commit} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "HFVersionInfo": + return cls(repo=data["repo"], commit=data["commit"]) + + +@dataclass +class GCSVersionInfo: + """GCS backend location for a version.""" + + bucket: str + generations: dict[str, int] + + def to_dict(self) -> dict[str, Any]: + return { + "bucket": self.bucket, + "generations": self.generations, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "GCSVersionInfo": + return cls( + bucket=data["bucket"], + generations=data["generations"], + ) + + +@dataclass +class VersionManifest: + """Single version entry tying semver to backend + identifiers. + + Consumers interact only with the semver version string. + HF commit SHAs and GCS generation numbers are internal + implementation details resolved by this manifest. + """ + + version: str + created_at: str + hf: Optional[HFVersionInfo] + gcs: GCSVersionInfo + special_operation: Optional[str] = None + roll_back_version: Optional[str] = None + pipeline_run_id: Optional[str] = None + diagnostics_path: Optional[str] = None + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = { + "version": self.version, + "created_at": self.created_at, + "hf": self.hf.to_dict() if self.hf else None, + "gcs": self.gcs.to_dict(), + } + if self.special_operation is not None: + result["special_operation"] = self.special_operation + if self.roll_back_version is not None: + result["roll_back_version"] = self.roll_back_version + if self.pipeline_run_id is not None: + result["pipeline_run_id"] = self.pipeline_run_id + if self.diagnostics_path is not None: + result["diagnostics_path"] = self.diagnostics_path + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionManifest": + hf_data = data.get("hf") + return cls( + version=data["version"], + created_at=data["created_at"], + hf=(HFVersionInfo.from_dict(hf_data) if hf_data else None), + gcs=GCSVersionInfo.from_dict(data["gcs"]), + special_operation=data.get("special_operation"), + roll_back_version=data.get("roll_back_version"), + pipeline_run_id=data.get("pipeline_run_id"), + diagnostics_path=data.get("diagnostics_path"), + ) + + +@dataclass +class VersionRegistry: + """Registry of all dataset versions. + + Contains a pointer to the current version and a list of + all version manifests (most recent first). + """ + + current: str = "" + versions: list[VersionManifest] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "current": self.current, + "versions": [v.to_dict() for v in self.versions], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VersionRegistry": + return cls( + current=data["current"], + versions=[VersionManifest.from_dict(v) for v in data["versions"]], + ) + + def get_version(self, version: str) -> VersionManifest: + """Look up a specific version entry. + + Args: + version: Semver version string. + + Returns: + The matching VersionManifest. + + Raises: + ValueError: If the version is not in the + registry. + """ + for v in self.versions: + if v.version == version: + return v + available = [v.version for v in self.versions[:10]] + raise ValueError( + f"Version '{version}' not found in registry. " + f"Available versions: {available}" + ) + + +# -- Internal helpers ---------------------------------------------- + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + +def _get_gcs_bucket() -> storage.Bucket: + """Return an authenticated GCS bucket handle.""" + credentials, project_id = google.auth.default() + client = storage.Client(credentials=credentials, project=project_id) + return client.bucket(GCS_BUCKET_NAME) + + +def _read_registry_from_gcs( + bucket: storage.Bucket, +) -> VersionRegistry: + """Read the version registry from GCS. + + Returns an empty registry if no registry exists yet. + """ + blob = bucket.blob(REGISTRY_BLOB) + try: + content = blob.download_as_text() + except NotFound: + return VersionRegistry() + return VersionRegistry.from_dict(json.loads(content)) + + +def _upload_registry_to_gcs( + bucket: storage.Bucket, + registry: VersionRegistry, +) -> None: + """Write the version registry to GCS.""" + data = json.dumps(registry.to_dict(), indent=2) + blob = bucket.blob(REGISTRY_BLOB) + blob.upload_from_string(data, content_type="application/json") + logging.info(f"Uploaded registry to GCS (current={registry.current}).") + + +def _upload_registry_to_hf( + registry: VersionRegistry, +) -> None: + """Write the version registry to Hugging Face.""" + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + data = json.dumps(registry.to_dict(), indent=2) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write(data) + tmp_path = f.name + + try: + api.upload_file( + path_or_fileobj=tmp_path, + path_in_repo=REGISTRY_BLOB, + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, + token=token, + commit_message=(f"Update version registry (current={registry.current})"), + ) + logging.info(f"Uploaded {REGISTRY_BLOB} to HF repo {HF_REPO_NAME}.") + finally: + os.unlink(tmp_path) + + +def _restore_gcs_generations( + bucket: storage.Bucket, + old_generations: dict[str, int], +) -> dict[str, int]: + """Copy old GCS generation blobs to live paths. + + Args: + bucket: GCS bucket containing the blobs. + old_generations: Map of blob path to old generation + number. + + Returns: + Map of blob path to new generation number. + """ + new_generations: dict[str, int] = {} + for file_path, generation in old_generations.items(): + source_blob = bucket.blob(file_path, generation=generation) + bucket.copy_blob(source_blob, bucket, file_path) + restored_blob = bucket.get_blob(file_path) + new_generations[file_path] = restored_blob.generation + logging.info( + f"Restored {file_path}: generation " + f"{generation} -> {restored_blob.generation}." + ) + return new_generations + + +def _restore_hf_commit( + old_manifest: VersionManifest, + new_version: str, +) -> str: + """Re-upload old HF data as a new commit and tag it. + + Args: + old_manifest: The manifest of the version being + restored. + new_version: The new semver version string for + tagging. + + Returns: + The commit SHA of the new HF commit. + """ + token = os.environ.get("HUGGING_FACE_TOKEN") + api = HfApi() + target_version = old_manifest.version + + operations = [] + with tempfile.TemporaryDirectory() as tmpdir: + for file_path in old_manifest.gcs.generations: + hf_hub_download( + repo_id=old_manifest.hf.repo, + repo_type=HF_REPO_TYPE, + filename=file_path, + revision=old_manifest.hf.commit, + local_dir=tmpdir, + token=token, + ) + downloaded = os.path.join(tmpdir, file_path) + operations.append( + CommitOperationAdd( + path_in_repo=file_path, + path_or_fileobj=downloaded, + ) + ) + + commit_info = api.create_commit( + token=token, + repo_id=HF_REPO_NAME, + operations=operations, + repo_type=HF_REPO_TYPE, + commit_message=(f"Roll back to {target_version} as {new_version}"), + ) + + try: + api.create_tag( + token=token, + repo_id=HF_REPO_NAME, + tag=new_version, + revision=commit_info.oid, + repo_type=HF_REPO_TYPE, + ) + except Exception as e: + if "already exists" in str(e) or "409" in str(e): + logging.warning(f"Tag {new_version} already exists. Skipping tag creation.") + else: + raise + + return commit_info.oid + + +# -- Public API ---------------------------------------------------- + + +def build_manifest( + version: str, + blob_names: list[str], + hf_info: Optional[HFVersionInfo] = None, +) -> VersionManifest: + """Build a version manifest by reading generation + numbers from uploaded blobs. + + Args: + version: Semver version string. + blob_names: List of blob paths to include. + hf_info: Optional HF backend info to include. + + Returns: + A VersionManifest with generation numbers for + each blob. + """ + bucket = _get_gcs_bucket() + generations: dict[str, int] = {} + for name in blob_names: + blob = bucket.get_blob(name) + if blob is None: + raise ValueError( + f"Blob '{name}' not found in bucket '{bucket.name}' after upload." + ) + generations[name] = blob.generation + + return VersionManifest( + version=version, + created_at=_utc_now_iso(), + hf=hf_info, + gcs=GCSVersionInfo( + bucket=bucket.name, + generations=generations, + ), + ) + + +def upload_manifest( + manifest: VersionManifest, +) -> None: + """Append a version manifest to the registry and + upload to both GCS and HF. + + Reads the existing registry from GCS (or starts fresh), + prepends the new manifest, updates the current pointer, + and writes the registry to both backends. + + Args: + manifest: The version manifest to add. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + registry.versions.insert(0, manifest) + registry.current = manifest.version + _upload_registry_to_gcs(bucket, registry) + _upload_registry_to_hf(registry) + + +def get_current_version() -> Optional[str]: + """Get the current version from the registry. + + Returns: + The current semver version string, or None if no + registry exists. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + if not registry.current: + return None + return registry.current + + +def get_manifest(version: str) -> VersionManifest: + """Get the manifest for a specific version. + + Args: + version: Semver version string. + + Returns: + The deserialized VersionManifest. + + Raises: + ValueError: If the version is not in the registry. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + return registry.get_version(version) + + +def list_versions() -> list[str]: + """List all available versions. + + Returns: + Sorted list of semver version strings. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + return sorted(v.version for v in registry.versions) + + +def download_versioned_file( + file_path: str, + version: str, + local_path: str, +) -> str: + """Download a specific file at a specific version. + + Args: + file_path: Path of the file within the bucket. + version: Semver version string. + local_path: Local path to save the file to. + + Returns: + The local path where the file was saved. + + Raises: + ValueError: If the version or file is not found. + """ + bucket = _get_gcs_bucket() + registry = _read_registry_from_gcs(bucket) + manifest = registry.get_version(version) + + if file_path not in manifest.gcs.generations: + raise ValueError( + f"File '{file_path}' not found in manifest " + f"for version '{version}'. Available files: " + f"{list(manifest.gcs.generations.keys())[:10]}" + "..." + ) + + generation = manifest.gcs.generations[file_path] + blob = bucket.blob(file_path, generation=generation) + + Path(local_path).parent.mkdir(parents=True, exist_ok=True) + blob.download_to_filename(local_path) + + logging.info( + f"Downloaded {file_path} at version {version} " + f"(generation {generation}) to {local_path}." + ) + return local_path + + +def rollback( + target_version: str, + new_version: str, +) -> VersionManifest: + """Roll back by releasing a new version with old data. + + Treats rollback as a new release: data from + target_version is copied to the live paths (creating + new GCS generations), a new HF commit is created with + the old data, and a new manifest is published under + new_version with special_operation="roll-back". + + Args: + target_version: Semver version to roll back to. + new_version: New semver version to publish. + + Returns: + The new VersionManifest for the rollback release. + + Raises: + ValueError: If target_version is not in the + registry. + """ + bucket = _get_gcs_bucket() + old_manifest = _read_registry_from_gcs(bucket).get_version(target_version) + + new_gens = _restore_gcs_generations(bucket, old_manifest.gcs.generations) + hf_commit = ( + _restore_hf_commit(old_manifest, new_version) if old_manifest.hf else None + ) + + manifest = VersionManifest( + version=new_version, + created_at=_utc_now_iso(), + hf=(HFVersionInfo(repo=HF_REPO_NAME, commit=hf_commit) if hf_commit else None), + gcs=GCSVersionInfo( + bucket=GCS_BUCKET_NAME, + generations=new_gens, + ), + special_operation="roll-back", + roll_back_version=target_version, + ) + upload_manifest(manifest) + + logging.info( + f"Rolled back to {target_version} as new " + f"version {new_version}. " + f"Restored {len(new_gens)} files." + ) + return manifest + + +# -- Consumer API -------------------------------------------------- + +_cached_registry: Optional[VersionRegistry] = None + + +def get_data_manifest() -> VersionRegistry: + """Get the full version registry from HF. + + Fetches version_manifest.json from the Hugging Face + repo and returns it as a VersionRegistry. The result + is cached in memory after the first call. + + Returns: + The full VersionRegistry. + """ + global _cached_registry + if _cached_registry is not None: + return _cached_registry + + local_path = hf_hub_download( + repo_id=HF_REPO_NAME, + repo_type=HF_REPO_TYPE, + filename=REGISTRY_BLOB, + ) + with open(local_path) as f: + data = json.load(f) + + _cached_registry = VersionRegistry.from_dict(data) + return _cached_registry + + +def get_data_version() -> str: + """Get the current deployed data version string. + + Convenience wrapper around get_data_manifest(). + + Returns: + The current semver version string. + """ + return get_data_manifest().current