Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
VALID_URL_SCHEMES = ["https"]
CLOUD_DEPLOYMENT_MODE = "cloud"
_HAS_LOGGED_FOR_SERIALIZATION_ERROR = False
_HAS_LOGGED_FOR_SERIALIZATION_FALLBACK = False


class AirbyteEntrypoint(object):
Expand Down Expand Up @@ -333,16 +334,28 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str,
@staticmethod
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
global _HAS_LOGGED_FOR_SERIALIZATION_ERROR
global _HAS_LOGGED_FOR_SERIALIZATION_FALLBACK
serialized_message = AirbyteMessageSerializer.dump(airbyte_message)
try:
return orjson.dumps(serialized_message).decode()
except Exception as exception:
except Exception as orjson_error:
if not _HAS_LOGGED_FOR_SERIALIZATION_ERROR:
logger.warning(
f"There was an error during the serialization of an AirbyteMessage: `{exception}`. This might impact the sync performances."
"Record serialization fell back to slower method. Sync will continue with reduced performance."
)
logger.debug("orjson serialization error: %s", orjson_error)
_HAS_LOGGED_FOR_SERIALIZATION_ERROR = True
return json.dumps(serialized_message)
try:
return json.dumps(serialized_message)
except TypeError as json_error:
if not _HAS_LOGGED_FOR_SERIALIZATION_FALLBACK:
logger.warning(
"Record contains a value that could not be serialized to JSON. "
"The value was converted to a string representation."
)
logger.debug("json serialization error: %s", json_error)
_HAS_LOGGED_FOR_SERIALIZATION_FALLBACK = True
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a false positive. `_HAS_LOGGED_FOR_SERIALIZATION_FALLBACK` is already declared at module level on line 52 and is read/written inside the function via `global _HAS_LOGGED_FOR_SERIALIZATION_FALLBACK`. The suggested fix describes exactly what this PR already does.

return json.dumps(serialized_message, default=str)

@classmethod
def extract_state(cls, args: List[str]) -> Optional[Any]:
Expand Down
32 changes: 32 additions & 0 deletions unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,10 @@ def test_handle_record_counts(
def test_given_serialization_error_using_orjson_then_fallback_on_json(
entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
):
# Reset global flags to avoid test pollution
entrypoint_module._HAS_LOGGED_FOR_SERIALIZATION_ERROR = False
entrypoint_module._HAS_LOGGED_FOR_SERIALIZATION_FALLBACK = False

parsed_args = Namespace(
command="read", config="config_path", state="statepath", catalog="catalogpath"
)
Expand All @@ -856,3 +860,31 @@ def test_given_serialization_error_using_orjson_then_fallback_on_json(
# There will be multiple messages here because the fixture `entrypoint` sets a control message. We only care about records here
record_messages = list(filter(lambda message: "RECORD" in message, messages))
assert len(record_messages) == 2


def test_given_non_json_serializable_type_then_fallback_with_default_str(
    entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
):
    """A record holding a `complex` value is still emitted, with the value stringified.

    The record carries `complex(1, 2)`; the test expects exactly one RECORD
    message whose `value` field deserializes to the string "(1+2j)".
    """
    # Both "already logged" module flags are cleared first so that state left
    # behind by other tests cannot influence this one.
    entrypoint_module._HAS_LOGGED_FOR_SERIALIZATION_ERROR = False
    entrypoint_module._HAS_LOGGED_FOR_SERIALIZATION_FALLBACK = False

    unserializable_record = AirbyteMessage(
        record=AirbyteRecordMessage(
            stream="stream",
            data={"value": complex(1, 2)},
            emitted_at=1,
        ),
        type=Type.RECORD,
    )

    # Stub the source so that `read` yields only the record under test.
    stubbed_returns = (
        ("read_state", {}),
        ("read_catalog", {}),
        ("read", [unserializable_record]),
    )
    for method_name, stub_value in stubbed_returns:
        mocker.patch.object(MockSource, method_name, return_value=stub_value)

    read_args = Namespace(
        command="read",
        config="config_path",
        state="statepath",
        catalog="catalogpath",
    )
    emitted = list(entrypoint.run(read_args))

    # Other message kinds may be emitted as well; only records matter here.
    records_only = [line for line in emitted if "RECORD" in line]
    assert len(records_only) == 1

    # Verify the complex value was converted to its string representation
    decoded = orjson.loads(records_only[0])
    assert decoded["record"]["data"]["value"] == "(1+2j)"
Loading