From a9b56e4f4ebd89108ebfdb32b33e20f3807165ce Mon Sep 17 00:00:00 2001 From: Kenneth Lippold Date: Tue, 3 Mar 2026 16:07:57 -0800 Subject: [PATCH] Revert "Revert "367 aggregation"" --- domains/etl/aggregation.py | 291 ++++++++++++++ ..._task_type_and_nullable_data_connection.py | 28 ++ domains/etl/models/task.py | 11 +- domains/etl/services/task.py | 204 ++++++++-- domains/etl/tasks.py | 366 +++++++++++++++++- interfaces/api/schemas/task.py | 9 +- ...est_aggregation_failure_messaging_cases.py | 112 ++++++ tests/etl/services/test_task.py | 166 +++++++- 8 files changed, 1140 insertions(+), 47 deletions(-) create mode 100644 domains/etl/aggregation.py create mode 100644 domains/etl/migrations/0004_task_task_type_and_nullable_data_connection.py create mode 100644 tests/etl/services/test_aggregation_failure_messaging_cases.py diff --git a/domains/etl/aggregation.py b/domains/etl/aggregation.py new file mode 100644 index 00000000..2c85b893 --- /dev/null +++ b/domains/etl/aggregation.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, time, timedelta, timezone as dt_timezone, tzinfo +import math +import re +from bisect import bisect_left +from typing import Iterable +from zoneinfo import ZoneInfo + + +AGGREGATION_STATISTICS = { + "simple_mean", + "time_weighted_daily_mean", + "last_value_of_day", +} +AGGREGATION_TIMEZONE_MODES = {"fixedOffset", "daylightSavings"} +_FIXED_OFFSET_RE = re.compile(r"^([+-])(\d{2})(\d{2})$") + + +@dataclass(frozen=True) +class AggregationTransformation: + aggregation_statistic: str + timezone_mode: str + timezone: str + + +def _first_non_empty(mapping: dict, keys: Iterable[str]) -> str | None: + for key in keys: + value = mapping.get(key) + if value is None: + continue + if isinstance(value, str): + value = value.strip() + if not value: + continue + return value + return None + + +def _parse_fixed_offset(offset: str) -> tzinfo: + match = _FIXED_OFFSET_RE.fullmatch(offset) + if not match: + raise ValueError("fixedOffset timezone must match +/-HHMM") + + sign, hours_raw, minutes_raw = match.groups() + hours = int(hours_raw) + minutes = int(minutes_raw) + if minutes >= 60: + raise ValueError("fixedOffset timezone minutes must be between 00 and 59") + + offset_delta = timedelta(hours=hours, minutes=minutes) + if sign == "-": + offset_delta = -offset_delta + + return dt_timezone(offset_delta) + + +def timezone_info_for_transformation(transform: AggregationTransformation) -> tzinfo: + if transform.timezone_mode == "fixedOffset": + return _parse_fixed_offset(transform.timezone) + + if transform.timezone_mode == "daylightSavings": + try: + return ZoneInfo(transform.timezone) + except Exception as exc: # pragma: no cover - platform-specific internals + raise ValueError( + "daylightSavings timezone must be a valid IANA timezone" + ) from exc + + raise ValueError(f"Unsupported timezoneMode: {transform.timezone_mode}") + + +def normalize_aggregation_transformation(raw: dict) -> dict: + if not isinstance(raw, dict): + raise ValueError("Aggregation transformation must be an object") + + transform_type = raw.get("type") + if transform_type != "aggregation": + raise ValueError("Aggregation transformation must set type='aggregation'") + + aggregation_statistic = _first_non_empty( + raw, ("aggregationStatistic", "aggregation_statistic") + ) + if not isinstance(aggregation_statistic, str) or aggregation_statistic not in AGGREGATION_STATISTICS: + allowed = ", ".join(sorted(AGGREGATION_STATISTICS)) + raise ValueError(f"aggregationStatistic must be one of: {allowed}") + + timezone_mode = _first_non_empty(raw, ("timezoneMode", "timezone_mode")) + if not isinstance(timezone_mode, str) or timezone_mode not in AGGREGATION_TIMEZONE_MODES: + allowed = ", ".join(sorted(AGGREGATION_TIMEZONE_MODES)) + raise ValueError(f"timezoneMode must be one of: {allowed}") + + timezone_value = _first_non_empty(raw, ("timezone",)) + if not isinstance(timezone_value, str): + raise ValueError("timezone is required for aggregation transformations") + + normalized = { + "type": "aggregation", + "aggregationStatistic": aggregation_statistic, + "timezoneMode": timezone_mode, + "timezone": timezone_value, + } + + # Validate timezone now so malformed configs fail early. + timezone_info_for_transformation( + AggregationTransformation( + aggregation_statistic=aggregation_statistic, + timezone_mode=timezone_mode, + timezone=timezone_value, + ) + ) + + return normalized + + +def parse_aggregation_transformation(raw: dict) -> AggregationTransformation: + normalized = normalize_aggregation_transformation(raw) + return AggregationTransformation( + aggregation_statistic=normalized["aggregationStatistic"], + timezone_mode=normalized["timezoneMode"], + timezone=normalized["timezone"], + ) + + +def _local_midnight(timestamp_utc: datetime, tz: tzinfo) -> datetime: + local = timestamp_utc.astimezone(tz) + return datetime.combine(local.date(), time.min, tzinfo=tz) + + +def closed_window_end_utc(source_end_utc: datetime, transform: AggregationTransformation) -> datetime: + tz = timezone_info_for_transformation(transform) + return _local_midnight(source_end_utc, tz).astimezone(dt_timezone.utc) + + +def first_window_start_utc(source_begin_utc: datetime, transform: AggregationTransformation) -> datetime: + tz = timezone_info_for_transformation(transform) + return _local_midnight(source_begin_utc, tz).astimezone(dt_timezone.utc) + + +def next_window_start_utc(destination_end_utc: datetime, transform: AggregationTransformation) -> datetime: + tz = timezone_info_for_transformation(transform) + destination_local = destination_end_utc.astimezone(tz) + next_date = destination_local.date() + timedelta(days=1) + local_midnight = datetime.combine(next_date, time.min, tzinfo=tz) + return local_midnight.astimezone(dt_timezone.utc) + + +def iter_daily_windows_utc( + start_utc: datetime, + end_utc: datetime, + transform: AggregationTransformation, +): + tz = timezone_info_for_transformation(transform) + + current_local = _local_midnight(start_utc, tz) + end_local = _local_midnight(end_utc, tz) + + while current_local < end_local: + next_local = datetime.combine( + current_local.date() + timedelta(days=1), + time.min, + tzinfo=tz, + ) + yield ( + current_local.astimezone(dt_timezone.utc), + next_local.astimezone(dt_timezone.utc), + current_local.date(), + ) + current_local = next_local + + +def _boundary_value( + target: datetime, + timestamps: list[datetime], + values: list[float], + prev_idx: int | None, + next_idx: int | None, +) -> float | None: + prev = None + nxt = None + + if prev_idx is not None and 0 <= prev_idx < len(timestamps): + prev = (timestamps[prev_idx], values[prev_idx]) + if next_idx is not None and 0 <= next_idx < len(timestamps): + nxt = (timestamps[next_idx], values[next_idx]) + + if prev and prev[0] == target: + return prev[1] + if nxt and nxt[0] == target: + return nxt[1] + + if prev and nxt: + t0, v0 = prev + t1, v1 = nxt + span = (t1 - t0).total_seconds() + if span <= 0: + return v1 + ratio = (target - t0).total_seconds() / span + return v0 + ratio * (v1 - v0) + + if prev: + return prev[1] + if nxt: + return nxt[1] + + return None + + +def aggregate_daily_window( + timestamps: list[datetime], + values: list[float], + window_start_utc: datetime, + window_end_utc: datetime, + statistic: str, +) -> float | None: + if statistic not in AGGREGATION_STATISTICS: + raise ValueError(f"Unsupported aggregationStatistic '{statistic}'") + + if not timestamps or len(timestamps) != len(values): + return None + + if window_end_utc <= window_start_utc: + return None + + left = bisect_left(timestamps, window_start_utc) + right = bisect_left(timestamps, window_end_utc) + + # No observations in this day -> skip writing this day. + if left == right: + return None + + window_values = values[left:right] + + if statistic == "simple_mean": + return sum(window_values) / len(window_values) + + if statistic == "last_value_of_day": + return window_values[-1] + + # Time-weighted daily mean using trapezoidal integration over the daily window. + start_value = _boundary_value( + target=window_start_utc, + timestamps=timestamps, + values=values, + prev_idx=(left - 1) if left > 0 else None, + next_idx=left, + ) + end_value = _boundary_value( + target=window_end_utc, + timestamps=timestamps, + values=values, + prev_idx=(right - 1) if right > 0 else None, + next_idx=right if right < len(timestamps) else None, + ) + + if start_value is None or end_value is None: + return None + + area_points: list[tuple[datetime, float]] = [(window_start_utc, start_value)] + for idx in range(left, right): + ts = timestamps[idx] + val = values[idx] + if ts == window_start_utc: + area_points[0] = (ts, val) + continue + area_points.append((ts, val)) + + if area_points[-1][0] == window_end_utc: + area_points[-1] = (window_end_utc, end_value) + else: + area_points.append((window_end_utc, end_value)) + + total_area = 0.0 + for idx in range(1, len(area_points)): + t0, v0 = area_points[idx - 1] + t1, v1 = area_points[idx] + span = (t1 - t0).total_seconds() + if span <= 0: + continue + total_area += (v0 + v1) * 0.5 * span + + duration = (window_end_utc - window_start_utc).total_seconds() + if duration <= 0: + return None + + result = total_area / duration + if math.isnan(result) or math.isinf(result): + return None + + return result diff --git a/domains/etl/migrations/0004_task_task_type_and_nullable_data_connection.py b/domains/etl/migrations/0004_task_task_type_and_nullable_data_connection.py new file mode 100644 index 00000000..b36f0111 --- /dev/null +++ b/domains/etl/migrations/0004_task_task_type_and_nullable_data_connection.py @@ -0,0 +1,28 @@ +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("etl", "0003_remove_datasource_orchestration_system_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="task", + name="task_type", + field=models.CharField(default="ETL", max_length=32), + ), + migrations.AlterField( + model_name="task", + name="data_connection", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="tasks", + to="etl.dataconnection", + ), + ), + ] diff --git a/domains/etl/models/task.py b/domains/etl/models/task.py index c9e387b2..ea173e18 100644 --- a/domains/etl/models/task.py +++ b/domains/etl/models/task.py @@ -59,10 +59,17 @@ def visible(self, principal: Union["User", "APIKey"]): class Task(models.Model, PermissionChecker): id = models.UUIDField(primary_key=True, default=uuid6.uuid7, editable=False) name = models.CharField(max_length=255) + task_type = models.CharField(max_length=32, default="ETL") workspace = models.ForeignKey( "iam.Workspace", related_name="tasks", on_delete=models.CASCADE ) - data_connection = models.ForeignKey(DataConnection, on_delete=models.CASCADE, related_name="tasks") + data_connection = models.ForeignKey( + DataConnection, + on_delete=models.CASCADE, + related_name="tasks", + null=True, + blank=True, + ) orchestration_system = models.ForeignKey( OrchestrationSystem, on_delete=models.CASCADE, related_name="tasks" ) @@ -92,7 +99,7 @@ def get_principal_permissions( self, principal: Union["User", "APIKey", None] ) -> list[Literal["edit", "delete", "view"]]: permissions = self.check_object_permissions( - principal=principal, workspace=self.data_connection.workspace, resource_type="Task" + principal=principal, workspace=self.workspace, resource_type="Task" ) return permissions diff --git a/domains/etl/services/task.py b/domains/etl/services/task.py index 784aed36..46a0951f 100644 --- a/domains/etl/services/task.py +++ b/domains/etl/services/task.py @@ -18,7 +18,8 @@ TaskMappingPath, TaskRun, ) -from domains.sta.models import ThingFileAttachment +from domains.sta.models import ThingFileAttachment, Datastream +from domains.etl.aggregation import normalize_aggregation_transformation from interfaces.api.schemas import ( TaskFields, TaskPostBody, @@ -78,6 +79,7 @@ def build_task_response(task: Task, expand: bool = True) -> dict: response = { "id": task.id, "name": task.name, + "task_type": task.task_type, "schedule": { "start_time": task.periodic_task.start_time, "paused": task.paused, @@ -118,24 +120,26 @@ def build_task_response(task: Task, expand: bool = True) -> dict: "name": task.workspace.name, "is_private": task.workspace.is_private, } - response["data_connection"] = { - "id": task.data_connection.id, - "name": task.data_connection.name, - "data_connection_type": task.data_connection.data_connection_type, - "workspace_id": task.data_connection.workspace.id, - "extractor": { - "settings_type": task.data_connection.extractor_type, - "settings": task.data_connection.extractor_settings - } if task.data_connection.extractor_type else None, - "transformer": { - "settings_type": task.data_connection.transformer_type, - "settings": task.data_connection.transformer_settings - } if task.data_connection.transformer_type else None, - "loader": { - "settings_type": task.data_connection.loader_type, - "settings": task.data_connection.loader_settings - } if task.data_connection.loader_type else None - } + response["data_connection"] = ( + { + "id": task.data_connection.id, + "name": task.data_connection.name, + "data_connection_type": task.data_connection.data_connection_type, + "workspace_id": task.data_connection.workspace_id, + "extractor": { + "settings_type": task.data_connection.extractor_type, + "settings": task.data_connection.extractor_settings, + } if task.data_connection.extractor_type else None, + "transformer": { + "settings_type": task.data_connection.transformer_type, + "settings": task.data_connection.transformer_settings, + } if task.data_connection.transformer_type else None, + "loader": { + "settings_type": task.data_connection.loader_type, + "settings": task.data_connection.loader_settings, + } if task.data_connection.loader_type else None, + } if task.data_connection else None + ) response["orchestration_system"] = { "id": task.orchestration_system.id, "name": task.orchestration_system.name, @@ -199,6 +203,7 @@ def list( for field in [ "workspace_id", + "task_type", "data_connection_id", "orchestration_system_id", "orchestration_system__type", @@ -224,6 +229,7 @@ def list( if order_by: order_by_aliases = { + "type": "task_type", "orchestrationSystemType": "orchestration_system__type", "startTime": "periodic_task__start_time", "dataConnectionType": "data_connection__data_connection_type", @@ -231,11 +237,15 @@ def list( "dataConnectionTransformerType": "data_connection__transformer_type", "dataConnectionLoaderType": "data_connection__loader_type", } + order_by_aliases.update( + {f"-{key}": f"-{value}" for key, value in order_by_aliases.items()} + ) queryset = self.apply_ordering( queryset, order_by, - [order_by_aliases.get(field, field) for field in list(get_args(TaskOrderByFields))] + list(get_args(TaskOrderByFields)), + field_aliases=order_by_aliases, ) else: queryset = queryset.order_by("id") @@ -277,15 +287,24 @@ def create( principal=principal, workspace=workspace ): raise HttpError( - 403, "You do not have permission to create this ETL task" + 403, "You do not have permission to create this task" ) - data_connection = data_connection_service.get_data_connection_for_action( - principal=principal, uid=data.data_connection_id, action="edit", raise_400=True, expand_related=True - ) + task_type = data.task_type or "ETL" - if data_connection.workspace and data_connection.workspace_id != workspace.id: - raise HttpError(400, "Task and data connection must belong to the same workspace.") + data_connection = None + if task_type == "Aggregation": + if data.data_connection_id is not None: + raise HttpError(400, "Aggregation tasks cannot define a data connection.") + else: + if data.data_connection_id is None: + raise HttpError(400, "ETL tasks require a data connection.") + data_connection = data_connection_service.get_data_connection_for_action( + principal=principal, uid=data.data_connection_id, action="edit", raise_400=True, expand_related=True + ) + + if data_connection.workspace and data_connection.workspace_id != workspace.id: + raise HttpError(400, "Task and data connection must belong to the same workspace.") orchestration_system = orchestration_system_service.get_orchestration_system_for_action( principal=principal, uid=data.orchestration_system_id, action="view", raise_400=True @@ -298,6 +317,7 @@ def create( task = Task.objects.create( pk=data.id, name=data.name, + task_type=task_type, workspace=workspace, data_connection=data_connection, orchestration_system=orchestration_system, @@ -308,10 +328,12 @@ def create( except IntegrityError: raise HttpError(409, "The operation could not be completed due to a resource conflict.") - task = self.update_scheduling(task, data.schedule.dict()) + task = self.update_scheduling(task, data.schedule.dict() if data.schedule else None) task = self.update_mapping( task, - [mapping.dict() for mapping in data.mappings] if data.mappings else None, + [mapping.dict() for mapping in data.mappings] + if data.mappings is not None + else None, ) task.save() @@ -333,13 +355,24 @@ def update( exclude_unset=True, ) - if "data_connection_id" in task_data: - data_connection = data_connection_service.get_data_connection_for_action( - principal=principal, uid=data.data_connection_id, action="edit", raise_400=True, expand_related=True - ) + next_task_type = task_data.get("task_type", task.task_type) - if data_connection.workspace_id != task.workspace_id: - raise HttpError(400, f"Task and data connection must belong to the same workspace.") + if next_task_type == "Aggregation": + if "data_connection_id" in task_data and task_data["data_connection_id"] is not None: + raise HttpError(400, "Aggregation tasks cannot define a data connection.") + task_data["data_connection_id"] = None + else: + next_data_connection_id = task_data.get("data_connection_id", task.data_connection_id) + if not next_data_connection_id: + raise HttpError(400, "ETL tasks require a data connection.") + + if "data_connection_id" in task_data and task_data["data_connection_id"] is not None: + data_connection = data_connection_service.get_data_connection_for_action( + principal=principal, uid=data.data_connection_id, action="edit", raise_400=True, expand_related=True + ) + + if data_connection.workspace_id != task.workspace_id: + raise HttpError(400, "Task and data connection must belong to the same workspace.") if "orchestration_system_id" in task_data: orchestration_system = orchestration_system_service.get_orchestration_system_for_action( @@ -349,6 +382,8 @@ def update( if orchestration_system.workspace and orchestration_system.workspace_id != task.workspace_id: raise HttpError(400, "Task and orchestration system must belong to the same workspace.") + task.task_type = next_task_type + if "schedule" in task_data: task = self.update_scheduling(task, task_data["schedule"]) @@ -465,7 +500,7 @@ def update_scheduling(task: Task, schedule_data: dict | None = None): task.periodic_task.crontab.hour = hour task.periodic_task.crontab.day_of_month = day task.periodic_task.crontab.month_of_year = month - task.periodic_task.crontab.weekday = weekday + task.periodic_task.crontab.day_of_week = weekday task.periodic_task.crontab.save() else: crontab_schedule = CrontabSchedule.objects.create( @@ -557,6 +592,93 @@ def _normalize_transformation(transformation: dict) -> dict: return normalized + @staticmethod + def _validate_aggregation_mapping_constraints( + workspace_id: uuid.UUID, + mapping_data: List[dict], + ): + if len(mapping_data) < 1: + raise HttpError( + 400, + "Aggregation tasks must include at least one mapping.", + ) + + datastream_ids: set[uuid.UUID] = set() + for mapping in mapping_data: + try: + source_identifier = str(uuid.UUID(str(mapping["source_identifier"]))) + except (KeyError, ValueError, TypeError): + raise HttpError(400, "Aggregation mappings require a valid source datastream UUID.") + + paths = mapping.get("paths", []) or [] + if len(paths) != 1: + raise HttpError( + 400, + "Aggregation mappings currently support exactly one target path per source.", + ) + + mapping["source_identifier"] = source_identifier + datastream_ids.add(uuid.UUID(source_identifier)) + + path = paths[0] + try: + target_identifier = str(uuid.UUID(str(path["target_identifier"]))) + except (KeyError, ValueError, TypeError): + raise HttpError(400, "Aggregation mappings require a valid target datastream UUID.") + + path["target_identifier"] = target_identifier + datastream_ids.add(uuid.UUID(target_identifier)) + + transformations = path.get("data_transformations", []) or [] + if not isinstance(transformations, list) or len(transformations) != 1: + raise HttpError( + 400, + "Aggregation mappings require exactly one aggregation transformation per path.", + ) + + if not isinstance(transformations[0], dict): + raise HttpError(400, "Invalid aggregation data transformation payload.") + + try: + path["data_transformations"] = [ + normalize_aggregation_transformation(transformations[0]) + ] + except ValueError as exc: + raise HttpError(400, str(exc)) from exc + + existing_datastream_ids = set( + Datastream.objects.filter( + thing__workspace_id=workspace_id, + id__in=datastream_ids, + ).values_list("id", flat=True) + ) + missing = sorted(str(uid) for uid in (datastream_ids - existing_datastream_ids)) + if missing: + raise HttpError( + 400, + "Aggregation mapping datastreams must exist in the task workspace.", + ) + + @staticmethod + def _reject_aggregation_transformations(mapping_data: List[dict]): + for mapping in mapping_data: + for path in mapping.get("paths", []) or []: + transformations = path.get("data_transformations", []) or [] + if not isinstance(transformations, list): + raise HttpError( + 400, + "Path data_transformations must be an array of transformation objects", + ) + + for transformation in transformations: + if not isinstance(transformation, dict): + raise HttpError(400, "Invalid data transformation payload") + if transformation.get("type") == "aggregation": + raise HttpError( + 400, + "Aggregation transformations are only valid when task type is Aggregation.", + ) + @staticmethod def _thing_attachment_rating_curve_references( workspace_id: uuid.UUID, @@ -650,9 +772,15 @@ def update_mapping(task: Task, mapping_data: List[dict] | None = None): if mapping_data is None: return task - TaskService._validate_rating_curve_transformation_references( - workspace_id=task.workspace_id, mapping_data=mapping_data - ) + if task.task_type == "Aggregation": + TaskService._validate_aggregation_mapping_constraints( + workspace_id=task.workspace_id, mapping_data=mapping_data + ) + else: + TaskService._reject_aggregation_transformations(mapping_data) + TaskService._validate_rating_curve_transformation_references( + workspace_id=task.workspace_id, mapping_data=mapping_data + ) task.mappings.all().delete() diff --git a/domains/etl/tasks.py b/domains/etl/tasks.py index cddbc466..4a51bbe4 100644 --- a/domains/etl/tasks.py +++ b/domains/etl/tasks.py @@ -13,6 +13,9 @@ from django.db.utils import IntegrityError from django.core.management import call_command from domains.etl.models import Task, TaskRun +from domains.sta.models import Datastream, Observation +from domains.sta.services import ObservationService +from interfaces.api.schemas.observation import ObservationBulkPostBody from .loader import HydroServerInternalLoader, LoadSummary from .etl_errors import ( EtlUserFacingError, @@ -20,6 +23,15 @@ user_facing_error_from_validation_error, ) from .run_result_normalizer import normalize_task_run_result, task_transformer_raw +from .aggregation import ( + AggregationTransformation, + aggregate_daily_window, + closed_window_end_utc, + first_window_start_utc, + iter_daily_windows_utc, + next_window_start_utc, + parse_aggregation_transformation, +) from hydroserverpy.etl.factories import extractor_factory, transformer_factory from hydroserverpy.etl.etl_configuration import ( ExtractorConfig, @@ -37,6 +49,13 @@ class TaskRunContext: task_meta: dict[str, Any] = field(default_factory=dict) +@dataclass(frozen=True) +class AggregationMapping: + source_datastream_id: UUID + target_datastream_id: UUID + transformation: AggregationTransformation + + class TaskLogFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: path = (record.pathname or "").replace("\\", "/") @@ -99,6 +118,7 @@ def as_text(self) -> str: TASK_RUN_CONTEXT: dict[str, TaskRunContext] = {} +observation_service = ObservationService() @contextmanager @@ -310,6 +330,337 @@ def _validate_component_config( raise user_facing_error_from_validation_error(component, ve, raw=raw) from ve +def _parse_datastream_uuid(raw_value: Any, field_name: str) -> UUID: + try: + return UUID(str(raw_value)) + except (TypeError, ValueError) as exc: + raise EtlUserFacingError( + f"Aggregation mapping {field_name} must be a valid datastream UUID." + ) from exc + + +def _extract_aggregation_mappings(task: Task) -> list[AggregationMapping]: + task_mappings = list(task.mappings.all()) + if len(task_mappings) < 1: + raise EtlUserFacingError( + "Aggregation tasks must include at least one mapping." + ) + + mappings: list[AggregationMapping] = [] + for task_mapping in task_mappings: + paths = list(task_mapping.paths.all()) + if len(paths) != 1: + raise EtlUserFacingError( + "Aggregation mappings must include exactly one target path per source." + ) + + path = paths[0] + source_id = _parse_datastream_uuid( + task_mapping.source_identifier, "sourceIdentifier" + ) + target_id = _parse_datastream_uuid( + path.target_identifier, "targetIdentifier" + ) + + transformations = path.data_transformations or [] + if not isinstance(transformations, list) or len(transformations) != 1: + raise EtlUserFacingError( + "Aggregation mappings must include exactly one aggregation transformation." + ) + if not isinstance(transformations[0], dict): + raise EtlUserFacingError("Invalid aggregation transformation payload.") + + try: + transformation = parse_aggregation_transformation(transformations[0]) + except ValueError as exc: + raise EtlUserFacingError(str(exc)) from exc + + mappings.append( + AggregationMapping( + source_datastream_id=source_id, + target_datastream_id=target_id, + transformation=transformation, + ) + ) + + return mappings + + +def _fetch_observation_points( + source_datastream_id: UUID, + query_start_utc: datetime, + query_end_utc: datetime, +) -> tuple[list[datetime], list[float]]: + points = list( + Observation.objects.filter( + datastream_id=source_datastream_id, + phenomenon_time__gte=query_start_utc, + phenomenon_time__lt=query_end_utc, + ) + .order_by("phenomenon_time") + .values_list("phenomenon_time", "result") + ) + + previous_point = ( + Observation.objects.filter( + datastream_id=source_datastream_id, + phenomenon_time__lt=query_start_utc, + ) + .order_by("-phenomenon_time") + .values_list("phenomenon_time", "result") + .first() + ) + if previous_point: + points.insert(0, previous_point) + + next_point = ( + Observation.objects.filter( + datastream_id=source_datastream_id, + phenomenon_time__gte=query_end_utc, + ) + .order_by("phenomenon_time") + .values_list("phenomenon_time", "result") + .first() + ) + if next_point: + points.append(next_point) + + cleaned: list[tuple[datetime, float]] = [] + for phenomenon_time, result in points: + try: + result_float = float(result) + except (TypeError, ValueError): + continue + if not pd.notna(result_float): + continue + if cleaned and cleaned[-1][0] == phenomenon_time: + cleaned[-1] = (phenomenon_time, result_float) + else: + cleaned.append((phenomenon_time, result_float)) + + timestamps = [point[0] for point in cleaned] + values = [point[1] for point in cleaned] + return timestamps, values + + +def _load_aggregated_rows(task: Task, target_datastream_id: UUID, rows: list[list[Any]]): + chunk_size = 5000 + for offset in range(0, len(rows), chunk_size): + chunk = rows[offset:offset + chunk_size] + payload = ObservationBulkPostBody( + fields=["phenomenonTime", "result"], + data=chunk, + ) + observation_service.bulk_create( + principal=task.workspace.owner, + data=payload, + datastream_id=target_datastream_id, + mode="append", + ) + + +def _run_aggregation_task(task: Task, context: TaskRunContext) -> dict[str, Any]: + context.stage = "aggregate" + + if not task.workspace.owner: + raise EtlUserFacingError("Task workspace does not have an owner account.") + + mappings = _extract_aggregation_mappings(task) + if not mappings: + return _build_task_result( + "Aggregation task has no mappings. Nothing to do.", + context, + stage=context.stage, + ) + + datastream_ids = { + mapping.source_datastream_id for mapping in mappings + } | {mapping.target_datastream_id for mapping in mappings} + datastreams = Datastream.objects.filter( + id__in=datastream_ids, + thing__workspace_id=task.workspace_id, + ).only("id", "name", "phenomenon_begin_time", "phenomenon_end_time") + datastream_map = {datastream.id: datastream for datastream in datastreams} + + loaded_rows = 0 + loaded_mappings = 0 + loaded_days = 0 + mapping_summaries: list[dict[str, Any]] = [] + + for mapping in mappings: + source = datastream_map.get(mapping.source_datastream_id) + target = datastream_map.get(mapping.target_datastream_id) + + if not source or not target: + raise EtlUserFacingError( + "Aggregation source and target datastreams must exist in the task workspace." + ) + + source_end = source.phenomenon_end_time + if not source_end: + logging.info( + "Skipping mapping source=%s target=%s: source has no observations yet.", + source.id, + target.id, + ) + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "skipped", + "reason": "Source datastream has no observations.", + "rowsLoaded": 0, + "daysLoaded": 0, + } + ) + continue + + closed_end = closed_window_end_utc(source_end, mapping.transformation) + destination_end = target.phenomenon_end_time + + source_begin = source.phenomenon_begin_time + if not source_begin: + logging.info( + "Skipping mapping source=%s target=%s: source has no phenomenon_begin_time.", + source.id, + target.id, + ) + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "skipped", + "reason": "Source datastream has no observation history.", + "rowsLoaded": 0, + "daysLoaded": 0, + } + ) + continue + + query_start = first_window_start_utc(source_begin, mapping.transformation) + if destination_end is None: + start_window = query_start + else: + start_window = next_window_start_utc(destination_end, mapping.transformation) + + if start_window >= closed_end: + logging.info( + "Skipping mapping source=%s target=%s: no new closed daily windows.", + source.id, + target.id, + ) + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "up_to_date", + "reason": "No new closed daily windows.", + "rowsLoaded": 0, + "daysLoaded": 0, + } + ) + continue + + timestamps, values = _fetch_observation_points( + source_datastream_id=source.id, + query_start_utc=query_start, + query_end_utc=closed_end, + ) + if not timestamps: + logging.info( + "Skipping mapping source=%s target=%s: no source observations in query range.", + source.id, + target.id, + ) + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "skipped", + "reason": "No source observations available for aggregation.", + "rowsLoaded": 0, + "daysLoaded": 0, + } + ) + continue + + rows: list[list[Any]] = [] + for day_start, day_end, _ in iter_daily_windows_utc( + start_window, + closed_end, + mapping.transformation, + ): + value = aggregate_daily_window( + timestamps=timestamps, + values=values, + window_start_utc=day_start, + window_end_utc=day_end, + statistic=mapping.transformation.aggregation_statistic, + ) + if value is None: + continue + rows.append([day_start, float(value)]) + + if not rows: + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "up_to_date", + "reason": "No complete daily windows contained source observations.", + "rowsLoaded": 0, + "daysLoaded": 0, + } + ) + continue + + _load_aggregated_rows(task=task, target_datastream_id=target.id, rows=rows) + + loaded_rows += len(rows) + loaded_days += len(rows) + loaded_mappings += 1 + + logging.info( + "Aggregated %s day(s) for mapping source=%s target=%s statistic=%s.", + len(rows), + source.id, + target.id, + mapping.transformation.aggregation_statistic, + ) + mapping_summaries.append( + { + "sourceDatastreamId": str(source.id), + "targetDatastreamId": str(target.id), + "status": "loaded", + "rowsLoaded": len(rows), + "daysLoaded": len(rows), + "statistic": mapping.transformation.aggregation_statistic, + } + ) + + if loaded_rows == 0: + result = _build_task_result( + "No new closed daily windows were available for aggregation.", + context, + stage=context.stage, + ) + else: + result = _build_task_result( + f"Aggregated {loaded_days} day(s) and loaded {loaded_rows} observation(s) across {loaded_mappings} mapping(s).", + context, + stage=context.stage, + ) + + result["aggregation"] = { + "mappingsProcessed": len(mappings), + "mappingsLoaded": loaded_mappings, + "daysLoaded": loaded_days, + "rowsLoaded": loaded_rows, + "mappings": mapping_summaries, + } + return result + + @shared_task(bind=True, expires=10, name="etl.tasks.run_etl_task") def run_etl_task(self, task_id: str): """ @@ -323,7 +674,7 @@ def run_etl_task(self, task_id: str): with capture_task_logs(context): try: task = ( - Task.objects.select_related("data_connection") + Task.objects.select_related("data_connection", "workspace") .prefetch_related("mappings", "mappings__paths") .get(pk=UUID(task_id)) ) @@ -331,11 +682,20 @@ def run_etl_task(self, task_id: str): context.task_meta = { "id": str(task.id), "name": task.name, - "data_connection_id": str(task.data_connection_id), - "data_connection_name": task.data_connection.name, + "type": task.task_type, } + if task.data_connection_id: + context.task_meta["data_connection_id"] = str(task.data_connection_id) + context.task_meta["data_connection_name"] = task.data_connection.name context.stage = "setup" + if task.task_type == "Aggregation": + logging.info("Starting aggregation task") + return _run_aggregation_task(task, context) + + if not task.data_connection: + raise EtlUserFacingError("ETL tasks require a data connection.") + extractor_raw = { "type": task.data_connection.extractor_type, **(task.data_connection.extractor_settings or {}), diff --git a/interfaces/api/schemas/task.py b/interfaces/api/schemas/task.py index 5c61e067..a28a2259 100644 --- a/interfaces/api/schemas/task.py +++ b/interfaces/api/schemas/task.py @@ -12,6 +12,7 @@ _order_by_fields = ( "name", + "type", "orchestrationSystemType", "latestRunStatus", "latestRunStartedAt", @@ -37,6 +38,7 @@ class TaskQueryParameters(CollectionQueryParameters): workspace_id: list[uuid.UUID] = Query( [], description="Filter by workspace ID." ) + task_type: list[str] = Query([], description="Filter by task type.", alias="type") data_connection_id: list[uuid.UUID] = Query([], description="Filter by data connection ID.") orchestration_system_id: list[uuid.UUID | Literal["null"]] = Query( [], description="Filter by orchestration system ID." @@ -151,6 +153,7 @@ class TaskMappingPostBody(BasePostBody, TaskMappingFields): class TaskFields(Schema): name: str + task_type: Literal["ETL", "Aggregation"] = Field("ETL", alias="type") extractor_variables: dict[str, Any] = Field(default_factory=dict) transformer_variables: dict[str, Any] = Field(default_factory=dict) loader_variables: dict[str, Any] = Field(default_factory=dict) @@ -159,7 +162,7 @@ class TaskFields(Schema): class TaskSummaryResponse(BaseGetResponse, TaskFields): id: uuid.UUID workspace_id: uuid.UUID - data_connection_id: uuid.UUID + data_connection_id: uuid.UUID | None = None orchestration_system_id: uuid.UUID schedule: TaskScheduleResponse | None = None latest_run: TaskRunResponse | None = None @@ -169,7 +172,7 @@ class TaskSummaryResponse(BaseGetResponse, TaskFields): class TaskDetailResponse(BaseGetResponse, TaskFields): id: uuid.UUID workspace: WorkspaceSummaryResponse - data_connection: DataConnectionSummaryResponse + data_connection: DataConnectionSummaryResponse | None = None orchestration_system: OrchestrationSystemSummaryResponse schedule: TaskScheduleResponse | None = None latest_run: TaskRunResponse | None = None @@ -179,7 +182,7 @@ class TaskDetailResponse(BaseGetResponse, TaskFields): class TaskPostBody(BasePostBody, TaskFields): id: Optional[uuid.UUID] = None workspace_id: uuid.UUID - data_connection_id: uuid.UUID + data_connection_id: uuid.UUID | None = None orchestration_system_id: uuid.UUID schedule: TaskSchedulePostBody | None = None mappings: list[TaskMappingPostBody] diff --git a/tests/etl/services/test_aggregation_failure_messaging_cases.py b/tests/etl/services/test_aggregation_failure_messaging_cases.py new file mode 100644 index 00000000..b1be3f86 --- /dev/null +++ b/tests/etl/services/test_aggregation_failure_messaging_cases.py @@ -0,0 +1,112 @@ +import json +import uuid +from pathlib import Path + +import pytest +from ninja.errors import HttpError + +from domains.etl.models import Task, TaskMappingPath, TaskRun +from domains.etl.services import TaskService +from interfaces.api.schemas import TaskPostBody + + +task_service = TaskService() + + +def _load_cases() -> list[dict]: + case_file = Path(__file__).resolve().parents[3] / "tmp" / "aggregation-messaging-cases.json" + if not case_file.exists(): + pytest.skip( + "Missing tmp/aggregation-messaging-cases.json. Generate cases before running this test module.", + allow_module_level=True, + ) + return json.loads(case_file.read_text()) + + +ALL_CASES = _load_cases() +CREATE_VALIDATION_CASES = [case for case in ALL_CASES if case["phase"] == "create_validation"] +RUNTIME_CASES = [case for case in ALL_CASES if case["phase"] == "run_time"] + + +def _create_task_for_case(case: dict, get_principal) -> dict: + task_data = TaskPostBody.model_validate(case["taskCreateBody"]) + return task_service.create( + principal=get_principal("owner"), + data=task_data, + ) + + +def _mutate_runtime_case(task_id: uuid.UUID, slug: str): + task = Task.objects.get(pk=task_id) + mapping = task.mappings.first() + + if slug == "17-run-no-mappings-after-create": + task.mappings.all().delete() + return + + if mapping is None: + raise AssertionError(f"Expected at least one mapping for runtime case '{slug}'.") + + first_path = mapping.paths.first() + if first_path is None: + raise AssertionError(f"Expected at least one mapping path for runtime case '{slug}'.") + + if slug == "18-run-branched-mapping-after-create": + TaskMappingPath.objects.create( + task_mapping=mapping, + target_identifier=first_path.target_identifier, + data_transformations=first_path.data_transformations, + ) + return + + if slug == "19-run-invalid-target-identifier-after-create": + first_path.target_identifier = "not-a-uuid" + first_path.save(update_fields=["target_identifier"]) + return + + if slug == "20-run-target-datastream-deleted-after-create": + # Keep this non-destructive: emulate "missing target in workspace scope" + # with a valid UUID that is not present in the workspace. + first_path.target_identifier = str(uuid.uuid4()) + first_path.save(update_fields=["target_identifier"]) + return + + raise AssertionError(f"Unknown runtime case slug: {slug}") + + +def _latest_run_payload(task_id: uuid.UUID) -> dict: + latest_run = TaskRun.objects.filter(task_id=task_id).order_by("-started_at").first() + assert latest_run is not None + return { + "id": latest_run.id, + "status": latest_run.status, + "result": latest_run.result, + } + + +@pytest.mark.parametrize("case", CREATE_VALIDATION_CASES, ids=lambda case: case["slug"]) +def test_aggregation_create_validation_messages(case, get_principal): + with pytest.raises(HttpError) as exc_info: + _create_task_for_case(case, get_principal) + + assert exc_info.value.status_code == 400 + assert exc_info.value.message == case["expected"]["message"] + + +@pytest.mark.parametrize("case", RUNTIME_CASES, ids=lambda case: case["slug"]) +def test_aggregation_runtime_failure_messages(case, get_principal, settings): + settings.CELERY_ENABLED = False + created_task = _create_task_for_case(case, get_principal) + task_id = created_task["id"] + + _mutate_runtime_case(task_id, case["slug"]) + + run_result = task_service.run( + principal=get_principal("owner"), + task_id=task_id, + ) + if not run_result.get("status") or not isinstance(run_result.get("result"), dict): + run_result = _latest_run_payload(task_id) + + assert run_result["status"] == case["expected"]["status"] + assert run_result["result"]["message"] == case["expected"]["message"] diff --git a/tests/etl/services/test_task.py b/tests/etl/services/test_task.py index d6859dd6..d87178d6 100644 --- a/tests/etl/services/test_task.py +++ b/tests/etl/services/test_task.py @@ -58,7 +58,7 @@ # Test filtering ( "owner", - {"task_type": "SDL"}, + {"task_type": "ETL"}, ["Test ETL Task"], 7, ), @@ -272,6 +272,170 @@ def test_create_task( assert TaskDetailResponse.from_orm(task_create) +def test_create_aggregation_task_without_data_connection(get_principal): + task_data = TaskPostBody( + name="New Aggregation Task", + task_type="Aggregation", + workspace_id=uuid.UUID("b27c51a0-7374-462d-8a53-d97d47176c10"), + data_connection_id=None, + orchestration_system_id=uuid.UUID("019aead4-df4e-7a08-a609-dbc96df6befe"), + schedule=TaskSchedulePostBody( + paused=True, + crontab="* * * * *", + ), + mappings=[ + TaskMappingPostBody( + source_identifier="dd1f9293-ce29-4b6a-88e6-d65110d1be65", + paths=[ + TaskMappingPathPostBody( + target_identifier="1c9a797e-6fd8-4e99-b1ae-87ab4affc0a2", + data_transformations=[ + { + "type": "aggregation", + "aggregationStatistic": "simple_mean", + "timezoneMode": "fixedOffset", + "timezone": "-0700", + } + ], + ) + ], + ) + ], + ) + + task_create = task_service.create( + principal=get_principal("owner"), + data=task_data, + ) + assert task_create["task_type"] == "Aggregation" + assert task_create["data_connection"] is None + assert TaskDetailResponse.from_orm(task_create) + + +def test_create_etl_task_requires_data_connection(get_principal): + task_data = TaskPostBody( + name="New ETL Task Without Data Connection", + task_type="ETL", + workspace_id=uuid.UUID("b27c51a0-7374-462d-8a53-d97d47176c10"), + data_connection_id=None, + orchestration_system_id=uuid.UUID("019aead4-df4e-7a08-a609-dbc96df6befe"), + schedule=TaskSchedulePostBody( + paused=True, + crontab="* * * * *", + ), + mappings=[ + TaskMappingPostBody( + source_identifier="test", + paths=[TaskMappingPathPostBody(target_identifier="test")], + ) + ], + ) + + with pytest.raises(HttpError) as exc_info: + task_service.create( + principal=get_principal("owner"), + data=task_data, + ) + assert exc_info.value.status_code == 400 + assert exc_info.value.message == "ETL tasks require a data connection." + + +def test_create_aggregation_task_supports_multiple_source_target_mappings(get_principal): + aggregation_transform = { + "type": "aggregation", + "aggregationStatistic": "simple_mean", + "timezoneMode": "fixedOffset", + "timezone": "-0700", + } + + task_data = TaskPostBody( + name="Aggregation Task With Multiple Mappings", + task_type="Aggregation", + workspace_id=uuid.UUID("b27c51a0-7374-462d-8a53-d97d47176c10"), + data_connection_id=None, + orchestration_system_id=uuid.UUID("019aead4-df4e-7a08-a609-dbc96df6befe"), + schedule=TaskSchedulePostBody( + paused=True, + crontab="* * * * *", + ), + mappings=[ + TaskMappingPostBody( + source_identifier="dd1f9293-ce29-4b6a-88e6-d65110d1be65", + paths=[ + TaskMappingPathPostBody( + target_identifier="1c9a797e-6fd8-4e99-b1ae-87ab4affc0a2", + data_transformations=[aggregation_transform], + ) + ], + ), + TaskMappingPostBody( + source_identifier="42e08eea-27bb-4ea3-8ced-63acff0f3334", + paths=[ + TaskMappingPathPostBody( + target_identifier="9f96957b-ee20-4c7b-bf2b-673a0cda3a04", + data_transformations=[aggregation_transform], + ) + ], + ), + ], + ) + + task_create = task_service.create( + principal=get_principal("owner"), + data=task_data, + ) + assert task_create["task_type"] == "Aggregation" + assert len(task_create["mappings"]) == 2 + assert TaskDetailResponse.from_orm(task_create) + + +def test_create_aggregation_task_rejects_multiple_paths_per_mapping(get_principal): + aggregation_transform = { + "type": "aggregation", + "aggregationStatistic": "simple_mean", + "timezoneMode": "fixedOffset", + "timezone": "-0700", + } + + task_data = TaskPostBody( + name="Aggregation Task With Branched Mapping", + task_type="Aggregation", + workspace_id=uuid.UUID("b27c51a0-7374-462d-8a53-d97d47176c10"), + data_connection_id=None, + orchestration_system_id=uuid.UUID("019aead4-df4e-7a08-a609-dbc96df6befe"), + schedule=TaskSchedulePostBody( + paused=True, + crontab="* * * * *", + ), + mappings=[ + TaskMappingPostBody( + source_identifier="dd1f9293-ce29-4b6a-88e6-d65110d1be65", + paths=[ + TaskMappingPathPostBody( + target_identifier="1c9a797e-6fd8-4e99-b1ae-87ab4affc0a2", + data_transformations=[aggregation_transform], + ), + TaskMappingPathPostBody( + target_identifier="9f96957b-ee20-4c7b-bf2b-673a0cda3a04", + data_transformations=[aggregation_transform], + ), + ], + ), + ], + ) + + with pytest.raises(HttpError) as exc_info: + task_service.create( + principal=get_principal("owner"), + data=task_data, + ) + assert exc_info.value.status_code == 400 + assert ( + exc_info.value.message + == "Aggregation mappings currently support exactly one target path per source." + ) + + @pytest.mark.parametrize( "principal, task, message, error_code", [