diff --git a/ccflow/exttypes/narwhals.py b/ccflow/exttypes/narwhals.py index 5bb0890..56976fb 100644 --- a/ccflow/exttypes/narwhals.py +++ b/ccflow/exttypes/narwhals.py @@ -59,6 +59,13 @@ def validate_from_any(value: Any): if source_args and source_args[0] and source_args[0] is not Any: backend = source_args[0].__module__.split(".", 1)[0] + if "backend" in value: + if backend is None: + # backend in source args takes precedence + backend = value["backend"] + + value = value["data"] + try: try: value = nw.from_dict(value, backend=backend) @@ -89,7 +96,10 @@ def validate_from_any(value: Any): def serialize(value: Any): if isinstance(value, nw.DataFrame): - return value.to_dict(as_series=False) + return { + "data": value.to_dict(as_series=False), + "backend": value.implementation.value, + } else: raise ValueError("Cannot serialize a LazyFrame to JSON. Please use the collect() method to convert it to a DataFrame first.") diff --git a/ccflow/tests/result/test_narwhals.py b/ccflow/tests/result/test_narwhals.py index a45e8f6..1b22028 100644 --- a/ccflow/tests/result/test_narwhals.py +++ b/ccflow/tests/result/test_narwhals.py @@ -3,6 +3,7 @@ import narwhals.stable.v1 as nw import polars as pl import pytest +from polars.testing import assert_frame_equal from ccflow.exttypes.narwhals import ( DataFrameT, @@ -90,3 +91,10 @@ class MyNarwhalsResult(NarwhalsDataFrameResult): df = pl.DataFrame(data) result = MyNarwhalsResult(df=df) assert result.df.schema["d"] == nw.Float64() + + +def test_serialization(data): + df = pl.DataFrame(data) + result = NarwhalsDataFrameResult(df=df) + result2 = NarwhalsDataFrameResult.model_validate_json(result.model_dump_json()) + assert_frame_equal(result.df.to_native(), result2.df.to_native())