diff --git a/pyproject.toml b/pyproject.toml index cd4ac761..c0aa3fa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,8 @@ dependencies = [ "anyio>=3.5.0, <5", "distro>=1.7.0, <2", "sniffio", + "pandas==2.2.3", + "numpy==2.0.2", ] requires-python = ">= 3.8" classifiers = [ @@ -55,6 +57,8 @@ dev-dependencies = [ "importlib-metadata>=6.7.0", "rich>=13.7.1", "nest_asyncio==1.6.0", + "pandas==2.2.3", + "numpy==2.0.2", ] [tool.rye.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index 83d02e00..19dcb392 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -102,3 +102,7 @@ virtualenv==20.24.5 # via nox zipp==3.17.0 # via importlib-metadata +pandas==2.2.3 + # via contextual-client +numpy==2.0.2 + # via contextual-client \ No newline at end of file diff --git a/requirements.lock b/requirements.lock index bc4698e1..3b833e41 100644 --- a/requirements.lock +++ b/requirements.lock @@ -43,3 +43,7 @@ typing-extensions==4.12.2 # via contextual-client # via pydantic # via pydantic-core +pandas==2.2.3 + # via contextual-client +numpy==2.0.2 + # via contextual-client \ No newline at end of file diff --git a/src/contextual/_response.py b/src/contextual/_response.py index 51fc249d..6d4a7bb1 100644 --- a/src/contextual/_response.py +++ b/src/contextual/_response.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +import ast +import json import inspect import logging import datetime @@ -23,6 +25,7 @@ import anyio import httpx import pydantic +from pandas import DataFrame # type: ignore[import] from ._types import NoneType from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base @@ -479,6 +482,61 @@ class BinaryAPIResponse(APIResponse[bytes]): the API request, e.g. `.with_streaming_response.get_binary_response()` """ + def to_dataframe(self) -> DataFrame: + """Convert the response data to a pandas DataFrame. + + Note: This method requires the `pandas` library to be installed. + + Returns: + DataFrame: Processed evaluation data + """ + # Read the binary content + content = self.read() + + # Now decode the content + lines = content.decode("utf-8").strip().split("\n") + + # Parse each line and flatten the results + data = [] + for line in lines: + try: + entry = json.loads(line) + # Parse the results field directly from JSON + if 'results' in entry: + if isinstance(entry['results'], str): + # Try to handle string representations that are valid JSON + try: + results = json.loads(entry['results']) + except Exception as e: + # If not valid JSON, fall back to safer processing + results = ast.literal_eval(entry['results']) + else: + # Already a dictionary + results = entry['results'] + + # Remove the original results field + del entry['results'] + # Flatten the nested dictionary structure + if isinstance(results, dict): + for key, value in results.items(): # type: ignore + if isinstance(value, dict): + for subkey, subvalue in value.items(): # type: ignore + if isinstance(subvalue, dict): + # Handle one more level of nesting + for subsubkey, subsubvalue in subvalue.items(): # type: ignore + entry[f'{key}_{subkey}_{subsubkey}'] = subsubvalue + else: + entry[f'{key}_{subkey}'] = subvalue + else: + entry[key] = value + + data.append(entry) # type: ignore + except Exception as e: + log.error(f"Error processing line: {e}") + log.error(f"Problematic line: {line[:200]}...") # Print first 200 chars of the line + continue + return DataFrame(data) + def write_to_file( self, file: str | os.PathLike[str], diff --git a/tests/test_response.py b/tests/test_response.py index cedd75ba..d4cb409f 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -73,6 +73,24 @@ def test_response_parse_mismatched_basemodel(client: ContextualAI) -> None: response.parse(to=PydanticModel) +def test_response_binary_response_to_dataframe(client: ContextualAI) -> None: + response = BinaryAPIResponse( + raw=httpx.Response( + 200, + content=b'{"prompt": "What was Apple\'s total net sales for 2022?", "reference": "...", "response": "...", "guideline": "", "knowledge": "[]", "results": "{\'equivalence_score\': {\'score\': 0.0, \'metadata\': \\"The generated response does not provide any information about Apple\'s total net sales for 2022, whereas the reference response provides the specific figure.\\"}, \'factuality_v4.5_score\': {\'score\': 0.0, \'metadata\': {\'description\': \'There are claims but no knowledge so response is ungrounded.\'}}}", "status": "completed"}\r\n', + ), + client=client, + stream=False, + stream_cls=None, + cast_to=bytes, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + df = response.to_dataframe() + assert df.shape == (1, 10) + assert df["prompt"].astype(str).iloc[0] == "What was Apple's total net sales for 2022?" # type: ignore + assert df["equivalence_score_score"].astype(float).iloc[0] == 0.0 # type: ignore + + @pytest.mark.asyncio async def test_async_response_parse_mismatched_basemodel(async_client: AsyncContextualAI) -> None: response = AsyncAPIResponse(