From 518d57ef6a939b16197f6cc409576fb8527b6bf3 Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Sun, 5 Apr 2026 13:47:28 +0200 Subject: [PATCH 1/6] feat(python-driver): add public API for connection pooling and model dict conversion Add two enhancements to the Python driver's public API: 1. Add configure_connection() function that registers AGE agtype adapters on an existing psycopg connection without creating a new one. This enables use with external connection pools (e.g. psycopg_pool) and managed PostgreSQL services where LOAD 'age' may be restricted. Also explicitly export AgeLoader and ClientCursor as public symbols in age/__init__.py. (#2369) 2. Add to_dict() methods to Vertex, Edge, and Path model classes for conversion to plain Python dicts. This enables direct JSON serialization with json.dumps() without requiring custom conversion logic. (#2371) - Vertex.to_dict() returns {id, label, properties} - Edge.to_dict() returns {id, label, start_id, end_id, properties} - Path.to_dict() returns a list of to_dict() results Closes #2369 Closes #2371 --- drivers/python/age/__init__.py | 1 + drivers/python/age/age.py | 37 +++++++++++++++++ drivers/python/age/models.py | 22 ++++++++++ drivers/python/test_age_py.py | 74 +++++++++++++++++++++++++++++++++- 4 files changed, 133 insertions(+), 1 deletion(-) diff --git a/drivers/python/age/__init__.py b/drivers/python/age/__init__.py index caee6a43c..2da5ef102 100644 --- a/drivers/python/age/__init__.py +++ b/drivers/python/age/__init__.py @@ -16,6 +16,7 @@ import psycopg.conninfo as conninfo from . import age from .age import * +from .age import AgeLoader, ClientCursor, configure_connection from .models import * from .builder import ResultHandler, DummyResultHandler, parseAgeValue, newResultHandler from . import VERSION diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index ae76bcf50..4021447fd 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -170,6 +170,43 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals if graphName != None: checkGraphCreated(conn, graphName) + +def configure_connection(conn, graph_name=None, skip_load=False): + """Register AGE agtype adapters on an existing connection. + + This enables use of AGE with externally-managed connections, such as + those from psycopg_pool.ConnectionPool. Unlike setUpAge(), this + function does not call LOAD 'age' by default, making it suitable for + managed PostgreSQL services where LOAD is restricted. + + Args: + conn: An existing psycopg connection. + graph_name: Optional graph name to check/create. + skip_load: If False (default), skip LOAD 'age'. Set to True to + include the LOAD command (equivalent to setUpAge behavior). + """ + with conn.cursor() as cursor: + if not skip_load: + cursor.execute("LOAD 'age';") + + cursor.execute("SET search_path = ag_catalog, '$user', public;") + + ag_info = TypeInfo.fetch(conn, 'agtype') + + if not ag_info: + raise AgeNotSet( + "AGE agtype type not found. Ensure the AGE extension is " + "installed and loaded in the current database. " + "Run CREATE EXTENSION age; first." + ) + + conn.adapters.register_loader(ag_info.oid, AgeLoader) + conn.adapters.register_loader(ag_info.array_oid, AgeLoader) + + if graph_name is not None: + checkGraphCreated(conn, graph_name) + + # Create the graph, if it does not exist def checkGraphCreated(conn:psycopg.connection, graphName:str): validate_graph_name(graphName) diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py index 6d9095485..93a5e52a0 100644 --- a/drivers/python/age/models.py +++ b/drivers/python/age/models.py @@ -118,6 +118,12 @@ def toJson(self) -> str: return buf.getvalue() + def to_dict(self) -> list: + return [ + e.to_dict() if isinstance(e, AGObj) else e + for e in self.entities + ] + @@ -146,6 +152,13 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.toString() + def to_dict(self) -> dict: + return { + "id": self.id, + "label": self.label, + "properties": dict(self.properties) if self.properties else {}, + } + def toString(self) -> str: return nodeToString(self) @@ -186,6 +199,15 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.toString() + def to_dict(self) -> dict: + return { + "id": self.id, + "label": self.label, + "start_id": self.start_id, + "end_id": self.end_id, + "properties": dict(self.properties) if self.properties else {}, + } + def extraStrFormat(node, buf): if node.start_id != None: buf.write(", start_id:") diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index 12dd7bd55..fb1ea632c 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -15,7 +15,7 @@ import json import re -from age.models import Vertex +from age.models import Vertex, Edge, Path import unittest import unittest.mock import decimal @@ -171,6 +171,78 @@ def test_validate_column_quoting(self): self.assertEqual(_validate_column("my_col"), '"my_col" agtype') +class TestModelToDict(unittest.TestCase): + """Unit tests for Vertex/Edge/Path to_dict() — no DB required.""" + + def test_vertex_to_dict(self): + v = Vertex(id=123, label="Person", properties={"name": "Alice", "age": 30}) + d = v.to_dict() + self.assertEqual(d["id"], 123) + self.assertEqual(d["label"], "Person") + self.assertEqual(d["properties"], {"name": "Alice", "age": 30}) + # Verify it's a plain dict (JSON-serializable) + json_str = json.dumps(d) + self.assertIn("Alice", json_str) + + def test_vertex_to_dict_empty_properties(self): + v = Vertex(id=1, label="Empty", properties=None) + d = v.to_dict() + self.assertEqual(d["properties"], {}) + + def test_edge_to_dict(self): + e = Edge(id=456, label="KNOWS", properties={"since": 2020}) + e.start_id = 123 + e.end_id = 789 + d = e.to_dict() + self.assertEqual(d["id"], 456) + self.assertEqual(d["label"], "KNOWS") + self.assertEqual(d["start_id"], 123) + self.assertEqual(d["end_id"], 789) + self.assertEqual(d["properties"], {"since": 2020}) + json_str = json.dumps(d) + self.assertIn("KNOWS", json_str) + + def test_path_to_dict(self): + v1 = Vertex(id=1, label="A", properties={"name": "start"}) + e = Edge(id=10, label="r", properties={"w": 1}) + e.start_id = 1 + e.end_id = 2 + v2 = Vertex(id=2, label="B", properties={"name": "end"}) + p = Path([v1, e, v2]) + d = p.to_dict() + self.assertEqual(len(d), 3) + self.assertEqual(d[0]["label"], "A") + self.assertEqual(d[1]["label"], "r") + self.assertEqual(d[1]["start_id"], 1) + self.assertEqual(d[2]["label"], "B") + # Verify the whole path is JSON-serializable + json_str = json.dumps(d) + self.assertIn("start", json_str) + + def test_vertex_to_dict_is_plain_dict(self): + """to_dict() returns standard dict, not a model object.""" + v = Vertex(id=1, label="X", properties={"k": "v"}) + d = v.to_dict() + self.assertIsInstance(d, dict) + self.assertIsInstance(d["properties"], dict) + + +class TestPublicImports(unittest.TestCase): + """Verify that public API symbols are importable without type: ignore.""" + + def test_import_configure_connection(self): + from age import configure_connection + self.assertTrue(callable(configure_connection)) + + def test_import_age_loader(self): + from age import AgeLoader + self.assertIsNotNone(AgeLoader) + + def test_import_client_cursor(self): + from age import ClientCursor + self.assertIsNotNone(ClientCursor) + + class TestAgeBasic(unittest.TestCase): ag = None args: argparse.Namespace = argparse.Namespace( From c1ef604d30c7ca3effe63adc20a4f8dd0f59d129 Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Sun, 5 Apr 2026 15:18:59 +0200 Subject: [PATCH 2/6] Fix configure_connection: correct parameter semantics, add load_from_plugins - Replace confusing `skip_load` (double-negative) with `load` (positive boolean, default False). The default now correctly matches the intent: no LOAD by default for connection pool / managed PostgreSQL use cases. - Add `load_from_plugins` parameter for parity with setUpAge(). - Fix docstring to accurately describe parameter behavior. - Add 6 unit tests for configure_connection covering: default no-load, explicit load, load_from_plugins, search_path always set, adapter registration, and graph_name check delegation. Made-with: Cursor --- drivers/python/age/age.py | 29 ++++++++++---- drivers/python/test_age_py.py | 72 +++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 8 deletions(-) diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index 4021447fd..5fb0f2a6f 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -171,23 +171,36 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals checkGraphCreated(conn, graphName) -def configure_connection(conn, graph_name=None, skip_load=False): +def configure_connection(conn, graph_name=None, load=False, load_from_plugins=False): """Register AGE agtype adapters on an existing connection. This enables use of AGE with externally-managed connections, such as - those from psycopg_pool.ConnectionPool. Unlike setUpAge(), this - function does not call LOAD 'age' by default, making it suitable for - managed PostgreSQL services where LOAD is restricted. + those from psycopg_pool.ConnectionPool. By default the function does + **not** execute ``LOAD 'age'``, making it safe for managed PostgreSQL + services (Azure, AWS RDS) where the extension is pre-loaded via + ``shared_preload_libraries``. + + Performs: + - ``SET search_path`` to include ``ag_catalog`` + - Fetches agtype OIDs and registers ``AgeLoader`` + - Optionally loads the AGE extension (``load=True``) + - Optionally checks/creates the graph Args: conn: An existing psycopg connection. graph_name: Optional graph name to check/create. - skip_load: If False (default), skip LOAD 'age'. Set to True to - include the LOAD command (equivalent to setUpAge behavior). + load: If True, execute ``LOAD 'age'`` (or the plugins path). + Default False — suitable for environments where AGE is + already loaded. + load_from_plugins: If True (and ``load=True``), use + ``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``. """ with conn.cursor() as cursor: - if not skip_load: - cursor.execute("LOAD 'age';") + if load: + if load_from_plugins: + cursor.execute("LOAD '$libdir/plugins/age';") + else: + cursor.execute("LOAD 'age';") cursor.execute("SET search_path = ag_catalog, '$user', public;") diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index fb1ea632c..61ae8d029 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -243,6 +243,78 @@ def test_import_client_cursor(self): self.assertIsNotNone(ClientCursor) +class TestConfigureConnection(unittest.TestCase): + """Unit tests for configure_connection() — no DB required.""" + + def _make_mock_conn(self): + mock_conn = unittest.mock.MagicMock() + mock_cursor = unittest.mock.MagicMock() + mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False) + mock_conn.adapters = unittest.mock.MagicMock() + mock_type_info = unittest.mock.MagicMock() + mock_type_info.oid = 1 + mock_type_info.array_oid = 2 + return mock_conn, mock_cursor, mock_type_info + + def test_default_does_not_load(self): + """By default, configure_connection should NOT execute LOAD.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + executed = [str(c) for c in mock_cursor.execute.call_args_list] + for stmt in executed: + self.assertNotIn("LOAD", stmt, f"LOAD should not be called by default, got: {stmt}") + + def test_load_true_executes_load(self): + """When load=True, LOAD 'age' must be executed.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn, load=True) + executed = [str(c) for c in mock_cursor.execute.call_args_list] + load_calls = [s for s in executed if "LOAD" in s and "age" in s] + self.assertTrue(len(load_calls) > 0, "LOAD should be called when load=True") + + def test_load_from_plugins(self): + """When load=True and load_from_plugins=True, use plugins path.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn, load=True, load_from_plugins=True) + executed = [str(c) for c in mock_cursor.execute.call_args_list] + plugins_calls = [s for s in executed if "plugins" in s] + self.assertTrue(len(plugins_calls) > 0, "LOAD from plugins path should be called") + + def test_always_sets_search_path(self): + """search_path must always be set regardless of load parameter.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + executed = [str(c) for c in mock_cursor.execute.call_args_list] + search_path_calls = [s for s in executed if "search_path" in s] + self.assertTrue(len(search_path_calls) > 0, "search_path should always be set") + + def test_registers_agtype_adapters(self): + """AgeLoader must be registered for agtype OIDs.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.configure_connection(mock_conn) + mock_conn.adapters.register_loader.assert_any_call(1, age.age.AgeLoader) + mock_conn.adapters.register_loader.assert_any_call(2, age.age.AgeLoader) + + def test_graph_name_triggers_check(self): + """When graph_name is provided, checkGraphCreated must be called.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated") as mock_check: + age.age.configure_connection(mock_conn, graph_name="my_graph") + mock_check.assert_called_once_with(mock_conn, "my_graph") + + class TestAgeBasic(unittest.TestCase): ag = None args: argparse.Namespace = argparse.Namespace( From 1f7570b32ec761022237a191c7e6e3c01e1e0503 Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Mon, 6 Apr 2026 00:56:20 +0200 Subject: [PATCH 3/6] Address review feedback for configure_connection and to_dict - Move TypeInfo.fetch() inside cursor block so search_path change is visible regardless of transaction isolation mode - Raise ValueError when load_from_plugins=True but load=False - Add type annotations to configure_connection signature - Document shallow-copy semantics in Vertex/Edge to_dict() - Path.to_dict() uses str() fallback for non-AGObj entities to guarantee JSON-serializable output - Add test for AgeNotSet when TypeInfo.fetch returns None - Add test for load_from_plugins=True without load=True - Replace fragile string assertions with assert_called_with/assert_any_call Made-with: Cursor --- drivers/python/age/age.py | 19 +++++++++++++++++-- drivers/python/age/models.py | 14 +++++++++++++- drivers/python/test_age_py.py | 34 ++++++++++++++++++++++------------ 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index 5fb0f2a6f..2b2ceaf18 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -171,7 +171,12 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals checkGraphCreated(conn, graphName) -def configure_connection(conn, graph_name=None, load=False, load_from_plugins=False): +def configure_connection( + conn: psycopg.connection, + graph_name: str | None = None, + load: bool = False, + load_from_plugins: bool = False, +) -> None: """Register AGE agtype adapters on an existing connection. This enables use of AGE with externally-managed connections, such as @@ -194,7 +199,17 @@ def configure_connection(conn, graph_name=None, load=False, load_from_plugins=Fa already loaded. load_from_plugins: If True (and ``load=True``), use ``LOAD '$libdir/plugins/age'`` instead of ``LOAD 'age'``. + + Raises: + ValueError: If ``load_from_plugins=True`` but ``load=False``. + AgeNotSet: If the agtype type is not found in the database. """ + if load_from_plugins and not load: + raise ValueError( + "load_from_plugins=True requires load=True. " + "Set load=True to enable extension loading." + ) + with conn.cursor() as cursor: if load: if load_from_plugins: @@ -204,7 +219,7 @@ def configure_connection(conn, graph_name=None, load=False, load_from_plugins=Fa cursor.execute("SET search_path = ag_catalog, '$user', public;") - ag_info = TypeInfo.fetch(conn, 'agtype') + ag_info = TypeInfo.fetch(conn, 'agtype') if not ag_info: raise AgeNotSet( diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py index 93a5e52a0..50fa1b26f 100644 --- a/drivers/python/age/models.py +++ b/drivers/python/age/models.py @@ -119,8 +119,10 @@ def toJson(self) -> str: return buf.getvalue() def to_dict(self) -> list: + # Non-AGObj elements (e.g. raw dicts/strings from malformed paths) + # are included as-is via str() to guarantee JSON-serializable output. return [ - e.to_dict() if isinstance(e, AGObj) else e + e.to_dict() if isinstance(e, AGObj) else str(e) for e in self.entities ] @@ -153,6 +155,11 @@ def __repr__(self) -> str: return self.toString() def to_dict(self) -> dict: + """Return a plain dict suitable for JSON serialization. + + Properties are shallow-copied; nested mutable values will share + references with the original Vertex. + """ return { "id": self.id, "label": self.label, @@ -200,6 +207,11 @@ def __repr__(self) -> str: return self.toString() def to_dict(self) -> dict: + """Return a plain dict suitable for JSON serialization. + + Properties are shallow-copied; nested mutable values will share + references with the original Edge. + """ return { "id": self.id, "label": self.label, diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index 61ae8d029..0578bedd8 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -263,9 +263,9 @@ def test_default_does_not_load(self): with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn) - executed = [str(c) for c in mock_cursor.execute.call_args_list] - for stmt in executed: - self.assertNotIn("LOAD", stmt, f"LOAD should not be called by default, got: {stmt}") + mock_cursor.execute.assert_called_once_with( + "SET search_path = ag_catalog, '$user', public;" + ) def test_load_true_executes_load(self): """When load=True, LOAD 'age' must be executed.""" @@ -273,9 +273,7 @@ def test_load_true_executes_load(self): with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn, load=True) - executed = [str(c) for c in mock_cursor.execute.call_args_list] - load_calls = [s for s in executed if "LOAD" in s and "age" in s] - self.assertTrue(len(load_calls) > 0, "LOAD should be called when load=True") + mock_cursor.execute.assert_any_call("LOAD 'age';") def test_load_from_plugins(self): """When load=True and load_from_plugins=True, use plugins path.""" @@ -283,9 +281,13 @@ def test_load_from_plugins(self): with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn, load=True, load_from_plugins=True) - executed = [str(c) for c in mock_cursor.execute.call_args_list] - plugins_calls = [s for s in executed if "plugins" in s] - self.assertTrue(len(plugins_calls) > 0, "LOAD from plugins path should be called") + mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';") + + def test_load_from_plugins_without_load_raises(self): + """load_from_plugins=True without load=True must raise ValueError.""" + mock_conn, _, _ = self._make_mock_conn() + with self.assertRaises(ValueError): + age.age.configure_connection(mock_conn, load_from_plugins=True) def test_always_sets_search_path(self): """search_path must always be set regardless of load parameter.""" @@ -293,9 +295,9 @@ def test_always_sets_search_path(self): with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn) - executed = [str(c) for c in mock_cursor.execute.call_args_list] - search_path_calls = [s for s in executed if "search_path" in s] - self.assertTrue(len(search_path_calls) > 0, "search_path should always be set") + mock_cursor.execute.assert_any_call( + "SET search_path = ag_catalog, '$user', public;" + ) def test_registers_agtype_adapters(self): """AgeLoader must be registered for agtype OIDs.""" @@ -314,6 +316,14 @@ def test_graph_name_triggers_check(self): age.age.configure_connection(mock_conn, graph_name="my_graph") mock_check.assert_called_once_with(mock_conn, "my_graph") + def test_age_not_set_when_type_info_is_none(self): + """AgeNotSet must be raised when TypeInfo.fetch returns None.""" + from age.exceptions import AgeNotSet + mock_conn, _, _ = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=None): + with self.assertRaises(AgeNotSet): + age.age.configure_connection(mock_conn) + class TestAgeBasic(unittest.TestCase): ag = None From 52f2e1c57408c521be2d759c3c6392d488f923a3 Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Mon, 6 Apr 2026 22:39:01 +0200 Subject: [PATCH 4/6] Fix Path.to_dict() to preserve JSON-native types, add tests to suite - Path.to_dict(): leave dict/list/str/int/float/bool/None unchanged instead of converting to str(); handle entities=None safely - Add TestModelToDict, TestPublicImports, TestConfigureConnection to the __main__ suite so they run via direct script execution Made-with: Cursor --- drivers/python/age/models.py | 18 ++++++++++++------ drivers/python/test_age_py.py | 5 +++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/drivers/python/age/models.py b/drivers/python/age/models.py index 50fa1b26f..62215c160 100644 --- a/drivers/python/age/models.py +++ b/drivers/python/age/models.py @@ -119,12 +119,18 @@ def toJson(self) -> str: return buf.getvalue() def to_dict(self) -> list: - # Non-AGObj elements (e.g. raw dicts/strings from malformed paths) - # are included as-is via str() to guarantee JSON-serializable output. - return [ - e.to_dict() if isinstance(e, AGObj) else str(e) - for e in self.entities - ] + # AGObj elements are recursively converted; JSON-native types + # (dict, list, str, int, float, bool, None) pass through unchanged. + # Non-serializable objects fall back to str() as a safety net. + result = [] + for e in (self.entities or []): + if isinstance(e, AGObj): + result.append(e.to_dict()) + elif isinstance(e, (dict, list, str, int, float, bool, type(None))): + result.append(e) + else: + result.append(str(e)) + return result diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index 0578bedd8..20ae6dec4 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -782,9 +782,14 @@ def testSerialization(self): args = parser.parse_args() suite = unittest.TestSuite() + # Unit tests (no DB required) loader = unittest.TestLoader() suite.addTests(loader.loadTestsFromTestCase(TestSetUpAge)) suite.addTests(loader.loadTestsFromTestCase(TestBuildCypher)) + suite.addTests(loader.loadTestsFromTestCase(TestModelToDict)) + suite.addTests(loader.loadTestsFromTestCase(TestPublicImports)) + suite.addTests(loader.loadTestsFromTestCase(TestConfigureConnection)) + # Integration tests (require DB) suite.addTest(TestAgeBasic("testExec")) suite.addTest(TestAgeBasic("testQuery")) suite.addTest(TestAgeBasic("testChangeData")) From 3e2318a0d9d09b4a442f43bc431eb48b9716b263 Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Fri, 10 Apr 2026 21:30:58 +0200 Subject: [PATCH 5/6] fix(python-driver): Python 3.9-safe hints and correct $user in search_path - Use Optional[str] for configure_connection graph_name (PEP 604 unions are invalid on Python 3.9). - Import Any/Optional from typing for annotations. - Quote $user in SET search_path; align unit test expectations. Made-with: Cursor --- drivers/python/age/age.py | 6 ++++-- drivers/python/test_age_py.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index 2b2ceaf18..a580936a3 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -14,6 +14,8 @@ # under the License. import re +from typing import Any, Optional + import psycopg from psycopg.types import TypeInfo from psycopg import sql @@ -173,7 +175,7 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals def configure_connection( conn: psycopg.connection, - graph_name: str | None = None, + graph_name: Optional[str] = None, load: bool = False, load_from_plugins: bool = False, ) -> None: @@ -217,7 +219,7 @@ def configure_connection( else: cursor.execute("LOAD 'age';") - cursor.execute("SET search_path = ag_catalog, '$user', public;") + cursor.execute('SET search_path = ag_catalog, "$user", public;') ag_info = TypeInfo.fetch(conn, 'agtype') diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index 20ae6dec4..5d8c2f29c 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -264,7 +264,7 @@ def test_default_does_not_load(self): unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn) mock_cursor.execute.assert_called_once_with( - "SET search_path = ag_catalog, '$user', public;" + 'SET search_path = ag_catalog, "$user", public;' ) def test_load_true_executes_load(self): @@ -296,7 +296,7 @@ def test_always_sets_search_path(self): unittest.mock.patch("age.age.checkGraphCreated"): age.age.configure_connection(mock_conn) mock_cursor.execute.assert_any_call( - "SET search_path = ag_catalog, '$user', public;" + 'SET search_path = ag_catalog, "$user", public;' ) def test_registers_agtype_adapters(self): From 9661b7a4c1c4cd799315326868b08362c2a3f2dc Mon Sep 17 00:00:00 2001 From: Ueslei Santos Lima Date: Mon, 20 Apr 2026 13:54:40 +0200 Subject: [PATCH 6/6] test(python-driver): add configure_connection + to_dict integration test Existing tests for the new public API are unit-only: - TestConfigureConnection mocks the psycopg connection, so it never proves that AgeLoader actually registers against real agtype OIDs. - TestModelToDict hand-constructs Vertex/Edge/Path via kwargs, so it never serialises objects produced by the ANTLR parser. Add a single TestAgeBasic.testConfigureConnection that: - opens a raw psycopg connection (bypassing age.connect()), - calls configure_connection(..., load=True) on it, - runs a Cypher CREATE/RETURN through the configured connection, - asserts the returned values are real Vertex/Edge instances and that their to_dict() output is JSON-serialisable with the expected label/start_id/end_id/properties shape, - repeats the round-trip for a Path returned by MATCH. This is the smallest test that proves the configure_connection + to_dict pipeline works end-to-end against a live AGE database. Made-with: Cursor --- drivers/python/test_age_py.py | 73 +++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index 5d8c2f29c..bb92f7d51 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -745,6 +745,78 @@ def testSerialization(self): self.assertFalse(as_str.endswith(", }}::VERTEX")) print("Vertex.toString() 'properties' field is formatted properly.") + def testConfigureConnection(self): + """Integration: configure_connection() on an externally-opened + connection must register agtype adapters so cypher queries return + real Vertex/Edge/Path objects, and to_dict() must round-trip those + parser-produced objects through json.dumps().""" + print("\n-------------------------------------------------------") + print("Test 8: configure_connection + to_dict end-to-end.....") + print("-------------------------------------------------------\n") + + import psycopg + + from age import configure_connection + + dsn = "host={host} port={port} dbname={database} user={user} password={password}".format( + **vars(self.args) + ) + # Deliberately bypass age.connect(): the whole point of + # configure_connection() is to enable AGE on a caller-managed + # connection (e.g. one obtained from psycopg_pool.ConnectionPool). + raw = psycopg.connect(dsn) + try: + configure_connection(raw, graph_name=self.args.graphName, load=True) + + graph = self.args.graphName + with raw.cursor() as cur: + cur.execute( + f"SELECT * FROM cypher('{graph}', $$ " + "CREATE (a:Person {name: 'Alice'})-[r:KNOWS {since: 2020}]->(b:Person {name: 'Bob'}) " + "RETURN a, r, b " + "$$) AS (a agtype, r agtype, b agtype);" + ) + row = cur.fetchone() + raw.commit() + + v_a, e, v_b = row + self.assertIsInstance(v_a, Vertex) + self.assertIsInstance(e, Edge) + self.assertIsInstance(v_b, Vertex) + self.assertEqual(v_a["name"], "Alice") + self.assertEqual(v_b["name"], "Bob") + self.assertEqual(e["since"], 2020) + + payload = { + "start": v_a.to_dict(), + "edge": e.to_dict(), + "end": v_b.to_dict(), + } + serialised = json.loads(json.dumps(payload)) + self.assertEqual(serialised["start"]["label"], "Person") + self.assertEqual(serialised["edge"]["label"], "KNOWS") + self.assertEqual(serialised["edge"]["start_id"], v_a.id) + self.assertEqual(serialised["edge"]["end_id"], v_b.id) + self.assertEqual(serialised["start"]["properties"]["name"], "Alice") + + with raw.cursor() as cur: + cur.execute( + f"SELECT * FROM cypher('{graph}', $$ " + "MATCH p=(:Person)-[:KNOWS]->(:Person) RETURN p " + "$$) AS (p agtype);" + ) + path = cur.fetchone()[0] + + self.assertIsInstance(path, Path) + path_dict = json.loads(json.dumps(path.to_dict())) + self.assertEqual(len(path_dict), 3) + self.assertEqual(path_dict[0]["label"], "Person") + self.assertEqual(path_dict[1]["label"], "KNOWS") + self.assertEqual(path_dict[2]["label"], "Person") + print("\nTest 8 Successful....") + finally: + raw.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -797,5 +869,6 @@ def testSerialization(self): suite.addTest(TestAgeBasic("testMultipleEdges")) suite.addTest(TestAgeBasic("testCollect")) suite.addTest(TestAgeBasic("testSerialization")) + suite.addTest(TestAgeBasic("testConfigureConnection")) TestAgeBasic.args = args unittest.TextTestRunner().run(suite)