|
6 | 6 | import warnings |
7 | 7 | from collections import defaultdict |
8 | 8 | from collections.abc import Mapping |
| 9 | +from datetime import datetime |
| 10 | +from functools import total_ordering |
9 | 11 | from typing import Literal, Protocol, runtime_checkable |
10 | 12 |
|
11 | 13 | if sys.version_info < (3, 11): |
@@ -273,6 +275,123 @@ def from_defaults( |
273 | 275 | return cls(supplementary_facets, **kwargs[source_type]) |
274 | 276 |
|
275 | 277 |
|
| 278 | +@frozen |
| 279 | +@total_ordering |
| 280 | +class PartialDateTime: |
| 281 | + """ |
| 282 | + A partial datetime object that can be used to compare datetimes. |
| 283 | +
|
| 284 | + Only the specified fields are used for comparison. |
| 285 | + """ |
| 286 | + |
| 287 | + year: int | None = None |
| 288 | + month: int | None = None |
| 289 | + day: int | None = None |
| 290 | + hour: int | None = None |
| 291 | + minute: int | None = None |
| 292 | + second: int | None = None |
| 293 | + |
| 294 | + @property |
| 295 | + def _attrs(self) -> dict[str, int]: |
| 296 | + """The attributes that are set.""" |
| 297 | + return { |
| 298 | + a: v |
| 299 | + for a in self.__slots__ # type: ignore[attr-defined] |
| 300 | + if not a.startswith("_") and (v := getattr(self, a)) is not None |
| 301 | + } |
| 302 | + |
| 303 | + def __repr__(self) -> str: |
| 304 | + return f"{self.__class__.__name__}({', '.join(f'{a}={v}' for a, v in self._attrs.items())})" |
| 305 | + |
| 306 | + def __eq__(self, other: object) -> bool: |
| 307 | + if not isinstance(other, datetime): |
| 308 | + msg = ( |
| 309 | + f"Can only compare PartialDateTime with `datetime.datetime` " |
| 310 | + f"objects, got object {other} of type {type(other)}" |
| 311 | + ) |
| 312 | + raise TypeError(msg) |
| 313 | + |
| 314 | + for attr, value in self._attrs.items(): |
| 315 | + other_value = getattr(other, attr) |
| 316 | + if value != other_value: |
| 317 | + return False |
| 318 | + return True |
| 319 | + |
| 320 | + def __lt__(self, other: object) -> bool: |
| 321 | + if not isinstance(other, datetime): |
| 322 | + msg = ( |
| 323 | + f"Can only compare PartialDateTime with `datetime.datetime` " |
| 324 | + f"objects, got object {other} of type {type(other)}" |
| 325 | + ) |
| 326 | + raise TypeError(msg) |
| 327 | + |
| 328 | + for attr, value in self._attrs.items(): |
| 329 | + other_value = getattr(other, attr) |
| 330 | + if value != other_value: |
| 331 | + return value < other_value # type: ignore[no-any-return] |
| 332 | + return False |
| 333 | + |
| 334 | + |
| 335 | +@frozen |
| 336 | +class RequireTimerange: |
| 337 | + """ |
| 338 | + A constraint that requires datasets to have a specific timerange. |
| 339 | +
|
| 340 | + Specify the start and/or end of the required timerange using a precision |
| 341 | + that matches the frequency of the datasets. |
| 342 | +
|
| 343 | + For example, to ensure that datasets at monthly frequency cover the period |
| 344 | + from 2000 to 2010, use start=PartialDateTime(year=2000, month=1) and |
| 345 | + end=PartialDateTime(year=2010, month=12). |
| 346 | + """ |
| 347 | + |
| 348 | + group_by: tuple[str, ...] |
| 349 | + """ |
| 350 | + The fields to group the datasets by. Each group must cover the timerange |
| 351 | + to fulfill the constraint. |
| 352 | + """ |
| 353 | + |
| 354 | + start: PartialDateTime | None = None |
| 355 | + """ |
| 356 | + The start time of the required timerange. If None, no start time is required. |
| 357 | + """ |
| 358 | + |
| 359 | + end: PartialDateTime | None = None |
| 360 | + """ |
| 361 | + The end time of the required timerange. If None, no end time is required. |
| 362 | + """ |
| 363 | + |
| 364 | + def validate(self, group: pd.DataFrame) -> bool: |
| 365 | + """ |
| 366 | + Check that all subgroups of the group have a contiguous timerange. |
| 367 | + """ |
| 368 | + group = group.dropna(subset=["start_time", "end_time"]) |
| 369 | + for _, subgroup in group.groupby(list(self.group_by)): |
| 370 | + start = subgroup["start_time"].min() |
| 371 | + end = subgroup["end_time"].max() |
| 372 | + result = True |
| 373 | + if self.start is not None and start > self.start: |
| 374 | + logger.debug( |
| 375 | + f"Constraint {self.__class__.__name__} not satisfied " |
| 376 | + f"because start time {start} is after required start time " |
| 377 | + f"{self.start} for {', '.join(subgroup['path'])}" |
| 378 | + ) |
| 379 | + result = False |
| 380 | + if self.end is not None and end < self.end: |
| 381 | + logger.debug( |
| 382 | + f"Constraint {self.__class__.__name__} not satisfied " |
| 383 | + f"because end time {end} is before required end time " |
| 384 | + f"{self.end} for {', '.join(subgroup['path'])}" |
| 385 | + ) |
| 386 | + result = False |
| 387 | + if result: |
| 388 | + result = RequireContiguousTimerange(group_by=self.group_by).validate(subgroup) |
| 389 | + if not result: |
| 390 | + return False |
| 391 | + |
| 392 | + return True |
| 393 | + |
| 394 | + |
276 | 395 | @frozen |
277 | 396 | class RequireContiguousTimerange: |
278 | 397 | """ |
|
0 commit comments