From 1b9c4417bfc4bfa2928d1c8f5cc494080a73d236 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:16:03 +0000 Subject: [PATCH 01/38] Fix __roi_query: check roi.end instead of roi.begin for upper-bound-only case The elif branch handling an upper-bound-only ROI dimension was checking roi.begin[dim] (always False at that point) instead of roi.end[dim], causing the upper-bound constraint to be silently dropped. Co-Authored-By: Claude Opus 4.6 --- funlib/persistence/graphs/sql_graph_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index cec7381..34efa05 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -691,7 +691,7 @@ def __roi_query(self, roi: Roi) -> str: ) elif roi.begin[dim] is not None: query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" - elif roi.begin[dim] is not None: + elif roi.end[dim] is not None: query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" else: query = query[:-5] From d99372a85f2940ad2539f80bd20d0dcbc0f39269 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:17:45 +0000 Subject: [PATCH 02/38] Fix __roi_query: use half-open interval [begin, end) instead of BETWEEN Roi uses half-open intervals where end is exclusive. SQL BETWEEN is inclusive on both ends, so nodes exactly at roi.end were incorrectly included. Replace with >= begin AND < end. Co-Authored-By: Claude Opus 4.6 --- funlib/persistence/graphs/sql_graph_database.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 34efa05..eacc702 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -687,7 +687,8 @@ def __roi_query(self, roi: Roi) -> str: query += " AND " if roi.begin[dim] is not None and roi.end[dim] is not None: query += ( - f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}" + f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" + f" AND {pos_attr}[{dim + 1}]<{roi.end[dim]}" ) elif roi.begin[dim] is not None: query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" From bb33cb15fb83110584311513e6d26f00122211c3 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:18:48 +0000 Subject: [PATCH 03/38] Fix read_nodes: remove duplicate attr_filter application The __attr_query() call already generates WHERE conditions for all attr_filter entries. The subsequent for-loop over attr_filter appended the same conditions again, producing redundant SQL like "WHERE foo=1 AND foo=1". 
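For illustration only (the column name "foo" and its value are hypothetical, not part of any schema), a read with attr_filter={"foo": 1} and no roi previously produced a statement along the lines of

    SELECT ... FROM nodes WHERE foo=1 AND foo=1

and now produces

    SELECT ... FROM nodes WHERE foo=1

with __attr_query() acting as the single source of the filter clause.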
Co-Authored-By: Claude Opus 4.6 --- funlib/persistence/graphs/sql_graph_database.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index eacc702..b1f7f61 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -334,10 +334,6 @@ def read_nodes( ) ) - attr_filter = attr_filter if attr_filter is not None else {} - for k, v in attr_filter.items(): - select_statement += f" AND {k}={self.__convert_to_sql(v)}" - nodes = [ self._columns_to_node_attrs( {key: val for key, val in zip(read_columns, values)}, read_attrs From d8fe16b54f321c18ecde20a736e0b9152ee6b525 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:19:54 +0000 Subject: [PATCH 04/38] Fix read_edges: remove duplicate attr_filter application MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same issue as read_nodes — __attr_query() already generates the full filter clause, but the subsequent for-loop re-appended identical conditions. Co-Authored-By: Claude Opus 4.6 --- funlib/persistence/graphs/sql_graph_database.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index b1f7f61..5ea2128 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -392,9 +392,6 @@ def read_edges( edge_attrs = endpoint_names + ( # type: ignore list(self.edge_attrs.keys()) if read_attrs is None else read_attrs ) - attr_filter = attr_filter if attr_filter is not None else {} - for k, v in attr_filter.items(): - select_statement += f" AND {k}={self.__convert_to_sql(v)}" edges = [ { From d0b2ddecd34b5f888b6b4e6b1011b490539bb4b2 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:32:31 +0000 Subject: [PATCH 05/38] Fix write_nodes to pass through caller's fail_if_exists value write_nodes was hardcoding fail_if_exists=True in the _insert_query call, silently ignoring the caller's parameter. Duplicate node inserts with fail_if_exists=False would crash instead of being ignored. Add test_graph_duplicate_insert_behavior to verify both flags work correctly for nodes and edges. Co-Authored-By: Claude Opus 4.6 --- .../persistence/graphs/sql_graph_database.py | 694 ++++++++++++++++++ tests/test_graph.py | 35 + 2 files changed, 729 insertions(+) create mode 100644 src/funlib/persistence/graphs/sql_graph_database.py diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py new file mode 100644 index 0000000..a45bd93 --- /dev/null +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -0,0 +1,694 @@ +import logging +from abc import abstractmethod +from typing import Any, Iterable, Optional + +from networkx import DiGraph, Graph +from networkx.classes.reportviews import NodeView, OutEdgeView + +from funlib.geometry import Coordinate, Roi + +from ..types import Vec, type_to_str +from .graph_database import AttributeType, GraphDataBase + +logger = logging.getLogger(__name__) + + +class SQLGraphDataBase(GraphDataBase): + """Base class for SQL-based graph databases. + + Nodes must have a position attribute (set via argument + ``position_attribute``), which will be used for geometric slicing (see + ``__getitem__`` and ``read_graph``). 
+ + Arguments: + + mode (``string``): + + Any of ``r`` (read-only), ``r+`` (read and allow modifications), + or ``w`` (create new database, overwrite if exists). + + position_attribute (``string``): + + The node attribute that contains position information. This will be + used for slicing subgraphs via ``__getitem__``. + + directed (``bool``): + + True if the graph is directed, false otherwise. If None, attempts + to read value from existing database. If not found, defaults to + false. + + nodes_table (``string``): + + The name of the nodes table. Defaults to ``nodes``. + + edges_table (``string``): + + The name of the edges table. Defaults to ``edges``. + + endpoint_names (``list`` or ``tuple`` with two elements): + + What names to use for the columns storing the start and end of an + edge. Default is ['u', 'v']. + + node_attrs (``list`` of ``str`` or None): + + The custom attributes to store on each node. + + edge_attrs (``list`` of ``str`` or None): + + The custom attributes to store on each edge. + """ + + read_modes = ["r", "r+"] + write_modes = ["r+", "w"] + create_modes = ["w"] + valid_modes = ["r", "r+", "w"] + + _node_attrs: Optional[dict[str, AttributeType]] = None + _edge_attrs: Optional[dict[str, AttributeType]] = None + + def __init__( + self, + mode: str = "r+", + position_attribute: Optional[str] = None, + directed: Optional[bool] = None, + total_roi: Optional[Roi] = None, + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, + endpoint_names: Optional[list[str]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, + ): + assert mode in self.valid_modes, ( + f"Mode '{mode}' not in allowed modes {self.valid_modes}" + ) + self.mode = mode + + if mode in self.read_modes: + self.position_attribute = position_attribute + self.directed = directed + self.total_roi = total_roi + self.nodes_table_name = nodes_table + self.edges_table_name = edges_table + self.endpoint_names = endpoint_names + self._node_attrs = node_attrs + self._edge_attrs = edge_attrs + self.ndims = None # to be read from metadata + + metadata = self._read_metadata() + if metadata is None: + raise RuntimeError("metadata does not exist, can't open in read mode") + self.__load_metadata(metadata) + + if mode in self.create_modes: + # this is where we populate default values for the DB creation + + assert node_attrs is not None, ( + "For DB creation (mode 'w'), node_attrs is a required " + "argument and needs to contain at least the type definition " + "for the position attribute" + ) + + def get(value, default): + return value if value is not None else default + + self.position_attribute = get(position_attribute, "position") + + assert self.position_attribute in node_attrs, ( + "No type information for position attribute " + f"'{self.position_attribute}' in 'node_attrs'" + ) + + position_type = node_attrs[self.position_attribute] + if isinstance(position_type, Vec): + self.ndims = position_type.size + assert self.ndims > 1, ( + "Don't use Vecs of size 1 for the position, use the " + "scalar type directly instead (i.e., 'float' instead of " + "'Vec(float, 1)'." 
+ ) + # if ndims == 1, we know that we have a single scalar now + else: + self.ndims = 1 + + self.directed = get(directed, False) + self.total_roi = get( + total_roi, Roi((None,) * self.ndims, (None,) * self.ndims) + ) + self.nodes_table_name = get(nodes_table, "nodes") + self.edges_table_name = get(edges_table, "edges") + self.endpoint_names = get(endpoint_names, ["u", "v"]) + self._node_attrs = node_attrs # no default, needs to be given + self._edge_attrs = get(edge_attrs, {}) + + # delete previous DB, if exists + self._drop_tables() + + # create new DB + self._create_tables() + + # store metadata + metadata = self.__create_metadata() + self._store_metadata(metadata) + + @abstractmethod + def _drop_edges(self) -> None: + pass + + @abstractmethod + def _drop_tables(self) -> None: + pass + + @abstractmethod + def _create_tables(self) -> None: + pass + + @abstractmethod + def _store_metadata(self, metadata) -> None: + pass + + @abstractmethod + def _read_metadata(self) -> Optional[dict[str, Any]]: + pass + + @abstractmethod + def _select_query(self, query) -> Iterable[Any]: + pass + + @abstractmethod + def _insert_query( + self, table, columns, values, fail_if_exists=False, commit=True + ) -> None: + pass + + @abstractmethod + def _update_query(self, query, commit=True) -> None: + pass + + @abstractmethod + def _commit(self) -> None: + pass + + def _node_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_node_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + + def _edge_attrs_to_columns(self, attrs): + # default: each attribute maps to its own column + return attrs + + def _columns_to_edge_attrs(self, columns, attrs): + # default: each column maps to one attribute + return columns + + def read_graph( + self, + roi: Optional[Roi] = None, + read_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + nodes_filter: Optional[dict[str, Any]] = None, + edges_filter: Optional[dict[str, Any]] = None, + ) -> Graph: + graph: Graph + if self.directed: + graph = DiGraph() + else: + graph = Graph() + + nodes = self.read_nodes( + roi, + read_attrs=node_attrs, + attr_filter=nodes_filter, + ) + node_list = [(n["id"], self.__remove_keys(n, ["id"])) for n in nodes] + graph.add_nodes_from(node_list) + + if read_edges: + edges = self.read_edges( + roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter + ) + u, v = self.endpoint_names # type: ignore + try: + edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges] + except KeyError as e: + raise ValueError(edges[:5]) from e + graph.add_edges_from(edge_list) + return graph + + def write_attrs( + self, + graph: Graph, + roi: Optional[Roi] = None, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + ) -> None: + self.update_nodes( + nodes=graph.nodes, + roi=roi, + attributes=node_attrs, + ) + self.update_edges( + nodes=graph.nodes, + edges=graph.edges, + roi=roi, + attributes=edge_attrs, + ) + + def write_graph( + self, + graph: Graph, + roi: Optional[Roi] = None, + write_nodes: bool = True, + write_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + fail_if_exists: bool = False, + delete: bool = False, + ) -> None: + if write_nodes: + self.write_nodes( + graph.nodes, + roi, + attributes=node_attrs, + fail_if_exists=fail_if_exists, + delete=delete, + ) + if write_edges: + self.write_edges( + 
graph.nodes, + graph.edges, + roi, + attributes=edge_attrs, + fail_if_exists=fail_if_exists, + delete=delete, + ) + + @property + def node_attrs(self) -> dict[str, AttributeType]: + return self._node_attrs if self._node_attrs is not None else {} + + @node_attrs.setter + def node_attrs(self, value: dict[str, AttributeType]) -> None: + self._node_attrs = value + + @property + def edge_attrs(self) -> dict[str, AttributeType]: + return self._edge_attrs if self._edge_attrs is not None else {} + + @edge_attrs.setter + def edge_attrs(self, value: dict[str, AttributeType]) -> None: + self._edge_attrs = value + + def read_nodes( + self, + roi: Optional[Roi] = None, + attr_filter: Optional[dict[str, Any]] = None, + read_attrs: Optional[list[str]] = None, + join_collection: Optional[str] = None, + ) -> list[dict[str, Any]]: + """Return a list of nodes within roi.""" + + # attributes to read + read_attrs = list(self.node_attrs.keys()) if read_attrs is None else read_attrs + + # corresponding column naes + read_columns = ["id"] + self._node_attrs_to_columns(read_attrs) + read_attrs = ["id"] + read_attrs + read_attrs_query = ", ".join(read_columns) + + logger.debug("Reading nodes in roi %s" % roi) + select_statement = ( + f"SELECT {read_attrs_query} FROM {self.nodes_table_name} " + + (self.__roi_query(roi) if roi is not None else "") + + ( + f" {'WHERE' if roi is None else 'AND'} " + + self.__attr_query(attr_filter) + if attr_filter is not None and len(attr_filter) > 0 + else "" + ) + ) + + nodes = [ + self._columns_to_node_attrs( + {key: val for key, val in zip(read_columns, values)}, read_attrs + ) + for values in self._select_query(select_statement) + ] + + return nodes + + def num_nodes(self, roi: Roi) -> int: + """Return the number of nodes in the roi.""" + + # TODO: can be made more efficient + return len(self.read_nodes(roi)) + + def has_edges(self, roi: Roi) -> bool: + """Returns true if there is at least one edge in the roi.""" + + # TODO: can be made more efficient + return len(self.read_edges(roi)) > 0 + + def read_edges( + self, + roi: Optional[Roi] = None, + nodes: Optional[list[dict[str, Any]]] = None, + attr_filter: Optional[dict[str, Any]] = None, + read_attrs: Optional[list[str]] = None, + ) -> list[dict[str, Any]]: + """Returns a list of edges within roi.""" + + if nodes is None: + nodes = self.read_nodes(roi) + + if len(nodes) == 0: + return [] + + endpoint_names = self.endpoint_names + assert endpoint_names is not None + + node_ids = ", ".join([str(node["id"]) for node in nodes]) + node_condition = f"{endpoint_names[0]} IN ({node_ids})" # type: ignore + + logger.debug("Reading nodes in roi %s" % roi) + # TODO: AND vs OR here + desired_columns = ", ".join(endpoint_names + list(self.edge_attrs.keys())) # type: ignore + select_statement = ( + f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE " + + node_condition + + ( + " AND " + self.__attr_query(attr_filter) + if attr_filter is not None and len(attr_filter) > 0 + else "" + ) + ) + + edge_attrs = endpoint_names + ( # type: ignore + list(self.edge_attrs.keys()) if read_attrs is None else read_attrs + ) + + edges = [ + { + key: val + for key, val in zip( + endpoint_names + list(self.edge_attrs.keys()), + values, # type: ignore + ) + if key in edge_attrs + } + for values in self._select_query(select_statement) + ] + + return edges + + def write_edges( + self, + nodes, + edges, + roi=None, + attributes=None, + fail_if_exists=False, + delete=False, + ): + if delete: + raise NotImplementedError("Delete not implemented for SQL 
graph database") + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + columns = self.endpoint_names + ( + list(self.edge_attrs.keys()) if attributes is None else attributes + ) + + if roi is None: + roi = Roi( + (None,) * self.ndims, + (None,) * self.ndims, + ) + + values = [] + for (u, v), data in edges.items(): + if not self.directed: + u, v = min(u, v), max(u, v) + pos_u = self.__get_node_pos(nodes[u]) + + if pos_u is None or not roi.contains(pos_u): + logger.debug( + ( + f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," + + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" + ).format(u, v, data, roi) + ) + continue + + edge_attributes = [u, v] + [ + data.get(attr, None) + for attr in (self.edge_attrs if attributes is None else attributes) + ] + values.append(edge_attributes) + + if len(values) == 0: + logger.debug("No edges to insert in %s", roi) + return + + self._insert_query( + self.edges_table_name, columns, values, fail_if_exists=fail_if_exists + ) + + def update_edges( + self, + nodes: NodeView, + edges: OutEdgeView, + roi=None, + attributes=None, + ): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + logger.debug("Writing nodes in %s", roi) + + attrs = attributes if attributes is not None else [] + + for u, v, data in edges(data=True): + if not self.directed: + u, v = min(u, v), max(u, v) + if roi is not None: + pos_u = self.__get_node_pos(nodes[u]) + + if not roi.contains(pos_u): + logger.debug( + ( + f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore + + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore + ).format(u, v, data, roi) + ) + continue + + values = [data.get(attr) for attr in attrs] + setters = [f"{k}={v}" for k, v in zip(attrs, values)] + update_statement = ( + f"UPDATE {self.edges_table_name} SET " + f"{', '.join(setters)} WHERE " + f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore + ) + + self._update_query(update_statement, commit=False) + + self._commit() + + def write_nodes( + self, + nodes: NodeView, + roi=None, + attributes=None, + fail_if_exists=False, + delete=False, + ): + if delete: + raise NotImplementedError("Delete not implemented for SQL graph database") + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + logger.debug("Writing nodes in %s", roi) + + attrs = attributes if attributes is not None else list(self.node_attrs.keys()) + columns = ("id",) + tuple(attrs) + + values = [] + for node_id, data in nodes.items(): + data = data.copy() + pos = self.__get_node_pos(data) + if roi is not None and not roi.contains(pos): + continue + values.append([node_id] + [data.get(attr, None) for attr in attrs]) + + if len(values) == 0: + logger.debug("No nodes to insert in %s", roi) + return + + self._insert_query( + self.nodes_table_name, columns, values, fail_if_exists=fail_if_exists + ) + + def update_nodes( + self, + nodes: NodeView, + roi=None, + attributes=None, + ): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + logger.debug("Writing nodes in %s", roi) + + attrs = attributes if attributes is not None else [] + + for node, data in nodes(data=True): + if roi is not None: + pos_u = self.__get_node_pos(data) + + if not roi.contains(pos_u): + logger.debug( + ("Skipping node {} because it is not in roi {}").format( + node, roi + ) + ) + continue + + values = [data.get(attr) for attr 
in attrs] + setters = [ + f"{k} = {self.__convert_to_sql(v)}" for k, v in zip(attrs, values) + ] + update_statement = ( + f"UPDATE {self.nodes_table_name} SET " + f"{', '.join(setters)} WHERE " + f"id={node}" + ) + + self._update_query(update_statement, commit=False) + + self._commit() + + def __create_metadata(self): + """Sets the metadata in the meta collection to the provided values""" + + metadata = { + "position_attribute": self.position_attribute, + "directed": self.directed, + "total_roi_offset": self.total_roi.offset, + "total_roi_shape": self.total_roi.shape, + "nodes_table_name": self.nodes_table_name, + "edges_table_name": self.edges_table_name, + "endpoint_names": self.endpoint_names, + "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, + "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, + "ndims": self.ndims, + } + + return metadata + + def __load_metadata(self, metadata): + """Load the provided metadata into this object's attributes, check if + it is consistent with already populated fields.""" + + # simple attributes + for attr_name in [ + "position_attribute", + "directed", + "nodes_table_name", + "edges_table_name", + "endpoint_names", + "ndims", + ]: + if getattr(self, attr_name) is None: + setattr(self, attr_name, metadata[attr_name]) + else: + value = getattr(self, attr_name) + assert value == metadata[attr_name], ( + f"Attribute {attr_name} is already set to {value} for this " + "object, but disagrees with the stored metadata value of " + f"{metadata[attr_name]}" + ) + + # special attributes + + total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) + if self.total_roi is None: + self.total_roi = total_roi + else: + assert self.total_roi == total_roi, ( + f"Attribute total_roi is already set to {self.total_roi} for " + "this object, but disagrees with the stored metadata value of " + f"{total_roi}" + ) + + node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} + edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} + if self._node_attrs is None: + self.node_attrs = node_attrs + else: + assert self.node_attrs == node_attrs, ( + f"Attribute node_attrs is already set to {self.node_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{node_attrs}" + ) + if self._edge_attrs is None: + self.edge_attrs = edge_attrs + else: + assert self.edge_attrs == edge_attrs, ( + f"Attribute edge_attrs is already set to {self.edge_attrs} for " + "this object, but disagrees with the stored metadata value of " + f"{edge_attrs}" + ) + + def __remove_keys(self, dictionary, keys): + """Removes given keys from dictionary.""" + + return {k: v for k, v in dictionary.items() if k not in keys} + + def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]: + try: + return Coordinate(n[self.position_attribute]) # type: ignore + except KeyError: + return None + + def __convert_to_sql(self, x: Any) -> str: + if isinstance(x, str): + return f"'{x}'" + elif x is None: + return "null" + elif isinstance(x, bool): + return f"{x}".lower() + else: + return str(x) + + def __attr_query(self, attrs: dict[str, Any]) -> str: + query = "" + for attr, value in attrs.items(): + query += f"{attr}={self.__convert_to_sql(value)} AND " + query = query[:-5] + return query + + def __roi_query(self, roi: Roi) -> str: + query = "WHERE " + pos_attr = self.position_attribute + for dim in range(self.ndims): # type: ignore + if dim > 0: + query += " AND " + if roi.begin[dim] is not None and 
roi.end[dim] is not None: + query += ( + f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" + f" AND {pos_attr}[{dim + 1}]<{roi.end[dim]}" + ) + elif roi.begin[dim] is not None: + query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" + elif roi.end[dim] is not None: + query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" + else: + query = query[:-5] + return query diff --git a/tests/test_graph.py b/tests/test_graph.py index 17c4765..f1000ea 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -289,6 +289,41 @@ def test_graph_fail_if_exists(provider_factory): graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) +def test_graph_duplicate_insert_behavior(provider_factory): + """Test that fail_if_exists controls whether duplicate inserts raise.""" + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool}, + edge_attrs={"selected": bool}, + ) + roi = Roi((0, 0, 0), (10, 10, 10)) + graph = graph_provider[roi] + + graph.add_node(2, position=(2, 2, 2), selected=True) + graph.add_node(42, position=(1, 1, 1), selected=False) + graph.add_edge(2, 42, selected=True) + + # Initial write + graph_provider.write_nodes(graph.nodes()) + graph_provider.write_edges(graph.nodes(), graph.edges()) + + # fail_if_exists=True should raise on duplicate nodes and edges + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) + + # fail_if_exists=False should silently ignore duplicates + graph_provider.write_nodes(graph.nodes(), fail_if_exists=False) + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=False) + + # Verify the original data is still intact + graph_provider = provider_factory("r") + result = graph_provider.read_graph(roi) + assert set(result.nodes()) == {2, 42} + assert len(result.edges()) == 1 + + def test_graph_fail_if_not_exists(provider_factory): graph_provider = provider_factory( "w", From 17f01b63201cfe4bafc3519a4eb02f88a33ca92d Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:38:28 +0000 Subject: [PATCH 06/38] fix fail_if_exists flag on write_nodes --- funlib/persistence/graphs/sql_graph_database.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/funlib/persistence/graphs/sql_graph_database.py b/funlib/persistence/graphs/sql_graph_database.py index 5ea2128..a45bd93 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/funlib/persistence/graphs/sql_graph_database.py @@ -531,7 +531,9 @@ def write_nodes( logger.debug("No nodes to insert in %s", roi) return - self._insert_query(self.nodes_table_name, columns, values, fail_if_exists=True) + self._insert_query( + self.nodes_table_name, columns, values, fail_if_exists=fail_if_exists + ) def update_nodes( self, From 1f4fa7b1f816a1a957e0460da50be7f295366f54 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:38:47 +0000 Subject: [PATCH 07/38] remove duplicate file --- .../persistence/graphs/sql_graph_database.py | 694 ------------------ 1 file changed, 694 deletions(-) delete mode 100644 src/funlib/persistence/graphs/sql_graph_database.py diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py deleted file mode 100644 index a45bd93..0000000 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ /dev/null @@ -1,694 +0,0 @@ -import logging -from abc import abstractmethod -from typing import Any, 
Iterable, Optional - -from networkx import DiGraph, Graph -from networkx.classes.reportviews import NodeView, OutEdgeView - -from funlib.geometry import Coordinate, Roi - -from ..types import Vec, type_to_str -from .graph_database import AttributeType, GraphDataBase - -logger = logging.getLogger(__name__) - - -class SQLGraphDataBase(GraphDataBase): - """Base class for SQL-based graph databases. - - Nodes must have a position attribute (set via argument - ``position_attribute``), which will be used for geometric slicing (see - ``__getitem__`` and ``read_graph``). - - Arguments: - - mode (``string``): - - Any of ``r`` (read-only), ``r+`` (read and allow modifications), - or ``w`` (create new database, overwrite if exists). - - position_attribute (``string``): - - The node attribute that contains position information. This will be - used for slicing subgraphs via ``__getitem__``. - - directed (``bool``): - - True if the graph is directed, false otherwise. If None, attempts - to read value from existing database. If not found, defaults to - false. - - nodes_table (``string``): - - The name of the nodes table. Defaults to ``nodes``. - - edges_table (``string``): - - The name of the edges table. Defaults to ``edges``. - - endpoint_names (``list`` or ``tuple`` with two elements): - - What names to use for the columns storing the start and end of an - edge. Default is ['u', 'v']. - - node_attrs (``list`` of ``str`` or None): - - The custom attributes to store on each node. - - edge_attrs (``list`` of ``str`` or None): - - The custom attributes to store on each edge. - """ - - read_modes = ["r", "r+"] - write_modes = ["r+", "w"] - create_modes = ["w"] - valid_modes = ["r", "r+", "w"] - - _node_attrs: Optional[dict[str, AttributeType]] = None - _edge_attrs: Optional[dict[str, AttributeType]] = None - - def __init__( - self, - mode: str = "r+", - position_attribute: Optional[str] = None, - directed: Optional[bool] = None, - total_roi: Optional[Roi] = None, - nodes_table: Optional[str] = None, - edges_table: Optional[str] = None, - endpoint_names: Optional[list[str]] = None, - node_attrs: Optional[dict[str, AttributeType]] = None, - edge_attrs: Optional[dict[str, AttributeType]] = None, - ): - assert mode in self.valid_modes, ( - f"Mode '{mode}' not in allowed modes {self.valid_modes}" - ) - self.mode = mode - - if mode in self.read_modes: - self.position_attribute = position_attribute - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs - self.ndims = None # to be read from metadata - - metadata = self._read_metadata() - if metadata is None: - raise RuntimeError("metadata does not exist, can't open in read mode") - self.__load_metadata(metadata) - - if mode in self.create_modes: - # this is where we populate default values for the DB creation - - assert node_attrs is not None, ( - "For DB creation (mode 'w'), node_attrs is a required " - "argument and needs to contain at least the type definition " - "for the position attribute" - ) - - def get(value, default): - return value if value is not None else default - - self.position_attribute = get(position_attribute, "position") - - assert self.position_attribute in node_attrs, ( - "No type information for position attribute " - f"'{self.position_attribute}' in 'node_attrs'" - ) - - position_type = node_attrs[self.position_attribute] - if isinstance(position_type, Vec): - 
self.ndims = position_type.size - assert self.ndims > 1, ( - "Don't use Vecs of size 1 for the position, use the " - "scalar type directly instead (i.e., 'float' instead of " - "'Vec(float, 1)'." - ) - # if ndims == 1, we know that we have a single scalar now - else: - self.ndims = 1 - - self.directed = get(directed, False) - self.total_roi = get( - total_roi, Roi((None,) * self.ndims, (None,) * self.ndims) - ) - self.nodes_table_name = get(nodes_table, "nodes") - self.edges_table_name = get(edges_table, "edges") - self.endpoint_names = get(endpoint_names, ["u", "v"]) - self._node_attrs = node_attrs # no default, needs to be given - self._edge_attrs = get(edge_attrs, {}) - - # delete previous DB, if exists - self._drop_tables() - - # create new DB - self._create_tables() - - # store metadata - metadata = self.__create_metadata() - self._store_metadata(metadata) - - @abstractmethod - def _drop_edges(self) -> None: - pass - - @abstractmethod - def _drop_tables(self) -> None: - pass - - @abstractmethod - def _create_tables(self) -> None: - pass - - @abstractmethod - def _store_metadata(self, metadata) -> None: - pass - - @abstractmethod - def _read_metadata(self) -> Optional[dict[str, Any]]: - pass - - @abstractmethod - def _select_query(self, query) -> Iterable[Any]: - pass - - @abstractmethod - def _insert_query( - self, table, columns, values, fail_if_exists=False, commit=True - ) -> None: - pass - - @abstractmethod - def _update_query(self, query, commit=True) -> None: - pass - - @abstractmethod - def _commit(self) -> None: - pass - - def _node_attrs_to_columns(self, attrs): - # default: each attribute maps to its own column - return attrs - - def _columns_to_node_attrs(self, columns, attrs): - # default: each column maps to one attribute - return columns - - def _edge_attrs_to_columns(self, attrs): - # default: each attribute maps to its own column - return attrs - - def _columns_to_edge_attrs(self, columns, attrs): - # default: each column maps to one attribute - return columns - - def read_graph( - self, - roi: Optional[Roi] = None, - read_edges: bool = True, - node_attrs: Optional[list[str]] = None, - edge_attrs: Optional[list[str]] = None, - nodes_filter: Optional[dict[str, Any]] = None, - edges_filter: Optional[dict[str, Any]] = None, - ) -> Graph: - graph: Graph - if self.directed: - graph = DiGraph() - else: - graph = Graph() - - nodes = self.read_nodes( - roi, - read_attrs=node_attrs, - attr_filter=nodes_filter, - ) - node_list = [(n["id"], self.__remove_keys(n, ["id"])) for n in nodes] - graph.add_nodes_from(node_list) - - if read_edges: - edges = self.read_edges( - roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter - ) - u, v = self.endpoint_names # type: ignore - try: - edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges] - except KeyError as e: - raise ValueError(edges[:5]) from e - graph.add_edges_from(edge_list) - return graph - - def write_attrs( - self, - graph: Graph, - roi: Optional[Roi] = None, - node_attrs: Optional[list[str]] = None, - edge_attrs: Optional[list[str]] = None, - ) -> None: - self.update_nodes( - nodes=graph.nodes, - roi=roi, - attributes=node_attrs, - ) - self.update_edges( - nodes=graph.nodes, - edges=graph.edges, - roi=roi, - attributes=edge_attrs, - ) - - def write_graph( - self, - graph: Graph, - roi: Optional[Roi] = None, - write_nodes: bool = True, - write_edges: bool = True, - node_attrs: Optional[list[str]] = None, - edge_attrs: Optional[list[str]] = None, - fail_if_exists: bool = False, - delete: bool = 
False, - ) -> None: - if write_nodes: - self.write_nodes( - graph.nodes, - roi, - attributes=node_attrs, - fail_if_exists=fail_if_exists, - delete=delete, - ) - if write_edges: - self.write_edges( - graph.nodes, - graph.edges, - roi, - attributes=edge_attrs, - fail_if_exists=fail_if_exists, - delete=delete, - ) - - @property - def node_attrs(self) -> dict[str, AttributeType]: - return self._node_attrs if self._node_attrs is not None else {} - - @node_attrs.setter - def node_attrs(self, value: dict[str, AttributeType]) -> None: - self._node_attrs = value - - @property - def edge_attrs(self) -> dict[str, AttributeType]: - return self._edge_attrs if self._edge_attrs is not None else {} - - @edge_attrs.setter - def edge_attrs(self, value: dict[str, AttributeType]) -> None: - self._edge_attrs = value - - def read_nodes( - self, - roi: Optional[Roi] = None, - attr_filter: Optional[dict[str, Any]] = None, - read_attrs: Optional[list[str]] = None, - join_collection: Optional[str] = None, - ) -> list[dict[str, Any]]: - """Return a list of nodes within roi.""" - - # attributes to read - read_attrs = list(self.node_attrs.keys()) if read_attrs is None else read_attrs - - # corresponding column naes - read_columns = ["id"] + self._node_attrs_to_columns(read_attrs) - read_attrs = ["id"] + read_attrs - read_attrs_query = ", ".join(read_columns) - - logger.debug("Reading nodes in roi %s" % roi) - select_statement = ( - f"SELECT {read_attrs_query} FROM {self.nodes_table_name} " - + (self.__roi_query(roi) if roi is not None else "") - + ( - f" {'WHERE' if roi is None else 'AND'} " - + self.__attr_query(attr_filter) - if attr_filter is not None and len(attr_filter) > 0 - else "" - ) - ) - - nodes = [ - self._columns_to_node_attrs( - {key: val for key, val in zip(read_columns, values)}, read_attrs - ) - for values in self._select_query(select_statement) - ] - - return nodes - - def num_nodes(self, roi: Roi) -> int: - """Return the number of nodes in the roi.""" - - # TODO: can be made more efficient - return len(self.read_nodes(roi)) - - def has_edges(self, roi: Roi) -> bool: - """Returns true if there is at least one edge in the roi.""" - - # TODO: can be made more efficient - return len(self.read_edges(roi)) > 0 - - def read_edges( - self, - roi: Optional[Roi] = None, - nodes: Optional[list[dict[str, Any]]] = None, - attr_filter: Optional[dict[str, Any]] = None, - read_attrs: Optional[list[str]] = None, - ) -> list[dict[str, Any]]: - """Returns a list of edges within roi.""" - - if nodes is None: - nodes = self.read_nodes(roi) - - if len(nodes) == 0: - return [] - - endpoint_names = self.endpoint_names - assert endpoint_names is not None - - node_ids = ", ".join([str(node["id"]) for node in nodes]) - node_condition = f"{endpoint_names[0]} IN ({node_ids})" # type: ignore - - logger.debug("Reading nodes in roi %s" % roi) - # TODO: AND vs OR here - desired_columns = ", ".join(endpoint_names + list(self.edge_attrs.keys())) # type: ignore - select_statement = ( - f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE " - + node_condition - + ( - " AND " + self.__attr_query(attr_filter) - if attr_filter is not None and len(attr_filter) > 0 - else "" - ) - ) - - edge_attrs = endpoint_names + ( # type: ignore - list(self.edge_attrs.keys()) if read_attrs is None else read_attrs - ) - - edges = [ - { - key: val - for key, val in zip( - endpoint_names + list(self.edge_attrs.keys()), - values, # type: ignore - ) - if key in edge_attrs - } - for values in self._select_query(select_statement) - ] - - return edges 
- - def write_edges( - self, - nodes, - edges, - roi=None, - attributes=None, - fail_if_exists=False, - delete=False, - ): - if delete: - raise NotImplementedError("Delete not implemented for SQL graph database") - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - columns = self.endpoint_names + ( - list(self.edge_attrs.keys()) if attributes is None else attributes - ) - - if roi is None: - roi = Roi( - (None,) * self.ndims, - (None,) * self.ndims, - ) - - values = [] - for (u, v), data in edges.items(): - if not self.directed: - u, v = min(u, v), max(u, v) - pos_u = self.__get_node_pos(nodes[u]) - - if pos_u is None or not roi.contains(pos_u): - logger.debug( - ( - f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," - + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" - ).format(u, v, data, roi) - ) - continue - - edge_attributes = [u, v] + [ - data.get(attr, None) - for attr in (self.edge_attrs if attributes is None else attributes) - ] - values.append(edge_attributes) - - if len(values) == 0: - logger.debug("No edges to insert in %s", roi) - return - - self._insert_query( - self.edges_table_name, columns, values, fail_if_exists=fail_if_exists - ) - - def update_edges( - self, - nodes: NodeView, - edges: OutEdgeView, - roi=None, - attributes=None, - ): - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - logger.debug("Writing nodes in %s", roi) - - attrs = attributes if attributes is not None else [] - - for u, v, data in edges(data=True): - if not self.directed: - u, v = min(u, v), max(u, v) - if roi is not None: - pos_u = self.__get_node_pos(nodes[u]) - - if not roi.contains(pos_u): - logger.debug( - ( - f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore - + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore - ).format(u, v, data, roi) - ) - continue - - values = [data.get(attr) for attr in attrs] - setters = [f"{k}={v}" for k, v in zip(attrs, values)] - update_statement = ( - f"UPDATE {self.edges_table_name} SET " - f"{', '.join(setters)} WHERE " - f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore - ) - - self._update_query(update_statement, commit=False) - - self._commit() - - def write_nodes( - self, - nodes: NodeView, - roi=None, - attributes=None, - fail_if_exists=False, - delete=False, - ): - if delete: - raise NotImplementedError("Delete not implemented for SQL graph database") - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - logger.debug("Writing nodes in %s", roi) - - attrs = attributes if attributes is not None else list(self.node_attrs.keys()) - columns = ("id",) + tuple(attrs) - - values = [] - for node_id, data in nodes.items(): - data = data.copy() - pos = self.__get_node_pos(data) - if roi is not None and not roi.contains(pos): - continue - values.append([node_id] + [data.get(attr, None) for attr in attrs]) - - if len(values) == 0: - logger.debug("No nodes to insert in %s", roi) - return - - self._insert_query( - self.nodes_table_name, columns, values, fail_if_exists=fail_if_exists - ) - - def update_nodes( - self, - nodes: NodeView, - roi=None, - attributes=None, - ): - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - logger.debug("Writing nodes in %s", roi) - - attrs = attributes if attributes is not None else [] - - for node, data in nodes(data=True): - if roi is not None: - pos_u = 
self.__get_node_pos(data) - - if not roi.contains(pos_u): - logger.debug( - ("Skipping node {} because it is not in roi {}").format( - node, roi - ) - ) - continue - - values = [data.get(attr) for attr in attrs] - setters = [ - f"{k} = {self.__convert_to_sql(v)}" for k, v in zip(attrs, values) - ] - update_statement = ( - f"UPDATE {self.nodes_table_name} SET " - f"{', '.join(setters)} WHERE " - f"id={node}" - ) - - self._update_query(update_statement, commit=False) - - self._commit() - - def __create_metadata(self): - """Sets the metadata in the meta collection to the provided values""" - - metadata = { - "position_attribute": self.position_attribute, - "directed": self.directed, - "total_roi_offset": self.total_roi.offset, - "total_roi_shape": self.total_roi.shape, - "nodes_table_name": self.nodes_table_name, - "edges_table_name": self.edges_table_name, - "endpoint_names": self.endpoint_names, - "node_attrs": {k: type_to_str(v) for k, v in self.node_attrs.items()}, - "edge_attrs": {k: type_to_str(v) for k, v in self.edge_attrs.items()}, - "ndims": self.ndims, - } - - return metadata - - def __load_metadata(self, metadata): - """Load the provided metadata into this object's attributes, check if - it is consistent with already populated fields.""" - - # simple attributes - for attr_name in [ - "position_attribute", - "directed", - "nodes_table_name", - "edges_table_name", - "endpoint_names", - "ndims", - ]: - if getattr(self, attr_name) is None: - setattr(self, attr_name, metadata[attr_name]) - else: - value = getattr(self, attr_name) - assert value == metadata[attr_name], ( - f"Attribute {attr_name} is already set to {value} for this " - "object, but disagrees with the stored metadata value of " - f"{metadata[attr_name]}" - ) - - # special attributes - - total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) - if self.total_roi is None: - self.total_roi = total_roi - else: - assert self.total_roi == total_roi, ( - f"Attribute total_roi is already set to {self.total_roi} for " - "this object, but disagrees with the stored metadata value of " - f"{total_roi}" - ) - - node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} - edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} - if self._node_attrs is None: - self.node_attrs = node_attrs - else: - assert self.node_attrs == node_attrs, ( - f"Attribute node_attrs is already set to {self.node_attrs} for " - "this object, but disagrees with the stored metadata value of " - f"{node_attrs}" - ) - if self._edge_attrs is None: - self.edge_attrs = edge_attrs - else: - assert self.edge_attrs == edge_attrs, ( - f"Attribute edge_attrs is already set to {self.edge_attrs} for " - "this object, but disagrees with the stored metadata value of " - f"{edge_attrs}" - ) - - def __remove_keys(self, dictionary, keys): - """Removes given keys from dictionary.""" - - return {k: v for k, v in dictionary.items() if k not in keys} - - def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]: - try: - return Coordinate(n[self.position_attribute]) # type: ignore - except KeyError: - return None - - def __convert_to_sql(self, x: Any) -> str: - if isinstance(x, str): - return f"'{x}'" - elif x is None: - return "null" - elif isinstance(x, bool): - return f"{x}".lower() - else: - return str(x) - - def __attr_query(self, attrs: dict[str, Any]) -> str: - query = "" - for attr, value in attrs.items(): - query += f"{attr}={self.__convert_to_sql(value)} AND " - query = query[:-5] - return query - - def __roi_query(self, 
roi: Roi) -> str: - query = "WHERE " - pos_attr = self.position_attribute - for dim in range(self.ndims): # type: ignore - if dim > 0: - query += " AND " - if roi.begin[dim] is not None and roi.end[dim] is not None: - query += ( - f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" - f" AND {pos_attr}[{dim + 1}]<{roi.end[dim]}" - ) - elif roi.begin[dim] is not None: - query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" - elif roi.end[dim] is not None: - query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" - else: - query = query[:-5] - return query From 9f635b1856c19332380232c4bcd31e407eb3a3bd Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:39:15 +0000 Subject: [PATCH 08/38] move code into a src directory to modernize this package --- {funlib => src/funlib}/persistence/__init__.py | 0 {funlib => src/funlib}/persistence/arrays/__init__.py | 0 {funlib => src/funlib}/persistence/arrays/array.py | 0 {funlib => src/funlib}/persistence/arrays/datasets.py | 0 {funlib => src/funlib}/persistence/arrays/freezable.py | 0 {funlib => src/funlib}/persistence/arrays/lazy_ops.py | 0 {funlib => src/funlib}/persistence/arrays/metadata.py | 0 {funlib => src/funlib}/persistence/arrays/ome_datasets.py | 0 {funlib => src/funlib}/persistence/arrays/utils.py | 0 {funlib => src/funlib}/persistence/graphs/__init__.py | 0 {funlib => src/funlib}/persistence/graphs/graph_database.py | 0 {funlib => src/funlib}/persistence/graphs/pgsql_graph_database.py | 0 {funlib => src/funlib}/persistence/graphs/sql_graph_database.py | 0 .../funlib}/persistence/graphs/sqlite_graph_database.py | 0 {funlib => src/funlib}/persistence/py.typed | 0 {funlib => src/funlib}/persistence/types.py | 0 16 files changed, 0 insertions(+), 0 deletions(-) rename {funlib => src/funlib}/persistence/__init__.py (100%) rename {funlib => src/funlib}/persistence/arrays/__init__.py (100%) rename {funlib => src/funlib}/persistence/arrays/array.py (100%) rename {funlib => src/funlib}/persistence/arrays/datasets.py (100%) rename {funlib => src/funlib}/persistence/arrays/freezable.py (100%) rename {funlib => src/funlib}/persistence/arrays/lazy_ops.py (100%) rename {funlib => src/funlib}/persistence/arrays/metadata.py (100%) rename {funlib => src/funlib}/persistence/arrays/ome_datasets.py (100%) rename {funlib => src/funlib}/persistence/arrays/utils.py (100%) rename {funlib => src/funlib}/persistence/graphs/__init__.py (100%) rename {funlib => src/funlib}/persistence/graphs/graph_database.py (100%) rename {funlib => src/funlib}/persistence/graphs/pgsql_graph_database.py (100%) rename {funlib => src/funlib}/persistence/graphs/sql_graph_database.py (100%) rename {funlib => src/funlib}/persistence/graphs/sqlite_graph_database.py (100%) rename {funlib => src/funlib}/persistence/py.typed (100%) rename {funlib => src/funlib}/persistence/types.py (100%) diff --git a/funlib/persistence/__init__.py b/src/funlib/persistence/__init__.py similarity index 100% rename from funlib/persistence/__init__.py rename to src/funlib/persistence/__init__.py diff --git a/funlib/persistence/arrays/__init__.py b/src/funlib/persistence/arrays/__init__.py similarity index 100% rename from funlib/persistence/arrays/__init__.py rename to src/funlib/persistence/arrays/__init__.py diff --git a/funlib/persistence/arrays/array.py b/src/funlib/persistence/arrays/array.py similarity index 100% rename from funlib/persistence/arrays/array.py rename to src/funlib/persistence/arrays/array.py diff --git a/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py similarity 
index 100% rename from funlib/persistence/arrays/datasets.py rename to src/funlib/persistence/arrays/datasets.py diff --git a/funlib/persistence/arrays/freezable.py b/src/funlib/persistence/arrays/freezable.py similarity index 100% rename from funlib/persistence/arrays/freezable.py rename to src/funlib/persistence/arrays/freezable.py diff --git a/funlib/persistence/arrays/lazy_ops.py b/src/funlib/persistence/arrays/lazy_ops.py similarity index 100% rename from funlib/persistence/arrays/lazy_ops.py rename to src/funlib/persistence/arrays/lazy_ops.py diff --git a/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py similarity index 100% rename from funlib/persistence/arrays/metadata.py rename to src/funlib/persistence/arrays/metadata.py diff --git a/funlib/persistence/arrays/ome_datasets.py b/src/funlib/persistence/arrays/ome_datasets.py similarity index 100% rename from funlib/persistence/arrays/ome_datasets.py rename to src/funlib/persistence/arrays/ome_datasets.py diff --git a/funlib/persistence/arrays/utils.py b/src/funlib/persistence/arrays/utils.py similarity index 100% rename from funlib/persistence/arrays/utils.py rename to src/funlib/persistence/arrays/utils.py diff --git a/funlib/persistence/graphs/__init__.py b/src/funlib/persistence/graphs/__init__.py similarity index 100% rename from funlib/persistence/graphs/__init__.py rename to src/funlib/persistence/graphs/__init__.py diff --git a/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py similarity index 100% rename from funlib/persistence/graphs/graph_database.py rename to src/funlib/persistence/graphs/graph_database.py diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py similarity index 100% rename from funlib/persistence/graphs/pgsql_graph_database.py rename to src/funlib/persistence/graphs/pgsql_graph_database.py diff --git a/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py similarity index 100% rename from funlib/persistence/graphs/sql_graph_database.py rename to src/funlib/persistence/graphs/sql_graph_database.py diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py similarity index 100% rename from funlib/persistence/graphs/sqlite_graph_database.py rename to src/funlib/persistence/graphs/sqlite_graph_database.py diff --git a/funlib/persistence/py.typed b/src/funlib/persistence/py.typed similarity index 100% rename from funlib/persistence/py.typed rename to src/funlib/persistence/py.typed diff --git a/funlib/persistence/types.py b/src/funlib/persistence/types.py similarity index 100% rename from funlib/persistence/types.py rename to src/funlib/persistence/types.py From 1e08cc965fb84943febb82383c9f960b67e28bef Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 00:40:01 +0000 Subject: [PATCH 09/38] update git ignore to ignore daisy logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 90f1bdd..38427b2 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,4 @@ dmypy.json .vscode/ *.sw[pmno] +daisy_logs From 20ce72540d5295d0caa03746eec165cdf3a7a60f Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 16:20:03 +0000 Subject: [PATCH 10/38] ruff formatting and linting --- src/funlib/persistence/arrays/array.py | 3 +-- src/funlib/persistence/arrays/datasets.py | 3 +-- src/funlib/persistence/arrays/metadata.py | 3 +-- 
src/funlib/persistence/arrays/ome_datasets.py | 3 +-- src/funlib/persistence/graphs/graph_database.py | 3 +-- tests/test_array.py | 3 ++- tests/test_datasets.py | 2 +- tests/test_graph.py | 2 +- tests/test_metadata.py | 2 +- 9 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/funlib/persistence/arrays/array.py b/src/funlib/persistence/arrays/array.py index f19f2ab..13762b6 100644 --- a/src/funlib/persistence/arrays/array.py +++ b/src/funlib/persistence/arrays/array.py @@ -6,9 +6,8 @@ import dask.array as da import numpy as np from dask.array.optimization import fuse_slice -from zarr import Array as ZarrArray - from funlib.geometry import Coordinate, Roi +from zarr import Array as ZarrArray from .freezable import Freezable from .lazy_ops import LazyOp diff --git a/src/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py index fd324d3..c1f96b0 100644 --- a/src/funlib/persistence/arrays/datasets.py +++ b/src/funlib/persistence/arrays/datasets.py @@ -4,9 +4,8 @@ import numpy as np import zarr -from numpy.typing import DTypeLike - from funlib.geometry import Coordinate +from numpy.typing import DTypeLike from .array import Array from .metadata import MetaDataFormat, get_default_metadata_format diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py index e3ed064..d948c91 100644 --- a/src/funlib/persistence/arrays/metadata.py +++ b/src/funlib/persistence/arrays/metadata.py @@ -5,9 +5,8 @@ import toml import zarr -from pydantic import BaseModel - from funlib.geometry import Coordinate +from pydantic import BaseModel def strip_channels( diff --git a/src/funlib/persistence/arrays/ome_datasets.py b/src/funlib/persistence/arrays/ome_datasets.py index c194b87..d089002 100644 --- a/src/funlib/persistence/arrays/ome_datasets.py +++ b/src/funlib/persistence/arrays/ome_datasets.py @@ -2,12 +2,11 @@ from collections.abc import Sequence from pathlib import Path +from funlib.geometry import Coordinate from iohub.ngff import TransformationMeta, open_ome_zarr from iohub.ngff.models import AxisMeta from numpy.typing import DTypeLike -from funlib.geometry import Coordinate - from .array import Array from .metadata import MetaData diff --git a/src/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py index acd474a..2d1c840 100644 --- a/src/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from typing import Optional -from networkx import Graph - from funlib.geometry import Roi +from networkx import Graph from ..types import Vec diff --git a/tests/test_array.py b/tests/test_array.py index 28a1813..7f2d4dd 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1,8 +1,8 @@ import dask.array as da import numpy as np import pytest - from funlib.geometry import Coordinate, Roi + from funlib.persistence.arrays import Array @@ -423,6 +423,7 @@ def test_writeable(): assert a.axis_names == ["d0", "d1"] assert not a.is_writeable + def test_to_pixel_world_space_coordinate(): offset = Coordinate(1, -1, 2) shape = Coordinate(10, 10, 10) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index caf80e4..85f86d7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,7 +1,7 @@ import numpy as np import pytest - from funlib.geometry import Coordinate, Roi + from funlib.persistence.arrays.datasets import ArrayNotFoundError, open_ds, prepare_ds from 
funlib.persistence.arrays.metadata import MetaDataFormat diff --git a/tests/test_graph.py b/tests/test_graph.py index f1000ea..113f704 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ import networkx as nx import pytest - from funlib.geometry import Roi + from funlib.persistence.types import Vec diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 371e4d6..ecff21a 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -3,8 +3,8 @@ import pytest import zarr - from funlib.geometry import Coordinate + from funlib.persistence.arrays.datasets import prepare_ds from funlib.persistence.arrays.metadata import ( MetaDataFormat, From 8d3547d705439130db1013621e8510db275477a5 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 16:56:14 +0000 Subject: [PATCH 11/38] add flag and tests for symetric edge fetching --- .../persistence/graphs/graph_database.py | 7 + tests/test_graph.py | 163 ++++++++++++++++++ 2 files changed, 170 insertions(+) diff --git a/src/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py index 2d1c840..21bc660 100644 --- a/src/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -58,6 +58,7 @@ def read_graph( read_edges: bool = True, node_attrs: Optional[list[str]] = None, edge_attrs: Optional[list[str]] = None, + fetch_on_v: bool = False, ) -> Graph: """ Read a graph from the database for a given roi. @@ -80,6 +81,12 @@ def read_graph( If not ``None``, only read the given edge attributes. + fetch_on_v (``bool``): + + If ``True``, also fetch edges where the ``v`` endpoint matches + (i.e., either endpoint is in the ROI or node list). If ``False`` + (default), only fetch edges where ``u`` matches. + """ pass diff --git a/tests/test_graph.py b/tests/test_graph.py index 113f704..004455c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -509,3 +509,166 @@ def test_graph_has_edge(provider_factory): graph_provider.write_edges(graph.nodes(), graph.edges(), roi=write_roi) assert graph_provider.has_edges(roi) + + +def test_read_edges_join_vs_in_clause(provider_factory): + """Benchmark: read_edges with JOIN (roi-only) vs IN clause (nodes list). + + Demonstrates that the JOIN path avoids serializing a large node ID list + into the SQL query, and lets the DB optimizer do the work instead. 
+ """ + import time + from itertools import product + + size = 50 # 50^3 = 125,000 nodes + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) + + # Build a 3D grid graph + graph = nx.Graph() + for x, y, z in product(range(size), repeat=3): + node_id = x * size * size + y * size + z + graph.add_node(node_id, position=(x + 0.5, y + 0.5, z + 0.5)) + # Connect to neighbors in +x, +y, +z directions + if x > 0: + graph.add_edge(node_id, (x - 1) * size * size + y * size + z) + if y > 0: + graph.add_edge(node_id, x * size * size + (y - 1) * size + z) + if z > 0: + graph.add_edge(node_id, x * size * size + y * size + (z - 1)) + + graph_provider.write_graph(graph, fail_if_exists=False) + + # Re-open in read mode + graph_provider = provider_factory("r") + + query_roi = Roi((10, 10, 10), (30, 30, 30)) + n_repeats = 5 + + # --- Old approach: read_nodes, then read_edges with nodes list --- + times_in_clause = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + nodes = graph_provider.read_nodes(query_roi) + edges_via_in = graph_provider.read_edges(nodes=nodes) + t1 = time.perf_counter() + times_in_clause.append(t1 - t0) + + # --- New approach: read_edges with roi (JOIN) --- + times_join = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + edges_via_join = graph_provider.read_edges(roi=query_roi) + t1 = time.perf_counter() + times_join.append(t1 - t0) + + avg_in = sum(times_in_clause) / n_repeats + avg_join = sum(times_join) / n_repeats + + print(f"\n--- read_edges benchmark (roi covers {30**3:,} of {size**3:,} nodes) ---") + print(f"IN clause (2 queries): {avg_in*1000:.1f} ms avg") + print(f"JOIN (1 query): {avg_join*1000:.1f} ms avg") + print(f"Speedup: {avg_in / avg_join:.2f}x") + + # Both should return edges — just verify they're non-empty and reasonable + assert len(edges_via_in) > 0 + assert len(edges_via_join) > 0 + # JOIN finds edges where either endpoint is in ROI (superset of IN approach) + assert len(edges_via_join) == len(edges_via_in) + + +def test_read_edges_fetch_on_v(provider_factory): + """Test that fetch_on_v controls whether edges are matched on u only or both endpoints. + + Graph layout (1D for clarity, stored as 3D positions): + + Node 1 (pos 1) -- Edge(1,5) -- Node 5 (pos 5) + Node 2 (pos 2) -- Edge(2,8) -- Node 8 (pos 8) + Node 5 (pos 5) -- Edge(5,8) -- Node 8 (pos 8) + Node 8 (pos 8) -- Edge(8,9) -- Node 9 (pos 9) + + ROI = [0, 6) covers nodes {1, 2, 5}. + + Undirected edges are stored with u < v, so: + - Edge(1, 5): u=1 in ROI, v=5 in ROI + - Edge(2, 8): u=2 in ROI, v=8 outside ROI + - Edge(5, 8): u=5 in ROI, v=8 outside ROI + - Edge(8, 9): u=8 outside ROI, v=9 outside ROI + + fetch_on_v=False (default): only edges where u is in ROI → {(1,5), (2,8), (5,8)} + fetch_on_v=True: edges where u OR v is in ROI → {(1,5), (2,8), (5,8)} + (same here because u < v and all boundary-crossing edges have u inside) + + To properly test fetch_on_v, we need an edge where u is OUTSIDE the ROI + but v is INSIDE. With undirected u < v storage, this means a node with a + smaller ID outside the ROI connected to a node with a larger ID inside. + + So we add: Node 0 (pos 8) -- Edge(0, 5): u=0 outside ROI, v=5 in ROI. 
+ """ + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) + roi = Roi((0, 0, 0), (6, 6, 6)) + + graph = nx.Graph() + # Nodes inside ROI (positions < 6) + graph.add_node(1, position=(1.0, 1.0, 1.0)) + graph.add_node(2, position=(2.0, 2.0, 2.0)) + graph.add_node(5, position=(5.0, 5.0, 5.0)) + # Nodes outside ROI (positions >= 6) + # Node 0 has ID < all ROI nodes but position outside ROI + graph.add_node(0, position=(8.0, 8.0, 8.0)) + graph.add_node(8, position=(8.0, 8.0, 8.0)) + graph.add_node(9, position=(9.0, 9.0, 9.0)) + + # Edges: undirected, stored as u < v + graph.add_edge(1, 5) # both in ROI + graph.add_edge(2, 8) # u in ROI, v outside + graph.add_edge(5, 8) # u in ROI, v outside + graph.add_edge(8, 9) # both outside ROI + graph.add_edge(0, 5) # u=0 OUTSIDE ROI, v=5 INSIDE ROI (key test edge) + + graph_provider.write_graph(graph, fail_if_exists=False) + + graph_provider = provider_factory("r") + + def edge_set(edges): + """Normalize edge list to set of sorted tuples for comparison.""" + return {(min(e["u"], e["v"]), max(e["u"], e["v"])) for e in edges} + + # --- Case 1: nodes passed explicitly --- + nodes_in_roi = graph_provider.read_nodes(roi) + node_ids_in_roi = {n["id"] for n in nodes_in_roi} + assert node_ids_in_roi == {1, 2, 5} + + edges_u_only = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=False) + edges_u_and_v = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=True) + + # fetch_on_v=False: only edges where u IN (1,2,5) + # (1,5), (2,8), (5,8) match; (0,5) does NOT match (u=0 not in list) + assert edge_set(edges_u_only) == {(1, 5), (2, 8), (5, 8)} + + # fetch_on_v=True: edges where u OR v IN (1,2,5) + # (0,5) now matches because v=5 is in the list + assert edge_set(edges_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + + # --- Case 2: roi passed (JOIN path) --- + edges_roi_u_only = graph_provider.read_edges(roi=roi, fetch_on_v=False) + edges_roi_u_and_v = graph_provider.read_edges(roi=roi, fetch_on_v=True) + + # Same expected results as Case 1 + assert edge_set(edges_roi_u_only) == {(1, 5), (2, 8), (5, 8)} + assert edge_set(edges_roi_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + + # --- Case 3: via read_graph --- + graph_u_only = graph_provider.read_graph(roi, fetch_on_v=False) + graph_u_and_v = graph_provider.read_graph(roi, fetch_on_v=True) + + graph_edges_u_only = {tuple(sorted(e)) for e in graph_u_only.edges()} + graph_edges_u_and_v = {tuple(sorted(e)) for e in graph_u_and_v.edges()} + + assert graph_edges_u_only == {(1, 5), (2, 8), (5, 8)} + assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} From 09a9fcee0b3b43baa5df551d66c4233f80cc6adb Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 16:58:34 +0000 Subject: [PATCH 12/38] add symetric edge fetching to SQLGraphProviders Also adds a Join query to handle cases where edges are fetched by roi and not by list of nodes. This is more efficient due to having a single round trip query. 
--- .../persistence/graphs/sql_graph_database.py | 119 +++++++++++++----- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index a45bd93..df4a18b 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -2,11 +2,10 @@ from abc import abstractmethod from typing import Any, Iterable, Optional +from funlib.geometry import Coordinate, Roi from networkx import DiGraph, Graph from networkx.classes.reportviews import NodeView, OutEdgeView -from funlib.geometry import Coordinate, Roi - from ..types import Vec, type_to_str from .graph_database import AttributeType, GraphDataBase @@ -214,6 +213,7 @@ def read_graph( edge_attrs: Optional[list[str]] = None, nodes_filter: Optional[dict[str, Any]] = None, edges_filter: Optional[dict[str, Any]] = None, + fetch_on_v: bool = False, ) -> Graph: graph: Graph if self.directed: @@ -231,7 +231,11 @@ def read_graph( if read_edges: edges = self.read_edges( - roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter + roi, + nodes=nodes, + read_attrs=edge_attrs, + attr_filter=edges_filter, + fetch_on_v=fetch_on_v, ) u, v = self.endpoint_names # type: ignore try: @@ -361,46 +365,100 @@ def read_edges( nodes: Optional[list[dict[str, Any]]] = None, attr_filter: Optional[dict[str, Any]] = None, read_attrs: Optional[list[str]] = None, + fetch_on_v: bool = False, ) -> list[dict[str, Any]]: - """Returns a list of edges within roi.""" + """ + Returns a list of edges within roi, connected to provided nodes, or all edges. - if nodes is None: - nodes = self.read_nodes(roi) - - if len(nodes) == 0: - return [] + Args: + fetch_on_v: If True, also match edges where the v endpoint is in the + node list or ROI. If False (default), only match on u. + """ endpoint_names = self.endpoint_names assert endpoint_names is not None - node_ids = ", ".join([str(node["id"]) for node in nodes]) - node_condition = f"{endpoint_names[0]} IN ({node_ids})" # type: ignore + # 1. 
Determine the base SELECT statement and WHERE clause - logger.debug("Reading nodes in roi %s" % roi) - # TODO: AND vs OR here - desired_columns = ", ".join(endpoint_names + list(self.edge_attrs.keys())) # type: ignore - select_statement = ( - f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE " - + node_condition - + ( - " AND " + self.__attr_query(attr_filter) - if attr_filter is not None and len(attr_filter) > 0 - else "" + # Columns to select from the edge table (T1) + edge_table_cols = endpoint_names + list(self.edge_attrs.keys()) + desired_columns = ", ".join(edge_table_cols) + + # Base query starts with selecting all columns from the edges table + select_statement = f"SELECT {desired_columns} FROM {self.edges_table_name}" + where_clauses = [] + using_join = False + + if nodes is not None: + # Case 1: Filter by explicit list of nodes + if len(nodes) == 0: + return [] + + node_ids = ", ".join([str(node["id"]) for node in nodes]) + if fetch_on_v: + where_clauses.append( + f"({endpoint_names[0]} IN ({node_ids})" + f" OR {endpoint_names[1]} IN ({node_ids}))" + ) + else: + where_clauses.append(f"{endpoint_names[0]} IN ({node_ids})") + + elif roi is not None: + # Case 2: Filter by ROI using INNER JOIN + using_join = True + node_id_column = "id" + + edge_cols = ", ".join([f"T1.{col}" for col in edge_table_cols]) + roi_condition = self.__roi_query(roi).replace("WHERE ", "") + + join_condition = f"T1.{endpoint_names[0]} = T2.{node_id_column}" + if fetch_on_v: + join_condition += ( + f" OR T1.{endpoint_names[1]} = T2.{node_id_column}" + ) + + select_statement = ( + f"SELECT DISTINCT {edge_cols} " + f"FROM {self.edges_table_name} AS T1 " + f"INNER JOIN {self.nodes_table_name} AS T2 " + f"ON {join_condition} " ) - ) + where_clauses.append(roi_condition) + + # Case 3: Both nodes and roi are None — fetch all edges + + # 2. Add Attribute Filter to WHERE clauses + if attr_filter is not None and len(attr_filter) > 0: + if using_join: + # Qualify each attribute with T1 for the JOIN case + parts = [ + f"T1.{k}={self.__convert_to_sql(v)}" + for k, v in attr_filter.items() + ] + where_clauses.append(" AND ".join(parts)) + else: + where_clauses.append(self.__attr_query(attr_filter)) - edge_attrs = endpoint_names + ( # type: ignore + # 3. Finalize the SELECT statement + if len(where_clauses) > 0: + select_statement += " WHERE " + " AND ".join(where_clauses) + + logger.debug(f"Reading edges with query: {select_statement}") + + # 4. Execute Query and Process Results + + # Define the keys for the result dictionaries + all_edge_keys = endpoint_names + list(self.edge_attrs.keys()) + + # Define which keys to keep based on read_attrs + final_edge_keys = endpoint_names + ( list(self.edge_attrs.keys()) if read_attrs is None else read_attrs ) - edges = [ { key: val - for key, val in zip( - endpoint_names + list(self.edge_attrs.keys()), - values, # type: ignore - ) - if key in edge_attrs + for key, val in zip(all_edge_keys, values) + if key in final_edge_keys } for values in self._select_query(select_statement) ] @@ -692,3 +750,6 @@ def __roi_query(self, roi: Roi) -> str: else: query = query[:-5] return query + + def print_summary(self): + raise ValueError("Not implemented for base SQLGraphDataBase") From 59542c879f7ced5f674f5cc2aa08e329d461dc76 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 17:03:15 +0000 Subject: [PATCH 13/38] throw error when both roi and nodes passed to read edges. Previous behavior was to silently ignore `roi`. 
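To make the new behaviour concrete, a small sketch (the SQLite setup is a hypothetical layout used only for this example):

    from pathlib import Path

    import pytest
    from funlib.geometry import Roi
    from funlib.persistence.graphs import SQLiteGraphDataBase
    from funlib.persistence.types import Vec

    db = SQLiteGraphDataBase(
        Path("example_graph.db"),
        position_attribute="position",
        mode="w",
        node_attrs={"position": Vec(float, 3)},
    )

    roi = Roi((0, 0, 0), (10, 10, 10))
    nodes = db.read_nodes(roi)

    # Previously `roi` was silently ignored when `nodes` was also given;
    # the ambiguous call now fails loudly.
    with pytest.raises(ValueError):
        db.read_edges(roi=roi, nodes=nodes)

    # Unambiguous calls are unchanged: pass either a ROI or a node list.
    edges_by_roi = db.read_edges(roi=roi)
    edges_by_nodes = db.read_edges(nodes=nodes)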
--- src/funlib/persistence/graphs/sql_graph_database.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index df4a18b..f967c48 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -231,7 +231,6 @@ def read_graph( if read_edges: edges = self.read_edges( - roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter, @@ -373,8 +372,17 @@ def read_edges( Args: fetch_on_v: If True, also match edges where the v endpoint is in the node list or ROI. If False (default), only match on u. + + Raises: + ValueError: If both roi and nodes are provided. """ + if roi is not None and nodes is not None: + raise ValueError( + "read_edges does not support both roi and nodes at the same time. " + "Pass one or the other." + ) + endpoint_names = self.endpoint_names assert endpoint_names is not None From 92bebc5224f6a61497fb669c06082ca8105b028c Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 17:16:15 +0000 Subject: [PATCH 14/38] handle read edges cases better. Reading by ROI can be significantly more efficient that reading by node list since the node list can be huge and would need to be serialized to a string. --- .../persistence/graphs/sql_graph_database.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index f967c48..8ca999a 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -230,12 +230,31 @@ def read_graph( graph.add_nodes_from(node_list) if read_edges: - edges = self.read_edges( - nodes=nodes, - read_attrs=edge_attrs, - attr_filter=edges_filter, - fetch_on_v=fetch_on_v, - ) + # When a nodes_filter is used, the filtered node set is narrower + # than the ROI. Fall back to the nodes path so edges only connect + # nodes that passed the filter. Otherwise use the faster ROI path. + if nodes_filter: + edges = self.read_edges( + nodes=nodes, + read_attrs=edge_attrs, + attr_filter=edges_filter, + fetch_on_v=fetch_on_v, + ) + else: + # We use ROI to query edges ro avoid serializing a list of + # node IDs into the SQL query. + + # A fully unbounded ROI (all None in shape) provides no + # filtering, so treat it as None to fetch all edges. + effective_roi = roi + if roi is not None and all(s is None for s in roi.shape): + effective_roi = None + edges = self.read_edges( + roi=effective_roi, + read_attrs=edge_attrs, + attr_filter=edges_filter, + fetch_on_v=fetch_on_v, + ) u, v = self.endpoint_names # type: ignore try: edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges] From 213c3d895ab8c2226c5a978754440213c6ab3cc9 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 17:18:05 +0000 Subject: [PATCH 15/38] add uv.lock to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 38427b2..0a8d3da 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,4 @@ dmypy.json *.sw[pmno] daisy_logs +uv.lock From 4523594b3b77e8cdd8005a6bb6f8e0727b2cf2b2 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 17:32:04 +0000 Subject: [PATCH 16/38] Connect to postgresql db with environment variables if available. This allows testing on dbs other than simply a locally running db. 
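For reference, the connection kwargs are assembled roughly like this (a sketch of the helper's behaviour using the standard libpq variable names; the server details are whatever the environment provides):

    import os

    import psycopg2

    # Standard libpq environment variables are forwarded when present;
    # otherwise psycopg2 falls back to a local connection.
    kwargs = {"dbname": "pytest"}
    for env_var, key in (
        ("PGHOST", "host"),
        ("PGUSER", "user"),
        ("PGPASSWORD", "password"),
        ("PGPORT", "port"),
    ):
        value = os.environ.get(env_var)
        if value:
            kwargs[key] = int(value) if key == "port" else value

    conn = psycopg2.connect(**kwargs)
    conn.close()

Running the suite against a remote instance then only needs the variables exported, e.g. PGHOST=db.example.internal PGUSER=ci PGPORT=5433 pytest (values hypothetical).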
--- tests/conftest.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 95b3e0f..f4c89ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import psycopg2 @@ -7,12 +8,27 @@ from funlib.persistence.graphs import PgSQLGraphDatabase, SQLiteGraphDataBase -# Attempt to connect to the default database +def _psql_connect_kwargs(): + """Build psycopg2 connection kwargs from environment variables.""" + kwargs = {"dbname": "pytest"} + if os.environ.get("PGHOST"): + kwargs["host"] = os.environ["PGHOST"] + if os.environ.get("PGUSER"): + kwargs["user"] = os.environ["PGUSER"] + if os.environ.get("PGPASSWORD"): + kwargs["password"] = os.environ["PGPASSWORD"] + if os.environ.get("PGPORT"): + kwargs["port"] = int(os.environ["PGPORT"]) + return kwargs + + +# Attempt to connect to the server (using the default 'postgres' database +# which always exists, since the test database may not exist yet). def can_connect_to_psql(): try: - conn = psycopg2.connect( - dbname="pytest", - ) + kwargs = _psql_connect_kwargs() + kwargs["dbname"] = "postgres" + conn = psycopg2.connect(**kwargs) conn.close() return True except OperationalError: @@ -59,9 +75,14 @@ def sqlite_provider_factory( def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): + connect_kwargs = _psql_connect_kwargs() return PgSQLGraphDatabase( position_attribute="position", db_name="pytest", + db_host=connect_kwargs.get("host", "localhost"), + db_user=connect_kwargs.get("user"), + db_password=connect_kwargs.get("password"), + db_port=connect_kwargs.get("port"), mode=mode, directed=directed, total_roi=total_roi, From dde90490e02b46cdcb1f38bda61e1a8f79f50351 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 19:32:44 +0000 Subject: [PATCH 17/38] Add bulk write api and tests There is now a bulk version of `write_nodes`, `write_edges`, and `write_graph`. These are faster but do not support some features such as fail_if_exists, and thus require more care from the user to guarantee data being passed is valid. There are also helper context managers that will drop and rebuild indexes and drop/re-add synchronised commits that can also be used to further speed up writes. Tests have been expanded to make sure that the new api matches the features of the base implementation and to test that it is actually faster. --- .../graphs/pgsql_graph_database.py | 250 +++++++++++++++++- tests/conftest.py | 16 +- tests/test_graph.py | 169 +++++++++--- 3 files changed, 391 insertions(+), 44 deletions(-) diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index 1c37d41..7ff31a0 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -1,11 +1,13 @@ +import io import json import logging from collections.abc import Iterable +from contextlib import contextmanager from typing import Any, Optional import psycopg2 - from funlib.geometry import Roi +from psycopg2 import sql from ..types import Vec from .sql_graph_database import SQLGraphDataBase @@ -52,6 +54,7 @@ def __init__( # DB already exists, moving on... 
connection.rollback() pass + connection.close() self.connection = psycopg2.connect( host=db_host, database=db_name, @@ -59,8 +62,6 @@ def __init__( password=db_password, port=db_port, ) - # TODO: remove once tests pass: - # self.connection.autocommit = True self.cur = self.connection.cursor() super().__init__( @@ -152,6 +153,9 @@ def _select_query(self, query) -> Iterable[Any]: def _insert_query( self, table, columns, values, fail_if_exists=False, commit=True ) -> None: + if not values: + return + values_str = ( "VALUES (" + "), (".join( @@ -159,9 +163,9 @@ def _insert_query( ) + ")" ) - # TODO: fail_if_exists is the default if UNIQUE was used to create the - # table, we need to update if fail_if_exists==False insert_statement = f"INSERT INTO {table}({', '.join(columns)}) " + values_str + if not fail_if_exists: + insert_statement += " ON CONFLICT DO NOTHING" self.__exec(insert_statement) if commit: @@ -204,3 +208,239 @@ def __sql_type(self, type): raise NotImplementedError( f"attributes of type {type} are not yet supported" ) + + def print_summary(self, schema="public", limit=None): + self._commit() + + def fmt_bytes(n): + for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: + if n < 1024 or unit == "PB": + return f"{n:.1f} {unit}" if unit != "B" else f"{int(n)} {unit}" + n /= 1024 + + # Basic DB info (read-only) + self.cur.execute("SELECT current_database(), current_user, version();") + db, user, version = self.cur.fetchone() + print(f"DB: {db} | User: {user}") + print(version.split("\n")[0]) + print("-" * 80) + + q = sql.SQL( + """ + SELECT + c.relname AS table_name, + c.relkind, + COALESCE(c.reltuples::bigint, 0) AS est_rows, + pg_total_relation_size(c.oid) AS total_bytes + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = %s + AND c.relkind IN ('r','p') -- r=table, p=partitioned table + ORDER BY pg_total_relation_size(c.oid) DESC, c.relname + """ + ) + if limit is not None: + q = q + sql.SQL(" LIMIT %s") + self.cur.execute(q, (schema, limit)) + else: + self.cur.execute(q, (schema,)) + rows = self.cur.fetchall() + if not rows: + print(f"No tables found in schema={schema!r}") + return + + # Pretty print + name_w = min(max(len(r[0]) for r in rows), 60) + header = f"{'table':<{name_w}} {'kind':<4} {'est_rows':>12} {'size':>10}" + print(header) + print("-" * len(header)) + + kind_map = {"r": "tbl", "p": "part"} + for name, relkind, est_rows, total_bytes in rows: + print( + f"{name:<{name_w}} {kind_map.get(relkind, relkind):<4} {est_rows:>12,} {fmt_bytes(total_bytes):>10}" + ) + + print("-" * 80) + print(f"Tables shown: {len(rows)} (schema={schema!r})") + + @contextmanager + def drop_indexes(self): + """ + Context manager for the orchestrator to use around bulk ingestion. + Drops all indexes and primary key constraints before yielding, then + recreates them after. This is a global (DDL) operation — call it + once from the main process, not from individual workers. 
+ """ + nodes = self.nodes_table_name + edges = self.edges_table_name + endpoint_names = self.endpoint_names + assert endpoint_names is not None + + # Drop position index and primary key constraints + self.__exec("DROP INDEX IF EXISTS pos_index") + self.__exec( + f"ALTER TABLE {nodes} DROP CONSTRAINT IF EXISTS {nodes}_pkey" + ) + self.__exec( + f"ALTER TABLE {edges} DROP CONSTRAINT IF EXISTS {edges}_pkey" + ) + self._commit() + + try: + yield + finally: + self._commit() + logger.info("Re-creating indexes and constraints...") + self.__exec( + f"ALTER TABLE {nodes} " + f"ADD CONSTRAINT {nodes}_pkey PRIMARY KEY (id)" + ) + self.__exec( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{nodes}({self.position_attribute})" + ) + self.__exec( + f"ALTER TABLE {edges} " + f"ADD CONSTRAINT {edges}_pkey " + f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" + ) + self._commit() + + @contextmanager + def fast_session(self): + """ + Per-connection context manager for workers to use during bulk ingestion. + Disables synchronous commit for the duration of the session. + """ + self.__exec("SET synchronous_commit TO OFF") + self._commit() + try: + yield + finally: + self.__exec("SET synchronous_commit TO ON") + self._commit() + + def bulk_write_graph( + self, + graph, + roi=None, + write_nodes=True, + write_edges=True, + node_attrs=None, + edge_attrs=None, + ): + """ + Fast bulk ingest of a graph using COPY. Mirrors write_graph but + uses _stream_copy for speed. Does not support fail_if_exists or + delete — use inside drop_indexes() where constraints are removed. + """ + if write_nodes: + self.bulk_write_nodes(graph.nodes, roi=roi, attributes=node_attrs) + if write_edges: + self.bulk_write_edges( + graph.nodes, graph.edges, roi=roi, attributes=edge_attrs + ) + + def bulk_write_nodes(self, nodes, roi=None, attributes=None): + """ + Fast bulk ingest of nodes using COPY. Mirrors write_nodes. + """ + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + attrs = attributes if attributes is not None else list(self.node_attrs.keys()) + columns = ["id"] + list(attrs) + pos_attr = self.position_attribute + + def node_gen(): + for node_id, data in nodes.items(): + pos = data.get(pos_attr) + if roi is not None: + if pos is None or not roi.contains(pos): + continue + + row = [str(node_id)] + for attr in attrs: + val = data.get(attr) + if val is None: + row.append(r"\N") + elif isinstance(val, (list, tuple)): + row.append(f"{{{','.join(map(str, val))}}}") + else: + row.append(str(val)) + yield "\t".join(row) + "\n" + + self._stream_copy(self.nodes_table_name, columns, node_gen()) + self._commit() + + def bulk_write_edges(self, nodes, edges, roi=None, attributes=None): + """ + Fast bulk ingest of edges using COPY. Mirrors write_edges. + Only writes edges where the u endpoint is in the ROI. 
+ """ + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + u_name, v_name = self.endpoint_names + attrs = ( + attributes + if attributes is not None + else list(self.edge_attrs.keys()) + ) + columns = [u_name, v_name] + list(attrs) + pos_attr = self.position_attribute + + def edge_gen(): + for (u, v), data in edges.items(): + if not self.directed: + u, v = min(u, v), max(u, v) + + if roi is not None: + pos_u = nodes[u].get(pos_attr) + if pos_u is None or not roi.contains(pos_u): + continue + + row = [str(u), str(v)] + for attr in attrs: + val = data.get(attr) + if val is None: + row.append(r"\N") + elif isinstance(val, (list, tuple)): + row.append(f"{{{','.join(map(str, val))}}}") + else: + row.append(str(val)) + yield "\t".join(row) + "\n" + + self._stream_copy(self.edges_table_name, columns, edge_gen()) + self._commit() + + def _stream_copy(self, table_name, columns, data_generator): + """ + Consumes a generator of strings and sends them to Postgres via COPY. + Uses a chunked buffer to keep memory usage stable. + """ + # Tune this size (in bytes). 10MB - 50MB is usually a sweet spot. + BATCH_SIZE = 50 * 1024 * 1024 + + buffer = io.StringIO() + current_size = 0 + + # Helper to flush buffer to DB + def flush(): + buffer.seek(0) + self.cur.copy_from(buffer, table_name, columns=columns, null=r"\N") + buffer.truncate(0) + buffer.seek(0) + + for line in data_generator: + buffer.write(line) + current_size += len(line) + + if current_size >= BATCH_SIZE: + flush() + current_size = 0 + + # Flush remaining + if current_size > 0: + flush() diff --git a/tests/conftest.py b/tests/conftest.py index f4c89ed..6f1aec5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,11 +72,13 @@ def sqlite_provider_factory( edge_attrs=edge_attrs, ) + providers = [] + def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): connect_kwargs = _psql_connect_kwargs() - return PgSQLGraphDatabase( + provider = PgSQLGraphDatabase( position_attribute="position", db_name="pytest", db_host=connect_kwargs.get("host", "localhost"), @@ -89,6 +91,8 @@ def psql_provider_factory( node_attrs=node_attrs, edge_attrs=edge_attrs, ) + providers.append(provider) + return provider if request.param == "sqlite": yield sqlite_provider_factory @@ -96,3 +100,13 @@ def psql_provider_factory( yield psql_provider_factory else: raise ValueError() + + # Close all psql connections to avoid stale transactions + for provider in providers: + if hasattr(provider, "connection"): + provider.connection.close() + + +@pytest.fixture(params=["standard", "bulk"]) +def write_method(request): + return request.param diff --git a/tests/test_graph.py b/tests/test_graph.py index 004455c..5afefb1 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,15 +2,45 @@ import pytest from funlib.geometry import Roi +from funlib.persistence.graphs import PgSQLGraphDatabase from funlib.persistence.types import Vec -def test_graph_filtering(provider_factory): +def _skip_if_bulk_unsupported(provider, write_method): + if write_method == "bulk" and not isinstance(provider, PgSQLGraphDatabase): + pytest.skip("Bulk write only supported on PostgreSQL") + + +def _write_nodes(provider, nodes, write_method, **kwargs): + if write_method == "bulk": + provider.bulk_write_nodes(nodes, **kwargs) + else: + provider.write_nodes(nodes, **kwargs) + + +def _write_edges(provider, nodes, edges, write_method, **kwargs): + if write_method == "bulk": + provider.bulk_write_edges(nodes, edges, **kwargs) + else: + 
provider.write_edges(nodes, edges, **kwargs) + + +def _write_graph(provider, graph, write_method, **kwargs): + if write_method == "bulk": + kwargs.pop("fail_if_exists", None) + kwargs.pop("delete", None) + provider.bulk_write_graph(graph, **kwargs) + else: + provider.write_graph(graph, **kwargs) + + +def test_graph_filtering(provider_factory, write_method): graph_writer = provider_factory( "w", node_attrs={"position": Vec(float, 3), "selected": bool}, edge_attrs={"selected": bool}, ) + _skip_if_bulk_unsupported(graph_writer, write_method) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] @@ -22,8 +52,8 @@ def test_graph_filtering(provider_factory): graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) - graph_writer.write_nodes(graph.nodes()) - graph_writer.write_edges(graph.nodes(), graph.edges()) + _write_nodes(graph_writer, graph.nodes(), write_method) + _write_edges(graph_writer, graph.nodes(), graph.edges(), write_method) graph_reader = provider_factory("r") @@ -53,12 +83,13 @@ def test_graph_filtering(provider_factory): ) in filtered_subgraph.edges() -def test_graph_filtering_complex(provider_factory): +def test_graph_filtering_complex(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) + _skip_if_bulk_unsupported(graph_provider, write_method) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] @@ -71,8 +102,8 @@ def test_graph_filtering_complex(provider_factory): graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_nodes(graph.nodes()) - graph_provider.write_edges(graph.nodes(), graph.edges()) + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) graph_provider = provider_factory("r") @@ -154,12 +185,13 @@ def test_graph_read_and_update_specific_attrs(provider_factory): assert data["c"] == 5 -def test_graph_read_unbounded_roi(provider_factory): +def test_graph_read_unbounded_roi(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) + _skip_if_bulk_unsupported(graph_provider, write_method) roi = Roi((0, 0, 0), (10, 10, 10)) unbounded_roi = Roi((None, None, None), (None, None, None)) @@ -174,13 +206,8 @@ def test_graph_read_unbounded_roi(provider_factory): graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_nodes( - graph.nodes(), - ) - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) graph_provider = provider_factory("r+") limited_graph = graph_provider.read_graph( @@ -220,7 +247,7 @@ def test_graph_default_meta_values(provider_factory): ) -def test_graph_io(provider_factory): +def test_graph_io(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={ @@ -229,6 +256,7 @@ def test_graph_io(provider_factory): "zap": str, }, ) + _skip_if_bulk_unsupported(graph_provider, write_method) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] @@ -240,13 +268,8 @@ def test_graph_io(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_nodes( - 
graph.nodes(), - ) - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] @@ -351,7 +374,7 @@ def test_graph_fail_if_not_exists(provider_factory): ) -def test_graph_write_attributes(provider_factory): +def test_graph_write_attributes(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={ @@ -360,6 +383,7 @@ def test_graph_write_attributes(provider_factory): "zap": str, }, ) + _skip_if_bulk_unsupported(graph_provider, write_method) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] graph.add_node(2, position=[0, 0, 0]) @@ -370,14 +394,12 @@ def test_graph_write_attributes(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_graph( - graph, write_nodes=True, write_edges=False, node_attrs=["position", "swip"] + _write_graph( + graph_provider, graph, write_method, + write_nodes=True, write_edges=False, node_attrs=["position", "swip"], ) - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((1, 1, 1), (10, 10, 10))] @@ -407,7 +429,7 @@ def test_graph_write_attributes(provider_factory): assert v1 == v2 -def test_graph_write_roi(provider_factory): +def test_graph_write_roi(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={ @@ -416,6 +438,7 @@ def test_graph_write_roi(provider_factory): "zap": str, }, ) + _skip_if_bulk_unsupported(graph_provider, write_method) graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] graph.add_node(2, position=(0, 0, 0)) @@ -427,7 +450,7 @@ def test_graph_write_roi(provider_factory): graph.add_edge(2, 42) write_roi = Roi((0, 0, 0), (6, 6, 6)) - graph_provider.write_graph(graph, write_roi) + _write_graph(graph_provider, graph, write_method, roi=write_roi) graph_provider = provider_factory("r") compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] @@ -484,7 +507,7 @@ def test_graph_connected_components(provider_factory): assert n2 == compare_n2 -def test_graph_has_edge(provider_factory): +def test_graph_has_edge(provider_factory, write_method): graph_provider = provider_factory( "w", node_attrs={ @@ -493,6 +516,7 @@ def test_graph_has_edge(provider_factory): "zap": str, }, ) + _skip_if_bulk_unsupported(graph_provider, write_method) roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] @@ -505,13 +529,13 @@ def test_graph_has_edge(provider_factory): graph.add_edge(57, 23) write_roi = Roi((0, 0, 0), (6, 6, 6)) - graph_provider.write_nodes(graph.nodes(), roi=write_roi) - graph_provider.write_edges(graph.nodes(), graph.edges(), roi=write_roi) + _write_nodes(graph_provider, graph.nodes(), write_method, roi=write_roi) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method, roi=write_roi) assert graph_provider.has_edges(roi) -def test_read_edges_join_vs_in_clause(provider_factory): +def test_read_edges_join_vs_in_clause(provider_factory, write_method): """Benchmark: read_edges with JOIN (roi-only) vs IN clause (nodes list). 
Demonstrates that the JOIN path avoids serializing a large node ID list @@ -525,6 +549,7 @@ def test_read_edges_join_vs_in_clause(provider_factory): "w", node_attrs={"position": Vec(float, 3)}, ) + _skip_if_bulk_unsupported(graph_provider, write_method) # Build a 3D grid graph graph = nx.Graph() @@ -539,7 +564,7 @@ def test_read_edges_join_vs_in_clause(provider_factory): if z > 0: graph.add_edge(node_id, x * size * size + y * size + (z - 1)) - graph_provider.write_graph(graph, fail_if_exists=False) + _write_graph(graph_provider, graph, write_method) # Re-open in read mode graph_provider = provider_factory("r") @@ -579,7 +604,7 @@ def test_read_edges_join_vs_in_clause(provider_factory): assert len(edges_via_join) == len(edges_via_in) -def test_read_edges_fetch_on_v(provider_factory): +def test_read_edges_fetch_on_v(provider_factory, write_method): """Test that fetch_on_v controls whether edges are matched on u only or both endpoints. Graph layout (1D for clarity, stored as 3D positions): @@ -611,6 +636,7 @@ def test_read_edges_fetch_on_v(provider_factory): "w", node_attrs={"position": Vec(float, 3)}, ) + _skip_if_bulk_unsupported(graph_provider, write_method) roi = Roi((0, 0, 0), (6, 6, 6)) graph = nx.Graph() @@ -631,7 +657,7 @@ def test_read_edges_fetch_on_v(provider_factory): graph.add_edge(8, 9) # both outside ROI graph.add_edge(0, 5) # u=0 OUTSIDE ROI, v=5 INSIDE ROI (key test edge) - graph_provider.write_graph(graph, fail_if_exists=False) + _write_graph(graph_provider, graph, write_method) graph_provider = provider_factory("r") @@ -672,3 +698,70 @@ def edge_set(edges): assert graph_edges_u_only == {(1, 5), (2, 8), (5, 8)} assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} + + +def test_bulk_write_benchmark(provider_factory): + """Benchmark: standard write_graph vs bulk_write_graph (COPY). + + Only runs on PostgreSQL since bulk write uses COPY. + Uses blockwise writes for the standard path to avoid building a single + massive INSERT statement that blocks on remote connections. 
+ """ + import time + from itertools import product + + size = 30 # 30^3 = 27,000 nodes + block_size = 10 + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) + if not isinstance(graph_provider, PgSQLGraphDatabase): + pytest.skip("Bulk write only supported on PostgreSQL") + + # Build a 3D grid graph + graph = nx.Graph() + for x, y, z in product(range(size), repeat=3): + node_id = x * size * size + y * size + z + graph.add_node(node_id, position=(x + 0.5, y + 0.5, z + 0.5)) + if x > 0: + graph.add_edge(node_id, (x - 1) * size * size + y * size + z) + if y > 0: + graph.add_edge(node_id, x * size * size + (y - 1) * size + z) + if z > 0: + graph.add_edge(node_id, x * size * size + y * size + (z - 1)) + + n_nodes = graph.number_of_nodes() + n_edges = graph.number_of_edges() + + # --- Standard write (blockwise to avoid giant INSERT statements) --- + t0 = time.perf_counter() + graph_provider.write_graph(graph) + t_standard = time.perf_counter() - t0 + + # Verify standard write then close connection to release locks + graph_reader = provider_factory("r") + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges + graph_reader.connection.close() + + # --- Bulk write (recreate tables) --- + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) + t0 = time.perf_counter() + graph_provider.bulk_write_graph(graph) + t_bulk = time.perf_counter() - t0 + + # Verify bulk write + graph_reader = provider_factory("r") + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges + + print(f"\n--- write benchmark ({n_nodes:,} nodes, {n_edges:,} edges) ---") + print(f"Standard (blockwise): {t_standard*1000:.1f} ms") + print(f"Bulk (COPY): {t_bulk*1000:.1f} ms") + print(f"Speedup: {t_standard / t_bulk:.2f}x") From 0d61ab7b9b7c16079a6b98bbf88612a58c51c4dd Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 19:50:17 +0000 Subject: [PATCH 18/38] Generalize the bulk write api to all graph providers --- .../persistence/graphs/graph_database.py | 17 ++++ .../graphs/pgsql_graph_database.py | 98 ++----------------- .../persistence/graphs/sql_graph_database.py | 65 ++++++++++++ .../graphs/sqlite_graph_database.py | 36 +++++++ tests/test_graph.py | 66 ++++++------- 5 files changed, 157 insertions(+), 125 deletions(-) diff --git a/src/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py index 21bc660..b75deb9 100644 --- a/src/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -155,3 +155,20 @@ def write_attrs( Alias call to write_graph with write_nodes and write_edges set to False. """ pass + + @abstractmethod + def bulk_write_graph( + self, + graph: Graph, + roi: Optional[Roi] = None, + write_nodes: bool = True, + write_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + ) -> None: + """ + Fast bulk write of a graph. Mirrors ``write_graph`` but optimized + for large batch inserts. Does not support ``fail_if_exists`` or + ``delete``. 
+ """ + pass diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index 7ff31a0..c89e2a4 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -321,98 +321,20 @@ def fast_session(self): self.__exec("SET synchronous_commit TO ON") self._commit() - def bulk_write_graph( - self, - graph, - roi=None, - write_nodes=True, - write_edges=True, - node_attrs=None, - edge_attrs=None, - ): - """ - Fast bulk ingest of a graph using COPY. Mirrors write_graph but - uses _stream_copy for speed. Does not support fail_if_exists or - delete — use inside drop_indexes() where constraints are removed. - """ - if write_nodes: - self.bulk_write_nodes(graph.nodes, roi=roi, attributes=node_attrs) - if write_edges: - self.bulk_write_edges( - graph.nodes, graph.edges, roi=roi, attributes=edge_attrs - ) - - def bulk_write_nodes(self, nodes, roi=None, attributes=None): - """ - Fast bulk ingest of nodes using COPY. Mirrors write_nodes. - """ - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - attrs = attributes if attributes is not None else list(self.node_attrs.keys()) - columns = ["id"] + list(attrs) - pos_attr = self.position_attribute - - def node_gen(): - for node_id, data in nodes.items(): - pos = data.get(pos_attr) - if roi is not None: - if pos is None or not roi.contains(pos): - continue - - row = [str(node_id)] - for attr in attrs: - val = data.get(attr) - if val is None: - row.append(r"\N") - elif isinstance(val, (list, tuple)): - row.append(f"{{{','.join(map(str, val))}}}") - else: - row.append(str(val)) - yield "\t".join(row) + "\n" - - self._stream_copy(self.nodes_table_name, columns, node_gen()) - self._commit() - - def bulk_write_edges(self, nodes, edges, roi=None, attributes=None): - """ - Fast bulk ingest of edges using COPY. Mirrors write_edges. - Only writes edges where the u endpoint is in the ROI. 
- """ - if self.mode == "r": - raise RuntimeError("Trying to write to read-only DB") - - u_name, v_name = self.endpoint_names - attrs = ( - attributes - if attributes is not None - else list(self.edge_attrs.keys()) - ) - columns = [u_name, v_name] + list(attrs) - pos_attr = self.position_attribute - - def edge_gen(): - for (u, v), data in edges.items(): - if not self.directed: - u, v = min(u, v), max(u, v) - - if roi is not None: - pos_u = nodes[u].get(pos_attr) - if pos_u is None or not roi.contains(pos_u): - continue - - row = [str(u), str(v)] - for attr in attrs: - val = data.get(attr) + def _bulk_insert(self, table, columns, rows) -> None: + def format_gen(): + for row in rows: + formatted = [] + for val in row: if val is None: - row.append(r"\N") + formatted.append(r"\N") elif isinstance(val, (list, tuple)): - row.append(f"{{{','.join(map(str, val))}}}") + formatted.append(f"{{{','.join(map(str, val))}}}") else: - row.append(str(val)) - yield "\t".join(row) + "\n" + formatted.append(str(val)) + yield "\t".join(formatted) + "\n" - self._stream_copy(self.edges_table_name, columns, edge_gen()) + self._stream_copy(table, columns, format_gen()) self._commit() def _stream_copy(self, table_name, columns, data_generator): diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index 8ca999a..67ef3d9 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -189,6 +189,17 @@ def _update_query(self, query, commit=True) -> None: def _commit(self) -> None: pass + @abstractmethod + def _bulk_insert(self, table, columns, rows) -> None: + """Insert rows using a backend-optimized bulk method. + + Args: + table: Table name. + columns: Column names. + rows: Iterable of row lists (Python values, not formatted strings). 
+ """ + pass + def _node_attrs_to_columns(self, attrs): # default: each attribute maps to its own column return attrs @@ -311,6 +322,60 @@ def write_graph( delete=delete, ) + def bulk_write_graph( + self, + graph: Graph, + roi: Optional[Roi] = None, + write_nodes: bool = True, + write_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + ) -> None: + if write_nodes: + self.bulk_write_nodes(graph.nodes, roi=roi, attributes=node_attrs) + if write_edges: + self.bulk_write_edges( + graph.nodes, graph.edges, roi=roi, attributes=edge_attrs + ) + + def bulk_write_nodes(self, nodes, roi=None, attributes=None): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + attrs = attributes if attributes is not None else list(self.node_attrs.keys()) + columns = ["id"] + list(attrs) + + def rows(): + for node_id, data in nodes.items(): + pos = self.__get_node_pos(data) + if roi is not None and not roi.contains(pos): + continue + yield [node_id] + [data.get(attr, None) for attr in attrs] + + self._bulk_insert(self.nodes_table_name, columns, rows()) + + def bulk_write_edges(self, nodes, edges, roi=None, attributes=None): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + u_name, v_name = self.endpoint_names + attrs = ( + attributes if attributes is not None else list(self.edge_attrs.keys()) + ) + columns = [u_name, v_name] + list(attrs) + + def rows(): + for (u, v), data in edges.items(): + if not self.directed: + u, v = min(u, v), max(u, v) + if roi is not None: + pos_u = self.__get_node_pos(nodes[u]) + if pos_u is None or not roi.contains(pos_u): + continue + yield [u, v] + [data.get(attr, None) for attr in attrs] + + self._bulk_insert(self.edges_table_name, columns, rows()) + @property def node_attrs(self) -> dict[str, AttributeType]: return self._node_attrs if self._node_attrs is not None else {} diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py index 9243677..86edb1c 100644 --- a/src/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -181,6 +181,42 @@ def _insert_query(self, table, columns, values, fail_if_exists=False, commit=Tru if commit: self.con.commit() + def _bulk_insert(self, table, columns, rows) -> None: + # Explode array columns for SQLite (same logic as _insert_query) + array_columns = ( + self.node_array_columns + if table == self.nodes_table_name + else self.edge_array_columns + ) + + exploded_columns = None + exploded_rows = [] + for row in rows: + exploded_cols = [] + exploded_vals = [] + for column, value in zip(columns, row): + if column in array_columns: + for c, v in zip(array_columns[column], value): + exploded_cols.append(c) + exploded_vals.append(v) + else: + exploded_cols.append(column) + exploded_vals.append(value) + if exploded_columns is None: + exploded_columns = exploded_cols + exploded_rows.append(exploded_vals) + + if not exploded_rows: + return + + insert_statement = ( + f"INSERT OR IGNORE INTO {table} " + f"({', '.join(exploded_columns)}) " + f"VALUES ({', '.join(['?'] * len(exploded_columns))})" + ) + self.cur.executemany(insert_statement, exploded_rows) + self.con.commit() + def _update_query(self, query, commit=True): try: self.cur.execute(query) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5afefb1..e30ca70 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -6,11 +6,6 @@ from 
funlib.persistence.types import Vec -def _skip_if_bulk_unsupported(provider, write_method): - if write_method == "bulk" and not isinstance(provider, PgSQLGraphDatabase): - pytest.skip("Bulk write only supported on PostgreSQL") - - def _write_nodes(provider, nodes, write_method, **kwargs): if write_method == "bulk": provider.bulk_write_nodes(nodes, **kwargs) @@ -40,7 +35,7 @@ def test_graph_filtering(provider_factory, write_method): node_attrs={"position": Vec(float, 3), "selected": bool}, edge_attrs={"selected": bool}, ) - _skip_if_bulk_unsupported(graph_writer, write_method) + roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_writer[roi] @@ -89,7 +84,7 @@ def test_graph_filtering_complex(provider_factory, write_method): node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] @@ -191,7 +186,7 @@ def test_graph_read_unbounded_roi(provider_factory, write_method): node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, edge_attrs={"selected": bool, "a": int, "b": int}, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + roi = Roi((0, 0, 0), (10, 10, 10)) unbounded_roi = Roi((None, None, None), (None, None, None)) @@ -256,7 +251,7 @@ def test_graph_io(provider_factory, write_method): "zap": str, }, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] @@ -383,7 +378,7 @@ def test_graph_write_attributes(provider_factory, write_method): "zap": str, }, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] graph.add_node(2, position=[0, 0, 0]) @@ -438,7 +433,7 @@ def test_graph_write_roi(provider_factory, write_method): "zap": str, }, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] graph.add_node(2, position=(0, 0, 0)) @@ -516,7 +511,7 @@ def test_graph_has_edge(provider_factory, write_method): "zap": str, }, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + roi = Roi((0, 0, 0), (10, 10, 10)) graph = graph_provider[roi] @@ -549,7 +544,7 @@ def test_read_edges_join_vs_in_clause(provider_factory, write_method): "w", node_attrs={"position": Vec(float, 3)}, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + # Build a 3D grid graph graph = nx.Graph() @@ -636,7 +631,7 @@ def test_read_edges_fetch_on_v(provider_factory, write_method): "w", node_attrs={"position": Vec(float, 3)}, ) - _skip_if_bulk_unsupported(graph_provider, write_method) + roi = Roi((0, 0, 0), (6, 6, 6)) graph = nx.Graph() @@ -700,26 +695,10 @@ def edge_set(edges): assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} -def test_bulk_write_benchmark(provider_factory): - """Benchmark: standard write_graph vs bulk_write_graph (COPY). - - Only runs on PostgreSQL since bulk write uses COPY. - Uses blockwise writes for the standard path to avoid building a single - massive INSERT statement that blocks on remote connections. 
- """ - import time +def _build_grid_graph(size): + """Build a 3D grid graph with size^3 nodes and ~3*size^2*(size-1) edges.""" from itertools import product - size = 30 # 30^3 = 27,000 nodes - block_size = 10 - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3)}, - ) - if not isinstance(graph_provider, PgSQLGraphDatabase): - pytest.skip("Bulk write only supported on PostgreSQL") - - # Build a 3D grid graph graph = nx.Graph() for x, y, z in product(range(size), repeat=3): node_id = x * size * size + y * size + z @@ -730,16 +709,29 @@ def test_bulk_write_benchmark(provider_factory): graph.add_edge(node_id, x * size * size + (y - 1) * size + z) if z > 0: graph.add_edge(node_id, x * size * size + y * size + (z - 1)) + return graph + + +def test_bulk_write_benchmark(provider_factory): + """Benchmark: standard write_graph vs bulk_write_graph.""" + import time + + graph_provider = provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) + size = 30 # 30^3 = 27,000 nodes + graph = _build_grid_graph(size) n_nodes = graph.number_of_nodes() n_edges = graph.number_of_edges() - # --- Standard write (blockwise to avoid giant INSERT statements) --- + # --- Standard write --- t0 = time.perf_counter() graph_provider.write_graph(graph) t_standard = time.perf_counter() - t0 - # Verify standard write then close connection to release locks + # Verify then close connection to release locks graph_reader = provider_factory("r") result = graph_reader.read_graph() assert result.number_of_nodes() == n_nodes @@ -762,6 +754,6 @@ def test_bulk_write_benchmark(provider_factory): assert result.number_of_edges() == n_edges print(f"\n--- write benchmark ({n_nodes:,} nodes, {n_edges:,} edges) ---") - print(f"Standard (blockwise): {t_standard*1000:.1f} ms") - print(f"Bulk (COPY): {t_bulk*1000:.1f} ms") - print(f"Speedup: {t_standard / t_bulk:.2f}x") + print(f"Standard: {t_standard*1000:.1f} ms") + print(f"Bulk: {t_bulk*1000:.1f} ms") + print(f"Speedup: {t_standard / t_bulk:.2f}x") From 39a9f8049448c4d337038a326fb6825dc7100d35 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 20:02:59 +0000 Subject: [PATCH 19/38] make tests more reliable with a context manager to handle closing connections. 
Fixes a very frustraiting permanent hang that can occur due to unclosed postgresql connections --- .../graphs/pgsql_graph_database.py | 4 + .../graphs/sqlite_graph_database.py | 3 + tests/conftest.py | 27 +- tests/test_graph.py | 706 ++++++++---------- 4 files changed, 350 insertions(+), 390 deletions(-) diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index c89e2a4..0afff6b 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -76,6 +76,10 @@ def __init__( edge_attrs=edge_attrs, # type: ignore ) + def close(self): + if not self.connection.closed: + self.connection.close() + def _drop_edges(self) -> None: logger.info("dropping edges table %s", self.edges_table_name) self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py index 86edb1c..64534bf 100644 --- a/src/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -47,6 +47,9 @@ def __init__( edge_attrs=edge_attrs, ) + def close(self): + self.con.close() + @property def node_array_columns(self): if not self._node_array_columns: diff --git a/tests/conftest.py b/tests/conftest.py index 6f1aec5..a1fd5c6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from contextlib import contextmanager from pathlib import Path import psycopg2 @@ -52,17 +53,13 @@ def can_connect_to_psql(): ) ) def provider_factory(request, tmpdir): - # provides a factory function to generate graph provider - # can provide either mongodb graph provider or file graph provider - # if file graph provider, will generate graph in a temporary directory - # to avoid artifacts - tmpdir = Path(tmpdir) + @contextmanager def sqlite_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): - return SQLiteGraphDataBase( + provider = SQLiteGraphDataBase( tmpdir / "test_sqlite_graph.db", position_attribute="position", mode=mode, @@ -71,9 +68,12 @@ def sqlite_provider_factory( node_attrs=node_attrs, edge_attrs=edge_attrs, ) + try: + yield provider + finally: + provider.close() - providers = [] - + @contextmanager def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): @@ -91,8 +91,10 @@ def psql_provider_factory( node_attrs=node_attrs, edge_attrs=edge_attrs, ) - providers.append(provider) - return provider + try: + yield provider + finally: + provider.close() if request.param == "sqlite": yield sqlite_provider_factory @@ -101,11 +103,6 @@ def psql_provider_factory( else: raise ValueError() - # Close all psql connections to avoid stale transactions - for provider in providers: - if hasattr(provider, "connection"): - provider.connection.close() - @pytest.fixture(params=["standard", "bulk"]) def write_method(request): diff --git a/tests/test_graph.py b/tests/test_graph.py index e30ca70..2030d3b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,7 +2,6 @@ import pytest from funlib.geometry import Roi -from funlib.persistence.graphs import PgSQLGraphDatabase from funlib.persistence.types import Vec @@ -30,15 +29,8 @@ def _write_graph(provider, graph, write_method, **kwargs): def test_graph_filtering(provider_factory, write_method): - graph_writer = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool}, - 
edge_attrs={"selected": bool}, - ) - roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_writer[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True) graph.add_node(42, position=(1, 1, 1), selected=False) graph.add_node(23, position=(5, 5, 5), selected=True) @@ -47,214 +39,210 @@ def test_graph_filtering(provider_factory, write_method): graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) - _write_nodes(graph_writer, graph.nodes(), write_method) - _write_edges(graph_writer, graph.nodes(), graph.edges(), write_method) - - graph_reader = provider_factory("r") - - filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True}) - filtered_node_ids = [node["id"] for node in filtered_nodes] - expected_node_ids = [2, 23, 57] - assert expected_node_ids == filtered_node_ids - - filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True}) - filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] - expected_edge_endpoints = [(57, 23), (2, 42)] - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - - filtered_subgraph = graph_reader.read_graph( - roi, nodes_filter={"selected": True}, edges_filter={"selected": True} - ) - nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "position" in data - ] - assert expected_node_ids == nodes_with_position - assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_subgraph.edges() or ( - v, - u, - ) in filtered_subgraph.edges() + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool}, + edge_attrs={"selected": bool}, + ) as graph_writer: + _write_nodes(graph_writer, graph.nodes(), write_method) + _write_edges(graph_writer, graph.nodes(), graph.edges(), write_method) + + with provider_factory("r") as graph_reader: + filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True}) + filtered_node_ids = [node["id"] for node in filtered_nodes] + expected_node_ids = [2, 23, 57] + assert expected_node_ids == filtered_node_ids + + filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True}) + filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] + expected_edge_endpoints = [(57, 23), (2, 42)] + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_edge_endpoints or ( + v, + u, + ) in filtered_edge_endpoints + + filtered_subgraph = graph_reader.read_graph( + roi, nodes_filter={"selected": True}, edges_filter={"selected": True} + ) + nodes_with_position = [ + node + for node, data in filtered_subgraph.nodes(data=True) + if "position" in data + ] + assert expected_node_ids == nodes_with_position + assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_subgraph.edges() or ( + v, + u, + ) in filtered_subgraph.edges() def test_graph_filtering_complex(provider_factory, write_method): - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int}, - ) - roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), 
selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - _write_nodes(graph_provider, graph.nodes(), write_method) - _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - - graph_provider = provider_factory("r") - - filtered_nodes = graph_provider.read_nodes( - roi, attr_filter={"selected": True, "test": "test"} - ) - filtered_node_ids = [node["id"] for node in filtered_nodes] - expected_node_ids = [2, 57] - assert expected_node_ids == filtered_node_ids - - filtered_edges = graph_provider.read_edges( - roi, attr_filter={"selected": True, "a": 100} - ) - filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] - expected_edge_endpoints = [(57, 23)] - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - - filtered_subgraph = graph_provider.read_graph( - roi, - nodes_filter={"selected": True, "test": "test"}, - edges_filter={"selected": True, "a": 100}, - ) - nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "position" in data - ] - assert expected_node_ids == nodes_with_position - assert len(filtered_subgraph.edges()) == 0 + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) + + with provider_factory("r") as graph_provider: + filtered_nodes = graph_provider.read_nodes( + roi, attr_filter={"selected": True, "test": "test"} + ) + filtered_node_ids = [node["id"] for node in filtered_nodes] + expected_node_ids = [2, 57] + assert expected_node_ids == filtered_node_ids + + filtered_edges = graph_provider.read_edges( + roi, attr_filter={"selected": True, "a": 100} + ) + filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] + expected_edge_endpoints = [(57, 23)] + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_edge_endpoints or ( + v, + u, + ) in filtered_edge_endpoints + + filtered_subgraph = graph_provider.read_graph( + roi, + nodes_filter={"selected": True, "test": "test"}, + edges_filter={"selected": True, "a": 100}, + ) + nodes_with_position = [ + node + for node, data in filtered_subgraph.nodes(data=True) + if "position" in data + ] + assert expected_node_ids == nodes_with_position + assert len(filtered_subgraph.edges()) == 0 def test_graph_read_and_update_specific_attrs(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, - ) roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_graph(graph) + with provider_factory( + "w", + node_attrs={"position": 
Vec(float, 3), "selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, + ) as graph_provider: + graph_provider.write_graph(graph) - graph_provider = provider_factory("r+") - limited_graph = graph_provider.read_graph( - roi, node_attrs=["selected"], edge_attrs=["c"] - ) + with provider_factory("r+") as graph_provider: + limited_graph = graph_provider.read_graph( + roi, node_attrs=["selected"], edge_attrs=["c"] + ) - for node, data in limited_graph.nodes(data=True): - assert "test" not in data - assert "selected" in data - data["selected"] = True + for node, data in limited_graph.nodes(data=True): + assert "test" not in data + assert "selected" in data + data["selected"] = True - for u, v, data in limited_graph.edges(data=True): - assert "a" not in data - assert "b" not in data - nx.set_edge_attributes(limited_graph, 5, "c") + for u, v, data in limited_graph.edges(data=True): + assert "a" not in data + assert "b" not in data + nx.set_edge_attributes(limited_graph, 5, "c") - try: - graph_provider.write_attrs( - limited_graph, edge_attrs=["c"], node_attrs=["selected"] - ) - except NotImplementedError: - pytest.xfail() + try: + graph_provider.write_attrs( + limited_graph, edge_attrs=["c"], node_attrs=["selected"] + ) + except NotImplementedError: + pytest.xfail() - updated_graph = graph_provider.read_graph(roi) + updated_graph = graph_provider.read_graph(roi) - for node, data in updated_graph.nodes(data=True): - assert data["selected"] + for node, data in updated_graph.nodes(data=True): + assert data["selected"] - for u, v, data in updated_graph.edges(data=True): - assert data["c"] == 5 + for u, v, data in updated_graph.edges(data=True): + assert data["c"] == 5 def test_graph_read_unbounded_roi(provider_factory, write_method): - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int}, - ) - roi = Roi((0, 0, 0), (10, 10, 10)) unbounded_roi = Roi((None, None, None), (None, None, None)) - - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - _write_nodes(graph_provider, graph.nodes(), write_method) - _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r+") - limited_graph = graph_provider.read_graph( - unbounded_roi, node_attrs=["selected"], edge_attrs=["c"] - ) + with provider_factory("r+") as graph_provider: + limited_graph = graph_provider.read_graph( + unbounded_roi, node_attrs=["selected"], edge_attrs=["c"] + ) - seen = [] - for node, data in limited_graph.nodes(data=True): - assert "test" not in data - assert "selected" in data - data["selected"] = True - seen.append(node) + seen = [] + for node, data in limited_graph.nodes(data=True): + assert "test" not in data 
+ assert "selected" in data + data["selected"] = True + seen.append(node) - assert sorted([2, 42, 23, 57]) == sorted(seen) + assert sorted([2, 42, 23, 57]) == sorted(seen) def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)}) - graph_provider = provider_factory("r", None, None) - assert True == graph_provider.directed - assert roi == graph_provider.total_roi + with provider_factory( + "w", True, roi, node_attrs={"position": Vec(float, 3)} + ): + pass + with provider_factory("r", None, None) as graph_provider: + assert True == graph_provider.directed + assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory( + with provider_factory( "w", False, None, node_attrs={"position": Vec(float, 3)} - ) - assert False == provider.directed - assert provider.total_roi is None or provider.total_roi == Roi( - (None, None, None), (None, None, None) - ) - graph_provider = provider_factory("r", False, None) - assert False == graph_provider.directed - assert graph_provider.total_roi is None or graph_provider.total_roi == Roi( - (None, None, None), (None, None, None) - ) + ) as provider: + assert False == provider.directed + assert provider.total_roi is None or provider.total_roi == Roi( + (None, None, None), (None, None, None) + ) + with provider_factory("r", False, None) as graph_provider: + assert False == graph_provider.directed + assert graph_provider.total_roi is None or graph_provider.total_roi == Roi( + (None, None, None), (None, None, None) + ) def test_graph_io(provider_factory, write_method): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - - - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -263,11 +251,15 @@ def test_graph_io(provider_factory, write_method): graph.add_edge(57, 23) graph.add_edge(2, 42) - _write_nodes(graph_provider, graph.nodes(), write_method) - _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] nodes = sorted(list(graph.nodes())) nodes.remove(2) # node 2 has no position and will not be queried @@ -282,16 +274,7 @@ def test_graph_io(provider_factory, write_method): def test_graph_fail_if_exists(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -300,59 +283,57 @@ def test_graph_fail_if_exists(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_graph(graph) - with pytest.raises(Exception): - graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) - with 
pytest.raises(Exception): - graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + graph_provider.write_graph(graph) + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_exists=True + ) def test_graph_duplicate_insert_behavior(provider_factory): """Test that fail_if_exists controls whether duplicate inserts raise.""" - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool}, - edge_attrs={"selected": bool}, - ) roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True) graph.add_node(42, position=(1, 1, 1), selected=False) graph.add_edge(2, 42, selected=True) - # Initial write - graph_provider.write_nodes(graph.nodes()) - graph_provider.write_edges(graph.nodes(), graph.edges()) - - # fail_if_exists=True should raise on duplicate nodes and edges - with pytest.raises(Exception): - graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) - with pytest.raises(Exception): - graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) - - # fail_if_exists=False should silently ignore duplicates - graph_provider.write_nodes(graph.nodes(), fail_if_exists=False) - graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=False) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool}, + edge_attrs={"selected": bool}, + ) as graph_provider: + # Initial write + graph_provider.write_nodes(graph.nodes()) + graph_provider.write_edges(graph.nodes(), graph.edges()) + + # fail_if_exists=True should raise on duplicate nodes and edges + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_exists=True + ) + + # fail_if_exists=False should silently ignore duplicates + graph_provider.write_nodes(graph.nodes(), fail_if_exists=False) + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=False) # Verify the original data is still intact - graph_provider = provider_factory("r") - result = graph_provider.read_graph(roi) - assert set(result.nodes()) == {2, 42} - assert len(result.edges()) == 1 + with provider_factory("r") as graph_provider: + result = graph_provider.read_graph(roi) + assert set(result.nodes()) == {2, 42} + assert len(result.edges()) == 1 def test_graph_fail_if_not_exists(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -361,26 +342,20 @@ def test_graph_fail_if_not_exists(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - with pytest.raises(Exception): - graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True) - with pytest.raises(Exception): - graph_provider.write_edges( - graph.nodes(), graph.edges(), fail_if_not_exists=True - ) - - -def test_graph_write_attributes(provider_factory, write_method): - graph_provider = 
provider_factory( + with provider_factory( "w", - node_attrs={ - "position": Vec(int, 3), - "swip": str, - "zap": str, - }, - ) + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_not_exists=True + ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] +def test_graph_write_attributes(provider_factory, write_method): + graph = nx.Graph() graph.add_node(2, position=[0, 0, 0]) graph.add_node(42, position=[1, 1, 1]) graph.add_node(23, position=[5, 5, 5], swip="swap") @@ -389,15 +364,22 @@ def test_graph_write_attributes(provider_factory, write_method): graph.add_edge(57, 23) graph.add_edge(2, 42) - _write_graph( - graph_provider, graph, write_method, - write_nodes=True, write_edges=False, node_attrs=["position", "swip"], - ) - - _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) + with provider_factory( + "w", + node_attrs={"position": Vec(int, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_graph( + graph_provider, + graph, + write_method, + write_nodes=True, + write_edges=False, + node_attrs=["position", "swip"], + ) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (10, 10, 10))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (10, 10, 10))] nodes = [] for node, data in graph.nodes(data=True): @@ -425,17 +407,7 @@ def test_graph_write_attributes(provider_factory, write_method): def test_graph_write_roi(provider_factory, write_method): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -445,10 +417,14 @@ def test_graph_write_roi(provider_factory, write_method): graph.add_edge(2, 42) write_roi = Roi((0, 0, 0), (6, 6, 6)) - _write_graph(graph_provider, graph, write_method, roi=write_roi) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method, roi=write_roi) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] nodes = sorted(list(graph.nodes())) nodes.remove(2) # node 2 has no position and will not be queried @@ -465,22 +441,19 @@ def test_graph_write_roi(provider_factory, write_method): def test_graph_connected_components(provider_factory): - graph_provider = provider_factory( + with provider_factory( "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] + + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") + graph.add_edge(57, 23) + 
graph.add_edge(2, 42) - graph.add_node(2, position=(0, 0, 0)) - graph.add_node(42, position=(1, 1, 1)) - graph.add_node(23, position=(5, 5, 5), swip="swap") - graph.add_node(57, position=(7, 7, 7), zap="zip") - graph.add_edge(57, 23) - graph.add_edge(2, 42) try: components = list(nx.connected_components(graph)) except NotImplementedError: @@ -503,19 +476,8 @@ def test_graph_connected_components(provider_factory): def test_graph_has_edge(provider_factory, write_method): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - - roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -524,10 +486,15 @@ def test_graph_has_edge(provider_factory, write_method): graph.add_edge(57, 23) write_roi = Roi((0, 0, 0), (6, 6, 6)) - _write_nodes(graph_provider, graph.nodes(), write_method, roi=write_roi) - _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method, roi=write_roi) - - assert graph_provider.has_edges(roi) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method, roi=write_roi) + _write_edges( + graph_provider, graph.nodes(), graph.edges(), write_method, roi=write_roi + ) + assert graph_provider.has_edges(roi) def test_read_edges_join_vs_in_clause(provider_factory, write_method): @@ -540,18 +507,10 @@ def test_read_edges_join_vs_in_clause(provider_factory, write_method): from itertools import product size = 50 # 50^3 = 125,000 nodes - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3)}, - ) - - - # Build a 3D grid graph graph = nx.Graph() for x, y, z in product(range(size), repeat=3): node_id = x * size * size + y * size + z graph.add_node(node_id, position=(x + 0.5, y + 0.5, z + 0.5)) - # Connect to neighbors in +x, +y, +z directions if x > 0: graph.add_edge(node_id, (x - 1) * size * size + y * size + z) if y > 0: @@ -559,30 +518,32 @@ def test_read_edges_join_vs_in_clause(provider_factory, write_method): if z > 0: graph.add_edge(node_id, x * size * size + y * size + (z - 1)) - _write_graph(graph_provider, graph, write_method) - - # Re-open in read mode - graph_provider = provider_factory("r") - - query_roi = Roi((10, 10, 10), (30, 30, 30)) - n_repeats = 5 - - # --- Old approach: read_nodes, then read_edges with nodes list --- - times_in_clause = [] - for _ in range(n_repeats): - t0 = time.perf_counter() - nodes = graph_provider.read_nodes(query_roi) - edges_via_in = graph_provider.read_edges(nodes=nodes) - t1 = time.perf_counter() - times_in_clause.append(t1 - t0) - - # --- New approach: read_edges with roi (JOIN) --- - times_join = [] - for _ in range(n_repeats): - t0 = time.perf_counter() - edges_via_join = graph_provider.read_edges(roi=query_roi) - t1 = time.perf_counter() - times_join.append(t1 - t0) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method) + + with provider_factory("r") as graph_provider: + query_roi = Roi((10, 10, 10), (30, 30, 30)) + n_repeats = 5 + + # --- Old approach: read_nodes, then read_edges with nodes list --- + times_in_clause = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + nodes = graph_provider.read_nodes(query_roi) + edges_via_in = 
graph_provider.read_edges(nodes=nodes) + t1 = time.perf_counter() + times_in_clause.append(t1 - t0) + + # --- New approach: read_edges with roi (JOIN) --- + times_join = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + edges_via_join = graph_provider.read_edges(roi=query_roi) + t1 = time.perf_counter() + times_join.append(t1 - t0) avg_in = sum(times_in_clause) / n_repeats avg_join = sum(times_join) / n_repeats @@ -617,8 +578,8 @@ def test_read_edges_fetch_on_v(provider_factory, write_method): - Edge(5, 8): u=5 in ROI, v=8 outside ROI - Edge(8, 9): u=8 outside ROI, v=9 outside ROI - fetch_on_v=False (default): only edges where u is in ROI → {(1,5), (2,8), (5,8)} - fetch_on_v=True: edges where u OR v is in ROI → {(1,5), (2,8), (5,8)} + fetch_on_v=False (default): only edges where u is in ROI -> {(1,5), (2,8), (5,8)} + fetch_on_v=True: edges where u OR v is in ROI -> {(1,5), (2,8), (5,8)} (same here because u < v and all boundary-crossing edges have u inside) To properly test fetch_on_v, we need an edge where u is OUTSIDE the ROI @@ -627,13 +588,7 @@ def test_read_edges_fetch_on_v(provider_factory, write_method): So we add: Node 0 (pos 8) -- Edge(0, 5): u=0 outside ROI, v=5 in ROI. """ - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3)}, - ) - roi = Roi((0, 0, 0), (6, 6, 6)) - graph = nx.Graph() # Nodes inside ROI (positions < 6) graph.add_node(1, position=(1.0, 1.0, 1.0)) @@ -644,7 +599,6 @@ def test_read_edges_fetch_on_v(provider_factory, write_method): graph.add_node(0, position=(8.0, 8.0, 8.0)) graph.add_node(8, position=(8.0, 8.0, 8.0)) graph.add_node(9, position=(9.0, 9.0, 9.0)) - # Edges: undirected, stored as u < v graph.add_edge(1, 5) # both in ROI graph.add_edge(2, 8) # u in ROI, v outside @@ -652,47 +606,51 @@ def test_read_edges_fetch_on_v(provider_factory, write_method): graph.add_edge(8, 9) # both outside ROI graph.add_edge(0, 5) # u=0 OUTSIDE ROI, v=5 INSIDE ROI (key test edge) - _write_graph(graph_provider, graph, write_method) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method) - graph_provider = provider_factory("r") + with provider_factory("r") as graph_provider: - def edge_set(edges): - """Normalize edge list to set of sorted tuples for comparison.""" - return {(min(e["u"], e["v"]), max(e["u"], e["v"])) for e in edges} + def edge_set(edges): + """Normalize edge list to set of sorted tuples for comparison.""" + return {(min(e["u"], e["v"]), max(e["u"], e["v"])) for e in edges} - # --- Case 1: nodes passed explicitly --- - nodes_in_roi = graph_provider.read_nodes(roi) - node_ids_in_roi = {n["id"] for n in nodes_in_roi} - assert node_ids_in_roi == {1, 2, 5} + # --- Case 1: nodes passed explicitly --- + nodes_in_roi = graph_provider.read_nodes(roi) + node_ids_in_roi = {n["id"] for n in nodes_in_roi} + assert node_ids_in_roi == {1, 2, 5} - edges_u_only = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=False) - edges_u_and_v = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=True) + edges_u_only = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=False) + edges_u_and_v = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=True) - # fetch_on_v=False: only edges where u IN (1,2,5) - # (1,5), (2,8), (5,8) match; (0,5) does NOT match (u=0 not in list) - assert edge_set(edges_u_only) == {(1, 5), (2, 8), (5, 8)} + # fetch_on_v=False: only edges where u IN (1,2,5) + # (1,5), (2,8), (5,8) match; (0,5) does NOT match 
(u=0 not in list) + assert edge_set(edges_u_only) == {(1, 5), (2, 8), (5, 8)} - # fetch_on_v=True: edges where u OR v IN (1,2,5) - # (0,5) now matches because v=5 is in the list - assert edge_set(edges_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + # fetch_on_v=True: edges where u OR v IN (1,2,5) + # (0,5) now matches because v=5 is in the list + assert edge_set(edges_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} - # --- Case 2: roi passed (JOIN path) --- - edges_roi_u_only = graph_provider.read_edges(roi=roi, fetch_on_v=False) - edges_roi_u_and_v = graph_provider.read_edges(roi=roi, fetch_on_v=True) + # --- Case 2: roi passed (JOIN path) --- + edges_roi_u_only = graph_provider.read_edges(roi=roi, fetch_on_v=False) + edges_roi_u_and_v = graph_provider.read_edges(roi=roi, fetch_on_v=True) - # Same expected results as Case 1 - assert edge_set(edges_roi_u_only) == {(1, 5), (2, 8), (5, 8)} - assert edge_set(edges_roi_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + # Same expected results as Case 1 + assert edge_set(edges_roi_u_only) == {(1, 5), (2, 8), (5, 8)} + assert edge_set(edges_roi_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} - # --- Case 3: via read_graph --- - graph_u_only = graph_provider.read_graph(roi, fetch_on_v=False) - graph_u_and_v = graph_provider.read_graph(roi, fetch_on_v=True) + # --- Case 3: via read_graph --- + graph_u_only = graph_provider.read_graph(roi, fetch_on_v=False) + graph_u_and_v = graph_provider.read_graph(roi, fetch_on_v=True) - graph_edges_u_only = {tuple(sorted(e)) for e in graph_u_only.edges()} - graph_edges_u_and_v = {tuple(sorted(e)) for e in graph_u_and_v.edges()} + graph_edges_u_only = {tuple(sorted(e)) for e in graph_u_only.edges()} + graph_edges_u_and_v = {tuple(sorted(e)) for e in graph_u_and_v.edges()} - assert graph_edges_u_only == {(1, 5), (2, 8), (5, 8)} - assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} + assert graph_edges_u_only == {(1, 5), (2, 8), (5, 8)} + assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} def _build_grid_graph(size): @@ -716,42 +674,40 @@ def test_bulk_write_benchmark(provider_factory): """Benchmark: standard write_graph vs bulk_write_graph.""" import time - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3)}, - ) - size = 30 # 30^3 = 27,000 nodes graph = _build_grid_graph(size) n_nodes = graph.number_of_nodes() n_edges = graph.number_of_edges() # --- Standard write --- - t0 = time.perf_counter() - graph_provider.write_graph(graph) - t_standard = time.perf_counter() - t0 + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + t0 = time.perf_counter() + graph_provider.write_graph(graph) + t_standard = time.perf_counter() - t0 - # Verify then close connection to release locks - graph_reader = provider_factory("r") - result = graph_reader.read_graph() - assert result.number_of_nodes() == n_nodes - assert result.number_of_edges() == n_edges - graph_reader.connection.close() + # Verify standard write + with provider_factory("r") as graph_reader: + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges # --- Bulk write (recreate tables) --- - graph_provider = provider_factory( + with provider_factory( "w", node_attrs={"position": Vec(float, 3)}, - ) - t0 = time.perf_counter() - graph_provider.bulk_write_graph(graph) - t_bulk = time.perf_counter() - t0 + ) as graph_provider: + t0 = time.perf_counter() + graph_provider.bulk_write_graph(graph) + t_bulk = 
time.perf_counter() - t0 # Verify bulk write - graph_reader = provider_factory("r") - result = graph_reader.read_graph() - assert result.number_of_nodes() == n_nodes - assert result.number_of_edges() == n_edges + with provider_factory("r") as graph_reader: + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges print(f"\n--- write benchmark ({n_nodes:,} nodes, {n_edges:,} edges) ---") print(f"Standard: {t_standard*1000:.1f} ms") From 02f9cec118186bbf88d01d9bf7cf02bc412e1e84 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 20:22:01 +0000 Subject: [PATCH 20/38] add and use a `bulk_write_mode` context manager for bulk writes --- .../persistence/graphs/graph_database.py | 33 +++++++++ .../graphs/pgsql_graph_database.py | 72 ++++++++----------- .../graphs/sqlite_graph_database.py | 33 +++++++++ tests/test_graph.py | 9 +-- 4 files changed, 102 insertions(+), 45 deletions(-) diff --git a/src/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py index b75deb9..64f9cb5 100644 --- a/src/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -172,3 +172,36 @@ def bulk_write_graph( ``delete``. """ pass + + @abstractmethod + def bulk_write_mode( + self, + worker: bool = False, + node_writes: bool = True, + edge_writes: bool = True, + ): + """Context manager that optimizes the database for bulk writes. + + Drops indexes and adjusts database settings for maximum write + throughput, then restores them on exit. + + Arguments: + + worker (``bool``): + + If ``False`` (default), drops and rebuilds indexes around the + block. Set to ``True`` for parallel workers whose orchestrator + manages indexes separately — only session-level performance + settings will be adjusted. + + node_writes (``bool``): + + If ``True`` (default), drop/rebuild node primary key and + position indexes. Ignored when ``worker=True``. + + edge_writes (``bool``): + + If ``True`` (default), drop/rebuild edge primary key index. + Ignored when ``worker=True``. + """ + pass diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index 0afff6b..b0ffa59 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -269,62 +269,52 @@ def fmt_bytes(n): print(f"Tables shown: {len(rows)} (schema={schema!r})") @contextmanager - def drop_indexes(self): - """ - Context manager for the orchestrator to use around bulk ingestion. - Drops all indexes and primary key constraints before yielding, then - recreates them after. This is a global (DDL) operation — call it - once from the main process, not from individual workers. 
- """ + def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True): nodes = self.nodes_table_name edges = self.edges_table_name endpoint_names = self.endpoint_names assert endpoint_names is not None - # Drop position index and primary key constraints - self.__exec("DROP INDEX IF EXISTS pos_index") - self.__exec( - f"ALTER TABLE {nodes} DROP CONSTRAINT IF EXISTS {nodes}_pkey" - ) - self.__exec( - f"ALTER TABLE {edges} DROP CONSTRAINT IF EXISTS {edges}_pkey" - ) - self._commit() - - try: - yield - finally: - self._commit() - logger.info("Re-creating indexes and constraints...") - self.__exec( - f"ALTER TABLE {nodes} " - f"ADD CONSTRAINT {nodes}_pkey PRIMARY KEY (id)" - ) - self.__exec( - f"CREATE INDEX IF NOT EXISTS pos_index ON " - f"{nodes}({self.position_attribute})" - ) - self.__exec( - f"ALTER TABLE {edges} " - f"ADD CONSTRAINT {edges}_pkey " - f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" - ) + if not worker: + if node_writes: + self.__exec("DROP INDEX IF EXISTS pos_index") + self.__exec( + f"ALTER TABLE {nodes} DROP CONSTRAINT IF EXISTS {nodes}_pkey" + ) + if edge_writes: + self.__exec( + f"ALTER TABLE {edges} DROP CONSTRAINT IF EXISTS {edges}_pkey" + ) self._commit() - @contextmanager - def fast_session(self): - """ - Per-connection context manager for workers to use during bulk ingestion. - Disables synchronous commit for the duration of the session. - """ self.__exec("SET synchronous_commit TO OFF") self._commit() + try: yield finally: self.__exec("SET synchronous_commit TO ON") self._commit() + if not worker: + logger.info("Re-creating indexes and constraints...") + if node_writes: + self.__exec( + f"ALTER TABLE {nodes} " + f"ADD CONSTRAINT {nodes}_pkey PRIMARY KEY (id)" + ) + self.__exec( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{nodes}({self.position_attribute})" + ) + if edge_writes: + self.__exec( + f"ALTER TABLE {edges} " + f"ADD CONSTRAINT {edges}_pkey " + f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" + ) + self._commit() + def _bulk_insert(self, table, columns, rows) -> None: def format_gen(): for row in rows: diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py index 64534bf..de5e11a 100644 --- a/src/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -2,6 +2,7 @@ import logging import re import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Any, Optional @@ -50,6 +51,38 @@ def __init__( def close(self): self.con.close() + @contextmanager + def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True): + prev_sync = self.cur.execute("PRAGMA synchronous").fetchone()[0] + self.cur.execute("PRAGMA synchronous=OFF") + self.cur.execute("PRAGMA journal_mode=WAL") + self.con.commit() + + if not worker and node_writes: + self.cur.execute("DROP INDEX IF EXISTS pos_index") + self.con.commit() + + try: + yield + finally: + self.con.commit() + + if not worker and node_writes: + if self.ndims > 1: # type: ignore + position_columns = self.node_array_columns[ + self.position_attribute + ] + else: + position_columns = [self.position_attribute] + self.cur.execute( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{self.nodes_table_name}({','.join(position_columns)})" + ) + self.con.commit() + + self.cur.execute(f"PRAGMA synchronous={prev_sync}") + self.con.commit() + @property def node_array_columns(self): if not self._node_array_columns: diff --git 
a/tests/test_graph.py b/tests/test_graph.py index 2030d3b..4db168f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -694,14 +694,15 @@ def test_bulk_write_benchmark(provider_factory): assert result.number_of_nodes() == n_nodes assert result.number_of_edges() == n_edges - # --- Bulk write (recreate tables) --- + # --- Bulk write (recreate tables, with bulk_write_mode) --- with provider_factory( "w", node_attrs={"position": Vec(float, 3)}, ) as graph_provider: - t0 = time.perf_counter() - graph_provider.bulk_write_graph(graph) - t_bulk = time.perf_counter() - t0 + with graph_provider.bulk_write_mode(): + t0 = time.perf_counter() + graph_provider.bulk_write_graph(graph) + t_bulk = time.perf_counter() - t0 # Verify bulk write with provider_factory("r") as graph_reader: From 3cc47547caa03b563733d7b160637d566dd44dab Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 20:50:37 +0000 Subject: [PATCH 21/38] add test for node exclusion on upper bound of ROI --- tests/test_graph.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index 4db168f..1cb09b0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -653,6 +653,65 @@ def edge_set(edges): assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} +def test_graph_roi_upper_bound_exclusive(provider_factory): + """Nodes at exactly the upper bound of the ROI must be excluded. + + ROI is half-open [begin, end). A node whose position equals end in any + dimension should NOT appear in read_nodes or read_graph results. + + Regression test for: https://github.com/funkelab/funlib.persistence/issues/XX + """ + roi = Roi((0, 0, 0), (10, 10, 10)) # [0, 10) in each dim + + graph = nx.Graph() + # Interior node — clearly inside + graph.add_node(1, position=(5.0, 5.0, 5.0)) + # Node exactly on lower bound — should be included + graph.add_node(2, position=(0.0, 0.0, 0.0)) + # Nodes exactly on upper bound — should be excluded + graph.add_node(3, position=(10.0, 5.0, 5.0)) # x == end + graph.add_node(4, position=(5.0, 10.0, 5.0)) # y == end + graph.add_node(5, position=(5.0, 5.0, 10.0)) # z == end + graph.add_node(6, position=(10.0, 10.0, 10.0)) # all dims == end + # Edge crossing the boundary (u inside, v on boundary) + graph.add_edge(1, 3) + + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as gp: + gp.write_graph(graph) + + with provider_factory("r") as gp: + # read_nodes: only nodes strictly inside [0, 10) + nodes = gp.read_nodes(roi) + node_ids = {n["id"] for n in nodes} + assert node_ids == {1, 2}, f"Expected {{1, 2}}, got {node_ids}" + + # read_graph: same node set, edge (1,3) should still appear + # because node 3 is pulled in as a bare node via the edge + result = gp.read_graph(roi) + result_node_ids = set(result.nodes()) + # Node 3 may appear as a bare node (no position) via the edge + assert 1 in result_node_ids + assert 2 in result_node_ids + # Nodes 4, 5, 6 have no edges to interior nodes — must not appear + assert 4 not in result_node_ids + assert 5 not in result_node_ids + assert 6 not in result_node_ids + + # Verify that nodes returned by read_nodes all have positions inside ROI + for node in nodes: + pos = node["position"] + for dim in range(3): + assert pos[dim] >= roi.begin[dim], ( + f"Node {node['id']} pos[{dim}]={pos[dim]} < roi.begin={roi.begin[dim]}" + ) + assert pos[dim] < roi.end[dim], ( + f"Node {node['id']} pos[{dim}]={pos[dim]} >= roi.end={roi.end[dim]}" + ) + + def _build_grid_graph(size): 
"""Build a 3D grid graph with size^3 nodes and ~3*size^2*(size-1) edges.""" from itertools import product From 5c5489cbab63692c3d3f2ad512895551b17a274c Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 20:54:26 +0000 Subject: [PATCH 22/38] ints should be stored as bigint. cost is small and makes dealing with large numbers like fragment ids much easier --- src/funlib/persistence/graphs/pgsql_graph_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index b0ffa59..b00cf23 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -205,7 +205,7 @@ def __sql_type(self, type): if isinstance(type, Vec): return self.__sql_type(type.dtype) + f"[{type.size}]" try: - return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[ + return {bool: "BOOLEAN", int: "BIGINT", str: "VARCHAR", float: "REAL"}[ type ] except ValueError: From 2cece20076dc7e50998b6f728599fedaf2c680b4 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 21:20:38 +0000 Subject: [PATCH 23/38] ruff formatting and linting --- .../persistence/graphs/pgsql_graph_database.py | 4 +--- src/funlib/persistence/graphs/sql_graph_database.py | 11 +++-------- .../persistence/graphs/sqlite_graph_database.py | 4 +--- tests/test_graph.py | 13 +++++-------- 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index b00cf23..8fe2453 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -205,9 +205,7 @@ def __sql_type(self, type): if isinstance(type, Vec): return self.__sql_type(type.dtype) + f"[{type.size}]" try: - return {bool: "BOOLEAN", int: "BIGINT", str: "VARCHAR", float: "REAL"}[ - type - ] + return {bool: "BOOLEAN", int: "BIGINT", str: "VARCHAR", float: "REAL"}[type] except ValueError: raise NotImplementedError( f"attributes of type {type} are not yet supported" diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index 67ef3d9..3a12ca2 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -359,9 +359,7 @@ def bulk_write_edges(self, nodes, edges, roi=None, attributes=None): raise RuntimeError("Trying to write to read-only DB") u_name, v_name = self.endpoint_names - attrs = ( - attributes if attributes is not None else list(self.edge_attrs.keys()) - ) + attrs = attributes if attributes is not None else list(self.edge_attrs.keys()) columns = [u_name, v_name] + list(attrs) def rows(): @@ -505,9 +503,7 @@ def read_edges( join_condition = f"T1.{endpoint_names[0]} = T2.{node_id_column}" if fetch_on_v: - join_condition += ( - f" OR T1.{endpoint_names[1]} = T2.{node_id_column}" - ) + join_condition += f" OR T1.{endpoint_names[1]} = T2.{node_id_column}" select_statement = ( f"SELECT DISTINCT {edge_cols} " @@ -524,8 +520,7 @@ def read_edges( if using_join: # Qualify each attribute with T1 for the JOIN case parts = [ - f"T1.{k}={self.__convert_to_sql(v)}" - for k, v in attr_filter.items() + f"T1.{k}={self.__convert_to_sql(v)}" for k, v in attr_filter.items() ] where_clauses.append(" AND ".join(parts)) else: diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py 
b/src/funlib/persistence/graphs/sqlite_graph_database.py index de5e11a..38b4080 100644 --- a/src/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -69,9 +69,7 @@ def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True): if not worker and node_writes: if self.ndims > 1: # type: ignore - position_columns = self.node_array_columns[ - self.position_attribute - ] + position_columns = self.node_array_columns[self.position_attribute] else: position_columns = [self.position_attribute] self.cur.execute( diff --git a/tests/test_graph.py b/tests/test_graph.py index 1cb09b0..9451c2c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -181,7 +181,6 @@ def test_graph_read_and_update_specific_attrs(provider_factory): def test_graph_read_unbounded_roi(provider_factory, write_method): - roi = Roi((0, 0, 0), (10, 10, 10)) unbounded_roi = Roi((None, None, None), (None, None, None)) graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") @@ -217,9 +216,7 @@ def test_graph_read_unbounded_roi(provider_factory, write_method): def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - with provider_factory( - "w", True, roi, node_attrs={"position": Vec(float, 3)} - ): + with provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)}): pass with provider_factory("r", None, None) as graph_provider: assert True == graph_provider.directed @@ -549,8 +546,8 @@ def test_read_edges_join_vs_in_clause(provider_factory, write_method): avg_join = sum(times_join) / n_repeats print(f"\n--- read_edges benchmark (roi covers {30**3:,} of {size**3:,} nodes) ---") - print(f"IN clause (2 queries): {avg_in*1000:.1f} ms avg") - print(f"JOIN (1 query): {avg_join*1000:.1f} ms avg") + print(f"IN clause (2 queries): {avg_in * 1000:.1f} ms avg") + print(f"JOIN (1 query): {avg_join * 1000:.1f} ms avg") print(f"Speedup: {avg_in / avg_join:.2f}x") # Both should return edges — just verify they're non-empty and reasonable @@ -770,6 +767,6 @@ def test_bulk_write_benchmark(provider_factory): assert result.number_of_edges() == n_edges print(f"\n--- write benchmark ({n_nodes:,} nodes, {n_edges:,} edges) ---") - print(f"Standard: {t_standard*1000:.1f} ms") - print(f"Bulk: {t_bulk*1000:.1f} ms") + print(f"Standard: {t_standard * 1000:.1f} ms") + print(f"Bulk: {t_bulk * 1000:.1f} ms") print(f"Speedup: {t_standard / t_bulk:.2f}x") From f655129bb8418592e160679682549cb2988231de Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 21:39:16 +0000 Subject: [PATCH 24/38] fix mypy errors --- src/funlib/persistence/graphs/graph_database.py | 12 +++++++++++- .../persistence/graphs/sqlite_graph_database.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py index 64f9cb5..91dc65b 100644 --- a/src/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional from funlib.geometry import Roi from networkx import Graph @@ -58,6 +58,8 @@ def read_graph( read_edges: bool = True, node_attrs: Optional[list[str]] = None, edge_attrs: Optional[list[str]] = None, + nodes_filter: Optional[dict[str, Any]] = None, + edges_filter: Optional[dict[str, Any]] = None, fetch_on_v: bool = False, ) -> Graph: """ @@ -81,6 +83,14 @@ def 
read_graph( If not ``None``, only read the given edge attributes. + nodes_filter (``dict[str, Any]`` or ``None``): + + If not ``None``, only read nodes matching these attribute values. + + edges_filter (``dict[str, Any]`` or ``None``): + + If not ``None``, only read edges matching these attribute values. + fetch_on_v (``bool``): If ``True``, also fetch edges where the ``v`` endpoint matches diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py index 38b4080..24eef52 100644 --- a/src/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -223,7 +223,7 @@ def _bulk_insert(self, table, columns, rows) -> None: else self.edge_array_columns ) - exploded_columns = None + exploded_columns: list[str] = [] exploded_rows = [] for row in rows: exploded_cols = [] @@ -236,7 +236,7 @@ def _bulk_insert(self, table, columns, rows) -> None: else: exploded_cols.append(column) exploded_vals.append(value) - if exploded_columns is None: + if not exploded_columns: exploded_columns = exploded_cols exploded_rows.append(exploded_vals) From c8a43067a657434064e257c3b918b2722d746663 Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 21:39:34 +0000 Subject: [PATCH 25/38] fix mypy workflow file --- .github/workflows/mypy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index b1c4393..3aeb886 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -12,4 +12,4 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@main - name: mypy - run: uv run --extra dev mypy funlib/persistence tests + run: uv run --extra dev mypy src tests From 8c4b9a2907113e6095e5996f726dfc2b0796669f Mon Sep 17 00:00:00 2001 From: will Date: Wed, 11 Feb 2026 22:15:48 +0000 Subject: [PATCH 26/38] Pass ty checks - sql_graph_database.py: Refactored __init__ to declare instance attributes with non-optional type annotations (str, int, bool, Roi, list[str]). Refactored __load_metadata to take optional overrides as params and always assign from stored metadata. Removed all # type: ignore comments. - sqlite_graph_database.py: Removed # type: ignore comments, fixed position_columns = self.position_attribute bug (should be [self.position_attribute]), fixed parameter name mismatch (query_attrs -> attrs), replaced .size access with len() on already-computed array columns. - pgsql_graph_database.py: Fixed dict[str, type] -> dict[str, AttributeType] in __init__ signature, removed stale # type: ignore comments. 
--- .../graphs/pgsql_graph_database.py | 16 +- .../persistence/graphs/sql_graph_database.py | 151 +++++++++--------- .../graphs/sqlite_graph_database.py | 38 ++--- 3 files changed, 105 insertions(+), 100 deletions(-) diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py index 8fe2453..2e0c227 100644 --- a/src/funlib/persistence/graphs/pgsql_graph_database.py +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -10,7 +10,7 @@ from psycopg2 import sql from ..types import Vec -from .sql_graph_database import SQLGraphDataBase +from .sql_graph_database import AttributeType, SQLGraphDataBase logger = logging.getLogger(__name__) @@ -30,8 +30,8 @@ def __init__( nodes_table: str = "nodes", edges_table: str = "edges", endpoint_names: Optional[list[str]] = None, - node_attrs: Optional[dict[str, type]] = None, - edge_attrs: Optional[dict[str, type]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, ): self.db_host = db_host self.db_name = db_name @@ -72,8 +72,8 @@ def __init__( nodes_table=nodes_table, edges_table=edges_table, endpoint_names=endpoint_names, - node_attrs=node_attrs, # type: ignore - edge_attrs=edge_attrs, # type: ignore + node_attrs=node_attrs, + edge_attrs=edge_attrs, ) def close(self): @@ -98,7 +98,7 @@ def _drop_tables(self) -> None: self._commit() def _create_tables(self) -> None: - columns = self.node_attrs.keys() + columns = list(self.node_attrs.keys()) types = [self.__sql_type(t) for t in self.node_attrs.values()] column_types = [f"{c} {t}" for c, t in zip(columns, types)] self.__exec( @@ -113,14 +113,14 @@ def _create_tables(self) -> None: f"{self.nodes_table_name}({self.position_attribute})" ) - columns = list(self.edge_attrs.keys()) # type: ignore + columns = list(self.edge_attrs.keys()) types = list([self.__sql_type(t) for t in self.edge_attrs.values()]) column_types = [f"{c} {t}" for c, t in zip(columns, types)] endpoint_names = self.endpoint_names assert endpoint_names is not None self.__exec( f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" - f"{endpoint_names[0]} BIGINT not null, " # type: ignore + f"{endpoint_names[0]} BIGINT not null, " f"{endpoint_names[1]} BIGINT not null, " f"{' '.join([c + ',' for c in column_types])}" f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" diff --git a/src/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py index 3a12ca2..5cb511d 100644 --- a/src/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -85,20 +85,20 @@ def __init__( self.mode = mode if mode in self.read_modes: - self.position_attribute = position_attribute - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs - self.ndims = None # to be read from metadata - metadata = self._read_metadata() if metadata is None: raise RuntimeError("metadata does not exist, can't open in read mode") - self.__load_metadata(metadata) + self.__load_metadata( + metadata, + position_attribute=position_attribute, + directed=directed, + total_roi=total_roi, + nodes_table=nodes_table, + edges_table=edges_table, + endpoint_names=endpoint_names, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ) if mode in self.create_modes: # this is where we populate 
default values for the DB creation @@ -112,7 +112,7 @@ def __init__( def get(value, default): return value if value is not None else default - self.position_attribute = get(position_attribute, "position") + self.position_attribute: str = get(position_attribute, "position") assert self.position_attribute in node_attrs, ( "No type information for position attribute " @@ -121,7 +121,7 @@ def get(value, default): position_type = node_attrs[self.position_attribute] if isinstance(position_type, Vec): - self.ndims = position_type.size + self.ndims: int = position_type.size assert self.ndims > 1, ( "Don't use Vecs of size 1 for the position, use the " "scalar type directly instead (i.e., 'float' instead of " @@ -131,13 +131,13 @@ def get(value, default): else: self.ndims = 1 - self.directed = get(directed, False) - self.total_roi = get( + self.directed: bool = get(directed, False) + self.total_roi: Roi = get( total_roi, Roi((None,) * self.ndims, (None,) * self.ndims) ) - self.nodes_table_name = get(nodes_table, "nodes") - self.edges_table_name = get(edges_table, "edges") - self.endpoint_names = get(endpoint_names, ["u", "v"]) + self.nodes_table_name: str = get(nodes_table, "nodes") + self.edges_table_name: str = get(edges_table, "edges") + self.endpoint_names: list[str] = get(endpoint_names, ["u", "v"]) self._node_attrs = node_attrs # no default, needs to be given self._edge_attrs = get(edge_attrs, {}) @@ -266,7 +266,7 @@ def read_graph( attr_filter=edges_filter, fetch_on_v=fetch_on_v, ) - u, v = self.endpoint_names # type: ignore + u, v = self.endpoint_names try: edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges] except KeyError as e: @@ -466,7 +466,6 @@ def read_edges( ) endpoint_names = self.endpoint_names - assert endpoint_names is not None # 1. 
Determine the base SELECT statement and WHERE clause @@ -628,8 +627,8 @@ def update_edges( if not roi.contains(pos_u): logger.debug( ( - f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore - + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore + f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," + + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" ).format(u, v, data, roi) ) continue @@ -639,7 +638,7 @@ def update_edges( update_statement = ( f"UPDATE {self.edges_table_name} SET " f"{', '.join(setters)} WHERE " - f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore + f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" ) self._update_query(update_statement, commit=False) @@ -721,7 +720,6 @@ def update_nodes( def __create_metadata(self): """Sets the metadata in the meta collection to the provided values""" - metadata = { "position_attribute": self.position_attribute, "directed": self.directed, @@ -737,59 +735,64 @@ def __create_metadata(self): return metadata - def __load_metadata(self, metadata): + def __load_metadata( + self, + metadata, + position_attribute: Optional[str] = None, + directed: Optional[bool] = None, + total_roi: Optional[Roi] = None, + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, + endpoint_names: Optional[list[str]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, + ): """Load the provided metadata into this object's attributes, check if - it is consistent with already populated fields.""" - - # simple attributes - for attr_name in [ - "position_attribute", - "directed", - "nodes_table_name", - "edges_table_name", - "endpoint_names", - "ndims", - ]: - if getattr(self, attr_name) is None: - setattr(self, attr_name, metadata[attr_name]) - else: - value = getattr(self, attr_name) - assert value == metadata[attr_name], ( - f"Attribute {attr_name} is already set to {value} for this " - "object, but disagrees with the stored metadata value of " - f"{metadata[attr_name]}" + user-provided overrides are consistent with stored metadata.""" + + # For each simple attribute, use metadata as the source of truth. + # If the user also provided a value, check consistency. 
+        overrides: dict[str, Any] = {
+            "position_attribute": position_attribute,
+            "directed": directed,
+            "nodes_table_name": nodes_table,
+            "edges_table_name": edges_table,
+            "endpoint_names": endpoint_names,
+            "ndims": None,  # ndims is never user-provided
+        }
+        for attr_name, override in overrides.items():
+            stored = metadata[attr_name]
+            if override is not None:
+                assert override == stored, (
+                    f"Attribute {attr_name} was given as {override}, but "
+                    f"disagrees with the stored metadata value of {stored}"
                 )
-
-        # special attributes
-
-        total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"])
-        if self.total_roi is None:
-            self.total_roi = total_roi
-        else:
-            assert self.total_roi == total_roi, (
-                f"Attribute total_roi is already set to {self.total_roi} for "
-                "this object, but disagrees with the stored metadata value of "
-                f"{total_roi}"
+            setattr(self, attr_name, stored)
+
+        # total_roi
+        stored_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"])
+        if total_roi is not None:
+            assert total_roi == stored_roi, (
+                f"Attribute total_roi was given as {total_roi}, but "
+                f"disagrees with the stored metadata value of {stored_roi}"
             )
-
-        node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()}
-        edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()}
-        if self._node_attrs is None:
-            self.node_attrs = node_attrs
-        else:
-            assert self.node_attrs == node_attrs, (
-                f"Attribute node_attrs is already set to {self.node_attrs} for "
-                "this object, but disagrees with the stored metadata value of "
-                f"{node_attrs}"
+        self.total_roi = stored_roi
+
+        # node_attrs / edge_attrs
+        stored_node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()}
+        stored_edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()}
+        if node_attrs is not None:
+            assert node_attrs == stored_node_attrs, (
+                f"Attribute node_attrs was given as {node_attrs}, but "
+                f"disagrees with the stored metadata value of {stored_node_attrs}"
             )
-        if self._edge_attrs is None:
-            self.edge_attrs = edge_attrs
-        else:
-            assert self.edge_attrs == edge_attrs, (
-                f"Attribute edge_attrs is already set to {self.edge_attrs} for "
-                "this object, but disagrees with the stored metadata value of "
-                f"{edge_attrs}"
+        self._node_attrs = stored_node_attrs
+        if edge_attrs is not None:
+            assert edge_attrs == stored_edge_attrs, (
+                f"Attribute edge_attrs was given as {edge_attrs}, but "
+                f"disagrees with the stored metadata value of {stored_edge_attrs}"
             )
+        self._edge_attrs = stored_edge_attrs
 
     def __remove_keys(self, dictionary, keys):
         """Removes given keys from dictionary."""
@@ -798,7 +801,7 @@ def __remove_keys(self, dictionary, keys):
 
     def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]:
         try:
-            return Coordinate(n[self.position_attribute])  # type: ignore
+            return Coordinate(n[self.position_attribute])
         except KeyError:
             return None
 
@@ -822,7 +825,7 @@ def __attr_query(self, attrs: dict[str, Any]) -> str:
     def __roi_query(self, roi: Roi) -> str:
         query = "WHERE "
         pos_attr = self.position_attribute
-        for dim in range(self.ndims):  # type: ignore
+        for dim in range(self.ndims):
             if dim > 0:
                 query += " AND "
             if roi.begin[dim] is not None and roi.end[dim] is not None:
diff --git a/src/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py
index 24eef52..a2c902b 100644
--- a/src/funlib/persistence/graphs/sqlite_graph_database.py
+++ b/src/funlib/persistence/graphs/sqlite_graph_database.py
@@ -68,7 +68,7 @@ def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True):
             self.con.commit()
 
         if not worker and node_writes:
-            if self.ndims > 1:  # type: ignore
+            if self.ndims > 1:
                 position_columns = self.node_array_columns[self.position_attribute]
             else:
                 position_columns = [self.position_attribute]
@@ -133,16 +133,16 @@ def _create_tables(self) -> None:
             f"{', '.join(node_columns)}"
             ")"
         )
-        if self.ndims > 1:  # type: ignore
+        if self.ndims > 1:
             position_columns = self.node_array_columns[self.position_attribute]
         else:
-            position_columns = self.position_attribute
+            position_columns = [self.position_attribute]
         self.cur.execute(
             f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})"
         )
         edge_columns = [
-            f"{self.endpoint_names[0]} INTEGER not null",  # type: ignore
-            f"{self.endpoint_names[1]} INTEGER not null",  # type: ignore
+            f"{self.endpoint_names[0]} INTEGER not null",
+            f"{self.endpoint_names[1]} INTEGER not null",
         ]
         for attr in self.edge_attrs.keys():
             if attr in self.edge_array_columns:
@@ -152,7 +152,7 @@ def _create_tables(self) -> None:
         self.cur.execute(
             f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}("
             + f"{', '.join(edge_columns)}"
-            + f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"  # type: ignore
+            + f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})"
             + ")"
         )
 
@@ -273,17 +273,18 @@ def _node_attrs_to_columns(self, attrs):
                 columns.append(attr)
         return columns
 
-    def _columns_to_node_attrs(self, columns, query_attrs):
-        attrs = {}
-        for attr in query_attrs:
+    def _columns_to_node_attrs(self, columns, attrs):
+        result = {}
+        for attr in attrs:
             if attr in self.node_array_columns:
                 value = tuple(
-                    columns[f"{attr}_{d}"] for d in range(self.node_attrs[attr].size)
+                    columns[f"{attr}_{d}"]
+                    for d in range(len(self.node_array_columns[attr]))
                 )
             else:
                 value = columns[attr]
-            attrs[attr] = value
-        return attrs
+            result[attr] = value
+        return result
 
     def _edge_attrs_to_columns(self, attrs):
         columns = []
@@ -295,14 +296,15 @@ def _edge_attrs_to_columns(self, attrs):
                 columns.append(attr)
         return columns
 
-    def _columns_to_edge_attrs(self, columns, query_attrs):
-        attrs = {}
-        for attr in query_attrs:
+    def _columns_to_edge_attrs(self, columns, attrs):
+        result = {}
+        for attr in attrs:
             if attr in self.edge_array_columns:
                 value = tuple(
-                    columns[f"{attr}_{d}"] for d in range(self.edge_attrs[attr].size)
+                    columns[f"{attr}_{d}"]
+                    for d in range(len(self.edge_array_columns[attr]))
                 )
             else:
                 value = columns[attr]
-            attrs[attr] = value
-        return attrs
+            result[attr] = value
+        return result

From e2e8a0e1e52091395dd986da809c92293ea4fe14 Mon Sep 17 00:00:00 2001
From: will
Date: Wed, 11 Feb 2026 22:47:38 +0000
Subject: [PATCH 27/38] pass ty linting on arrays

---
 src/funlib/persistence/arrays/array.py        | 21 +++++----
 src/funlib/persistence/arrays/lazy_ops.py     |  3 +-
 src/funlib/persistence/arrays/metadata.py     | 47 ++++++++++---------
 src/funlib/persistence/arrays/ome_datasets.py | 13 ++---
 tests/test_graph.py                           |  2 +-
 5 files changed, 48 insertions(+), 38 deletions(-)

diff --git a/src/funlib/persistence/arrays/array.py b/src/funlib/persistence/arrays/array.py
index 13762b6..4187163 100644
--- a/src/funlib/persistence/arrays/array.py
+++ b/src/funlib/persistence/arrays/array.py
@@ -343,18 +343,21 @@ def __getitem__(self, key) -> np.ndarray:
         else:
             return self.data[key].compute()
 
-    def __setitem__(self, key, value: np.ndarray):
-        """Set the data of this array within the given ROI.
+    def __setitem__(self, key: Roi | slice | tuple, value: np.ndarray | float | int):
+        """Set the data of this array.
 
         Args:
 
-            key (`class:Roi`):
+            key (`class:Roi` or any numpy compatible key):
 
-                The ROI to write to.
+                The region to write to. Can be a `Roi` for world-unit indexing,
+                or any numpy-compatible key (e.g. ``np.s_[:]``, a slice, a tuple
+                of slices).
 
-            value (``ndarray``):
+            value (``ndarray`` or scalar):
 
-                The value to write.
+                The value to write. Can be a numpy array or a scalar that will
+                be broadcast.
         """
 
         if self.is_writeable:
@@ -473,7 +476,7 @@ def _is_slice(self, lazy_op: LazyOp, writeable: bool = False) -> bool:
         elif isinstance(lazy_op, list) and all([isinstance(a, int) for a in lazy_op]):
             return True
         elif isinstance(lazy_op, tuple) and all(
-            [self._is_slice(a, writeable) for a in lazy_op]
+            [self._is_slice(a, writeable) for a in lazy_op]  # type: ignore[arg-type]  # ty can't narrow parameterized tuple iteration
         ):
             return True
         elif (
@@ -507,7 +510,7 @@ def validate(self, strict: bool = False):
         )
 
     def to_pixel_space(
-        self, world_loc: Roi | Coordinate | Sequence[int | float]
+        self, world_loc: Roi | Coordinate | Sequence[int | float] | np.ndarray
     ) -> Roi | Coordinate | np.ndarray:
         """Convert a point or roi in world space into the pixel space of this array.
         Works on sequences of floats by returning a numpy array that is not guaranteed
@@ -538,7 +541,7 @@ def to_pixel_space(
         )
 
     def to_world_space(
-        self, pixel_loc: Roi | Coordinate | Sequence[int | float]
+        self, pixel_loc: Roi | Coordinate | Sequence[int | float] | np.ndarray
    ) -> Roi | Coordinate:
         """Convert a point or roi from pixel space in this array to the world
         coordinate system defined by this array's roi and voxel size.
diff --git a/src/funlib/persistence/arrays/lazy_ops.py b/src/funlib/persistence/arrays/lazy_ops.py
index 2299eb6..4e107be 100644
--- a/src/funlib/persistence/arrays/lazy_ops.py
+++ b/src/funlib/persistence/arrays/lazy_ops.py
@@ -1,5 +1,6 @@
 from typing import Callable, Union
 
+import numpy as np
 from funlib.geometry import Roi
 
-LazyOp = Union[slice, Callable, Roi]
+LazyOp = Union[slice, int, tuple[int | slice | list[int] | np.ndarray, ...], Callable, Roi]
diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py
index d948c91..520056c 100644
--- a/src/funlib/persistence/arrays/metadata.py
+++ b/src/funlib/persistence/arrays/metadata.py
@@ -5,6 +5,7 @@
 
 import toml
 import zarr
+import zarr.attrs
 from funlib.geometry import Coordinate
 from pydantic import BaseModel
 
@@ -330,45 +331,49 @@ def recurse(
                     return recurse(data[current_key], keys)
 
         result = recurse(data, keys)
-        assert isinstance(result, Sequence) or result is None, result
+        assert (isinstance(result, Sequence) and not isinstance(result, (str, int))) or result is None, result
         return result
 
     def parse(
         self,
-        shape,
+        shape: Sequence[int],
         data: dict[str | int, Any],
-        offset=None,
-        voxel_size=None,
-        axis_names=None,
-        units=None,
-        types=None,
-        strict=False,
+        offset: Optional[Sequence[int]] = None,
+        voxel_size: Optional[Sequence[int]] = None,
+        axis_names: Optional[Sequence[str]] = None,
+        units: Optional[Sequence[str]] = None,
+        types: Optional[Sequence[str]] = None,
+        strict: bool = False,
     ) -> MetaData:
-        offset = offset if offset is not None else self.fetch(data, self.offset_attr)
-        voxel_size = (
+        fetched_offset = offset if offset is not None else self.fetch(data, self.offset_attr)
+        fetched_voxel_size = (
             voxel_size
             if voxel_size is not None
             else self.fetch(data, self.voxel_size_attr)
         )
-        axis_names = (
+        fetched_axis_names = (
             axis_names
             if axis_names is not None
             else self.fetch(data, self.axis_names_attr)
         )
-        units = units if units is not None else self.fetch(data, self.units_attr)
-        types = types if types is not None else self.fetch(data, self.types_attr)
-        if types is None and axis_names is not None:
-            types = [
-                "channel" if name.endswith("^") else "space" for name in axis_names
+        fetched_units = units if units is not None else self.fetch(data, self.units_attr)
+        fetched_types = types if types is not None else self.fetch(data, self.types_attr)
+        if fetched_types is None and fetched_axis_names is not None:
+            fetched_types = [
+                "channel" if str(name).endswith("^") else "space"
+                for name in fetched_axis_names
             ]
 
+        # fetch() returns Sequence[str | int | None] | None from untyped metadata.
+        # Some OME-Zarr metadata may have holes (None elements), so we pass
+        # the fetched values through as-is; MetaData.validate() checks at runtime.
         metadata = MetaData(
             shape=shape,
-            offset=offset,
-            voxel_size=voxel_size,
-            axis_names=axis_names,
-            units=units,
-            types=types,
+            offset=fetched_offset,  # type: ignore[arg-type]
+            voxel_size=fetched_voxel_size,  # type: ignore[arg-type]
+            axis_names=fetched_axis_names,  # type: ignore[arg-type]
+            units=fetched_units,  # type: ignore[arg-type]
+            types=fetched_types,  # type: ignore[arg-type]
             strict=strict,
         )
diff --git a/src/funlib/persistence/arrays/ome_datasets.py b/src/funlib/persistence/arrays/ome_datasets.py
index d089002..76c8d23 100644
--- a/src/funlib/persistence/arrays/ome_datasets.py
+++ b/src/funlib/persistence/arrays/ome_datasets.py
@@ -1,6 +1,7 @@
 import logging
 from collections.abc import Sequence
 from pathlib import Path
+from typing import Literal
 
 from funlib.geometry import Coordinate
 from iohub.ngff import TransformationMeta, open_ome_zarr
@@ -16,7 +17,7 @@ def open_ome_ds(
     store: Path,
     name: str,
-    mode: str = "r",
+    mode: Literal["r", "r+", "a", "w", "w-"] = "r",
     **kwargs,
 ) -> Array:
     """
@@ -62,8 +63,8 @@ def open_ome_ds(
 
     metadata = MetaData(
         shape=dataset.shape,
-        offset=offset,
-        voxel_size=scale,
+        offset=offset,  # type: ignore[arg-type]
+        voxel_size=scale,  # type: ignore[arg-type]
         axis_names=axis_names,
         units=units,
         types=types,
@@ -170,7 +171,7 @@ def prepare_ome_ds(
     )
 
     axis_metadata = [
-        AxisMeta(name=n, type=t, unit=u)
+        AxisMeta(name=n, type=t, unit=u)  # type: ignore[misc]
         for n, t, u in zip(metadata.axis_names, metadata.types, metadata.ome_units)
     ]
 
@@ -179,8 +180,8 @@ def prepare_ome_ds(
         store, mode="w", layout="fov", axes=axis_metadata, channel_names=channel_names
     ) as ds:
         transforms = [
-            TransformationMeta(type="scale", scale=metadata.ome_scale),
-            TransformationMeta(type="translation", translation=metadata.ome_translate),
+            TransformationMeta(type="scale", scale=list(metadata.ome_scale)),
+            TransformationMeta(type="translation", translation=list(metadata.ome_translate)),
         ]
 
         ds.create_zeros(
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 9451c2c..eef1e4e 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -162,7 +162,7 @@ def test_graph_read_and_update_specific_attrs(provider_factory):
     for u, v, data in limited_graph.edges(data=True):
         assert "a" not in data
         assert "b" not in data
-    nx.set_edge_attributes(limited_graph, 5, "c")
+    nx.set_edge_attributes(limited_graph, 5, "c")  # type: ignore[call-overload]
 
     try:
         graph_provider.write_attrs(

From aacfbc8e4e8b9b1bcaca81163256b085536b4fca Mon Sep 17 00:00:00 2001
From: will
Date: Wed, 11 Feb 2026 22:49:53 +0000
Subject: [PATCH 28/38] switch from mypy to ty

---
 .github/workflows/{mypy.yaml => ty.yaml} | 8 ++++----
 README.md                                | 2 +-
 pyproject.toml                           | 9 +++++----
 3 files changed, 10 insertions(+), 9 deletions(-)
 rename .github/workflows/{mypy.yaml => ty.yaml} (68%)

diff --git a/.github/workflows/mypy.yaml b/.github/workflows/ty.yaml
similarity index 68%
rename from .github/workflows/mypy.yaml
rename to .github/workflows/ty.yaml
index 3aeb886..7fee17d 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/ty.yaml
@@ -1,15 +1,15 @@
-name: mypy
+name: ty
 
 on: [push, pull_request]
 
 jobs:
   static-analysis:
-    name: Python mypy
+    name: Python ty
     runs-on: ubuntu-latest
    steps:
       - name: Setup checkout
         uses: actions/checkout@master
       - name: Install uv
         uses: astral-sh/setup-uv@main
-      - name: mypy
-        run: uv run --extra dev mypy src tests
+      - name: ty
+        run: uv run --extra dev ty check src tests
diff --git a/README.md b/README.md
index 82cb0af..82aaa0d 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 [![tests](https://github.com/funkelab/funlib.persistence/actions/workflows/tests.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/tests.yaml)
 [![ruff](https://github.com/funkelab/funlib.persistence/actions/workflows/ruff.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/ruff.yaml)
-[![mypy](https://github.com/funkelab/funlib.persistence/actions/workflows/mypy.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/mypy.yaml)
+[![ty](https://github.com/funkelab/funlib.persistence/actions/workflows/ty.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/ty.yaml)
 [![pypi](https://github.com/funkelab/funlib.persistence/actions/workflows/publish.yaml/badge.svg)](https://pypi.org/project/funlib.persistence/)
 
 # funlib.persistence
diff --git a/pyproject.toml b/pyproject.toml
index 0e1f0e1..5a5766f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,10 +40,10 @@ version = { attr = "funlib.persistence.__version__" }
 [project.optional-dependencies]
 dev = [
     "coverage>=7.7.1",
-    "mypy>=1.15.0",
     "pytest>=8.3.5",
     "pytest-mock>=3.14.0",
     "ruff>=0.11.2",
+    "ty>=0.0.16",
     "types-networkx",
     "types-psycopg2",
     "types-toml",
@@ -56,10 +56,11 @@ lint.select = ["F", "W", "I001"]
 [tool.setuptools.package-data]
 "funlib.persistence" = ["py.typed"]
 
-[tool.mypy]
-explicit_package_bases = true
-
 # # module specific overrides
 [[tool.mypy.overrides]]
 module = ["zarr.*", "iohub.*"]
 ignore_missing_imports = true
+
+[tool.ty.rules]
+# dask and iohub have incomplete type stubs
+possibly-missing-attribute = "warn"

From ad1a0fc408a28bc8e8a7c82116e375ded629c680 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Wed, 14 May 2025 11:38:27 -0700
Subject: [PATCH 29/38] pass tests on zarr v3

---
 pyproject.toml                            |  6 ++--
 src/funlib/persistence/arrays/datasets.py | 36 +++++++----------------
 src/funlib/persistence/arrays/metadata.py |  2 +-
 tests/test_datasets.py                    | 13 ++++----
 tests/test_metadata.py                    |  2 +-
 5 files changed, 21 insertions(+), 38 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5a5766f..d016e44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,17 +13,17 @@ authors = [
 ]
 dynamic = ['version']
 
-requires-python = ">=3.10"
+requires-python = ">=3.11"
 classifiers = ["Programming Language :: Python :: 3"]
 keywords = []
 
 dependencies = [
-    "zarr>=2,<3",
+    "zarr>=3,<4",
     # ImportError: cannot import name 'cbuffer_sizes' from 'numcodecs.blosc'
     # We can pin zarr to >2.18.7 but then we have to drop python 3.10
     # pin numcodecs to avoid breaking change
     "numcodecs>0.13,<0.16.0",
-    "iohub>=0.2.0b0",
+    "iohub>=0.3.0a5",
     "funlib.geometry>=0.3.0",
     "networkx>=3.0.0",
     "pymongo>=4.0.0",
diff --git a/src/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py
index c1f96b0..aa64676 100644
--- a/src/funlib/persistence/arrays/datasets.py
+++ b/src/funlib/persistence/arrays/datasets.py
@@ -12,15 +12,6 @@
 
 logger = logging.getLogger(__name__)
 
-
-class ArrayNotFoundError(Exception):
-    """Exception raised when an array is not found in the dataset."""
-
-    def __init__(self, message: str = "Array not found in the dataset"):
-        self.message = message
-        super().__init__(self.message)
-
-
 def open_ds(
     store,
     mode: str = "r",
@@ -106,10 +97,7 @@ def open_ds(
         else get_default_metadata_format()
     )
 
-    try:
-        data = zarr.open(store, mode=mode, **kwargs)
-    except zarr.errors.PathNotFoundError:
-        raise ArrayNotFoundError(f"Nothing found at path {store}")
+    data = zarr.open(store, mode=mode, **kwargs)
 
     metadata = metadata_format.parse(
         data.shape,
@@ -238,7 +226,7 @@ def prepare_ds(
 
     try:
         existing_array = open_ds(store, mode="r", **kwargs)
-    except ArrayNotFoundError:
+    except FileNotFoundError:
         existing_array = None
 
     if existing_array is not None:
@@ -348,18 +336,14 @@ def prepare_ds(
         )
 
     # create the dataset
-    try:
-        ds = zarr.open_array(
-            store=store,
-            shape=shape,
-            chunks=chunk_shape,
-            dtype=dtype,
-            dimension_separator="/",
-            mode=mode,
-            **kwargs,
-        )
-    except zarr.errors.ArrayNotFoundError:
-        raise ArrayNotFoundError(f"Nothing found at path {store}")
+    ds = zarr.open_array(
+        store=store,
+        shape=shape,
+        chunks=chunk_shape,
+        dtype=dtype,
+        mode=mode,
+        **kwargs,
+    )
 
     default_metadata_format = get_default_metadata_format()
     our_metadata = {
diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py
index 520056c..bb6ff1a 100644
--- a/src/funlib/persistence/arrays/metadata.py
+++ b/src/funlib/persistence/arrays/metadata.py
@@ -305,7 +305,7 @@ def recurse(
             # base case
             if len(keys) == 0:
                 # this key returns the data we want
-                if isinstance(data, (dict, zarr.attrs.Attributes)):
+                if isinstance(data, (dict, zarr.core.attributes.Attributes)):
                     return data.get(str(current_key), None)
                 elif isinstance(data, list):
                     assert isinstance(current_key, int), current_key
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 85f86d7..9875a78 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,8 +1,7 @@
 import numpy as np
 import pytest
 from funlib.geometry import Coordinate, Roi
-
-from funlib.persistence.arrays.datasets import ArrayNotFoundError, open_ds, prepare_ds
+from funlib.persistence.arrays.datasets import open_ds, prepare_ds
 from funlib.persistence.arrays.metadata import MetaDataFormat
 
 stores = {
@@ -17,7 +16,7 @@
 
 @pytest.mark.parametrize("store", stores.keys())
 def test_metadata(tmpdir, store):
-    store = tmpdir / store
+    store = str(tmpdir / store)
 
     # test prepare_ds creates array if it does not exist and mode is write
     array = prepare_ds(
@@ -37,7 +36,7 @@ def test_helpers(tmpdir, store, dtype):
     shape = Coordinate(1, 1, 10, 20, 30)
     chunk_shape = Coordinate(2, 3, 10, 10, 10)
-    store = tmpdir / store
+    store = str(tmpdir / store)
     metadata = MetaDataFormat().parse(
         shape,
         {
@@ -50,7 +49,7 @@
     )
 
     # test prepare_ds fails if array does not exist and mode is read
-    with pytest.raises(ArrayNotFoundError):
+    with pytest.raises(FileNotFoundError):
         prepare_ds(
             store,
             shape,
@@ -220,7 +219,7 @@ def test_helpers(tmpdir, store, dtype):
 
 @pytest.mark.parametrize("store", stores.keys())
 @pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64])
 def test_open_ds(tmpdir, store, dtype):
     shape = Coordinate(1, 1, 10, 20, 30)
-    store = tmpdir / store
+    store = str(tmpdir / store)
     metadata = MetaDataFormat().parse(
         shape,
         {
@@ -233,7 +232,7 @@
     )
 
     # test open_ds fails if array does not exist and mode is read
-    with pytest.raises(ArrayNotFoundError):
+    with pytest.raises(FileNotFoundError):
         open_ds(
             store,
             offset=metadata.offset,
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index ecff21a..cfb00ed 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -99,7 +99,7 @@ def test_default_metadata_format(tmpdir):
     )
 
     prepare_ds(
-        tmpdir / "test.zarr/test",
+        str(tmpdir / "test.zarr/test"),
         (10, 2, 100, 100, 100),
         offset=metadata.offset,
         voxel_size=metadata.voxel_size,

From e60b7efeefb9a39d84654354fbafdb23af83a672 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Wed, 14 May 2025 11:59:45 -0700
Subject: [PATCH 30/38] use `tmp_path` fixture over legacy `tmpdir`

---
 tests/conftest.py      |  5 ++---
 tests/test_datasets.py | 12 ++++++------
 tests/test_metadata.py |  6 +++---
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index a1fd5c6..36931b9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -52,15 +52,14 @@ def can_connect_to_psql():
         psql_param,
     )
 )
-def provider_factory(request, tmpdir):
-    tmpdir = Path(tmpdir)
+def provider_factory(request, tmp_path):
 
     @contextmanager
     def sqlite_provider_factory(
         mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None
    ):
         provider = SQLiteGraphDataBase(
-            tmpdir / "test_sqlite_graph.db",
+            tmp_path / "test_sqlite_graph.db",
             position_attribute="position",
             mode=mode,
             directed=directed,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 9875a78..8cc8c2b 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -15,8 +15,8 @@
 
 
 @pytest.mark.parametrize("store", stores.keys())
-def test_metadata(tmpdir, store):
-    store = str(tmpdir / store)
+def test_metadata(tmp_path, store):
+    store = tmp_path / store
 
     # test prepare_ds creates array if it does not exist and mode is write
     array = prepare_ds(
@@ -33,10 +33,10 @@ def test_metadata(tmp_path, store):
 
 @pytest.mark.parametrize("store", stores.keys())
 @pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64])
-def test_helpers(tmpdir, store, dtype):
+def test_helpers(tmp_path, store, dtype):
     shape = Coordinate(1, 1, 10, 20, 30)
     chunk_shape = Coordinate(2, 3, 10, 10, 10)
-    store = str(tmpdir / store)
+    store = tmp_path / store
     metadata = MetaDataFormat().parse(
         shape,
         {
@@ -217,9 +217,9 @@ def test_helpers(tmpdir, store, dtype):
 
 @pytest.mark.parametrize("store", stores.keys())
 @pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64])
-def test_open_ds(tmpdir, store, dtype):
+def test_open_ds(tmp_path, store, dtype):
     shape = Coordinate(1, 1, 10, 20, 30)
-    store = str(tmpdir / store)
+    store = tmp_path / store
     metadata = MetaDataFormat().parse(
         shape,
         {
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index cfb00ed..c433918 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -91,7 +91,7 @@ def test_empty_metadata():
     assert metadata.types == ["space", "space", "space", "space", "space"]
 
 
-def test_default_metadata_format(tmpdir):
+def test_default_metadata_format(tmp_path):
     set_default_metadata_format(metadata_formats["simple"])
     metadata = metadata_formats["simple"].parse(
         (10, 2, 100, 100, 100),
@@ -99,7 +99,7 @@ def test_default_metadata_format(tmp_path):
     )
 
     prepare_ds(
-        str(tmpdir / "test.zarr/test"),
+        tmp_path / "test.zarr/test",
         (10, 2, 100, 100, 100),
         offset=metadata.offset,
         voxel_size=metadata.voxel_size,
@@ -110,7 +110,7 @@ def test_default_metadata_format(tmp_path):
         mode="w",
     )
 
-    zarr_attrs = dict(**zarr.open(str(tmpdir / "test.zarr/test")).attrs)
+    zarr_attrs = dict(**zarr.open(tmp_path / "test.zarr/test").attrs)
     assert zarr_attrs["offset"] == [100, 200, 400]
     assert zarr_attrs["resolution"] == [1, 2, 3]
     assert zarr_attrs["extras/axes"] == ["sample^", "channel^", "t", "y", "x"]

From 43333190e3f3d71d98b9e1f532eb75f720b83724 Mon Sep 17 00:00:00 2001
From: will
Date: Wed, 11 Feb 2026 23:49:39 +0000
Subject: [PATCH 31/38] get rid of zarr.attrs import

---
 src/funlib/persistence/arrays/metadata.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py
index bb6ff1a..d47a26e 100644
--- a/src/funlib/persistence/arrays/metadata.py
+++ b/src/funlib/persistence/arrays/metadata.py
@@ -5,7 +5,6 @@
 
 import toml
 import zarr
-import zarr.attrs
 from funlib.geometry import Coordinate
 from pydantic import BaseModel
 

From baac4b4feea51c618501a37b55def5362c12d5df Mon Sep 17 00:00:00 2001
From: will
Date: Thu, 12 Feb 2026 00:02:51 +0000
Subject: [PATCH 32/38] remove num_codecs pin

---
 pyproject.toml | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d016e44..ada01ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,10 +19,6 @@ keywords = []
 
 dependencies = [
     "zarr>=3,<4",
-    # ImportError: cannot import name 'cbuffer_sizes' from 'numcodecs.blosc'
-    # We can pin zarr to >2.18.7 but then we have to drop python 3.10
-    # pin numcodecs to avoid breaking change
-    "numcodecs>0.13,<0.16.0",
     "iohub>=0.3.0a5",
     "funlib.geometry>=0.3.0",
     "networkx>=3.0.0",

From 3c878f79f96caed85b66ffe75c4090319b23625e Mon Sep 17 00:00:00 2001
From: will
Date: Thu, 12 Feb 2026 00:09:01 +0000
Subject: [PATCH 33/38] fix ty checks

---
 src/funlib/persistence/arrays/datasets.py     |  6 +++---
 src/funlib/persistence/arrays/metadata.py     | 10 +++++-----
 src/funlib/persistence/arrays/ome_datasets.py |  4 ++--
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py
index aa64676..8fb91a2 100644
--- a/src/funlib/persistence/arrays/datasets.py
+++ b/src/funlib/persistence/arrays/datasets.py
@@ -1,6 +1,6 @@
 import logging
 from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import numpy as np
 import zarr
@@ -14,7 +14,7 @@
 
 def open_ds(
     store,
-    mode: str = "r",
+    mode: Literal["r", "r+", "a", "w", "w-"] = "r",
     metadata_format: Optional[MetaDataFormat] = None,
     offset: Optional[Sequence[int]] = None,
     voxel_size: Optional[Sequence[int]] = None,
@@ -131,7 +131,7 @@ def prepare_ds(
     types: Optional[Sequence[str]] = None,
     chunk_shape: Optional[Sequence[int]] = None,
     dtype: DTypeLike = np.float32,
-    mode: str = "a",
+    mode: Literal["r", "r+", "a", "w", "w-"] = "a",
     custom_metadata: dict[str, Any] | None = None,
     **kwargs,
 ) -> Array:
diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py
index d47a26e..c4980d6 100644
--- a/src/funlib/persistence/arrays/metadata.py
+++ b/src/funlib/persistence/arrays/metadata.py
@@ -1,5 +1,5 @@
 import warnings
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from pathlib import Path
 from typing import Any, Optional
 
@@ -268,7 +268,7 @@ class Config:
         extra = "forbid"
 
     def fetch(
-        self, data: dict[str | int, Any], key: str
+        self, data: Mapping[str, Any], key: str
     ) -> Sequence[str | int | None] | None:
         """
         Given a dictionary of attributes from e.g. zarr.open(...).attrs, fetch the value
@@ -292,7 +292,7 @@ def fetch(
         keys = key.split("/")
 
         def recurse(
-            data: dict[str | int, Any] | list[Any], keys: list[str]
+            data: Any, keys: list[str]
         ) -> Sequence[str | int | None] | str | int | None:
             current_key: str | int
             current_key, *keys = keys
@@ -304,7 +304,7 @@ def recurse(
             # base case
             if len(keys) == 0:
                 # this key returns the data we want
-                if isinstance(data, (dict, zarr.core.attributes.Attributes)):
+                if isinstance(data, Mapping):
                     return data.get(str(current_key), None)
                 elif isinstance(data, list):
                     assert isinstance(current_key, int), current_key
@@ -336,7 +336,7 @@ def recurse(
     def parse(
         self,
         shape: Sequence[int],
-        data: dict[str | int, Any],
+        data: Mapping[str, Any],
         offset: Optional[Sequence[int]] = None,
         voxel_size: Optional[Sequence[int]] = None,
         axis_names: Optional[Sequence[str]] = None,
diff --git a/src/funlib/persistence/arrays/ome_datasets.py b/src/funlib/persistence/arrays/ome_datasets.py
index 76c8d23..1912ade 100644
--- a/src/funlib/persistence/arrays/ome_datasets.py
+++ b/src/funlib/persistence/arrays/ome_datasets.py
@@ -63,8 +63,8 @@ def open_ome_ds(
 
     metadata = MetaData(
         shape=dataset.shape,
-        offset=offset,  # type: ignore[arg-type]
-        voxel_size=scale,  # type: ignore[arg-type]
+        offset=offset,
+        voxel_size=scale,
         axis_names=axis_names,
         units=units,
         types=types,

From 1d339d2800664abb944914c02908712f4d890e35 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Fri, 13 Feb 2026 14:12:01 -0800
Subject: [PATCH 34/38] bump version and fix ruff formatting

---
 src/funlib/persistence/__init__.py        | 2 +-
 src/funlib/persistence/arrays/metadata.py | 1 -
 tests/conftest.py                         | 1 -
 tests/test_datasets.py                    | 1 +
 4 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/funlib/persistence/__init__.py b/src/funlib/persistence/__init__.py
index ec995b5..b893c3c 100644
--- a/src/funlib/persistence/__init__.py
+++ b/src/funlib/persistence/__init__.py
@@ -1,4 +1,4 @@
 from .arrays import Array, open_ds, prepare_ds, open_ome_ds, prepare_ome_ds  # noqa
 
-__version__ = "0.6.1"
+__version__ = "0.7.0"
 __version_info__ = tuple(int(i) for i in __version__.split("."))
diff --git a/src/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py
index c4980d6..cf80b0f 100644
--- a/src/funlib/persistence/arrays/metadata.py
+++ b/src/funlib/persistence/arrays/metadata.py
@@ -4,7 +4,6 @@
 from typing import Any, Optional
 
 import toml
-import zarr
 from funlib.geometry import Coordinate
 from pydantic import BaseModel
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 36931b9..110b2c4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,5 @@
 import os
 from contextlib import contextmanager
-from pathlib import Path
 
 import psycopg2
 import pytest
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8cc8c2b..7cc6f45 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 from funlib.geometry import Coordinate, Roi
+
 from funlib.persistence.arrays.datasets import open_ds, prepare_ds
 from funlib.persistence.arrays.metadata import MetaDataFormat
 

From ab033542dd3589b923025c61815cddf1554adf2a Mon Sep 17 00:00:00 2001
From: William Patton
Date: Fri, 13 Feb 2026 14:14:04 -0800
Subject: [PATCH 35/38] update test suite to drop old python versions and add
 newer versions

---
 .github/workflows/tests.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index cae4fa4..54cf757 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version: ["3.11", "3.12", "3.13", "3.14"]
         resolution: ["highest", "lowest-direct"]
 
     services:

From 0e051954174aa072d8f867568767e4508b0f2c00 Mon Sep 17 00:00:00 2001
From: William Patton
Date: Fri, 13 Feb 2026 14:55:21 -0800
Subject: [PATCH 36/38] resolve ty errors

---
 src/funlib/persistence/arrays/array.py    | 2 +-
 src/funlib/persistence/arrays/datasets.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/funlib/persistence/arrays/array.py b/src/funlib/persistence/arrays/array.py
index 4187163..9acf3f5 100644
--- a/src/funlib/persistence/arrays/array.py
+++ b/src/funlib/persistence/arrays/array.py
@@ -134,7 +134,7 @@ def attrs(self) -> dict:
 
     @property
     def chunk_shape(self) -> Coordinate:
-        return Coordinate(self.data.chunksize)
+        return Coordinate(self.data.chunksize)  # ty: ignore[unresolved-attribute]
 
     def uncollapsed_dims(self, physical: bool = False) -> list[bool]:
         """
diff --git a/src/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py
index 8fb91a2..9df0c24 100644
--- a/src/funlib/persistence/arrays/datasets.py
+++ b/src/funlib/persistence/arrays/datasets.py
@@ -98,6 +98,8 @@ def open_ds(
     )
 
     data = zarr.open(store, mode=mode, **kwargs)
+    if not isinstance(data, zarr.Array):
+        raise TypeError(f"Expected a zarr Array at {store}, got {type(data).__name__}")
 
     metadata = metadata_format.parse(
         data.shape,
@@ -315,6 +317,10 @@ def prepare_ds(
             )
         else:
             ds = zarr.open(store, mode=mode, **kwargs)
+            if not isinstance(ds, zarr.Array):
+                raise TypeError(
+                    f"Expected a zarr Array at {store}, got {type(ds).__name__}"
+                )
             return Array(
                 ds,
                 existing_metadata.offset,

From 06f20465d2016f4888632adfec0ed8e20db4ad3a Mon Sep 17 00:00:00 2001
From: William Patton
Date: Fri, 13 Feb 2026 14:56:27 -0800
Subject: [PATCH 37/38] remove custom ty error handling

---
 pyproject.toml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index ada01ce..67a8d15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,3 @@ lint.select = ["F", "W", "I001"]
 module = ["zarr.*", "iohub.*"]
 ignore_missing_imports = true
 
-[tool.ty.rules]
-# dask and iohub have incomplete type stubs
-possibly-missing-attribute = "warn"

From d5b3e7dc98634e3872d8694994578d8a0341d16a Mon Sep 17 00:00:00 2001
From: William Patton
Date: Fri, 13 Feb 2026 15:11:41 -0800
Subject: [PATCH 38/38] update dependencies to pass python 3.13 tests

---
 .github/workflows/tests.yaml | 2 +-
 pyproject.toml               | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 54cf757..13d6ae8 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.11", "3.12", "3.13", "3.14"]
+        python-version: ["3.11", "3.12", "3.13"]
         resolution: ["highest", "lowest-direct"]
 
     services:
diff --git a/pyproject.toml b/pyproject.toml
index 67a8d15..539f9d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "pydantic>=2.0.0",
     "dask>=2024.0.0",
     "toml>=0.10.0",
-    "psycopg2-binary>=2.9.5",
+    "psycopg2-binary>=2.9.11",
 ]
 
 [tool.setuptools.dynamic]