diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cae4fa4..13d6ae8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11", "3.12", "3.13"] resolution: ["highest", "lowest-direct"] services: diff --git a/.github/workflows/mypy.yaml b/.github/workflows/ty.yaml similarity index 65% rename from .github/workflows/mypy.yaml rename to .github/workflows/ty.yaml index b1c4393..7fee17d 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/ty.yaml @@ -1,15 +1,15 @@ -name: mypy +name: ty on: [push, pull_request] jobs: static-analysis: - name: Python mypy + name: Python ty runs-on: ubuntu-latest steps: - name: Setup checkout uses: actions/checkout@master - name: Install uv uses: astral-sh/setup-uv@main - - name: mypy - run: uv run --extra dev mypy funlib/persistence tests + - name: ty + run: uv run --extra dev ty check src tests diff --git a/.gitignore b/.gitignore index 90f1bdd..0a8d3da 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,5 @@ dmypy.json .vscode/ *.sw[pmno] +daisy_logs +uv.lock diff --git a/README.md b/README.md index 82cb0af..82aaa0d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![tests](https://github.com/funkelab/funlib.persistence/actions/workflows/tests.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/tests.yaml) [![ruff](https://github.com/funkelab/funlib.persistence/actions/workflows/ruff.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/ruff.yaml) -[![mypy](https://github.com/funkelab/funlib.persistence/actions/workflows/mypy.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/mypy.yaml) +[![ty](https://github.com/funkelab/funlib.persistence/actions/workflows/ty.yaml/badge.svg)](https://github.com/funkelab/funlib.persistence/actions/workflows/ty.yaml) [![pypi](https://github.com/funkelab/funlib.persistence/actions/workflows/publish.yaml/badge.svg)](https://pypi.org/project/funlib.persistence/) # funlib.persistence diff --git a/funlib/persistence/arrays/lazy_ops.py b/funlib/persistence/arrays/lazy_ops.py deleted file mode 100644 index 2299eb6..0000000 --- a/funlib/persistence/arrays/lazy_ops.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Callable, Union - -from funlib.geometry import Roi - -LazyOp = Union[slice, Callable, Roi] diff --git a/funlib/persistence/graphs/pgsql_graph_database.py b/funlib/persistence/graphs/pgsql_graph_database.py deleted file mode 100644 index 1c37d41..0000000 --- a/funlib/persistence/graphs/pgsql_graph_database.py +++ /dev/null @@ -1,206 +0,0 @@ -import json -import logging -from collections.abc import Iterable -from typing import Any, Optional - -import psycopg2 - -from funlib.geometry import Roi - -from ..types import Vec -from .sql_graph_database import SQLGraphDataBase - -logger = logging.getLogger(__name__) - - -class PgSQLGraphDatabase(SQLGraphDataBase): - def __init__( - self, - position_attribute: str, - db_name: str, - db_host: str = "localhost", - db_user: Optional[str] = None, - db_password: Optional[str] = None, - db_port: Optional[int] = None, - mode: str = "r+", - directed: Optional[bool] = None, - total_roi: Optional[Roi] = None, - nodes_table: str = "nodes", - edges_table: str = "edges", - endpoint_names: Optional[list[str]] = None, - node_attrs: Optional[dict[str, type]] = None, - edge_attrs: Optional[dict[str, type]] = None, - ): - 
self.db_host = db_host - self.db_name = db_name - self.db_user = db_user - self.db_password = db_password - self.db_port = db_port - - connection = psycopg2.connect( - host=db_host, - database="postgres", - user=db_user, - password=db_password, - port=db_port, - ) - connection.autocommit = True - cur = connection.cursor() - try: - cur.execute(f"CREATE DATABASE {db_name}") - except psycopg2.errors.DuplicateDatabase: - # DB already exists, moving on... - connection.rollback() - pass - self.connection = psycopg2.connect( - host=db_host, - database=db_name, - user=db_user, - password=db_password, - port=db_port, - ) - # TODO: remove once tests pass: - # self.connection.autocommit = True - self.cur = self.connection.cursor() - - super().__init__( - mode=mode, - position_attribute=position_attribute, - directed=directed, - total_roi=total_roi, - nodes_table=nodes_table, - edges_table=edges_table, - endpoint_names=endpoint_names, - node_attrs=node_attrs, # type: ignore - edge_attrs=edge_attrs, # type: ignore - ) - - def _drop_edges(self) -> None: - logger.info("dropping edges table %s", self.edges_table_name) - self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") - self._commit() - - def _drop_tables(self) -> None: - logger.info( - "dropping tables %s, %s", - self.nodes_table_name, - self.edges_table_name, - ) - - self.__exec(f"DROP TABLE IF EXISTS {self.nodes_table_name}") - self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") - self.__exec("DROP TABLE IF EXISTS metadata") - self._commit() - - def _create_tables(self) -> None: - columns = self.node_attrs.keys() - types = [self.__sql_type(t) for t in self.node_attrs.values()] - column_types = [f"{c} {t}" for c, t in zip(columns, types)] - self.__exec( - f"CREATE TABLE IF NOT EXISTS " - f"{self.nodes_table_name}(" - "id BIGINT not null PRIMARY KEY, " - f"{', '.join(column_types)}" - ")" - ) - self.__exec( - f"CREATE INDEX IF NOT EXISTS pos_index ON " - f"{self.nodes_table_name}({self.position_attribute})" - ) - - columns = list(self.edge_attrs.keys()) # type: ignore - types = list([self.__sql_type(t) for t in self.edge_attrs.values()]) - column_types = [f"{c} {t}" for c, t in zip(columns, types)] - endpoint_names = self.endpoint_names - assert endpoint_names is not None - self.__exec( - f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" - f"{endpoint_names[0]} BIGINT not null, " # type: ignore - f"{endpoint_names[1]} BIGINT not null, " - f"{' '.join([c + ',' for c in column_types])}" - f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" - ")" - ) - - self._commit() - - def _store_metadata(self, metadata) -> None: - self.__exec("DROP TABLE IF EXISTS metadata") - self.__exec("CREATE TABLE metadata (value VARCHAR)") - self._insert_query( - "metadata", ["value"], [[json.dumps(metadata)]], fail_if_exists=True - ) - - def _read_metadata(self) -> Optional[dict[str, Any]]: - try: - self.__exec("SELECT value FROM metadata") - except psycopg2.errors.UndefinedTable: - self.connection.rollback() - return None - - result = self.cur.fetchone() - if result is not None: - metadata = result[0] - return json.loads(metadata) - - return None - - def _select_query(self, query) -> Iterable[Any]: - self.__exec(query) - return self.cur - - def _insert_query( - self, table, columns, values, fail_if_exists=False, commit=True - ) -> None: - values_str = ( - "VALUES (" - + "), (".join( - [", ".join([self.__sql_value(v) for v in value]) for value in values] - ) - + ")" - ) - # TODO: fail_if_exists is the default if UNIQUE was used to create the - # 
table, we need to update if fail_if_exists==False - insert_statement = f"INSERT INTO {table}({', '.join(columns)}) " + values_str - self.__exec(insert_statement) - - if commit: - self.connection.commit() - - def _update_query(self, query, commit=True) -> None: - self.__exec(query) - - if commit: - self.connection.commit() - - def _commit(self) -> None: - self.connection.commit() - - def __exec(self, query): - try: - return self.cur.execute(query) - except: - self.connection.rollback() - raise - - def __sql_value(self, value): - if isinstance(value, str): - return f"'{value}'" - if isinstance(value, Iterable): - return f"array[{','.join([self.__sql_value(v) for v in value])}]" - elif value is None: - return "NULL" - else: - return str(value) - - def __sql_type(self, type): - if isinstance(type, Vec): - return self.__sql_type(type.dtype) + f"[{type.size}]" - try: - return {bool: "BOOLEAN", int: "INTEGER", str: "VARCHAR", float: "REAL"}[ - type - ] - except ValueError: - raise NotImplementedError( - f"attributes of type {type} are not yet supported" - ) diff --git a/pyproject.toml b/pyproject.toml index 0e1f0e1..539f9d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,17 +13,13 @@ authors = [ ] dynamic = ['version'] -requires-python = ">=3.10" +requires-python = ">=3.11" classifiers = ["Programming Language :: Python :: 3"] keywords = [] dependencies = [ - "zarr>=2,<3", - # ImportError: cannot import name 'cbuffer_sizes' from 'numcodecs.blosc' - # We can pin zarr to >2.18.7 but then we have to drop python 3.10 - # pin numcodecs to avoid breaking change - "numcodecs>0.13,<0.16.0", - "iohub>=0.2.0b0", + "zarr>=3,<4", + "iohub>=0.3.0a5", "funlib.geometry>=0.3.0", "networkx>=3.0.0", "pymongo>=4.0.0", @@ -31,7 +27,7 @@ dependencies = [ "pydantic>=2.0.0", "dask>=2024.0.0", "toml>=0.10.0", - "psycopg2-binary>=2.9.5", + "psycopg2-binary>=2.9.11", ] [tool.setuptools.dynamic] @@ -40,10 +36,10 @@ version = { attr = "funlib.persistence.__version__" } [project.optional-dependencies] dev = [ "coverage>=7.7.1", - "mypy>=1.15.0", "pytest>=8.3.5", "pytest-mock>=3.14.0", "ruff>=0.11.2", + "ty>=0.0.16", "types-networkx", "types-psycopg2", "types-toml", @@ -56,10 +52,8 @@ lint.select = ["F", "W", "I001"] [tool.setuptools.package-data] "funlib.persistence" = ["py.typed"] -[tool.mypy] -explicit_package_bases = true - # # module specific overrides [[tool.mypy.overrides]] module = ["zarr.*", "iohub.*"] ignore_missing_imports = true + diff --git a/funlib/persistence/__init__.py b/src/funlib/persistence/__init__.py similarity index 87% rename from funlib/persistence/__init__.py rename to src/funlib/persistence/__init__.py index ec995b5..b893c3c 100644 --- a/funlib/persistence/__init__.py +++ b/src/funlib/persistence/__init__.py @@ -1,4 +1,4 @@ from .arrays import Array, open_ds, prepare_ds, open_ome_ds, prepare_ome_ds # noqa -__version__ = "0.6.1" +__version__ = "0.7.0" __version_info__ = tuple(int(i) for i in __version__.split(".")) diff --git a/funlib/persistence/arrays/__init__.py b/src/funlib/persistence/arrays/__init__.py similarity index 100% rename from funlib/persistence/arrays/__init__.py rename to src/funlib/persistence/arrays/__init__.py diff --git a/funlib/persistence/arrays/array.py b/src/funlib/persistence/arrays/array.py similarity index 96% rename from funlib/persistence/arrays/array.py rename to src/funlib/persistence/arrays/array.py index f19f2ab..9acf3f5 100644 --- a/funlib/persistence/arrays/array.py +++ b/src/funlib/persistence/arrays/array.py @@ -6,9 +6,8 @@ import dask.array as da import 
numpy as np from dask.array.optimization import fuse_slice -from zarr import Array as ZarrArray - from funlib.geometry import Coordinate, Roi +from zarr import Array as ZarrArray from .freezable import Freezable from .lazy_ops import LazyOp @@ -135,7 +134,7 @@ def attrs(self) -> dict: @property def chunk_shape(self) -> Coordinate: - return Coordinate(self.data.chunksize) + return Coordinate(self.data.chunksize) # ty: ignore[unresolved-attribute] def uncollapsed_dims(self, physical: bool = False) -> list[bool]: """ @@ -344,18 +343,21 @@ def __getitem__(self, key) -> np.ndarray: else: return self.data[key].compute() - def __setitem__(self, key, value: np.ndarray): - """Set the data of this array within the given ROI. + def __setitem__(self, key: Roi | slice | tuple, value: np.ndarray | float | int): + """Set the data of this array. Args: - key (`class:Roi`): + key (`class:Roi` or any numpy compatible key): - The ROI to write to. + The region to write to. Can be a `Roi` for world-unit indexing, + or any numpy-compatible key (e.g. ``np.s_[:]``, a slice, a tuple + of slices). - value (``ndarray``): + value (``ndarray`` or scalar): - The value to write. + The value to write. Can be a numpy array or a scalar that will + be broadcast. """ if self.is_writeable: @@ -474,7 +476,7 @@ def _is_slice(self, lazy_op: LazyOp, writeable: bool = False) -> bool: elif isinstance(lazy_op, list) and all([isinstance(a, int) for a in lazy_op]): return True elif isinstance(lazy_op, tuple) and all( - [self._is_slice(a, writeable) for a in lazy_op] + [self._is_slice(a, writeable) for a in lazy_op] # type: ignore[arg-type] # ty can't narrow parameterized tuple iteration ): return True elif ( @@ -508,7 +510,7 @@ def validate(self, strict: bool = False): ) def to_pixel_space( - self, world_loc: Roi | Coordinate | Sequence[int | float] + self, world_loc: Roi | Coordinate | Sequence[int | float] | np.ndarray ) -> Roi | Coordinate | np.ndarray: """Convert a point or roi in world space into the pixel space of this array. Works on sequences of floats by returning a numpy array that is not guaranteed @@ -539,7 +541,7 @@ def to_pixel_space( ) def to_world_space( - self, pixel_loc: Roi | Coordinate | Sequence[int | float] + self, pixel_loc: Roi | Coordinate | Sequence[int | float] | np.ndarray ) -> Roi | Coordinate: """Convert a point or roi from pixel space in this array to the world coordinate system defined by this array's roi and voxel size. 
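
Not part of the patch: a minimal usage sketch of the widened ``__setitem__`` accepted by the array.py hunk above (world-unit ``Roi`` keys, numpy-compatible keys such as ``np.s_``, and scalar broadcast), using the ``prepare_ds`` API from this patch. The store path and array sizes are made up for illustration.

import numpy as np
from funlib.geometry import Roi
from funlib.persistence import prepare_ds

# create a small writeable array at a hypothetical path
array = prepare_ds(
    "/tmp/example.zarr",
    (10, 10, 10),
    voxel_size=(1, 1, 1),
    dtype=np.uint8,
    mode="w",
)

# world-unit ROI key with scalar broadcast
array[Roi((0, 0, 0), (5, 5, 5))] = 1

# numpy-compatible key (tuple of slices) with scalar broadcast
array[np.s_[:, :, 5:]] = 2
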
diff --git a/funlib/persistence/arrays/datasets.py b/src/funlib/persistence/arrays/datasets.py similarity index 92% rename from funlib/persistence/arrays/datasets.py rename to src/funlib/persistence/arrays/datasets.py index fd324d3..9df0c24 100644 --- a/funlib/persistence/arrays/datasets.py +++ b/src/funlib/persistence/arrays/datasets.py @@ -1,30 +1,20 @@ import logging from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union import numpy as np import zarr -from numpy.typing import DTypeLike - from funlib.geometry import Coordinate +from numpy.typing import DTypeLike from .array import Array from .metadata import MetaDataFormat, get_default_metadata_format logger = logging.getLogger(__name__) - -class ArrayNotFoundError(Exception): - """Exception raised when an array is not found in the dataset.""" - - def __init__(self, message: str = "Array not found in the dataset"): - self.message = message - super().__init__(self.message) - - def open_ds( store, - mode: str = "r", + mode: Literal["r", "r+", "a", "w", "w-"] = "r", metadata_format: Optional[MetaDataFormat] = None, offset: Optional[Sequence[int]] = None, voxel_size: Optional[Sequence[int]] = None, @@ -107,10 +97,9 @@ def open_ds( else get_default_metadata_format() ) - try: - data = zarr.open(store, mode=mode, **kwargs) - except zarr.errors.PathNotFoundError: - raise ArrayNotFoundError(f"Nothing found at path {store}") + data = zarr.open(store, mode=mode, **kwargs) + if not isinstance(data, zarr.Array): + raise TypeError(f"Expected a zarr Array at {store}, got {type(data).__name__}") metadata = metadata_format.parse( data.shape, @@ -144,7 +133,7 @@ def prepare_ds( types: Optional[Sequence[str]] = None, chunk_shape: Optional[Sequence[int]] = None, dtype: DTypeLike = np.float32, - mode: str = "a", + mode: Literal["r", "r+", "a", "w", "w-"] = "a", custom_metadata: dict[str, Any] | None = None, **kwargs, ) -> Array: @@ -239,7 +228,7 @@ def prepare_ds( try: existing_array = open_ds(store, mode="r", **kwargs) - except ArrayNotFoundError: + except FileNotFoundError: existing_array = None if existing_array is not None: @@ -328,6 +317,10 @@ def prepare_ds( ) else: ds = zarr.open(store, mode=mode, **kwargs) + if not isinstance(ds, zarr.Array): + raise TypeError( + f"Expected a zarr Array at {store}, got {type(ds).__name__}" + ) return Array( ds, existing_metadata.offset, @@ -349,18 +342,14 @@ def prepare_ds( ) # create the dataset - try: - ds = zarr.open_array( - store=store, - shape=shape, - chunks=chunk_shape, - dtype=dtype, - dimension_separator="/", - mode=mode, - **kwargs, - ) - except zarr.errors.ArrayNotFoundError: - raise ArrayNotFoundError(f"Nothing found at path {store}") + ds = zarr.open_array( + store=store, + shape=shape, + chunks=chunk_shape, + dtype=dtype, + mode=mode, + **kwargs, + ) default_metadata_format = get_default_metadata_format() our_metadata = { diff --git a/funlib/persistence/arrays/freezable.py b/src/funlib/persistence/arrays/freezable.py similarity index 100% rename from funlib/persistence/arrays/freezable.py rename to src/funlib/persistence/arrays/freezable.py diff --git a/src/funlib/persistence/arrays/lazy_ops.py b/src/funlib/persistence/arrays/lazy_ops.py new file mode 100644 index 0000000..4e107be --- /dev/null +++ b/src/funlib/persistence/arrays/lazy_ops.py @@ -0,0 +1,6 @@ +from typing import Callable, Union + +import numpy as np +from funlib.geometry import Roi + +LazyOp = Union[slice, int, tuple[int | slice | list[int] | np.ndarray, ...], 
Callable, Roi] diff --git a/funlib/persistence/arrays/metadata.py b/src/funlib/persistence/arrays/metadata.py similarity index 88% rename from funlib/persistence/arrays/metadata.py rename to src/funlib/persistence/arrays/metadata.py index e3ed064..cf80b0f 100644 --- a/funlib/persistence/arrays/metadata.py +++ b/src/funlib/persistence/arrays/metadata.py @@ -1,13 +1,11 @@ import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any, Optional import toml -import zarr -from pydantic import BaseModel - from funlib.geometry import Coordinate +from pydantic import BaseModel def strip_channels( @@ -269,7 +267,7 @@ class Config: extra = "forbid" def fetch( - self, data: dict[str | int, Any], key: str + self, data: Mapping[str, Any], key: str ) -> Sequence[str | int | None] | None: """ Given a dictionary of attributes from e.g. zarr.open(...).attrs, fetch the value @@ -293,7 +291,7 @@ def fetch( keys = key.split("/") def recurse( - data: dict[str | int, Any] | list[Any], keys: list[str] + data: Any, keys: list[str] ) -> Sequence[str | int | None] | str | int | None: current_key: str | int current_key, *keys = keys @@ -305,7 +303,7 @@ def recurse( # base case if len(keys) == 0: # this key returns the data we want - if isinstance(data, (dict, zarr.attrs.Attributes)): + if isinstance(data, Mapping): return data.get(str(current_key), None) elif isinstance(data, list): assert isinstance(current_key, int), current_key @@ -331,45 +329,49 @@ def recurse( return recurse(data[current_key], keys) result = recurse(data, keys) - assert isinstance(result, Sequence) or result is None, result + assert (isinstance(result, Sequence) and not isinstance(result, (str, int))) or result is None, result return result def parse( self, - shape, - data: dict[str | int, Any], - offset=None, - voxel_size=None, - axis_names=None, - units=None, - types=None, - strict=False, + shape: Sequence[int], + data: Mapping[str, Any], + offset: Optional[Sequence[int]] = None, + voxel_size: Optional[Sequence[int]] = None, + axis_names: Optional[Sequence[str]] = None, + units: Optional[Sequence[str]] = None, + types: Optional[Sequence[str]] = None, + strict: bool = False, ) -> MetaData: - offset = offset if offset is not None else self.fetch(data, self.offset_attr) - voxel_size = ( + fetched_offset = offset if offset is not None else self.fetch(data, self.offset_attr) + fetched_voxel_size = ( voxel_size if voxel_size is not None else self.fetch(data, self.voxel_size_attr) ) - axis_names = ( + fetched_axis_names = ( axis_names if axis_names is not None else self.fetch(data, self.axis_names_attr) ) - units = units if units is not None else self.fetch(data, self.units_attr) - types = types if types is not None else self.fetch(data, self.types_attr) - if types is None and axis_names is not None: - types = [ - "channel" if name.endswith("^") else "space" for name in axis_names + fetched_units = units if units is not None else self.fetch(data, self.units_attr) + fetched_types = types if types is not None else self.fetch(data, self.types_attr) + if fetched_types is None and fetched_axis_names is not None: + fetched_types = [ + "channel" if str(name).endswith("^") else "space" + for name in fetched_axis_names ] + # fetch() returns Sequence[str | int | None] | None from untyped metadata. + # Some OME-Zarr metadata may have holes (None elements), so we pass + # the fetched values through as-is; MetaData.validate() checks at runtime. 
metadata = MetaData( shape=shape, - offset=offset, - voxel_size=voxel_size, - axis_names=axis_names, - units=units, - types=types, + offset=fetched_offset, # type: ignore[arg-type] + voxel_size=fetched_voxel_size, # type: ignore[arg-type] + axis_names=fetched_axis_names, # type: ignore[arg-type] + units=fetched_units, # type: ignore[arg-type] + types=fetched_types, # type: ignore[arg-type] strict=strict, ) diff --git a/funlib/persistence/arrays/ome_datasets.py b/src/funlib/persistence/arrays/ome_datasets.py similarity index 94% rename from funlib/persistence/arrays/ome_datasets.py rename to src/funlib/persistence/arrays/ome_datasets.py index c194b87..1912ade 100644 --- a/funlib/persistence/arrays/ome_datasets.py +++ b/src/funlib/persistence/arrays/ome_datasets.py @@ -1,13 +1,13 @@ import logging from collections.abc import Sequence from pathlib import Path +from typing import Literal +from funlib.geometry import Coordinate from iohub.ngff import TransformationMeta, open_ome_zarr from iohub.ngff.models import AxisMeta from numpy.typing import DTypeLike -from funlib.geometry import Coordinate - from .array import Array from .metadata import MetaData @@ -17,7 +17,7 @@ def open_ome_ds( store: Path, name: str, - mode: str = "r", + mode: Literal["r", "r+", "a", "w", "w-"] = "r", **kwargs, ) -> Array: """ @@ -171,7 +171,7 @@ def prepare_ome_ds( ) axis_metadata = [ - AxisMeta(name=n, type=t, unit=u) + AxisMeta(name=n, type=t, unit=u) # type: ignore[misc] for n, t, u in zip(metadata.axis_names, metadata.types, metadata.ome_units) ] @@ -180,8 +180,8 @@ def prepare_ome_ds( store, mode="w", layout="fov", axes=axis_metadata, channel_names=channel_names ) as ds: transforms = [ - TransformationMeta(type="scale", scale=metadata.ome_scale), - TransformationMeta(type="translation", translation=metadata.ome_translate), + TransformationMeta(type="scale", scale=list(metadata.ome_scale)), + TransformationMeta(type="translation", translation=list(metadata.ome_translate)), ] ds.create_zeros( diff --git a/funlib/persistence/arrays/utils.py b/src/funlib/persistence/arrays/utils.py similarity index 100% rename from funlib/persistence/arrays/utils.py rename to src/funlib/persistence/arrays/utils.py diff --git a/funlib/persistence/graphs/__init__.py b/src/funlib/persistence/graphs/__init__.py similarity index 100% rename from funlib/persistence/graphs/__init__.py rename to src/funlib/persistence/graphs/__init__.py diff --git a/funlib/persistence/graphs/graph_database.py b/src/funlib/persistence/graphs/graph_database.py similarity index 63% rename from funlib/persistence/graphs/graph_database.py rename to src/funlib/persistence/graphs/graph_database.py index acd474a..91dc65b 100644 --- a/funlib/persistence/graphs/graph_database.py +++ b/src/funlib/persistence/graphs/graph_database.py @@ -1,10 +1,9 @@ import logging from abc import ABC, abstractmethod -from typing import Optional - -from networkx import Graph +from typing import Any, Optional from funlib.geometry import Roi +from networkx import Graph from ..types import Vec @@ -59,6 +58,9 @@ def read_graph( read_edges: bool = True, node_attrs: Optional[list[str]] = None, edge_attrs: Optional[list[str]] = None, + nodes_filter: Optional[dict[str, Any]] = None, + edges_filter: Optional[dict[str, Any]] = None, + fetch_on_v: bool = False, ) -> Graph: """ Read a graph from the database for a given roi. @@ -81,6 +83,20 @@ def read_graph( If not ``None``, only read the given edge attributes. 
+ nodes_filter (``dict[str, Any]`` or ``None``): + + If not ``None``, only read nodes matching these attribute values. + + edges_filter (``dict[str, Any]`` or ``None``): + + If not ``None``, only read edges matching these attribute values. + + fetch_on_v (``bool``): + + If ``True``, also fetch edges where the ``v`` endpoint matches + (i.e., either endpoint is in the ROI or node list). If ``False`` + (default), only fetch edges where ``u`` matches. + """ pass @@ -149,3 +165,53 @@ def write_attrs( Alias call to write_graph with write_nodes and write_edges set to False. """ pass + + @abstractmethod + def bulk_write_graph( + self, + graph: Graph, + roi: Optional[Roi] = None, + write_nodes: bool = True, + write_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + ) -> None: + """ + Fast bulk write of a graph. Mirrors ``write_graph`` but optimized + for large batch inserts. Does not support ``fail_if_exists`` or + ``delete``. + """ + pass + + @abstractmethod + def bulk_write_mode( + self, + worker: bool = False, + node_writes: bool = True, + edge_writes: bool = True, + ): + """Context manager that optimizes the database for bulk writes. + + Drops indexes and adjusts database settings for maximum write + throughput, then restores them on exit. + + Arguments: + + worker (``bool``): + + If ``False`` (default), drops and rebuilds indexes around the + block. Set to ``True`` for parallel workers whose orchestrator + manages indexes separately — only session-level performance + settings will be adjusted. + + node_writes (``bool``): + + If ``True`` (default), drop/rebuild node primary key and + position indexes. Ignored when ``worker=True``. + + edge_writes (``bool``): + + If ``True`` (default), drop/rebuild edge primary key index. + Ignored when ``worker=True``. + """ + pass diff --git a/src/funlib/persistence/graphs/pgsql_graph_database.py b/src/funlib/persistence/graphs/pgsql_graph_database.py new file mode 100644 index 0000000..2e0c227 --- /dev/null +++ b/src/funlib/persistence/graphs/pgsql_graph_database.py @@ -0,0 +1,360 @@ +import io +import json +import logging +from collections.abc import Iterable +from contextlib import contextmanager +from typing import Any, Optional + +import psycopg2 +from funlib.geometry import Roi +from psycopg2 import sql + +from ..types import Vec +from .sql_graph_database import AttributeType, SQLGraphDataBase + +logger = logging.getLogger(__name__) + + +class PgSQLGraphDatabase(SQLGraphDataBase): + def __init__( + self, + position_attribute: str, + db_name: str, + db_host: str = "localhost", + db_user: Optional[str] = None, + db_password: Optional[str] = None, + db_port: Optional[int] = None, + mode: str = "r+", + directed: Optional[bool] = None, + total_roi: Optional[Roi] = None, + nodes_table: str = "nodes", + edges_table: str = "edges", + endpoint_names: Optional[list[str]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, + ): + self.db_host = db_host + self.db_name = db_name + self.db_user = db_user + self.db_password = db_password + self.db_port = db_port + + connection = psycopg2.connect( + host=db_host, + database="postgres", + user=db_user, + password=db_password, + port=db_port, + ) + connection.autocommit = True + cur = connection.cursor() + try: + cur.execute(f"CREATE DATABASE {db_name}") + except psycopg2.errors.DuplicateDatabase: + # DB already exists, moving on... 
+ connection.rollback() + pass + connection.close() + self.connection = psycopg2.connect( + host=db_host, + database=db_name, + user=db_user, + password=db_password, + port=db_port, + ) + self.cur = self.connection.cursor() + + super().__init__( + mode=mode, + position_attribute=position_attribute, + directed=directed, + total_roi=total_roi, + nodes_table=nodes_table, + edges_table=edges_table, + endpoint_names=endpoint_names, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ) + + def close(self): + if not self.connection.closed: + self.connection.close() + + def _drop_edges(self) -> None: + logger.info("dropping edges table %s", self.edges_table_name) + self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") + self._commit() + + def _drop_tables(self) -> None: + logger.info( + "dropping tables %s, %s", + self.nodes_table_name, + self.edges_table_name, + ) + + self.__exec(f"DROP TABLE IF EXISTS {self.nodes_table_name}") + self.__exec(f"DROP TABLE IF EXISTS {self.edges_table_name}") + self.__exec("DROP TABLE IF EXISTS metadata") + self._commit() + + def _create_tables(self) -> None: + columns = list(self.node_attrs.keys()) + types = [self.__sql_type(t) for t in self.node_attrs.values()] + column_types = [f"{c} {t}" for c, t in zip(columns, types)] + self.__exec( + f"CREATE TABLE IF NOT EXISTS " + f"{self.nodes_table_name}(" + "id BIGINT not null PRIMARY KEY, " + f"{', '.join(column_types)}" + ")" + ) + self.__exec( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{self.nodes_table_name}({self.position_attribute})" + ) + + columns = list(self.edge_attrs.keys()) + types = list([self.__sql_type(t) for t in self.edge_attrs.values()]) + column_types = [f"{c} {t}" for c, t in zip(columns, types)] + endpoint_names = self.endpoint_names + assert endpoint_names is not None + self.__exec( + f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" + f"{endpoint_names[0]} BIGINT not null, " + f"{endpoint_names[1]} BIGINT not null, " + f"{' '.join([c + ',' for c in column_types])}" + f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" + ")" + ) + + self._commit() + + def _store_metadata(self, metadata) -> None: + self.__exec("DROP TABLE IF EXISTS metadata") + self.__exec("CREATE TABLE metadata (value VARCHAR)") + self._insert_query( + "metadata", ["value"], [[json.dumps(metadata)]], fail_if_exists=True + ) + + def _read_metadata(self) -> Optional[dict[str, Any]]: + try: + self.__exec("SELECT value FROM metadata") + except psycopg2.errors.UndefinedTable: + self.connection.rollback() + return None + + result = self.cur.fetchone() + if result is not None: + metadata = result[0] + return json.loads(metadata) + + return None + + def _select_query(self, query) -> Iterable[Any]: + self.__exec(query) + return self.cur + + def _insert_query( + self, table, columns, values, fail_if_exists=False, commit=True + ) -> None: + if not values: + return + + values_str = ( + "VALUES (" + + "), (".join( + [", ".join([self.__sql_value(v) for v in value]) for value in values] + ) + + ")" + ) + insert_statement = f"INSERT INTO {table}({', '.join(columns)}) " + values_str + if not fail_if_exists: + insert_statement += " ON CONFLICT DO NOTHING" + self.__exec(insert_statement) + + if commit: + self.connection.commit() + + def _update_query(self, query, commit=True) -> None: + self.__exec(query) + + if commit: + self.connection.commit() + + def _commit(self) -> None: + self.connection.commit() + + def __exec(self, query): + try: + return self.cur.execute(query) + except: + self.connection.rollback() + raise + + def 
__sql_value(self, value): + if isinstance(value, str): + return f"'{value}'" + if isinstance(value, Iterable): + return f"array[{','.join([self.__sql_value(v) for v in value])}]" + elif value is None: + return "NULL" + else: + return str(value) + + def __sql_type(self, type): + if isinstance(type, Vec): + return self.__sql_type(type.dtype) + f"[{type.size}]" + try: + return {bool: "BOOLEAN", int: "BIGINT", str: "VARCHAR", float: "REAL"}[type] + except ValueError: + raise NotImplementedError( + f"attributes of type {type} are not yet supported" + ) + + def print_summary(self, schema="public", limit=None): + self._commit() + + def fmt_bytes(n): + for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: + if n < 1024 or unit == "PB": + return f"{n:.1f} {unit}" if unit != "B" else f"{int(n)} {unit}" + n /= 1024 + + # Basic DB info (read-only) + self.cur.execute("SELECT current_database(), current_user, version();") + db, user, version = self.cur.fetchone() + print(f"DB: {db} | User: {user}") + print(version.split("\n")[0]) + print("-" * 80) + + q = sql.SQL( + """ + SELECT + c.relname AS table_name, + c.relkind, + COALESCE(c.reltuples::bigint, 0) AS est_rows, + pg_total_relation_size(c.oid) AS total_bytes + FROM pg_class c + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = %s + AND c.relkind IN ('r','p') -- r=table, p=partitioned table + ORDER BY pg_total_relation_size(c.oid) DESC, c.relname + """ + ) + if limit is not None: + q = q + sql.SQL(" LIMIT %s") + self.cur.execute(q, (schema, limit)) + else: + self.cur.execute(q, (schema,)) + rows = self.cur.fetchall() + if not rows: + print(f"No tables found in schema={schema!r}") + return + + # Pretty print + name_w = min(max(len(r[0]) for r in rows), 60) + header = f"{'table':<{name_w}} {'kind':<4} {'est_rows':>12} {'size':>10}" + print(header) + print("-" * len(header)) + + kind_map = {"r": "tbl", "p": "part"} + for name, relkind, est_rows, total_bytes in rows: + print( + f"{name:<{name_w}} {kind_map.get(relkind, relkind):<4} {est_rows:>12,} {fmt_bytes(total_bytes):>10}" + ) + + print("-" * 80) + print(f"Tables shown: {len(rows)} (schema={schema!r})") + + @contextmanager + def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True): + nodes = self.nodes_table_name + edges = self.edges_table_name + endpoint_names = self.endpoint_names + assert endpoint_names is not None + + if not worker: + if node_writes: + self.__exec("DROP INDEX IF EXISTS pos_index") + self.__exec( + f"ALTER TABLE {nodes} DROP CONSTRAINT IF EXISTS {nodes}_pkey" + ) + if edge_writes: + self.__exec( + f"ALTER TABLE {edges} DROP CONSTRAINT IF EXISTS {edges}_pkey" + ) + self._commit() + + self.__exec("SET synchronous_commit TO OFF") + self._commit() + + try: + yield + finally: + self.__exec("SET synchronous_commit TO ON") + self._commit() + + if not worker: + logger.info("Re-creating indexes and constraints...") + if node_writes: + self.__exec( + f"ALTER TABLE {nodes} " + f"ADD CONSTRAINT {nodes}_pkey PRIMARY KEY (id)" + ) + self.__exec( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{nodes}({self.position_attribute})" + ) + if edge_writes: + self.__exec( + f"ALTER TABLE {edges} " + f"ADD CONSTRAINT {edges}_pkey " + f"PRIMARY KEY ({endpoint_names[0]}, {endpoint_names[1]})" + ) + self._commit() + + def _bulk_insert(self, table, columns, rows) -> None: + def format_gen(): + for row in rows: + formatted = [] + for val in row: + if val is None: + formatted.append(r"\N") + elif isinstance(val, (list, tuple)): + formatted.append(f"{{{','.join(map(str, 
val))}}}") + else: + formatted.append(str(val)) + yield "\t".join(formatted) + "\n" + + self._stream_copy(table, columns, format_gen()) + self._commit() + + def _stream_copy(self, table_name, columns, data_generator): + """ + Consumes a generator of strings and sends them to Postgres via COPY. + Uses a chunked buffer to keep memory usage stable. + """ + # Tune this size (in bytes). 10MB - 50MB is usually a sweet spot. + BATCH_SIZE = 50 * 1024 * 1024 + + buffer = io.StringIO() + current_size = 0 + + # Helper to flush buffer to DB + def flush(): + buffer.seek(0) + self.cur.copy_from(buffer, table_name, columns=columns, null=r"\N") + buffer.truncate(0) + buffer.seek(0) + + for line in data_generator: + buffer.write(line) + current_size += len(line) + + if current_size >= BATCH_SIZE: + flush() + current_size = 0 + + # Flush remaining + if current_size > 0: + flush() diff --git a/funlib/persistence/graphs/sql_graph_database.py b/src/funlib/persistence/graphs/sql_graph_database.py similarity index 62% rename from funlib/persistence/graphs/sql_graph_database.py rename to src/funlib/persistence/graphs/sql_graph_database.py index cec7381..5cb511d 100644 --- a/funlib/persistence/graphs/sql_graph_database.py +++ b/src/funlib/persistence/graphs/sql_graph_database.py @@ -2,11 +2,10 @@ from abc import abstractmethod from typing import Any, Iterable, Optional +from funlib.geometry import Coordinate, Roi from networkx import DiGraph, Graph from networkx.classes.reportviews import NodeView, OutEdgeView -from funlib.geometry import Coordinate, Roi - from ..types import Vec, type_to_str from .graph_database import AttributeType, GraphDataBase @@ -86,20 +85,20 @@ def __init__( self.mode = mode if mode in self.read_modes: - self.position_attribute = position_attribute - self.directed = directed - self.total_roi = total_roi - self.nodes_table_name = nodes_table - self.edges_table_name = edges_table - self.endpoint_names = endpoint_names - self._node_attrs = node_attrs - self._edge_attrs = edge_attrs - self.ndims = None # to be read from metadata - metadata = self._read_metadata() if metadata is None: raise RuntimeError("metadata does not exist, can't open in read mode") - self.__load_metadata(metadata) + self.__load_metadata( + metadata, + position_attribute=position_attribute, + directed=directed, + total_roi=total_roi, + nodes_table=nodes_table, + edges_table=edges_table, + endpoint_names=endpoint_names, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ) if mode in self.create_modes: # this is where we populate default values for the DB creation @@ -113,7 +112,7 @@ def __init__( def get(value, default): return value if value is not None else default - self.position_attribute = get(position_attribute, "position") + self.position_attribute: str = get(position_attribute, "position") assert self.position_attribute in node_attrs, ( "No type information for position attribute " @@ -122,7 +121,7 @@ def get(value, default): position_type = node_attrs[self.position_attribute] if isinstance(position_type, Vec): - self.ndims = position_type.size + self.ndims: int = position_type.size assert self.ndims > 1, ( "Don't use Vecs of size 1 for the position, use the " "scalar type directly instead (i.e., 'float' instead of " @@ -132,13 +131,13 @@ def get(value, default): else: self.ndims = 1 - self.directed = get(directed, False) - self.total_roi = get( + self.directed: bool = get(directed, False) + self.total_roi: Roi = get( total_roi, Roi((None,) * self.ndims, (None,) * self.ndims) ) - self.nodes_table_name = 
get(nodes_table, "nodes") - self.edges_table_name = get(edges_table, "edges") - self.endpoint_names = get(endpoint_names, ["u", "v"]) + self.nodes_table_name: str = get(nodes_table, "nodes") + self.edges_table_name: str = get(edges_table, "edges") + self.endpoint_names: list[str] = get(endpoint_names, ["u", "v"]) self._node_attrs = node_attrs # no default, needs to be given self._edge_attrs = get(edge_attrs, {}) @@ -190,6 +189,17 @@ def _update_query(self, query, commit=True) -> None: def _commit(self) -> None: pass + @abstractmethod + def _bulk_insert(self, table, columns, rows) -> None: + """Insert rows using a backend-optimized bulk method. + + Args: + table: Table name. + columns: Column names. + rows: Iterable of row lists (Python values, not formatted strings). + """ + pass + def _node_attrs_to_columns(self, attrs): # default: each attribute maps to its own column return attrs @@ -214,6 +224,7 @@ def read_graph( edge_attrs: Optional[list[str]] = None, nodes_filter: Optional[dict[str, Any]] = None, edges_filter: Optional[dict[str, Any]] = None, + fetch_on_v: bool = False, ) -> Graph: graph: Graph if self.directed: @@ -230,10 +241,32 @@ def read_graph( graph.add_nodes_from(node_list) if read_edges: - edges = self.read_edges( - roi, nodes=nodes, read_attrs=edge_attrs, attr_filter=edges_filter - ) - u, v = self.endpoint_names # type: ignore + # When a nodes_filter is used, the filtered node set is narrower + # than the ROI. Fall back to the nodes path so edges only connect + # nodes that passed the filter. Otherwise use the faster ROI path. + if nodes_filter: + edges = self.read_edges( + nodes=nodes, + read_attrs=edge_attrs, + attr_filter=edges_filter, + fetch_on_v=fetch_on_v, + ) + else: + # We use ROI to query edges ro avoid serializing a list of + # node IDs into the SQL query. + + # A fully unbounded ROI (all None in shape) provides no + # filtering, so treat it as None to fetch all edges. 
+ effective_roi = roi + if roi is not None and all(s is None for s in roi.shape): + effective_roi = None + edges = self.read_edges( + roi=effective_roi, + read_attrs=edge_attrs, + attr_filter=edges_filter, + fetch_on_v=fetch_on_v, + ) + u, v = self.endpoint_names try: edge_list = [(e[u], e[v], self.__remove_keys(e, [u, v])) for e in edges] except KeyError as e: @@ -289,6 +322,58 @@ def write_graph( delete=delete, ) + def bulk_write_graph( + self, + graph: Graph, + roi: Optional[Roi] = None, + write_nodes: bool = True, + write_edges: bool = True, + node_attrs: Optional[list[str]] = None, + edge_attrs: Optional[list[str]] = None, + ) -> None: + if write_nodes: + self.bulk_write_nodes(graph.nodes, roi=roi, attributes=node_attrs) + if write_edges: + self.bulk_write_edges( + graph.nodes, graph.edges, roi=roi, attributes=edge_attrs + ) + + def bulk_write_nodes(self, nodes, roi=None, attributes=None): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + attrs = attributes if attributes is not None else list(self.node_attrs.keys()) + columns = ["id"] + list(attrs) + + def rows(): + for node_id, data in nodes.items(): + pos = self.__get_node_pos(data) + if roi is not None and not roi.contains(pos): + continue + yield [node_id] + [data.get(attr, None) for attr in attrs] + + self._bulk_insert(self.nodes_table_name, columns, rows()) + + def bulk_write_edges(self, nodes, edges, roi=None, attributes=None): + if self.mode == "r": + raise RuntimeError("Trying to write to read-only DB") + + u_name, v_name = self.endpoint_names + attrs = attributes if attributes is not None else list(self.edge_attrs.keys()) + columns = [u_name, v_name] + list(attrs) + + def rows(): + for (u, v), data in edges.items(): + if not self.directed: + u, v = min(u, v), max(u, v) + if roi is not None: + pos_u = self.__get_node_pos(nodes[u]) + if pos_u is None or not roi.contains(pos_u): + continue + yield [u, v] + [data.get(attr, None) for attr in attrs] + + self._bulk_insert(self.edges_table_name, columns, rows()) + @property def node_attrs(self) -> dict[str, AttributeType]: return self._node_attrs if self._node_attrs is not None else {} @@ -334,10 +419,6 @@ def read_nodes( ) ) - attr_filter = attr_filter if attr_filter is not None else {} - for k, v in attr_filter.items(): - select_statement += f" AND {k}={self.__convert_to_sql(v)}" - nodes = [ self._columns_to_node_attrs( {key: val for key, val in zip(read_columns, values)}, read_attrs @@ -365,49 +446,105 @@ def read_edges( nodes: Optional[list[dict[str, Any]]] = None, attr_filter: Optional[dict[str, Any]] = None, read_attrs: Optional[list[str]] = None, + fetch_on_v: bool = False, ) -> list[dict[str, Any]]: - """Returns a list of edges within roi.""" + """ + Returns a list of edges within roi, connected to provided nodes, or all edges. + + Args: + fetch_on_v: If True, also match edges where the v endpoint is in the + node list or ROI. If False (default), only match on u. - if nodes is None: - nodes = self.read_nodes(roi) + Raises: + ValueError: If both roi and nodes are provided. + """ - if len(nodes) == 0: - return [] + if roi is not None and nodes is not None: + raise ValueError( + "read_edges does not support both roi and nodes at the same time. " + "Pass one or the other." + ) endpoint_names = self.endpoint_names - assert endpoint_names is not None - node_ids = ", ".join([str(node["id"]) for node in nodes]) - node_condition = f"{endpoint_names[0]} IN ({node_ids})" # type: ignore + # 1. 
Determine the base SELECT statement and WHERE clause - logger.debug("Reading nodes in roi %s" % roi) - # TODO: AND vs OR here - desired_columns = ", ".join(endpoint_names + list(self.edge_attrs.keys())) # type: ignore - select_statement = ( - f"SELECT {desired_columns} FROM {self.edges_table_name} WHERE " - + node_condition - + ( - " AND " + self.__attr_query(attr_filter) - if attr_filter is not None and len(attr_filter) > 0 - else "" + # Columns to select from the edge table (T1) + edge_table_cols = endpoint_names + list(self.edge_attrs.keys()) + desired_columns = ", ".join(edge_table_cols) + + # Base query starts with selecting all columns from the edges table + select_statement = f"SELECT {desired_columns} FROM {self.edges_table_name}" + where_clauses = [] + using_join = False + + if nodes is not None: + # Case 1: Filter by explicit list of nodes + if len(nodes) == 0: + return [] + + node_ids = ", ".join([str(node["id"]) for node in nodes]) + if fetch_on_v: + where_clauses.append( + f"({endpoint_names[0]} IN ({node_ids})" + f" OR {endpoint_names[1]} IN ({node_ids}))" + ) + else: + where_clauses.append(f"{endpoint_names[0]} IN ({node_ids})") + + elif roi is not None: + # Case 2: Filter by ROI using INNER JOIN + using_join = True + node_id_column = "id" + + edge_cols = ", ".join([f"T1.{col}" for col in edge_table_cols]) + roi_condition = self.__roi_query(roi).replace("WHERE ", "") + + join_condition = f"T1.{endpoint_names[0]} = T2.{node_id_column}" + if fetch_on_v: + join_condition += f" OR T1.{endpoint_names[1]} = T2.{node_id_column}" + + select_statement = ( + f"SELECT DISTINCT {edge_cols} " + f"FROM {self.edges_table_name} AS T1 " + f"INNER JOIN {self.nodes_table_name} AS T2 " + f"ON {join_condition} " ) - ) + where_clauses.append(roi_condition) + + # Case 3: Both nodes and roi are None — fetch all edges + + # 2. Add Attribute Filter to WHERE clauses + if attr_filter is not None and len(attr_filter) > 0: + if using_join: + # Qualify each attribute with T1 for the JOIN case + parts = [ + f"T1.{k}={self.__convert_to_sql(v)}" for k, v in attr_filter.items() + ] + where_clauses.append(" AND ".join(parts)) + else: + where_clauses.append(self.__attr_query(attr_filter)) + + # 3. Finalize the SELECT statement + if len(where_clauses) > 0: + select_statement += " WHERE " + " AND ".join(where_clauses) + + logger.debug(f"Reading edges with query: {select_statement}") + + # 4. 
Execute Query and Process Results - edge_attrs = endpoint_names + ( # type: ignore + # Define the keys for the result dictionaries + all_edge_keys = endpoint_names + list(self.edge_attrs.keys()) + + # Define which keys to keep based on read_attrs + final_edge_keys = endpoint_names + ( list(self.edge_attrs.keys()) if read_attrs is None else read_attrs ) - attr_filter = attr_filter if attr_filter is not None else {} - for k, v in attr_filter.items(): - select_statement += f" AND {k}={self.__convert_to_sql(v)}" - edges = [ { key: val - for key, val in zip( - endpoint_names + list(self.edge_attrs.keys()), - values, # type: ignore - ) - if key in edge_attrs + for key, val in zip(all_edge_keys, values) + if key in final_edge_keys } for values in self._select_query(select_statement) ] @@ -490,8 +627,8 @@ def update_edges( if not roi.contains(pos_u): logger.debug( ( - f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," # type: ignore - + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" # type: ignore + f"Skipping edge with {self.endpoint_names[0]} {{}}, {self.endpoint_names[1]} {{}}," + + f"and data {{}} because {self.endpoint_names[0]} not in roi {{}}" ).format(u, v, data, roi) ) continue @@ -501,7 +638,7 @@ def update_edges( update_statement = ( f"UPDATE {self.edges_table_name} SET " f"{', '.join(setters)} WHERE " - f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" # type: ignore + f"{self.endpoint_names[0]}={u} AND {self.endpoint_names[1]}={v}" ) self._update_query(update_statement, commit=False) @@ -538,7 +675,9 @@ def write_nodes( logger.debug("No nodes to insert in %s", roi) return - self._insert_query(self.nodes_table_name, columns, values, fail_if_exists=True) + self._insert_query( + self.nodes_table_name, columns, values, fail_if_exists=fail_if_exists + ) def update_nodes( self, @@ -581,7 +720,6 @@ def update_nodes( def __create_metadata(self): """Sets the metadata in the meta collection to the provided values""" - metadata = { "position_attribute": self.position_attribute, "directed": self.directed, @@ -597,59 +735,64 @@ def __create_metadata(self): return metadata - def __load_metadata(self, metadata): + def __load_metadata( + self, + metadata, + position_attribute: Optional[str] = None, + directed: Optional[bool] = None, + total_roi: Optional[Roi] = None, + nodes_table: Optional[str] = None, + edges_table: Optional[str] = None, + endpoint_names: Optional[list[str]] = None, + node_attrs: Optional[dict[str, AttributeType]] = None, + edge_attrs: Optional[dict[str, AttributeType]] = None, + ): """Load the provided metadata into this object's attributes, check if - it is consistent with already populated fields.""" - - # simple attributes - for attr_name in [ - "position_attribute", - "directed", - "nodes_table_name", - "edges_table_name", - "endpoint_names", - "ndims", - ]: - if getattr(self, attr_name) is None: - setattr(self, attr_name, metadata[attr_name]) - else: - value = getattr(self, attr_name) - assert value == metadata[attr_name], ( - f"Attribute {attr_name} is already set to {value} for this " - "object, but disagrees with the stored metadata value of " - f"{metadata[attr_name]}" + user-provided overrides are consistent with stored metadata.""" + + # For each simple attribute, use metadata as the source of truth. + # If the user also provided a value, check consistency. 
+ overrides: dict[str, Any] = { + "position_attribute": position_attribute, + "directed": directed, + "nodes_table_name": nodes_table, + "edges_table_name": edges_table, + "endpoint_names": endpoint_names, + "ndims": None, # ndims is never user-provided + } + for attr_name, override in overrides.items(): + stored = metadata[attr_name] + if override is not None: + assert override == stored, ( + f"Attribute {attr_name} was given as {override}, but " + f"disagrees with the stored metadata value of {stored}" ) - - # special attributes - - total_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) - if self.total_roi is None: - self.total_roi = total_roi - else: - assert self.total_roi == total_roi, ( - f"Attribute total_roi is already set to {self.total_roi} for " - "this object, but disagrees with the stored metadata value of " - f"{total_roi}" + setattr(self, attr_name, stored) + + # total_roi + stored_roi = Roi(metadata["total_roi_offset"], metadata["total_roi_shape"]) + if total_roi is not None: + assert total_roi == stored_roi, ( + f"Attribute total_roi was given as {total_roi}, but " + f"disagrees with the stored metadata value of {stored_roi}" ) - - node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} - edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} - if self._node_attrs is None: - self.node_attrs = node_attrs - else: - assert self.node_attrs == node_attrs, ( - f"Attribute node_attrs is already set to {self.node_attrs} for " - "this object, but disagrees with the stored metadata value of " - f"{node_attrs}" + self.total_roi = stored_roi + + # node_attrs / edge_attrs + stored_node_attrs = {k: eval(v) for k, v in metadata["node_attrs"].items()} + stored_edge_attrs = {k: eval(v) for k, v in metadata["edge_attrs"].items()} + if node_attrs is not None: + assert node_attrs == stored_node_attrs, ( + f"Attribute node_attrs was given as {node_attrs}, but " + f"disagrees with the stored metadata value of {stored_node_attrs}" ) - if self._edge_attrs is None: - self.edge_attrs = edge_attrs - else: - assert self.edge_attrs == edge_attrs, ( - f"Attribute edge_attrs is already set to {self.edge_attrs} for " - "this object, but disagrees with the stored metadata value of " - f"{edge_attrs}" + self._node_attrs = stored_node_attrs + if edge_attrs is not None: + assert edge_attrs == stored_edge_attrs, ( + f"Attribute edge_attrs was given as {edge_attrs}, but " + f"disagrees with the stored metadata value of {stored_edge_attrs}" ) + self._edge_attrs = stored_edge_attrs def __remove_keys(self, dictionary, keys): """Removes given keys from dictionary.""" @@ -658,7 +801,7 @@ def __remove_keys(self, dictionary, keys): def __get_node_pos(self, n: dict[str, Any]) -> Optional[Coordinate]: try: - return Coordinate(n[self.position_attribute]) # type: ignore + return Coordinate(n[self.position_attribute]) except KeyError: return None @@ -682,17 +825,21 @@ def __attr_query(self, attrs: dict[str, Any]) -> str: def __roi_query(self, roi: Roi) -> str: query = "WHERE " pos_attr = self.position_attribute - for dim in range(self.ndims): # type: ignore + for dim in range(self.ndims): if dim > 0: query += " AND " if roi.begin[dim] is not None and roi.end[dim] is not None: query += ( - f"{pos_attr}[{dim + 1}] BETWEEN {roi.begin[dim]} and {roi.end[dim]}" + f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" + f" AND {pos_attr}[{dim + 1}]<{roi.end[dim]}" ) elif roi.begin[dim] is not None: query += f"{pos_attr}[{dim + 1}]>={roi.begin[dim]}" - elif roi.begin[dim] is not None: + elif 
roi.end[dim] is not None: query += f"{pos_attr}[{dim + 1}]<{roi.end[dim]}" else: query = query[:-5] return query + + def print_summary(self): + raise ValueError("Not implemented for base SQLGraphDataBase") diff --git a/funlib/persistence/graphs/sqlite_graph_database.py b/src/funlib/persistence/graphs/sqlite_graph_database.py similarity index 69% rename from funlib/persistence/graphs/sqlite_graph_database.py rename to src/funlib/persistence/graphs/sqlite_graph_database.py index 9243677..a2c902b 100644 --- a/funlib/persistence/graphs/sqlite_graph_database.py +++ b/src/funlib/persistence/graphs/sqlite_graph_database.py @@ -2,6 +2,7 @@ import logging import re import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Any, Optional @@ -47,6 +48,39 @@ def __init__( edge_attrs=edge_attrs, ) + def close(self): + self.con.close() + + @contextmanager + def bulk_write_mode(self, worker=False, node_writes=True, edge_writes=True): + prev_sync = self.cur.execute("PRAGMA synchronous").fetchone()[0] + self.cur.execute("PRAGMA synchronous=OFF") + self.cur.execute("PRAGMA journal_mode=WAL") + self.con.commit() + + if not worker and node_writes: + self.cur.execute("DROP INDEX IF EXISTS pos_index") + self.con.commit() + + try: + yield + finally: + self.con.commit() + + if not worker and node_writes: + if self.ndims > 1: + position_columns = self.node_array_columns[self.position_attribute] + else: + position_columns = [self.position_attribute] + self.cur.execute( + f"CREATE INDEX IF NOT EXISTS pos_index ON " + f"{self.nodes_table_name}({','.join(position_columns)})" + ) + self.con.commit() + + self.cur.execute(f"PRAGMA synchronous={prev_sync}") + self.con.commit() + @property def node_array_columns(self): if not self._node_array_columns: @@ -99,16 +133,16 @@ def _create_tables(self) -> None: f"{', '.join(node_columns)}" ")" ) - if self.ndims > 1: # type: ignore + if self.ndims > 1: position_columns = self.node_array_columns[self.position_attribute] else: - position_columns = self.position_attribute + position_columns = [self.position_attribute] self.cur.execute( f"CREATE INDEX IF NOT EXISTS pos_index ON {self.nodes_table_name}({','.join(position_columns)})" ) edge_columns = [ - f"{self.endpoint_names[0]} INTEGER not null", # type: ignore - f"{self.endpoint_names[1]} INTEGER not null", # type: ignore + f"{self.endpoint_names[0]} INTEGER not null", + f"{self.endpoint_names[1]} INTEGER not null", ] for attr in self.edge_attrs.keys(): if attr in self.edge_array_columns: @@ -118,7 +152,7 @@ def _create_tables(self) -> None: self.cur.execute( f"CREATE TABLE IF NOT EXISTS {self.edges_table_name}(" + f"{', '.join(edge_columns)}" - + f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" # type: ignore + + f", PRIMARY KEY ({self.endpoint_names[0]}, {self.endpoint_names[1]})" + ")" ) @@ -181,6 +215,42 @@ def _insert_query(self, table, columns, values, fail_if_exists=False, commit=Tru if commit: self.con.commit() + def _bulk_insert(self, table, columns, rows) -> None: + # Explode array columns for SQLite (same logic as _insert_query) + array_columns = ( + self.node_array_columns + if table == self.nodes_table_name + else self.edge_array_columns + ) + + exploded_columns: list[str] = [] + exploded_rows = [] + for row in rows: + exploded_cols = [] + exploded_vals = [] + for column, value in zip(columns, row): + if column in array_columns: + for c, v in zip(array_columns[column], value): + exploded_cols.append(c) + exploded_vals.append(v) + else: + 
exploded_cols.append(column) + exploded_vals.append(value) + if not exploded_columns: + exploded_columns = exploded_cols + exploded_rows.append(exploded_vals) + + if not exploded_rows: + return + + insert_statement = ( + f"INSERT OR IGNORE INTO {table} " + f"({', '.join(exploded_columns)}) " + f"VALUES ({', '.join(['?'] * len(exploded_columns))})" + ) + self.cur.executemany(insert_statement, exploded_rows) + self.con.commit() + def _update_query(self, query, commit=True): try: self.cur.execute(query) @@ -203,17 +273,18 @@ def _node_attrs_to_columns(self, attrs): columns.append(attr) return columns - def _columns_to_node_attrs(self, columns, query_attrs): - attrs = {} - for attr in query_attrs: + def _columns_to_node_attrs(self, columns, attrs): + result = {} + for attr in attrs: if attr in self.node_array_columns: value = tuple( - columns[f"{attr}_{d}"] for d in range(self.node_attrs[attr].size) + columns[f"{attr}_{d}"] + for d in range(len(self.node_array_columns[attr])) ) else: value = columns[attr] - attrs[attr] = value - return attrs + result[attr] = value + return result def _edge_attrs_to_columns(self, attrs): columns = [] @@ -225,14 +296,15 @@ def _edge_attrs_to_columns(self, attrs): columns.append(attr) return columns - def _columns_to_edge_attrs(self, columns, query_attrs): - attrs = {} - for attr in query_attrs: + def _columns_to_edge_attrs(self, columns, attrs): + result = {} + for attr in attrs: if attr in self.edge_array_columns: value = tuple( - columns[f"{attr}_{d}"] for d in range(self.edge_attrs[attr].size) + columns[f"{attr}_{d}"] + for d in range(len(self.edge_array_columns[attr])) ) else: value = columns[attr] - attrs[attr] = value - return attrs + result[attr] = value + return result diff --git a/funlib/persistence/py.typed b/src/funlib/persistence/py.typed similarity index 100% rename from funlib/persistence/py.typed rename to src/funlib/persistence/py.typed diff --git a/funlib/persistence/types.py b/src/funlib/persistence/types.py similarity index 100% rename from funlib/persistence/types.py rename to src/funlib/persistence/types.py diff --git a/tests/conftest.py b/tests/conftest.py index 95b3e0f..110b2c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -from pathlib import Path +import os +from contextlib import contextmanager import psycopg2 import pytest @@ -7,12 +8,27 @@ from funlib.persistence.graphs import PgSQLGraphDatabase, SQLiteGraphDataBase -# Attempt to connect to the default database +def _psql_connect_kwargs(): + """Build psycopg2 connection kwargs from environment variables.""" + kwargs = {"dbname": "pytest"} + if os.environ.get("PGHOST"): + kwargs["host"] = os.environ["PGHOST"] + if os.environ.get("PGUSER"): + kwargs["user"] = os.environ["PGUSER"] + if os.environ.get("PGPASSWORD"): + kwargs["password"] = os.environ["PGPASSWORD"] + if os.environ.get("PGPORT"): + kwargs["port"] = int(os.environ["PGPORT"]) + return kwargs + + +# Attempt to connect to the server (using the default 'postgres' database +# which always exists, since the test database may not exist yet). 
def can_connect_to_psql(): try: - conn = psycopg2.connect( - dbname="pytest", - ) + kwargs = _psql_connect_kwargs() + kwargs["dbname"] = "postgres" + conn = psycopg2.connect(**kwargs) conn.close() return True except OperationalError: @@ -35,19 +51,14 @@ def can_connect_to_psql(): psql_param, ) ) -def provider_factory(request, tmpdir): - # provides a factory function to generate graph provider - # can provide either mongodb graph provider or file graph provider - # if file graph provider, will generate graph in a temporary directory - # to avoid artifacts - - tmpdir = Path(tmpdir) +def provider_factory(request, tmp_path): + @contextmanager def sqlite_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): - return SQLiteGraphDataBase( - tmpdir / "test_sqlite_graph.db", + provider = SQLiteGraphDataBase( + tmp_path / "test_sqlite_graph.db", position_attribute="position", mode=mode, directed=directed, @@ -55,19 +66,33 @@ def sqlite_provider_factory( node_attrs=node_attrs, edge_attrs=edge_attrs, ) + try: + yield provider + finally: + provider.close() + @contextmanager def psql_provider_factory( mode, directed=None, total_roi=None, node_attrs=None, edge_attrs=None ): - return PgSQLGraphDatabase( + connect_kwargs = _psql_connect_kwargs() + provider = PgSQLGraphDatabase( position_attribute="position", db_name="pytest", + db_host=connect_kwargs.get("host", "localhost"), + db_user=connect_kwargs.get("user"), + db_password=connect_kwargs.get("password"), + db_port=connect_kwargs.get("port"), mode=mode, directed=directed, total_roi=total_roi, node_attrs=node_attrs, edge_attrs=edge_attrs, ) + try: + yield provider + finally: + provider.close() if request.param == "sqlite": yield sqlite_provider_factory @@ -75,3 +100,8 @@ def psql_provider_factory( yield psql_provider_factory else: raise ValueError() + + +@pytest.fixture(params=["standard", "bulk"]) +def write_method(request): + return request.param diff --git a/tests/test_array.py b/tests/test_array.py index 28a1813..7f2d4dd 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1,8 +1,8 @@ import dask.array as da import numpy as np import pytest - from funlib.geometry import Coordinate, Roi + from funlib.persistence.arrays import Array @@ -423,6 +423,7 @@ def test_writeable(): assert a.axis_names == ["d0", "d1"] assert not a.is_writeable + def test_to_pixel_world_space_coordinate(): offset = Coordinate(1, -1, 2) shape = Coordinate(10, 10, 10) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index caf80e4..7cc6f45 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,8 +1,8 @@ import numpy as np import pytest - from funlib.geometry import Coordinate, Roi -from funlib.persistence.arrays.datasets import ArrayNotFoundError, open_ds, prepare_ds + +from funlib.persistence.arrays.datasets import open_ds, prepare_ds from funlib.persistence.arrays.metadata import MetaDataFormat stores = { @@ -16,8 +16,8 @@ @pytest.mark.parametrize("store", stores.keys()) -def test_metadata(tmpdir, store): - store = tmpdir / store +def test_metadata(tmp_path, store): + store = tmp_path / store # test prepare_ds creates array if it does not exist and mode is write array = prepare_ds( @@ -34,10 +34,10 @@ def test_metadata(tmpdir, store): @pytest.mark.parametrize("store", stores.keys()) @pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64]) -def test_helpers(tmpdir, store, dtype): +def test_helpers(tmp_path, store, dtype): shape = Coordinate(1, 1, 10, 20, 30) chunk_shape = Coordinate(2, 3, 10, 
10, 10) - store = tmpdir / store + store = tmp_path / store metadata = MetaDataFormat().parse( shape, { @@ -50,7 +50,7 @@ def test_helpers(tmpdir, store, dtype): ) # test prepare_ds fails if array does not exist and mode is read - with pytest.raises(ArrayNotFoundError): + with pytest.raises(FileNotFoundError): prepare_ds( store, shape, @@ -218,9 +218,9 @@ def test_helpers(tmpdir, store, dtype): @pytest.mark.parametrize("store", stores.keys()) @pytest.mark.parametrize("dtype", [np.float32, np.uint8, np.uint64]) -def test_open_ds(tmpdir, store, dtype): +def test_open_ds(tmp_path, store, dtype): shape = Coordinate(1, 1, 10, 20, 30) - store = tmpdir / store + store = tmp_path / store metadata = MetaDataFormat().parse( shape, { @@ -233,7 +233,7 @@ def test_open_ds(tmpdir, store, dtype): ) # test open_ds fails if array does not exist and mode is read - with pytest.raises(ArrayNotFoundError): + with pytest.raises(FileNotFoundError): open_ds( store, offset=metadata.offset, diff --git a/tests/test_graph.py b/tests/test_graph.py index 17c4765..eef1e4e 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,19 +1,36 @@ import networkx as nx import pytest - from funlib.geometry import Roi + from funlib.persistence.types import Vec -def test_graph_filtering(provider_factory): - graph_writer = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool}, - edge_attrs={"selected": bool}, - ) - roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_writer[roi] +def _write_nodes(provider, nodes, write_method, **kwargs): + if write_method == "bulk": + provider.bulk_write_nodes(nodes, **kwargs) + else: + provider.write_nodes(nodes, **kwargs) + + +def _write_edges(provider, nodes, edges, write_method, **kwargs): + if write_method == "bulk": + provider.bulk_write_edges(nodes, edges, **kwargs) + else: + provider.write_edges(nodes, edges, **kwargs) + +def _write_graph(provider, graph, write_method, **kwargs): + if write_method == "bulk": + kwargs.pop("fail_if_exists", None) + kwargs.pop("delete", None) + provider.bulk_write_graph(graph, **kwargs) + else: + provider.write_graph(graph, **kwargs) + + +def test_graph_filtering(provider_factory, write_method): + roi = Roi((0, 0, 0), (10, 10, 10)) + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True) graph.add_node(42, position=(1, 1, 1), selected=False) graph.add_node(23, position=(5, 5, 5), selected=True) @@ -22,216 +39,207 @@ def test_graph_filtering(provider_factory): graph.add_edge(57, 23, selected=True) graph.add_edge(2, 42, selected=True) - graph_writer.write_nodes(graph.nodes()) - graph_writer.write_edges(graph.nodes(), graph.edges()) - - graph_reader = provider_factory("r") - - filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True}) - filtered_node_ids = [node["id"] for node in filtered_nodes] - expected_node_ids = [2, 23, 57] - assert expected_node_ids == filtered_node_ids - - filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True}) - filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] - expected_edge_endpoints = [(57, 23), (2, 42)] - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - - filtered_subgraph = graph_reader.read_graph( - roi, nodes_filter={"selected": True}, edges_filter={"selected": True} - ) - nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "position" in data - ] - assert expected_node_ids == nodes_with_position - assert 
len(filtered_subgraph.edges()) == len(expected_edge_endpoints) - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_subgraph.edges() or ( - v, - u, - ) in filtered_subgraph.edges() - - -def test_graph_filtering_complex(provider_factory): - graph_provider = provider_factory( + with provider_factory( "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int}, - ) + node_attrs={"position": Vec(float, 3), "selected": bool}, + edge_attrs={"selected": bool}, + ) as graph_writer: + _write_nodes(graph_writer, graph.nodes(), write_method) + _write_edges(graph_writer, graph.nodes(), graph.edges(), write_method) + + with provider_factory("r") as graph_reader: + filtered_nodes = graph_reader.read_nodes(roi, attr_filter={"selected": True}) + filtered_node_ids = [node["id"] for node in filtered_nodes] + expected_node_ids = [2, 23, 57] + assert expected_node_ids == filtered_node_ids + + filtered_edges = graph_reader.read_edges(roi, attr_filter={"selected": True}) + filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] + expected_edge_endpoints = [(57, 23), (2, 42)] + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_edge_endpoints or ( + v, + u, + ) in filtered_edge_endpoints + + filtered_subgraph = graph_reader.read_graph( + roi, nodes_filter={"selected": True}, edges_filter={"selected": True} + ) + nodes_with_position = [ + node + for node, data in filtered_subgraph.nodes(data=True) + if "position" in data + ] + assert expected_node_ids == nodes_with_position + assert len(filtered_subgraph.edges()) == len(expected_edge_endpoints) + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_subgraph.edges() or ( + v, + u, + ) in filtered_subgraph.edges() + + +def test_graph_filtering_complex(provider_factory, write_method): roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_nodes(graph.nodes()) - graph_provider.write_edges(graph.nodes(), graph.edges()) - - graph_provider = provider_factory("r") - - filtered_nodes = graph_provider.read_nodes( - roi, attr_filter={"selected": True, "test": "test"} - ) - filtered_node_ids = [node["id"] for node in filtered_nodes] - expected_node_ids = [2, 57] - assert expected_node_ids == filtered_node_ids - - filtered_edges = graph_provider.read_edges( - roi, attr_filter={"selected": True, "a": 100} - ) - filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] - expected_edge_endpoints = [(57, 23)] - for u, v in expected_edge_endpoints: - assert (u, v) in filtered_edge_endpoints or (v, u) in filtered_edge_endpoints - - filtered_subgraph = graph_provider.read_graph( - roi, - nodes_filter={"selected": True, "test": "test"}, - edges_filter={"selected": True, "a": 100}, - ) - nodes_with_position = [ - node for node, data in filtered_subgraph.nodes(data=True) if "position" in data - ] - assert expected_node_ids == nodes_with_position - assert len(filtered_subgraph.edges()) == 0 + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), 
"selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) + + with provider_factory("r") as graph_provider: + filtered_nodes = graph_provider.read_nodes( + roi, attr_filter={"selected": True, "test": "test"} + ) + filtered_node_ids = [node["id"] for node in filtered_nodes] + expected_node_ids = [2, 57] + assert expected_node_ids == filtered_node_ids + + filtered_edges = graph_provider.read_edges( + roi, attr_filter={"selected": True, "a": 100} + ) + filtered_edge_endpoints = [(edge["u"], edge["v"]) for edge in filtered_edges] + expected_edge_endpoints = [(57, 23)] + for u, v in expected_edge_endpoints: + assert (u, v) in filtered_edge_endpoints or ( + v, + u, + ) in filtered_edge_endpoints + + filtered_subgraph = graph_provider.read_graph( + roi, + nodes_filter={"selected": True, "test": "test"}, + edges_filter={"selected": True, "a": 100}, + ) + nodes_with_position = [ + node + for node, data in filtered_subgraph.nodes(data=True) + if "position" in data + ] + assert expected_node_ids == nodes_with_position + assert len(filtered_subgraph.edges()) == 0 def test_graph_read_and_update_specific_attrs(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, - ) roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_graph(graph) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int, "c": int}, + ) as graph_provider: + graph_provider.write_graph(graph) - graph_provider = provider_factory("r+") - limited_graph = graph_provider.read_graph( - roi, node_attrs=["selected"], edge_attrs=["c"] - ) + with provider_factory("r+") as graph_provider: + limited_graph = graph_provider.read_graph( + roi, node_attrs=["selected"], edge_attrs=["c"] + ) - for node, data in limited_graph.nodes(data=True): - assert "test" not in data - assert "selected" in data - data["selected"] = True + for node, data in limited_graph.nodes(data=True): + assert "test" not in data + assert "selected" in data + data["selected"] = True - for u, v, data in limited_graph.edges(data=True): - assert "a" not in data - assert "b" not in data - nx.set_edge_attributes(limited_graph, 5, "c") + for u, v, data in limited_graph.edges(data=True): + assert "a" not in data + assert "b" not in data + nx.set_edge_attributes(limited_graph, 5, "c") # type: ignore[call-overload] - try: - graph_provider.write_attrs( - limited_graph, edge_attrs=["c"], node_attrs=["selected"] - ) - except NotImplementedError: - pytest.xfail() + try: + graph_provider.write_attrs( + limited_graph, edge_attrs=["c"], node_attrs=["selected"] + ) + except NotImplementedError: + pytest.xfail() - updated_graph = graph_provider.read_graph(roi) + updated_graph = graph_provider.read_graph(roi) - 
for node, data in updated_graph.nodes(data=True): - assert data["selected"] + for node, data in updated_graph.nodes(data=True): + assert data["selected"] - for u, v, data in updated_graph.edges(data=True): - assert data["c"] == 5 + for u, v, data in updated_graph.edges(data=True): + assert data["c"] == 5 -def test_graph_read_unbounded_roi(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, - edge_attrs={"selected": bool, "a": int, "b": int}, - ) - roi = Roi((0, 0, 0), (10, 10, 10)) +def test_graph_read_unbounded_roi(provider_factory, write_method): unbounded_roi = Roi((None, None, None), (None, None, None)) - - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(2, 2, 2), selected=True, test="test") graph.add_node(42, position=(1, 1, 1), selected=False, test="test2") graph.add_node(23, position=(5, 5, 5), selected=True, test="test2") graph.add_node(57, position=(7, 7, 7), selected=True, test="test") - graph.add_edge(42, 23, selected=False, a=100, b=3) graph.add_edge(57, 23, selected=True, a=100, b=2) graph.add_edge(2, 42, selected=True, a=101, b=3) - graph_provider.write_nodes( - graph.nodes(), - ) - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "selected": bool, "test": str}, + edge_attrs={"selected": bool, "a": int, "b": int}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r+") - limited_graph = graph_provider.read_graph( - unbounded_roi, node_attrs=["selected"], edge_attrs=["c"] - ) + with provider_factory("r+") as graph_provider: + limited_graph = graph_provider.read_graph( + unbounded_roi, node_attrs=["selected"], edge_attrs=["c"] + ) - seen = [] - for node, data in limited_graph.nodes(data=True): - assert "test" not in data - assert "selected" in data - data["selected"] = True - seen.append(node) + seen = [] + for node, data in limited_graph.nodes(data=True): + assert "test" not in data + assert "selected" in data + data["selected"] = True + seen.append(node) - assert sorted([2, 42, 23, 57]) == sorted(seen) + assert sorted([2, 42, 23, 57]) == sorted(seen) def test_graph_read_meta_values(provider_factory): roi = Roi((0, 0, 0), (10, 10, 10)) - provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)}) - graph_provider = provider_factory("r", None, None) - assert True == graph_provider.directed - assert roi == graph_provider.total_roi + with provider_factory("w", True, roi, node_attrs={"position": Vec(float, 3)}): + pass + with provider_factory("r", None, None) as graph_provider: + assert True == graph_provider.directed + assert roi == graph_provider.total_roi def test_graph_default_meta_values(provider_factory): - provider = provider_factory( + with provider_factory( "w", False, None, node_attrs={"position": Vec(float, 3)} - ) - assert False == provider.directed - assert provider.total_roi is None or provider.total_roi == Roi( - (None, None, None), (None, None, None) - ) - graph_provider = provider_factory("r", False, None) - assert False == graph_provider.directed - assert graph_provider.total_roi is None or graph_provider.total_roi == Roi( - (None, None, None), (None, None, None) - ) - - -def test_graph_io(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": 
str, - "zap": str, - }, - ) + ) as provider: + assert False == provider.directed + assert provider.total_roi is None or provider.total_roi == Roi( + (None, None, None), (None, None, None) + ) + with provider_factory("r", False, None) as graph_provider: + assert False == graph_provider.directed + assert graph_provider.total_roi is None or graph_provider.total_roi == Roi( + (None, None, None), (None, None, None) + ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] +def test_graph_io(provider_factory, write_method): + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -240,16 +248,15 @@ def test_graph_io(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_nodes( - graph.nodes(), - ) - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] nodes = sorted(list(graph.nodes())) nodes.remove(2) # node 2 has no position and will not be queried @@ -264,16 +271,7 @@ def test_graph_io(provider_factory): def test_graph_fail_if_exists(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -282,24 +280,57 @@ def test_graph_fail_if_exists(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_graph(graph) - with pytest.raises(Exception): - graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) - with pytest.raises(Exception): - graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=True) - + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + graph_provider.write_graph(graph) + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_exists=True + ) + + +def test_graph_duplicate_insert_behavior(provider_factory): + """Test that fail_if_exists controls whether duplicate inserts raise.""" + roi = Roi((0, 0, 0), (10, 10, 10)) + graph = nx.Graph() + graph.add_node(2, position=(2, 2, 2), selected=True) + graph.add_node(42, position=(1, 1, 1), selected=False) + graph.add_edge(2, 42, selected=True) -def test_graph_fail_if_not_exists(provider_factory): - graph_provider = provider_factory( + with provider_factory( "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] + node_attrs={"position": Vec(float, 3), "selected": bool}, + edge_attrs={"selected": bool}, + ) as graph_provider: + # Initial write + graph_provider.write_nodes(graph.nodes()) + graph_provider.write_edges(graph.nodes(), graph.edges()) + + # fail_if_exists=True should raise on duplicate nodes and 
edges + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_exists=True + ) + + # fail_if_exists=False should silently ignore duplicates + graph_provider.write_nodes(graph.nodes(), fail_if_exists=False) + graph_provider.write_edges(graph.nodes(), graph.edges(), fail_if_exists=False) + + # Verify the original data is still intact + with provider_factory("r") as graph_provider: + result = graph_provider.read_graph(roi) + assert set(result.nodes()) == {2, 42} + assert len(result.edges()) == 1 + +def test_graph_fail_if_not_exists(provider_factory): + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -308,25 +339,20 @@ def test_graph_fail_if_not_exists(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - with pytest.raises(Exception): - graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True) - with pytest.raises(Exception): - graph_provider.write_edges( - graph.nodes(), graph.edges(), fail_if_not_exists=True - ) - - -def test_graph_write_attributes(provider_factory): - graph_provider = provider_factory( + with provider_factory( "w", - node_attrs={ - "position": Vec(int, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + with pytest.raises(Exception): + graph_provider.write_nodes(graph.nodes(), fail_if_not_exists=True) + with pytest.raises(Exception): + graph_provider.write_edges( + graph.nodes(), graph.edges(), fail_if_not_exists=True + ) + + +def test_graph_write_attributes(provider_factory, write_method): + graph = nx.Graph() graph.add_node(2, position=[0, 0, 0]) graph.add_node(42, position=[1, 1, 1]) graph.add_node(23, position=[5, 5, 5], swip="swap") @@ -335,17 +361,22 @@ def test_graph_write_attributes(provider_factory): graph.add_edge(57, 23) graph.add_edge(2, 42) - graph_provider.write_graph( - graph, write_nodes=True, write_edges=False, node_attrs=["position", "swip"] - ) - - graph_provider.write_edges( - graph.nodes(), - graph.edges(), - ) + with provider_factory( + "w", + node_attrs={"position": Vec(int, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_graph( + graph_provider, + graph, + write_method, + write_nodes=True, + write_edges=False, + node_attrs=["position", "swip"], + ) + _write_edges(graph_provider, graph.nodes(), graph.edges(), write_method) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (10, 10, 10))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (10, 10, 10))] nodes = [] for node, data in graph.nodes(data=True): @@ -372,17 +403,8 @@ def test_graph_write_attributes(provider_factory): assert v1 == v2 -def test_graph_write_roi(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] - +def test_graph_write_roi(provider_factory, write_method): + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -392,10 +414,14 @@ def test_graph_write_roi(provider_factory): graph.add_edge(2, 42) write_roi = Roi((0, 0, 0), (6, 6, 6)) - 
graph_provider.write_graph(graph, write_roi) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method, roi=write_roi) - graph_provider = provider_factory("r") - compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] + with provider_factory("r") as graph_provider: + compare_graph = graph_provider[Roi((1, 1, 1), (9, 9, 9))] nodes = sorted(list(graph.nodes())) nodes.remove(2) # node 2 has no position and will not be queried @@ -412,22 +438,19 @@ def test_graph_write_roi(provider_factory): def test_graph_connected_components(provider_factory): - graph_provider = provider_factory( + with provider_factory( "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + graph = graph_provider[Roi((0, 0, 0), (10, 10, 10))] + + graph.add_node(2, position=(0, 0, 0)) + graph.add_node(42, position=(1, 1, 1)) + graph.add_node(23, position=(5, 5, 5), swip="swap") + graph.add_node(57, position=(7, 7, 7), zap="zip") + graph.add_edge(57, 23) + graph.add_edge(2, 42) - graph.add_node(2, position=(0, 0, 0)) - graph.add_node(42, position=(1, 1, 1)) - graph.add_node(23, position=(5, 5, 5), swip="swap") - graph.add_node(57, position=(7, 7, 7), zap="zip") - graph.add_edge(57, 23) - graph.add_edge(2, 42) try: components = list(nx.connected_components(graph)) except NotImplementedError: @@ -449,19 +472,9 @@ def test_graph_connected_components(provider_factory): assert n2 == compare_n2 -def test_graph_has_edge(provider_factory): - graph_provider = provider_factory( - "w", - node_attrs={ - "position": Vec(float, 3), - "swip": str, - "zap": str, - }, - ) - +def test_graph_has_edge(provider_factory, write_method): roi = Roi((0, 0, 0), (10, 10, 10)) - graph = graph_provider[roi] - + graph = nx.Graph() graph.add_node(2, position=(0, 0, 0)) graph.add_node(42, position=(1, 1, 1)) graph.add_node(23, position=(5, 5, 5), swip="swap") @@ -470,7 +483,290 @@ def test_graph_has_edge(provider_factory): graph.add_edge(57, 23) write_roi = Roi((0, 0, 0), (6, 6, 6)) - graph_provider.write_nodes(graph.nodes(), roi=write_roi) - graph_provider.write_edges(graph.nodes(), graph.edges(), roi=write_roi) + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3), "swip": str, "zap": str}, + ) as graph_provider: + _write_nodes(graph_provider, graph.nodes(), write_method, roi=write_roi) + _write_edges( + graph_provider, graph.nodes(), graph.edges(), write_method, roi=write_roi + ) + assert graph_provider.has_edges(roi) + + +def test_read_edges_join_vs_in_clause(provider_factory, write_method): + """Benchmark: read_edges with JOIN (roi-only) vs IN clause (nodes list). + + Demonstrates that the JOIN path avoids serializing a large node ID list + into the SQL query, and lets the DB optimizer do the work instead. 
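+
+    Rough sketch of the two call patterns timed below (same provider API as in
+    the test body; the `provider` name is illustrative):
+
+        nodes = provider.read_nodes(query_roi)      # old path, query 1
+        edges = provider.read_edges(nodes=nodes)    # old path, query 2 (IN clause)
+        edges = provider.read_edges(roi=query_roi)  # new path, single JOIN query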
+ """ + import time + from itertools import product + + size = 50 # 50^3 = 125,000 nodes + graph = nx.Graph() + for x, y, z in product(range(size), repeat=3): + node_id = x * size * size + y * size + z + graph.add_node(node_id, position=(x + 0.5, y + 0.5, z + 0.5)) + if x > 0: + graph.add_edge(node_id, (x - 1) * size * size + y * size + z) + if y > 0: + graph.add_edge(node_id, x * size * size + (y - 1) * size + z) + if z > 0: + graph.add_edge(node_id, x * size * size + y * size + (z - 1)) + + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method) + + with provider_factory("r") as graph_provider: + query_roi = Roi((10, 10, 10), (30, 30, 30)) + n_repeats = 5 + + # --- Old approach: read_nodes, then read_edges with nodes list --- + times_in_clause = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + nodes = graph_provider.read_nodes(query_roi) + edges_via_in = graph_provider.read_edges(nodes=nodes) + t1 = time.perf_counter() + times_in_clause.append(t1 - t0) + + # --- New approach: read_edges with roi (JOIN) --- + times_join = [] + for _ in range(n_repeats): + t0 = time.perf_counter() + edges_via_join = graph_provider.read_edges(roi=query_roi) + t1 = time.perf_counter() + times_join.append(t1 - t0) + + avg_in = sum(times_in_clause) / n_repeats + avg_join = sum(times_join) / n_repeats + + print(f"\n--- read_edges benchmark (roi covers {30**3:,} of {size**3:,} nodes) ---") + print(f"IN clause (2 queries): {avg_in * 1000:.1f} ms avg") + print(f"JOIN (1 query): {avg_join * 1000:.1f} ms avg") + print(f"Speedup: {avg_in / avg_join:.2f}x") + + # Both should return edges — just verify they're non-empty and reasonable + assert len(edges_via_in) > 0 + assert len(edges_via_join) > 0 + # JOIN finds edges where either endpoint is in ROI (superset of IN approach) + assert len(edges_via_join) == len(edges_via_in) + + +def test_read_edges_fetch_on_v(provider_factory, write_method): + """Test that fetch_on_v controls whether edges are matched on u only or both endpoints. + + Graph layout (1D for clarity, stored as 3D positions): + + Node 1 (pos 1) -- Edge(1,5) -- Node 5 (pos 5) + Node 2 (pos 2) -- Edge(2,8) -- Node 8 (pos 8) + Node 5 (pos 5) -- Edge(5,8) -- Node 8 (pos 8) + Node 8 (pos 8) -- Edge(8,9) -- Node 9 (pos 9) + + ROI = [0, 6) covers nodes {1, 2, 5}. + + Undirected edges are stored with u < v, so: + - Edge(1, 5): u=1 in ROI, v=5 in ROI + - Edge(2, 8): u=2 in ROI, v=8 outside ROI + - Edge(5, 8): u=5 in ROI, v=8 outside ROI + - Edge(8, 9): u=8 outside ROI, v=9 outside ROI + + fetch_on_v=False (default): only edges where u is in ROI -> {(1,5), (2,8), (5,8)} + fetch_on_v=True: edges where u OR v is in ROI -> {(1,5), (2,8), (5,8)} + (same here because u < v and all boundary-crossing edges have u inside) + + To properly test fetch_on_v, we need an edge where u is OUTSIDE the ROI + but v is INSIDE. With undirected u < v storage, this means a node with a + smaller ID outside the ROI connected to a node with a larger ID inside. + + So we add: Node 0 (pos 8) -- Edge(0, 5): u=0 outside ROI, v=5 in ROI. 
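+
+    In short (mirroring the calls made in the test body; `provider` is
+    illustrative):
+
+        provider.read_edges(nodes=nodes, fetch_on_v=False)  # match edges on u only
+        provider.read_edges(nodes=nodes, fetch_on_v=True)   # match edges on u or v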
+ """ + roi = Roi((0, 0, 0), (6, 6, 6)) + graph = nx.Graph() + # Nodes inside ROI (positions < 6) + graph.add_node(1, position=(1.0, 1.0, 1.0)) + graph.add_node(2, position=(2.0, 2.0, 2.0)) + graph.add_node(5, position=(5.0, 5.0, 5.0)) + # Nodes outside ROI (positions >= 6) + # Node 0 has ID < all ROI nodes but position outside ROI + graph.add_node(0, position=(8.0, 8.0, 8.0)) + graph.add_node(8, position=(8.0, 8.0, 8.0)) + graph.add_node(9, position=(9.0, 9.0, 9.0)) + # Edges: undirected, stored as u < v + graph.add_edge(1, 5) # both in ROI + graph.add_edge(2, 8) # u in ROI, v outside + graph.add_edge(5, 8) # u in ROI, v outside + graph.add_edge(8, 9) # both outside ROI + graph.add_edge(0, 5) # u=0 OUTSIDE ROI, v=5 INSIDE ROI (key test edge) + + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + _write_graph(graph_provider, graph, write_method) + + with provider_factory("r") as graph_provider: + + def edge_set(edges): + """Normalize edge list to set of sorted tuples for comparison.""" + return {(min(e["u"], e["v"]), max(e["u"], e["v"])) for e in edges} + + # --- Case 1: nodes passed explicitly --- + nodes_in_roi = graph_provider.read_nodes(roi) + node_ids_in_roi = {n["id"] for n in nodes_in_roi} + assert node_ids_in_roi == {1, 2, 5} + + edges_u_only = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=False) + edges_u_and_v = graph_provider.read_edges(nodes=nodes_in_roi, fetch_on_v=True) + + # fetch_on_v=False: only edges where u IN (1,2,5) + # (1,5), (2,8), (5,8) match; (0,5) does NOT match (u=0 not in list) + assert edge_set(edges_u_only) == {(1, 5), (2, 8), (5, 8)} - assert graph_provider.has_edges(roi) + # fetch_on_v=True: edges where u OR v IN (1,2,5) + # (0,5) now matches because v=5 is in the list + assert edge_set(edges_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + + # --- Case 2: roi passed (JOIN path) --- + edges_roi_u_only = graph_provider.read_edges(roi=roi, fetch_on_v=False) + edges_roi_u_and_v = graph_provider.read_edges(roi=roi, fetch_on_v=True) + + # Same expected results as Case 1 + assert edge_set(edges_roi_u_only) == {(1, 5), (2, 8), (5, 8)} + assert edge_set(edges_roi_u_and_v) == {(0, 5), (1, 5), (2, 8), (5, 8)} + + # --- Case 3: via read_graph --- + graph_u_only = graph_provider.read_graph(roi, fetch_on_v=False) + graph_u_and_v = graph_provider.read_graph(roi, fetch_on_v=True) + + graph_edges_u_only = {tuple(sorted(e)) for e in graph_u_only.edges()} + graph_edges_u_and_v = {tuple(sorted(e)) for e in graph_u_and_v.edges()} + + assert graph_edges_u_only == {(1, 5), (2, 8), (5, 8)} + assert graph_edges_u_and_v == {(0, 5), (1, 5), (2, 8), (5, 8)} + + +def test_graph_roi_upper_bound_exclusive(provider_factory): + """Nodes at exactly the upper bound of the ROI must be excluded. + + ROI is half-open [begin, end). A node whose position equals end in any + dimension should NOT appear in read_nodes or read_graph results. 
+ + Regression test for: https://github.com/funkelab/funlib.persistence/issues/XX + """ + roi = Roi((0, 0, 0), (10, 10, 10)) # [0, 10) in each dim + + graph = nx.Graph() + # Interior node — clearly inside + graph.add_node(1, position=(5.0, 5.0, 5.0)) + # Node exactly on lower bound — should be included + graph.add_node(2, position=(0.0, 0.0, 0.0)) + # Nodes exactly on upper bound — should be excluded + graph.add_node(3, position=(10.0, 5.0, 5.0)) # x == end + graph.add_node(4, position=(5.0, 10.0, 5.0)) # y == end + graph.add_node(5, position=(5.0, 5.0, 10.0)) # z == end + graph.add_node(6, position=(10.0, 10.0, 10.0)) # all dims == end + # Edge crossing the boundary (u inside, v on boundary) + graph.add_edge(1, 3) + + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as gp: + gp.write_graph(graph) + + with provider_factory("r") as gp: + # read_nodes: only nodes strictly inside [0, 10) + nodes = gp.read_nodes(roi) + node_ids = {n["id"] for n in nodes} + assert node_ids == {1, 2}, f"Expected {{1, 2}}, got {node_ids}" + + # read_graph: same node set, edge (1,3) should still appear + # because node 3 is pulled in as a bare node via the edge + result = gp.read_graph(roi) + result_node_ids = set(result.nodes()) + # Node 3 may appear as a bare node (no position) via the edge + assert 1 in result_node_ids + assert 2 in result_node_ids + # Nodes 4, 5, 6 have no edges to interior nodes — must not appear + assert 4 not in result_node_ids + assert 5 not in result_node_ids + assert 6 not in result_node_ids + + # Verify that nodes returned by read_nodes all have positions inside ROI + for node in nodes: + pos = node["position"] + for dim in range(3): + assert pos[dim] >= roi.begin[dim], ( + f"Node {node['id']} pos[{dim}]={pos[dim]} < roi.begin={roi.begin[dim]}" + ) + assert pos[dim] < roi.end[dim], ( + f"Node {node['id']} pos[{dim}]={pos[dim]} >= roi.end={roi.end[dim]}" + ) + + +def _build_grid_graph(size): + """Build a 3D grid graph with size^3 nodes and ~3*size^2*(size-1) edges.""" + from itertools import product + + graph = nx.Graph() + for x, y, z in product(range(size), repeat=3): + node_id = x * size * size + y * size + z + graph.add_node(node_id, position=(x + 0.5, y + 0.5, z + 0.5)) + if x > 0: + graph.add_edge(node_id, (x - 1) * size * size + y * size + z) + if y > 0: + graph.add_edge(node_id, x * size * size + (y - 1) * size + z) + if z > 0: + graph.add_edge(node_id, x * size * size + y * size + (z - 1)) + return graph + + +def test_bulk_write_benchmark(provider_factory): + """Benchmark: standard write_graph vs bulk_write_graph.""" + import time + + size = 30 # 30^3 = 27,000 nodes + graph = _build_grid_graph(size) + n_nodes = graph.number_of_nodes() + n_edges = graph.number_of_edges() + + # --- Standard write --- + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + t0 = time.perf_counter() + graph_provider.write_graph(graph) + t_standard = time.perf_counter() - t0 + + # Verify standard write + with provider_factory("r") as graph_reader: + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges + + # --- Bulk write (recreate tables, with bulk_write_mode) --- + with provider_factory( + "w", + node_attrs={"position": Vec(float, 3)}, + ) as graph_provider: + with graph_provider.bulk_write_mode(): + t0 = time.perf_counter() + graph_provider.bulk_write_graph(graph) + t_bulk = time.perf_counter() - t0 + + # Verify bulk write + with provider_factory("r") as 
graph_reader: + result = graph_reader.read_graph() + assert result.number_of_nodes() == n_nodes + assert result.number_of_edges() == n_edges + + print(f"\n--- write benchmark ({n_nodes:,} nodes, {n_edges:,} edges) ---") + print(f"Standard: {t_standard * 1000:.1f} ms") + print(f"Bulk: {t_bulk * 1000:.1f} ms") + print(f"Speedup: {t_standard / t_bulk:.2f}x") diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 371e4d6..c433918 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -3,8 +3,8 @@ import pytest import zarr - from funlib.geometry import Coordinate + from funlib.persistence.arrays.datasets import prepare_ds from funlib.persistence.arrays.metadata import ( MetaDataFormat, @@ -91,7 +91,7 @@ def test_empty_metadata(): assert metadata.types == ["space", "space", "space", "space", "space"] -def test_default_metadata_format(tmpdir): +def test_default_metadata_format(tmp_path): set_default_metadata_format(metadata_formats["simple"]) metadata = metadata_formats["simple"].parse( (10, 2, 100, 100, 100), @@ -99,7 +99,7 @@ def test_default_metadata_format(tmpdir): ) prepare_ds( - tmpdir / "test.zarr/test", + tmp_path / "test.zarr/test", (10, 2, 100, 100, 100), offset=metadata.offset, voxel_size=metadata.voxel_size, @@ -110,7 +110,7 @@ def test_default_metadata_format(tmpdir): mode="w", ) - zarr_attrs = dict(**zarr.open(str(tmpdir / "test.zarr/test")).attrs) + zarr_attrs = dict(**zarr.open(tmp_path / "test.zarr/test").attrs) assert zarr_attrs["offset"] == [100, 200, 400] assert zarr_attrs["resolution"] == [1, 2, 3] assert zarr_attrs["extras/axes"] == ["sample^", "channel^", "t", "y", "x"]
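
Usage sketch (not part of the patch): the new SQLite bulk-write path, as exercised by
test_bulk_write_benchmark above. Constructor arguments mirror tests/conftest.py; the
file name and the `graph` variable are placeholders.

    import networkx as nx
    from funlib.persistence.graphs import SQLiteGraphDataBase
    from funlib.persistence.types import Vec

    graph = nx.Graph()  # nodes need a 3-vector "position" attribute

    provider = SQLiteGraphDataBase(
        "test_sqlite_graph.db",
        position_attribute="position",
        mode="w",
        node_attrs={"position": Vec(float, 3)},
    )
    # bulk_write_mode() relaxes PRAGMA synchronous / journal_mode and drops the
    # position index for the duration of the block, restoring both on exit.
    with provider.bulk_write_mode():
        provider.bulk_write_graph(graph)
    provider.close()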