diff --git a/src/dve/core_engine/backends/base/rules.py b/src/dve/core_engine/backends/base/rules.py
index b66b3ae..9b6b4fe 100644
--- a/src/dve/core_engine/backends/base/rules.py
+++ b/src/dve/core_engine/backends/base/rules.py
@@ -681,3 +681,13 @@ def read_parquet(self, path: URI, **kwargs) -> EntityType:
def write_parquet(self, entity: EntityType, target_location: URI, **kwargs) -> URI:
"""Method to write parquet files"""
raise NotImplementedError()
+
+    def filter_data_contract_record_rejections(
+        self,
+        working_directory: URI,
+        entity: EntityType,
+        entity_name: EntityName,
+        **kwargs,
+    ) -> EntityType:
+        """Filter out record rejection errors from the data contract for a given entity.
+
+        This base implementation only raises; backend implementations are expected to
+        replace it.
+
+        Args:
+            working_directory: location used by implementations to find the feedback.
+            entity: the entity to filter.
+            entity_name: name of the entity being filtered.
+            **kwargs: backend-specific options.
+
+        Raises:
+            NotImplementedError: always, in this base implementation.
+
+        """
+        raise NotImplementedError()
diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
index 627822b..786ef8f 100644
--- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
+++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
@@ -18,9 +18,10 @@
from pydantic import BaseModel
from typing_extensions import Annotated, get_args, get_origin, get_type_hints
+from dve.common.error_utils import get_feedback_errors_uri
from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
-from dve.core_engine.type_hints import URI
+from dve.core_engine.type_hints import URI, EntityName
from dve.parser.file_handling.service import LocalFilesystemImplementation, _get_implementation
@@ -100,7 +101,7 @@ def table_exists(connection: DuckDBPyConnection, table_name: str) -> bool:
def relation_is_empty(relation: DuckDBPyRelation) -> bool:
"""Check if a duckdb relation is empty"""
- if relation.limit(1).count("*"):
+ if relation.limit(1).shape[0] > 0:
return False
return True
@@ -256,6 +257,48 @@ def duckdb_write_parquet(cls):
return cls
+def _ddb_filter_contract_errors(
+    self,
+    working_directory: URI,
+    entity: DuckDBPyRelation,
+    entity_name: EntityName,
+    **kwargs,  # pylint: disable=unused-argument
+) -> DuckDBPyRelation:
+    """Drop the rows of `entity` that the data contract rejected at record level.
+
+    Reads the data contract feedback file located via `get_feedback_errors_uri` and
+    anti-joins `entity` against the record indexes of every feedback entry for
+    `entity_name` whose `FailureType` is 'record' and whose `Status` is not
+    'informational'.
+
+    Args:
+        working_directory: passed to `get_feedback_errors_uri` to locate the feedback file.
+        entity: relation to filter; matched on the record index column.
+        entity_name: value compared against the `Entity` field of each feedback entry.
+        **kwargs: accepted for compatibility with the base class signature; unused.
+
+    Returns:
+        `entity` itself when there is no feedback file or no matching entries,
+        otherwise `entity` without the rejected rows.
+
+    """
+    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")
+    # NOTE(review): `Path` only understands local filesystem paths. If the location is a
+    # `file://` or remote URI this check always fails and the filter is silently skipped -
+    # confirm callers only pass plain paths, or switch to a URI-aware existence check.
+    if not Path(contract_error_location).exists():
+        return entity
+
+    # Double up single quotes so the name cannot break out of the SQL string literal.
+    entity_literal = entity_name.replace("'", "''")
+    # No ORDER BY: the result of an anti-join does not depend on the order of the
+    # right-hand side, so sorting the indexes would only add work.
+    rejected_record_indexes_rel = (
+        self._connection.read_json(
+            contract_error_location,
+            columns={
+                "RecordIndex": "INTEGER",
+                "FailureType": "STRING",
+                "Status": "STRING",
+                "Entity": "STRING",
+            },
+        )
+        .filter(
+            "FailureType = 'record' AND Status != 'informational' "
+            f"AND Entity = '{entity_literal}'"
+        )
+        .select("RecordIndex")
+        .distinct()
+    )
+
+    if relation_is_empty(rejected_record_indexes_rel):
+        return entity
+
+    return entity.join(
+        rejected_record_indexes_rel,
+        condition=f"{RECORD_INDEX_COLUMN_NAME} = RecordIndex",
+        how="anti",
+    )
+
+
+def ddb_filter_contract_errors(cls):
+    """Class decorator to filter out records that failed casting and have record rejection scope.
+
+    Assigns `_ddb_filter_contract_errors` to `cls.filter_data_contract_record_rejections`
+    and returns the same class object.
+    """
+    cls.filter_data_contract_record_rejections = _ddb_filter_contract_errors
+    return cls
+
+
@staticmethod # type: ignore
def _duckdb_get_entity_count(entity: DuckDBPyRelation) -> int:
"""Method to obtain entity count from a persisted parquet entity"""
diff --git a/src/dve/core_engine/backends/implementations/duckdb/rules.py b/src/dve/core_engine/backends/implementations/duckdb/rules.py
index debb8fe..dc73dad 100644
--- a/src/dve/core_engine/backends/implementations/duckdb/rules.py
+++ b/src/dve/core_engine/backends/implementations/duckdb/rules.py
@@ -22,6 +22,7 @@
from dve.core_engine.backends.exceptions import ConstraintError
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
DDBStruct,
+ ddb_filter_contract_errors,
duckdb_read_parquet,
duckdb_record_index,
duckdb_rel_to_dictionaries,
@@ -61,6 +62,7 @@
@duckdb_record_index
@duckdb_write_parquet
@duckdb_read_parquet
+@ddb_filter_contract_errors
class DuckDBStepImplementations(BaseStepImplementations[DuckDBPyRelation]):
"""An implementation of transformation steps in duckdb."""
diff --git a/src/dve/core_engine/backends/implementations/spark/rules.py b/src/dve/core_engine/backends/implementations/spark/rules.py
index 307e71a..66564ee 100644
--- a/src/dve/core_engine/backends/implementations/spark/rules.py
+++ b/src/dve/core_engine/backends/implementations/spark/rules.py
@@ -17,6 +17,7 @@
spark_read_parquet,
spark_record_index,
spark_write_parquet,
+ spark_filter_contract_errors,
)
from dve.core_engine.backends.implementations.spark.types import (
Joined,
@@ -53,6 +54,7 @@
@spark_record_index
@spark_write_parquet
@spark_read_parquet
+@spark_filter_contract_errors
class SparkStepImplementations(BaseStepImplementations[DataFrame]):
"""An implementation of transformation steps in Apache Spark."""
diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py
index ced985a..2c2fde4 100644
--- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py
+++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py
@@ -12,6 +12,7 @@
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from functools import wraps
+from pathlib import Path
from typing import Any, ClassVar, Optional, TypeVar, Union, overload
from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException
@@ -26,8 +27,9 @@
from typing_extensions import Annotated, Protocol, TypedDict, get_args, get_origin, get_type_hints
from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type
+from dve.common.error_utils import get_feedback_errors_uri
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
-from dve.core_engine.type_hints import URI
+from dve.core_engine.type_hints import URI, EntityName
# It would be really nice if there was a more parameterisable
# way of doing this.
@@ -365,6 +367,51 @@ def spark_write_parquet(cls):
return cls
+def _spark_filter_contract_errors(
+    self,
+    working_directory: URI,
+    entity: DataFrame,
+    entity_name: EntityName,
+    **kwargs,  # pylint: disable=unused-argument
+) -> DataFrame:
+    """Drop the rows of `entity` that the data contract rejected at record level.
+
+    Reads the data contract feedback file located via `get_feedback_errors_uri` and
+    anti-joins `entity` against the record indexes of every feedback entry for
+    `entity_name` whose `FailureType` is 'record' and whose `Status` is not
+    'informational'.
+
+    Args:
+        working_directory: passed to `get_feedback_errors_uri` to locate the feedback file.
+        entity: dataframe to filter; matched on the record index column.
+        entity_name: value compared against the `Entity` field of each feedback entry.
+        **kwargs: accepted for compatibility with the base class signature; unused.
+
+    Returns:
+        `entity` itself when there is no feedback file or no matching entries,
+        otherwise `entity` without the rejected rows.
+
+    """
+    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")
+    # NOTE(review): `Path` only understands local filesystem paths. If the location is a
+    # `file://` or remote URI this check always fails and the filter is silently skipped -
+    # confirm callers only pass plain paths, or switch to a URI-aware existence check.
+    if not Path(contract_error_location).exists():
+        return entity
+
+    feedback_schema = st.StructType(
+        [
+            st.StructField("RecordIndex", st.IntegerType()),
+            st.StructField("FailureType", st.StringType()),
+            st.StructField("Status", st.StringType()),
+            st.StructField("Entity", st.StringType()),
+        ]
+    )
+    # No explicit ordering: the result of a left-anti join does not depend on the order
+    # of the right-hand side, so sorting the indexes would only add work.
+    rejected_record_indexes_df = (
+        self.spark_session.read.json(path=contract_error_location, schema=feedback_schema)
+        .filter(
+            (sf.col("FailureType") == sf.lit("record"))
+            & (sf.col("Status") != sf.lit("informational"))
+            & (sf.col("Entity") == sf.lit(entity_name))
+        )
+        # Only the index is needed for the join; narrowing first keeps distinct() cheap.
+        .select("RecordIndex")
+        .distinct()
+    )
+    if df_is_empty(rejected_record_indexes_df):
+        return entity
+
+    return entity.join(
+        rejected_record_indexes_df,
+        on=entity[RECORD_INDEX_COLUMN_NAME] == rejected_record_indexes_df["RecordIndex"],
+        how="anti",
+    )
+
+
+def spark_filter_contract_errors(cls):
+    """Class decorator to filter out records that failed casting and have record rejection scope.
+
+    Assigns `_spark_filter_contract_errors` to `cls.filter_data_contract_record_rejections`
+    and returns the same class object.
+    """
+    cls.filter_data_contract_record_rejections = _spark_filter_contract_errors
+    return cls
+
+
@staticmethod # type: ignore
def _spark_get_entity_count(entity: DataFrame) -> int:
"""Method to obtain entity count from a persisted parquet entity"""
diff --git a/src/dve/core_engine/models.py b/src/dve/core_engine/models.py
index 09fcbb3..f29889a 100644
--- a/src/dve/core_engine/models.py
+++ b/src/dve/core_engine/models.py
@@ -105,6 +105,8 @@ class SubmissionStatisticsRecord(AuditRecord):
record_count: Optional[int]
"""Count of records in the submitted file"""
+ number_submission_rejections: Optional[int]
+ """Number of submission rejections raised following validation"""
number_record_rejections: Optional[int]
"""Number of record rejections raised following validation"""
number_warnings: Optional[int]
diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py
index 00a0c51..1c32e87 100644
--- a/src/dve/pipeline/pipeline.py
+++ b/src/dve/pipeline/pipeline.py
@@ -379,9 +379,6 @@ def file_transformation_step(
failed.append((submission_info, submission_status))
else:
success.append((submission_info, submission_status))
- except AttributeError as exc:
- self._logger.error(f"File transformation raised exception: {exc}")
- raise exc
except PERMISSIBLE_EXCEPTIONS as exc:
self._logger.warning(
f"File transformation raised exception: {exc}. Will be retried later."
@@ -509,9 +506,6 @@ def data_contract_step(
submission_info: SubmissionInfo
submission_status: SubmissionStatus
submission_info, submission_status = future.result()
- except AttributeError as exc:
- self._logger.error(f"Data Contract raised exception: {exc}")
- raise exc
except PERMISSIBLE_EXCEPTIONS as exc:
self._logger.warning(
f"Data Contract raised exception: {exc}. Will be retried later."
@@ -616,8 +610,19 @@ def apply_business_rules( # pylint: disable=R0914
submission_status.processing_failed = True
for entity_name, entity in entity_manager.entities.items():
+ # Note BI filtering done within the apply_rules
+ self._logger.info(f"applying data contract filter to {entity_name}.")
+ if not entity_name.startswith("Original"):
+ filtered_entity = self._step_implementations.filter_data_contract_record_rejections(
+ working_directory,
+ entity,
+ entity_name,
+ )
+ else:
+ self._logger.info(f"Skipping {entity_name}. Marked original.")
+ filtered_entity = entity
projected = self._step_implementations.write_parquet( # type: ignore
- entity,
+ filtered_entity,
fh.joinuri(
self.processed_files_path,
submission_info.submission_id,
@@ -629,6 +634,7 @@ def apply_business_rules( # pylint: disable=R0914
projected
)
+        # TODO: record on submission_status how many records passed / failed record validation
submission_status.number_of_records = self.get_entity_count(
entity=entity_manager.entities[
f"""Original{rules.global_variables.get(
@@ -682,9 +688,6 @@ def business_rule_step(
unsucessful_files.append((submission_info, submission_status)) # type: ignore
else:
successful_files.append((submission_info, submission_status)) # type: ignore
- except AttributeError as exc:
- self._logger.error(f"Business Rules raised exception: {exc}")
- raise exc
except PERMISSIBLE_EXCEPTIONS as exc:
self._logger.warning(
f"Business Rules raised exception: {exc}. Will be retried later."
@@ -758,10 +761,12 @@ def _get_error_dataframes(self, submission_id: str):
df = pl.DataFrame(errors, schema={key: pl.Utf8() for key in errors[0]}) # type: ignore
df = df.with_columns(
- pl.when(pl.col("Status") == pl.lit("error")) # type: ignore
+ pl.when(pl.col("Status") == pl.lit("informational"))
+ .then(pl.lit("Warning"))
+ .when(pl.col("FailureType") == pl.lit("submission")) # type: ignore
.then(pl.lit("Submission Failure")) # type: ignore
- .otherwise(pl.lit("Warning")) # type: ignore
- .alias("error_type")
+ .otherwise(pl.lit("Record Rejection")) # type: ignore
+ .alias("error_type") # type: ignore
)
df = df.select(
pl.col("Entity").alias("Table"), # type: ignore
@@ -823,7 +828,8 @@ def error_report(
sub_stats = SubmissionStatisticsRecord(
submission_id=submission_info.submission_id,
record_count=submission_status.number_of_records,
- number_record_rejections=err_types.get("Submission Failure", 0),
+ number_submission_rejections=err_types.get("Submission Failure", 0),
+ number_record_rejections=err_types.get("Record Rejection", 0),
number_warnings=err_types.get("Warning", 0),
)
@@ -835,7 +841,7 @@ def error_report(
summary_items = er.SummaryItems(
submission_status=submission_status,
summary_dict=summary_dict,
- row_headings=["Submission Failure", "Warning"],
+ row_headings=["Submission Failure", "Record Rejection", "Warning"],
)
workbook = er.ExcelFormat(
@@ -894,9 +900,6 @@ def error_report_step(
try:
submission_info, submission_status, submission_stats, feedback_uri = future.result()
reports.append((submission_info, submission_status, submission_stats, feedback_uri))
- except AttributeError as exc:
- self._logger.error(f"Error reports raised exception: {exc}")
- raise exc
except PERMISSIBLE_EXCEPTIONS as exc:
self._logger.warning(
f"Error reports raised exception: {exc}. Will be retried later."
diff --git a/src/dve/reporting/excel_report.py b/src/dve/reporting/excel_report.py
index 82aa510..9471c83 100644
--- a/src/dve/reporting/excel_report.py
+++ b/src/dve/reporting/excel_report.py
@@ -141,6 +141,11 @@ def _add_submission_info(self, status: str, summary: Worksheet):
for key, value in self.summary_dict.items():
summary.append(["", _key_renames.get(key, key), str(value)])
+ summary.append([
+ "",
+ "Total Number of Records Processed",
+ self.submission_status.number_of_records if self.submission_status.number_of_records else 0 # pylint: disable=C0301
+ ])
summary.append(["", ""])
diff --git a/tests/features/animals.feature b/tests/features/animals.feature
new file mode 100644
index 0000000..d68ddbf
--- /dev/null
+++ b/tests/features/animals.feature
@@ -0,0 +1,59 @@
+Feature: Pipeline tests using the animal dataset
+  Test record rejection, ensuring that rejected records are correctly removed from the entity
+  and that the correct validation feedback is raised in the error report.
+
+ Scenario: Validate XML data with just record level rejections (duckdb)
+ Given I submit the animals file animals.xml for processing
+ And A duckdb pipeline is configured with schema file 'animals.dischema.json'
+ And I add initial audit entries for the submission
+ Then the latest audit record for the submission is marked with processing status file_transformation
+ When I run the file transformation phase
+ Then the animals entity is stored as a parquet after the file_transformation phase
+ And the latest audit record for the submission is marked with processing status data_contract
+ When I run the data contract phase
+ Then there are no record rejections from the data_contract phase
+ And the animals entity is stored as a parquet after the data_contract phase
+ And the latest audit record for the submission is marked with processing status business_rules
+ When I run the business rules phase
+ Then there are errors with the following details and associated error_count from the business_rules phase
+ | ErrorType | ErrorCode | error_count |
+ | record | ANE01 | 2 |
+ And The rules restrict "animals" to 3 qualifying records
+ When I run the error report phase
+ Then An error report is produced
+ And The statistics entry for the submission shows the following information
+ | parameter | value |
+ | record_count | 5 |
+ | number_record_rejections | 2 |
+ | number_warnings | 0 |
+
+ Scenario: Validate XML data with a mixture of error types in (duckdb)
+ Given I submit the animals file animals_mixture.xml for processing
+ And A duckdb pipeline is configured with schema file 'animals.dischema.json'
+ And I add initial audit entries for the submission
+ Then the latest audit record for the submission is marked with processing status file_transformation
+ When I run the file transformation phase
+ Then the animals entity is stored as a parquet after the file_transformation phase
+ And the latest audit record for the submission is marked with processing status data_contract
+ When I run the data contract phase
+ Then there are no record rejections from the data_contract phase
+ # Then there are errors with the following details and associated error_count from the data_contract phase
+ # | FailureType | Status | ErrorCode | error_count |
+ # | record | error | FieldBlank | 1 |
+ And the animals entity is stored as a parquet after the data_contract phase
+ And the latest audit record for the submission is marked with processing status business_rules
+ When I run the business rules phase
+ Then there are errors with the following details and associated error_count from the business_rules phase
+ | FailureType | Status | ErrorCode | error_count |
+ | record | error | ANE01 | 2 |
+ | submission | error | ANE02 | 1 |
+ | record | informational | ANE03 | 1 |
+ And The rules restrict "animals" to 5 qualifying records
+ When I run the error report phase
+ Then An error report is produced
+ And The statistics entry for the submission shows the following information
+ | parameter | value |
+ | record_count | 7 |
+ | number_submission_rejections | 1 |
+ | number_record_rejections | 2 |
+ | number_warnings | 1 |
diff --git a/tests/features/demographics.feature b/tests/features/demographics.feature
index aa59bfc..af4b62a 100644
--- a/tests/features/demographics.feature
+++ b/tests/features/demographics.feature
@@ -17,7 +17,7 @@ Feature: Pipeline tests using the ambsys dataset
And the demographics entity is stored as a parquet after the data_contract phase
And the latest audit record for the submission is marked with processing status business_rules
When I run the business rules phase
- Then The rules restrict "demographics" to 6 qualifying records
+ Then The rules restrict "demographics" to 2 qualifying records
And At least one row from "demographics" has generated error code "BAD_NHS"
And the demographics entity is stored as a parquet after the business_rules phase
And The entity "demographics" does not contain an entry for "FALSE" in column "NHS_Number_Valid"
@@ -43,7 +43,7 @@ Feature: Pipeline tests using the ambsys dataset
And the demographics entity is stored as a parquet after the data_contract phase
And the latest audit record for the submission is marked with processing status business_rules
When I run the business rules phase
- Then The rules restrict "demographics" to 6 qualifying records
+ Then The rules restrict "demographics" to 2 qualifying records
And At least one row from "demographics" has generated error code "BAD_NHS"
And the demographics entity is stored as a parquet after the business_rules phase
And The entity "demographics" does not contain an entry for "FALSE" in column "NHS_Number_Valid"
diff --git a/tests/features/movies.feature b/tests/features/movies.feature
index fa041ea..6916a4e 100644
--- a/tests/features/movies.feature
+++ b/tests/features/movies.feature
@@ -28,7 +28,7 @@ Feature: Pipeline tests using the movies dataset
And the movies entity is stored as a parquet after the data_contract phase
And the latest audit record for the submission is marked with processing status business_rules
When I run the business rules phase
- Then The rules restrict "movies" to 4 qualifying records
+ Then The rules restrict "movies" to 2 qualifying records
And there are errors with the following details and associated error_count from the business_rules phase
| ErrorCode | ErrorMessage | RecordIndex | error_count |
| LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 |
@@ -64,7 +64,7 @@ Feature: Pipeline tests using the movies dataset
And the movies entity is stored as a parquet after the data_contract phase
And the latest audit record for the submission is marked with processing status business_rules
When I run the business rules phase
- Then The rules restrict "movies" to 4 qualifying records
+ Then The rules restrict "movies" to 2 qualifying records
And there are errors with the following details and associated error_count from the business_rules phase
| ErrorCode | ErrorMessage | RecordIndex | error_count |
| LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 |
diff --git a/tests/features/steps/steps_pipeline.py b/tests/features/steps/steps_pipeline.py
index 55acadd..061bfe1 100644
--- a/tests/features/steps/steps_pipeline.py
+++ b/tests/features/steps/steps_pipeline.py
@@ -50,14 +50,14 @@ def setup_spark_pipeline(
rules_path = get_test_file_path(f"{dataset_id}/{schema_file_name}").resolve().as_uri()
return SparkDVEPipeline(
- processed_files_path=processing_path.as_uri(),
+ processed_files_path=processing_path.as_posix(),
audit_tables=SparkAuditingManager(
database="dve",
spark=spark,
),
job_run_id=12345,
rules_path=rules_path,
- submitted_files_path=processing_path.as_uri(),
+ submitted_files_path=processing_path.as_posix(),
spark=spark,
)
diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py
index 19e96e2..4a24960 100644
--- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py
+++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py
@@ -1,10 +1,15 @@
"""Test Duck DB helpers"""
+# pylint: disable=C0301,C0116
+
import datetime
+import json
+import os
import tempfile
from pathlib import Path
from typing import Any, List
+import polars as pl
import pytest
import pyspark.sql.types as pst
from duckdb import DuckDBPyRelation, DuckDBPyConnection
@@ -12,10 +17,13 @@
from pyspark.sql import Row, SparkSession
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
+ _ddb_filter_contract_errors,
_ddb_read_parquet,
duckdb_rel_to_dictionaries,
get_duckdb_cast_statement_from_annotation,
- get_duckdb_type_from_annotation)
+ get_duckdb_type_from_annotation,
+ relation_is_empty,
+)
@pytest.fixture
def casting_test_table(temp_ddb_conn):
@@ -51,8 +59,60 @@ def casting_test_table(temp_ddb_conn):
yield temp_ddb_conn
conn.sql("DROP TABLE IF EXISTS test_casting")
-
-
+
+
+@pytest.fixture
+def example_data_contract_error_codes(temp_ddb_conn):
+    """Yield a connection, a four-row entity relation and a temp dir of contract errors.
+
+    The temp directory contains `errors/data_contract_errors.jsonl` with one JSON line per
+    message; the two messages target record indexes 2 and 4 of `test_entity`.
+    """
+    _, con = temp_ddb_conn
+
+    # Do not rename: the SQL text below refers to this local variable by name.
+    test_df = pl.DataFrame([  # pylint: disable=W0612
+        {"id": "field1", "attr": 1, "__record_index__": 1,},
+        {"id": "field2", "attr": None, "__record_index__": 2,},
+        {"id": "field3", "attr": 2, "__record_index__": 3,},
+        {"id": "field4", "attr": None, "__record_index__": 4,},
+    ])
+    test_entity = con.sql("SELECT * FROM test_df")
+    # Both messages are record-scoped errors for "test_entity".
+    error_contract_messages = [
+        {
+            "Entity": "test_entity",
+            "Key": "",
+            "FailureType": "record",
+            "Status": "error",
+            "ErrorType": "",
+            "ErrorLocation": "attr",
+            "ErrorMessage": "",
+            "ErrorCode": "",
+            "ReportingField": "attr",
+            "RecordIndex": 2,
+            "Value": "hello",
+            "Category": "Bad value"
+        },
+        {
+            "Entity": "test_entity",
+            "Key": "",
+            "FailureType": "record",
+            "Status": "error",
+            "ErrorType": "",
+            "ErrorLocation": "attr",
+            "ErrorMessage": "",
+            "ErrorCode": "",
+            "ReportingField": "attr",
+            "RecordIndex": 4,
+            "Value": "world",
+            "Category": "Bad value"
+        }
+    ]
+    with tempfile.TemporaryDirectory() as temp_dir_path:
+        # Presumably the layout `get_feedback_errors_uri(temp_dir, "data_contract")`
+        # resolves to - TODO confirm against that helper.
+        os.mkdir(Path(temp_dir_path, "errors"))
+        temp_error_file = Path(temp_dir_path, "errors", "data_contract_errors.jsonl")
+        with open(temp_error_file, encoding="utf-8", mode="w") as tpf:
+            # JSON Lines: one message per line.
+            for error in error_contract_messages:
+                json.dump(error, tpf)
+                tpf.write("\n")
+
+        # Yield inside the `with` so the directory still exists while the test runs.
+        yield con, test_entity, temp_dir_path
+
+
class BasicModel(BaseModel):
str_field: str
@@ -176,4 +236,23 @@ def test_use_cast_statements(casting_test_table):
not dodgy_date_rec.get("basic_model",{}).get("date_field")
and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[]))
)
-
+
+
+def test_ddb_filter_contract_errors(example_data_contract_error_codes):  # pylint: disable=W0621
+    """Rows with record indexes 2 and 4 are removed; rows 1 and 3 survive."""
+    ddb_cnn, entity_rel, temp_dir = example_data_contract_error_codes
+    # Do not rename: the SQL text below refers to this local variable by name.
+    expected_df = pl.DataFrame([  # pylint: disable=W0612
+        {"id": "field1", "attr": 1, "__record_index__": 1,},
+        {"id": "field3", "attr": 2, "__record_index__": 3,},
+    ])
+    expected_rel = ddb_cnn.sql("SELECT * FROM expected_df")
+    # NOTE(review): `TempConnection` is neither imported nor defined in the visible changes -
+    # confirm it exists in this module (presumably a stub exposing `_connection`).
+    result_rel = _ddb_filter_contract_errors(
+        TempConnection(ddb_cnn), temp_dir, entity_rel, "test_entity"
+    )
+    assert result_rel.pl().shape[0] == 2
+    # Anti-join: no expected row may be missing from the result.
+    assert expected_rel.join(result_rel, "__record_index__", "anti").pl().shape[0] == 0
+
+
+def test_relation_is_empty(temp_ddb_conn: DuckDBPyConnection):
+    """`relation_is_empty` is True for a zero-row relation and False for a populated one."""
+    _, con = temp_ddb_conn
+    populated_rel = con.sql("SELECT 'abc' AS test")
+    # Filtering on IS NULL removes the only row, leaving an empty relation.
+    empty_rel = populated_rel.filter("test IS NULL")
+    assert relation_is_empty(empty_rel)
+    # Cover the other branch too, so an implementation that always answers True fails.
+    assert not relation_is_empty(populated_rel)
diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py
index 7502673..8a0e45e 100644
--- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py
+++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py
@@ -1,9 +1,15 @@
"""Tests for UDF helpers."""
# pylint: disable=redefined-outer-name
+# pylint: disable=C0301,C0115,C0116
+
import datetime as dt
+import json
+import os
+import tempfile
from dataclasses import dataclass
from decimal import Decimal
+from pathlib import Path
from typing import Any, List, Optional, Union
from uuid import UUID
@@ -19,6 +25,7 @@
from dve.core_engine.backends.implementations.spark.spark_helpers import (
DecimalConfig,
create_udf,
+ _spark_filter_contract_errors,
get_spark_cast_statement_from_annotation,
get_type_from_annotation,
object_to_spark_literal,
@@ -42,9 +49,56 @@ def casting_dataframe(spark):
StructField("basic_model", bm_schema),
StructField("another_model", StructType([StructField("unique_id", StringType()), StructField("basic_models", ArrayType(bm_schema))]))])
yield spark.createDataFrame(data, schema=schema)
-
-
-
+
+
+@pytest.fixture
+def example_data_contract_error_codes(spark: SparkSession):
+    """Yield a four-row entity DataFrame and a temp dir holding data contract errors.
+
+    The temp directory contains `errors/data_contract_errors.jsonl` with one JSON line per
+    message; the two messages target record indexes 2 and 4 of `test_entity`.
+    """
+    entity_rows = [
+        {"id": "field1", "attr": 1, "__record_index__": 1},
+        {"id": "field2", "attr": None, "__record_index__": 2},
+        {"id": "field3", "attr": 2, "__record_index__": 3},
+        {"id": "field4", "attr": None, "__record_index__": 4},
+    ]
+    test_df = spark.createDataFrame(entity_rows)
+
+    # The messages differ only in the record they point at and the offending value.
+    rejected_values = {2: "hello", 4: "world"}
+    error_contract_messages = [
+        {
+            "Entity": "test_entity",
+            "Key": "",
+            "FailureType": "record",
+            "Status": "error",
+            "ErrorType": "",
+            "ErrorLocation": "attr",
+            "ErrorMessage": "",
+            "ErrorCode": "",
+            "ReportingField": "attr",
+            "RecordIndex": record_index,
+            "Value": bad_value,
+            "Category": "Bad value",
+        }
+        for record_index, bad_value in rejected_values.items()
+    ]
+
+    with tempfile.TemporaryDirectory() as temp_dir_path:
+        temp_error_file = Path(temp_dir_path, "errors", "data_contract_errors.jsonl")
+        os.mkdir(temp_error_file.parent)
+        with open(temp_error_file, encoding="utf-8", mode="w") as tpf:
+            # JSON Lines: one message per line.
+            tpf.writelines(json.dumps(message) + "\n" for message in error_contract_messages)
+
+        # Yield inside the `with` so the directory still exists while the test runs.
+        yield test_df, temp_dir_path
+
class BasicModel(BaseModel):
str_field: str
@@ -264,4 +318,25 @@ def test_use_cast_statements(spark, casting_dataframe):
not dodgy_date_rec.get("basic_model",{}).get("date_field")
and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[]))
)
- assert cast_df
\ No newline at end of file
+ assert cast_df
+
+
+class TempSparkSession:
+    """Minimal stand-in for `self` that exposes only a `spark_session` attribute."""
+
+    def __init__(self, spark: SparkSession):
+        self.spark_session = spark
+
+
+def test_spark_filter_contract_errors(spark: SparkSession, example_data_contract_error_codes):  # pylint: disable=W0621
+    """Rows with record indexes 2 and 4 are removed; rows 1 and 3 survive."""
+    entity_df, temp_dir = example_data_contract_error_codes
+    surviving_rows = [
+        {"id": "field1", "attr": 1, "__record_index__": 1},
+        {"id": "field3", "attr": 2, "__record_index__": 3},
+    ]
+    expected_df = spark.createDataFrame(surviving_rows)
+
+    step_impl_stub = TempSparkSession(spark)
+    result_df = _spark_filter_contract_errors(step_impl_stub, temp_dir, entity_df, "test_entity")
+
+    assert result_df.count() == 2
+    # Anti-join: no expected row may be missing from the result.
+    rows_missing_from_result = expected_df.join(result_df, "__record_index__", "anti")
+    assert rows_missing_from_result.count() == 0
diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py
index b3048a1..063ced7 100644
--- a/tests/test_pipeline/test_spark_pipeline.py
+++ b/tests/test_pipeline/test_spark_pipeline.py
@@ -439,7 +439,9 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out
("Dataset Id", "planets"),
("File Name", "doesnotmatter"),
("File Extension", "json"),
- ("Submission Failure", "2"),
+ ("Total Number of Records Processed", "9"),
+ ("Submission Failure", "0"),
+ ("Record Rejection", "2"),
("Warning", "0"),
]
@@ -455,7 +457,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out
[
OrderedDict(
**{
- "Type": "Submission Failure",
+ "Type": "Record Rejection",
"Group": "planets",
"Data Item Submission Name": "orbitalPeriod",
"Category": "Bad value",
@@ -465,7 +467,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out
),
OrderedDict(
**{
- "Type": "Submission Failure",
+ "Type": "Record Rejection",
"Group": "planets",
"Data Item Submission Name": "gravity",
"Category": "Bad value",
@@ -485,7 +487,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out
OrderedDict(
**{
"Group": "planets",
- "Type": "Submission Failure",
+ "Type": "Record Rejection",
"Error Code": "LONG_ORBIT",
"Data Item Submission Name": "orbitalPeriod",
"Errors and Warnings": "Planet has long orbital period",
@@ -498,7 +500,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out
OrderedDict(
**{
"Group": "planets",
- "Type": "Submission Failure",
+ "Type": "Record Rejection",
"Error Code": "STRONG_GRAVITY",
"Data Item Submission Name": "gravity",
"Errors and Warnings": "Planet has too strong gravity",
diff --git a/tests/testdata/animals/animals.dischema.json b/tests/testdata/animals/animals.dischema.json
new file mode 100644
index 0000000..0e1eda1
--- /dev/null
+++ b/tests/testdata/animals/animals.dischema.json
@@ -0,0 +1,54 @@
+{
+ "contract": {
+ "schemas": {},
+ "datasets": {
+ "animals": {
+ "fields": {
+ "name": "str",
+ "height": "float",
+ "weight": "float",
+ "region": "str"
+ },
+ "reader_config": {
+ ".xml": {
+ "reader": "DuckDBXMLStreamReader",
+ "kwargs": {
+ "record_tag": "animal",
+ "root_tag": "animals"
+ }
+ }
+ },
+ "mandatory_fields": [
+ "name"
+ ]
+ }
+ }
+ },
+ "transformations": {
+ "filters": [
+ {
+ "entity": "animals",
+ "name": "check_valid_region",
+ "expression": "lower(region) in ('africa', 'asia')",
+ "error_code": "ANE01",
+ "failure_message": "Record rejected - `{{ region }}` is not in a valid region."
+ },
+ {
+ "entity": "animals",
+ "name": "check_for_pets",
+ "expression": "lower(name) != 'human'",
+ "error_code": "ANE02",
+ "failure_message": "Submission Rejected - 'Human' is not a valid animal.",
+ "failure_type": "submission"
+ },
+ {
+ "entity": "animals",
+ "name": "check_valid_weight",
+ "expression": "weight > 0",
+ "error_code": "ANE03",
+ "failure_message": "Warning - `{{ weight }}` is below zero.",
+ "is_informational": true
+ }
+ ]
+ }
+}
\ No newline at end of file
diff --git a/tests/testdata/animals/animals.xml b/tests/testdata/animals/animals.xml
new file mode 100644
index 0000000..60bdcef
--- /dev/null
+++ b/tests/testdata/animals/animals.xml
@@ -0,0 +1,33 @@
+
+
+
+ African Elephant
+ 3.5
+ 6000.0
+ Africa
+
+
+ Bengal Tiger
+ 1.1
+ 260.0
+ Asia
+
+
+ Giraffe
+ 5.5
+ 1200.0
+ Africa
+
+
+ Polar Bear
+ 2.6
+ 900.0
+ Arctic
+
+
+ Blue Whale
+ 24.0
+ 180000.0
+ Oceans
+
+
diff --git a/tests/testdata/animals/animals_mixture.xml b/tests/testdata/animals/animals_mixture.xml
new file mode 100644
index 0000000..230f790
--- /dev/null
+++ b/tests/testdata/animals/animals_mixture.xml
@@ -0,0 +1,45 @@
+
+
+
+ African Elephant
+ 3.5
+ 6000.0
+ Africa
+
+
+ Bengal Tiger
+ 1.1
+ 260.0
+ Asia
+
+
+ Giraffe
+ 5.5
+ 1200.0
+ Africa
+
+
+ Polar Bear
+ 2.6
+ 900.0
+ Arctic
+
+
+ Blue Whale
+ 24.0
+ 180000.0
+ Oceans
+
+
+ Human
+ 1.7
+ 70.0
+ Africa
+
+
+ African Elephant
+ 3.5
+ -6000.0
+ Africa
+
+