From cafe1b9ccf14104b64a3ab36a2ee223d82ac3bd8 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 28 May 2026 11:44:38 +0900 Subject: [PATCH 1/5] first trial --- src/tracksdata/_test/test_attrs.py | 114 +++++++++ src/tracksdata/attrs.py | 238 ++++++++++++++++-- src/tracksdata/graph/_base_graph.py | 16 +- src/tracksdata/graph/_graph_view.py | 4 +- src/tracksdata/graph/_rustworkx_graph.py | 60 +++-- src/tracksdata/graph/_sql_graph.py | 39 ++- src/tracksdata/graph/_test/test_subgraph.py | 92 +++++++ .../graph/filters/_indexed_filter.py | 4 +- 8 files changed, 507 insertions(+), 60 deletions(-) diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 8b29bf7d..b4582643 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -9,6 +9,7 @@ from tracksdata.attrs import ( Attr, AttrComparison, + AttrFilter, EdgeAttr, NodeAttr, attr_comps_to_strs, @@ -650,3 +651,116 @@ def test_attr_is_in_accepts_numpy_arrays() -> None: evaluated = comp.to_attr().evaluate(df) assert evaluated.to_list() == [False, True, False] + + +# --------------------------------------------------------------------------- +# AttrFilter (compound boolean filters built from AttrComparisons) +# --------------------------------------------------------------------------- + + +def test_attr_filter_or_operator_returns_filter() -> None: + comp1 = NodeAttr("t") == 1 + comp2 = NodeAttr("t") == 2 + f = comp1 | comp2 + + assert isinstance(f, AttrFilter) + assert f.op == "or" + assert f.operands == [comp1, comp2] + assert f.columns == ["t"] + + +def test_attr_filter_xor_and_invert_operators() -> None: + comp1 = NodeAttr("a") > 0 + comp2 = NodeAttr("b") > 0 + xor_f = comp1 ^ comp2 + assert isinstance(xor_f, AttrFilter) + assert xor_f.op == "xor" + + not_f = ~comp1 + assert isinstance(not_f, AttrFilter) + assert not_f.op == "not" + assert not_f.operands == [comp1] + + +def test_attr_filter_and_operator_between_comparisons() -> None: + comp1 = NodeAttr("a") > 0 + comp2 = NodeAttr("b") < 1 + and_f = comp1 & comp2 + assert isinstance(and_f, AttrFilter) + assert and_f.op == "and" + + +def test_attr_filter_nested_composition() -> None: + f = (NodeAttr("a") > 0) & ((NodeAttr("b") == 1) | (NodeAttr("b") == 2)) + assert isinstance(f, AttrFilter) + assert f.op == "and" + assert isinstance(f.operands[1], AttrFilter) + assert f.operands[1].op == "or" + assert sorted({leaf.column for leaf in f.leaves()}) == ["a", "b"] + + +def test_attr_filter_mixed_node_and_edge_raises_on_split() -> None: + """A single compound that mixes node and edge attributes must error in split.""" + f = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) + with pytest.raises(ValueError, match="cannot mix NodeAttr and EdgeAttr"): + split_attr_comps([f]) + + +def test_attr_filter_split_attr_comps_with_compounds() -> None: + node_f = (NodeAttr("t") == 1) | (NodeAttr("t") == 2) + edge_f = (EdgeAttr("w") > 0.5) | (EdgeAttr("w") < -0.5) + node_only = NodeAttr("label") == "A" + + nodes, edges = split_attr_comps([node_f, edge_f, node_only]) + assert nodes == [node_f, node_only] + assert edges == [edge_f] + + +def test_attr_filter_invalid_op_raises() -> None: + with pytest.raises(ValueError, match="Unknown logical operator"): + AttrFilter("nor", [NodeAttr("a") == 1, NodeAttr("a") == 2]) + + +def test_attr_filter_not_with_multiple_operands_raises() -> None: + with pytest.raises(ValueError, match="'not' filter requires exactly one operand"): + AttrFilter("not", [NodeAttr("a") == 1, NodeAttr("a") == 2]) + + +def test_attr_filter_or_with_single_operand_raises() -> None: + with pytest.raises(ValueError, match="'or' filter requires at least two operands"): + AttrFilter("or", [NodeAttr("a") == 1]) + + +def test_attr_filter_rejects_non_filter_operands() -> None: + with pytest.raises(TypeError, match="must be AttrComparison or AttrFilter"): + AttrFilter("or", [NodeAttr("a") == 1, 5]) + + +def test_attr_filter_polars_reduce_or() -> None: + df = pl.DataFrame({"t": [0, 1, 2, 3, 4]}) + f = (NodeAttr("t") == 1) | (NodeAttr("t") == 3) + expr = polars_reduce_attr_comps(df, [f], operator.and_) + result = df.select(expr).to_series() + assert result.to_list() == [False, True, False, True, False] + + +def test_attr_filter_polars_reduce_xor() -> None: + df = pl.DataFrame({"a": [0, 1, 0, 1], "b": [0, 0, 1, 1]}) + f = (NodeAttr("a") == 1) ^ (NodeAttr("b") == 1) + expr = polars_reduce_attr_comps(df, [f], operator.and_) + result = df.select(expr).to_series() + assert result.to_list() == [False, True, True, False] + + +def test_attr_filter_polars_reduce_not() -> None: + df = pl.DataFrame({"t": [0, 1, 2]}) + f = ~(NodeAttr("t") == 1) + expr = polars_reduce_attr_comps(df, [f], operator.and_) + result = df.select(expr).to_series() + assert result.to_list() == [True, False, True] + + +def test_attr_filter_attr_comps_to_strs_with_compound() -> None: + f = (NodeAttr("a") == 1) | (NodeAttr("b") == 2) + plain = NodeAttr("c") == 3 + assert attr_comps_to_strs([f, plain]) == ["a", "b", "c"] diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 60f82db8..d6e07512 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -12,6 +12,14 @@ graph.filter(NodeAttr("t") == 1).subgraph() ``` +Boolean combinations of comparisons can be expressed with `|` (or), `^` (xor), +`&` (and) and `~` (not). Comparisons passed as multiple positional arguments to +`filter()` are still implicitly AND-ed together. +```python +graph.filter((NodeAttr("t") == 1) | (NodeAttr("t") == 2)).subgraph() +graph.filter(~(NodeAttr("t") == 0)).subgraph() +``` + Or to create complex expression when solving the tracking problem: ```python NearestNeighborsSolver(-Attr("iou") * (-Attr("distance") / 30.0).exp()) @@ -32,9 +40,14 @@ ExprInput = Union[str, Scalar, "Attr", Expr, "AttrComparison"] MembershipExprInput = Sequence[Scalar] +# Logical operators supported by AttrFilter compounds. +_FILTER_LOGICAL_OPS = ("and", "or", "xor", "not") +FilterInput = Union["AttrComparison", "AttrFilter"] + __all__ = [ "AttrComparison", + "AttrFilter", "EdgeAttr", "NodeAttr", "attr_comps_to_strs", @@ -662,67 +675,244 @@ class EdgeAttr(Attr): """ -def split_attr_comps(attr_comps: Sequence[AttrComparison]) -> tuple[list[AttrComparison], list[AttrComparison]]: +_FILTER_OP_SYMBOLS = {"and": "&", "or": "|", "xor": "^", "not": "~"} + + +class AttrFilter: + """ + A compound boolean combination of [AttrComparison][tracksdata.attrs.AttrComparison] + (or nested `AttrFilter`) operands, used to express OR / XOR / AND / NOT + relationships when filtering nodes or edges in a graph. + + Use Python's bitwise operators on `AttrComparison` (or `AttrFilter`) + instances to build compounds: + + ```python + graph.filter((NodeAttr("t") == 1) | (NodeAttr("t") == 2)) + graph.filter(~(NodeAttr("t") == 0)) + graph.filter((EdgeAttr("w") > 0.5) ^ (EdgeAttr("w") < -0.5)) + ``` + + All leaves of a single `AttrFilter` must reference attributes of the same + kind (either all [NodeAttr][tracksdata.attrs.NodeAttr] or all + [EdgeAttr][tracksdata.attrs.EdgeAttr]). Mixing node and edge attributes + inside one compound is not supported because it would require joining the + node and edge tables in a way that conflicts with the existing AND-based + filter semantics. Top-level node/edge filters can still be combined via + positional arguments to `graph.filter()` (implicit AND). + + Parameters + ---------- + op : str + Logical operator, one of `"and"`, `"or"`, `"xor"`, `"not"`. + operands : Sequence[AttrComparison | AttrFilter] + Operands. `"not"` requires exactly one operand; the others require at + least two. + """ + + def __init__(self, op: str, operands: Sequence[FilterInput]) -> None: + if op not in _FILTER_LOGICAL_OPS: + raise ValueError(f"Unknown logical operator '{op}'. Expected one of {_FILTER_LOGICAL_OPS}.") + operands = list(operands) + for o in operands: + if not isinstance(o, AttrComparison | AttrFilter): + raise TypeError(f"AttrFilter operands must be AttrComparison or AttrFilter, got {type(o).__name__}.") + if op == "not": + if len(operands) != 1: + raise ValueError("'not' filter requires exactly one operand.") + else: + if len(operands) < 2: + raise ValueError(f"'{op}' filter requires at least two operands.") + self.op = op + self.operands = operands + + def __and__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("and", [self, other]) + + def __rand__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("and", [other, self]) + + def __or__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("or", [self, other]) + + def __ror__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("or", [other, self]) + + def __xor__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("xor", [self, other]) + + def __rxor__(self, other: FilterInput) -> "AttrFilter": + return AttrFilter("xor", [other, self]) + + def __invert__(self) -> "AttrFilter": + return AttrFilter("not", [self]) + + def leaves(self) -> list["AttrComparison"]: + """Flatten the filter tree to its leaf comparisons.""" + out: list[AttrComparison] = [] + for o in self.operands: + if isinstance(o, AttrFilter): + out.extend(o.leaves()) + else: + out.append(o) + return out + + @property + def columns(self) -> list[str]: + return list(dict.fromkeys(leaf.column for leaf in self.leaves())) + + def __repr__(self) -> str: + if self.op == "not": + return f"~{self.operands[0]!r}" + sep = f" {_FILTER_OP_SYMBOLS[self.op]} " + return "(" + sep.join(repr(o) for o in self.operands) + ")" + + +def _filter_attr_kind(f: FilterInput) -> type[Attr]: + """Return the leaf-attribute kind (NodeAttr / EdgeAttr) of a filter. + + Raises ValueError if the filter mixes node and edge attributes. + """ + if isinstance(f, AttrComparison): + if isinstance(f.attr, NodeAttr): + return NodeAttr + if isinstance(f.attr, EdgeAttr): + return EdgeAttr + raise ValueError(f"Expected comparisons of 'NodeAttr' or 'EdgeAttr' objects, got {type(f.attr)}") + + kinds = {_filter_attr_kind(o) for o in f.operands} + if len(kinds) > 1: + raise ValueError( + "A single AttrFilter compound cannot mix NodeAttr and EdgeAttr comparisons. " + "Combine node and edge filters via separate positional arguments to graph.filter()." + ) + return kinds.pop() + + +# --- AttrComparison boolean operator overrides -------------------------------- +# These run after _setup_ops() and replace the auto-generated `__and__`, +# `__or__`, `__xor__` on AttrComparison so that combining comparisons builds an +# AttrFilter compound instead of an Attr expression. The boolean methods on +# `Attr` itself are unchanged. + + +def _attr_comparison_logical(self: "AttrComparison", other: Any, op_name: str, py_op: Callable) -> Any: + if isinstance(other, AttrComparison | AttrFilter): + return AttrFilter(op_name, [self, other]) + return self._delegate_operator(other, py_op, reverse=False) + + +def _attr_comparison_r_logical(self: "AttrComparison", other: Any, op_name: str, py_op: Callable) -> Any: + if isinstance(other, AttrComparison | AttrFilter): + return AttrFilter(op_name, [other, self]) + return self._delegate_operator(other, py_op, reverse=True) + + +AttrComparison.__and__ = functools.partialmethod(_attr_comparison_logical, op_name="and", py_op=operator.and_) +AttrComparison.__rand__ = functools.partialmethod(_attr_comparison_r_logical, op_name="and", py_op=operator.and_) +AttrComparison.__or__ = functools.partialmethod(_attr_comparison_logical, op_name="or", py_op=operator.or_) +AttrComparison.__ror__ = functools.partialmethod(_attr_comparison_r_logical, op_name="or", py_op=operator.or_) +AttrComparison.__xor__ = functools.partialmethod(_attr_comparison_logical, op_name="xor", py_op=operator.xor) +AttrComparison.__rxor__ = functools.partialmethod(_attr_comparison_r_logical, op_name="xor", py_op=operator.xor) + + +def _attr_comparison_invert(self: "AttrComparison") -> "AttrFilter": + return AttrFilter("not", [self]) + + +AttrComparison.__invert__ = _attr_comparison_invert + + +def split_attr_comps( + attr_comps: Sequence[FilterInput], +) -> tuple[list[FilterInput], list[FilterInput]]: """ - Split a list of attribute comparisons into node and edge attribute comparisons. + Split a list of attribute comparisons (or compound filters) into node and + edge groups based on the kind of their leaf comparisons. Parameters ---------- - attr_comps : Sequence[AttrComparison] - The attribute comparisons to split. + attr_comps : Sequence[AttrComparison | AttrFilter] + The attribute comparisons or compound filters to split. Returns ------- - tuple[list[AttrComparison], list[AttrComparison]] - A tuple of lists of node and edge attribute comparisons. + tuple[list[AttrComparison | AttrFilter], list[AttrComparison | AttrFilter]] + A tuple of lists of node and edge filters. """ - node_attr_comps = [] - edge_attr_comps = [] + node_attr_comps: list[FilterInput] = [] + edge_attr_comps: list[FilterInput] = [] for attr_comp in attr_comps: - if isinstance(attr_comp.attr, NodeAttr): + kind = _filter_attr_kind(attr_comp) + if kind is NodeAttr: node_attr_comps.append(attr_comp) - elif isinstance(attr_comp.attr, EdgeAttr): - edge_attr_comps.append(attr_comp) else: - raise ValueError(f"Expected comparisons of 'NodeAttr' or 'EdgeAttr' objects, got {type(attr_comp.attr)}") + edge_attr_comps.append(attr_comp) return node_attr_comps, edge_attr_comps -def attr_comps_to_strs(attr_comps: Sequence[AttrComparison]) -> list[str]: +def attr_comps_to_strs(attr_comps: Sequence[FilterInput]) -> list[str]: """ - Convert a list of attribute comparisons to a list of strings. + Convert a list of attribute comparisons (or compound filters) to a list of + column names involved in them. Parameters ---------- - attr_comps : Sequence[AttrComparison] - The attribute comparisons to convert to strings. + attr_comps : Sequence[AttrComparison | AttrFilter] + The filters to extract column names from. Returns ------- list[str] - The attribute comparisons as strings. + The column names referenced by the filters, deduplicated while + preserving order. """ - return [str(attr_comp.column) for attr_comp in attr_comps] + out: list[str] = [] + for attr_comp in attr_comps: + if isinstance(attr_comp, AttrFilter): + out.extend(attr_comp.columns) + else: + out.append(str(attr_comp.column)) + return list(dict.fromkeys(out)) + + +def _polars_filter_expr(f: FilterInput, df: pl.DataFrame) -> pl.Expr | pl.Series: + """Translate a single AttrComparison/AttrFilter to a polars expression.""" + if isinstance(f, AttrComparison): + return f.op(df[str(f.column)], f.other) + + if f.op == "not": + return ~_polars_filter_expr(f.operands[0], df) + + child_exprs = [_polars_filter_expr(o, df) for o in f.operands] + if f.op == "and": + return functools.reduce(operator.and_, child_exprs) + if f.op == "or": + return functools.reduce(operator.or_, child_exprs) + # xor + return functools.reduce(operator.xor, child_exprs) def polars_reduce_attr_comps( df: pl.DataFrame, - attr_comps: Sequence[AttrComparison], + attr_comps: Sequence[FilterInput], reduce_op: Callable[[Expr, Expr], Expr], ) -> pl.Expr: """ - Reduce a list of attribute comparisons into a single polars expression. + Reduce a list of attribute comparisons (or compound filters) into a single + polars expression, combined with `reduce_op` at the top level (AND-ed by + default in callers). Parameters ---------- df : pl.DataFrame The dataframe to reduce the attribute comparisons on. - attr_comps : Sequence[AttrComparison] - The attribute comparisons to reduce. + attr_comps : Sequence[AttrComparison | AttrFilter] + The filters to reduce. reduce_op : Callable[[Expr, Expr], Expr] - The operation to reduce the attribute comparisons with. + The operation to reduce the top-level filters with. Returns ------- @@ -733,4 +923,4 @@ def polars_reduce_attr_comps( # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - return pl.reduce(reduce_op, [attr_comp.op(df[str(attr_comp.column)], attr_comp.other) for attr_comp in attr_comps]) + return pl.reduce(reduce_op, [_polars_filter_expr(f, df) for f in attr_comps]) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 6067b71c..ebed4617 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -15,7 +15,7 @@ from psygnal import Signal from zarr.storage import StoreLike -from tracksdata.attrs import AttrComparison, NodeAttr +from tracksdata.attrs import FilterInput, NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.utils._cache import cache_method from tracksdata.utils._dtypes import ( @@ -598,8 +598,8 @@ def predecessors( def _validate_subgraph_args( self, node_ids: Sequence[int] | None = None, - node_attr_comps: list[AttrComparison] | None = None, - edge_attr_comps: list[AttrComparison] | None = None, + node_attr_comps: list[FilterInput] | None = None, + edge_attr_comps: list[FilterInput] | None = None, ) -> None: if node_ids is None and not node_attr_comps and not edge_attr_comps: raise ValueError("Either node IDs or one of the attributes' comparisons must be provided") @@ -619,7 +619,7 @@ def edge_ids(self) -> list[int]: @abc.abstractmethod def filter( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, node_ids: Sequence[int] | None = None, include_targets: bool = False, include_sources: bool = False, @@ -627,10 +627,14 @@ def filter( """ Creates a filter object that can be used to create a subgraph or query ids and attributes. + Multiple positional filters are implicitly AND-ed together. Each filter + can itself be a compound `AttrFilter` built from `AttrComparison`s using + `&`, `|`, `^`, `~` (e.g. `(NodeAttr("t") == 1) | (NodeAttr("t") == 2)`). + Parameters ---------- - *attr_filters : AttrComparison - The attributes to filter the nodes by. + *attr_filters : AttrComparison | AttrFilter + The attribute filters to apply. Positional args are AND-ed. node_ids : Sequence[int] | None The IDs of the nodes to include in the filter. If None, all nodes are used. diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index a9350258..79068003 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -5,7 +5,7 @@ import polars as pl import rustworkx as rx -from tracksdata.attrs import AttrComparison +from tracksdata.attrs import FilterInput from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph._mapped_graph_mixin import MappedGraphMixin @@ -237,7 +237,7 @@ def overlaps( def filter( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, node_ids: Sequence[int] | None = None, include_targets: bool = False, include_sources: bool = False, diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 707d32fd..9387fe5d 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -7,7 +7,7 @@ import polars as pl import rustworkx as rx -from tracksdata.attrs import AttrComparison, split_attr_comps +from tracksdata.attrs import AttrComparison, FilterInput, split_attr_comps from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph._mapped_graph_mixin import MappedGraphMixin @@ -23,26 +23,32 @@ def _pop_time_eq( - attrs: Sequence[AttrComparison], -) -> tuple[list[AttrComparison], int | None]: + attrs: Sequence[FilterInput], +) -> tuple[list[FilterInput], int | None]: """ - Pop the time equality filter from a list of attribute filters. - If multiple time equality filters are found, an error is raised. + Pop the top-level time equality filter from a list of attribute filters. + Compound (AttrFilter) entries are left untouched even if they reference + the time column. If multiple time equality filters are found at the top + level, an error is raised. Parameters ---------- - attrs : Sequence[AttrComparison] + attrs : Sequence[AttrComparison | AttrFilter] The attribute filters to pop the time equality filter from. Returns ------- - tuple[list[AttrComparison], int | None] + tuple[list[AttrComparison | AttrFilter], int | None] The attribute filters without the time equality filter and the time value. """ - out_attrs = [] + out_attrs: list[FilterInput] = [] time = None for attr_comp in attrs: - if str(attr_comp.column) == DEFAULT_ATTR_KEYS.T and attr_comp.op == operator.eq: + if ( + isinstance(attr_comp, AttrComparison) + and str(attr_comp.column) == DEFAULT_ATTR_KEYS.T + and attr_comp.op == operator.eq + ): if time is not None: raise ValueError(f"Multiple '{DEFAULT_ATTR_KEYS.T}' equality filters are not allowed\n {attrs}") time = int(attr_comp.other) @@ -82,16 +88,36 @@ def _list_to_pl_series(key: str, values: list[Any], schema: AttrSchema) -> pl.Se return s +def _eval_filter( + f: FilterInput, + attrs: dict[str, Any], + schema: dict[str, AttrSchema], +) -> bool: + """Evaluate a single comparison or compound filter against an attrs dict.""" + if isinstance(f, AttrComparison): + value = attrs.get(f.column, schema[f.column].default_value) + return bool(f.op(value, f.other)) + + if f.op == "and": + return all(_eval_filter(o, attrs, schema) for o in f.operands) + if f.op == "or": + return any(_eval_filter(o, attrs, schema) for o in f.operands) + if f.op == "xor": + truthy_count = sum(1 for o in f.operands if _eval_filter(o, attrs, schema)) + return truthy_count % 2 == 1 + # not + return not _eval_filter(f.operands[0], attrs, schema) + + def _create_filter_func( - attr_comps: Sequence[AttrComparison], + attr_comps: Sequence[FilterInput], schema: dict[str, AttrSchema], ) -> Callable[[dict[str, Any]], bool]: LOG.info(f"Creating filter function for {attr_comps}") def _filter(attrs: dict[str, Any]) -> bool: - for attr_op in attr_comps: - value = attrs.get(attr_op.column, schema[attr_op.column].default_value) - if not attr_op.op(value, attr_op.other): + for f in attr_comps: + if not _eval_filter(f, attrs, schema): return False return True @@ -101,7 +127,7 @@ def _filter(attrs: dict[str, Any]) -> bool: class RXFilter(BaseFilter): def __init__( self, - *attr_comps: AttrComparison, + *attr_comps: FilterInput, graph: "RustWorkXGraph", node_ids: Sequence[int] | None = None, include_targets: bool = False, @@ -459,7 +485,7 @@ def rx_graph(self) -> rx.PyDiGraph: def filter( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, node_ids: Sequence[int] | None = None, include_targets: bool = False, include_sources: bool = False, @@ -888,7 +914,7 @@ def predecessors( def _filter_nodes_by_attrs( self, - *attrs: AttrComparison, + *attrs: FilterInput, node_ids: Sequence[int] | None = None, ) -> list[int]: """ @@ -896,7 +922,7 @@ def _filter_nodes_by_attrs( Parameters ---------- - *attrs : AttrComparison + *attrs : AttrComparison | AttrFilter The attributes to filter by, for example: node_ids : list[int] | None The IDs of the nodes to include in the filter. diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index de1ebff7..01d748bf 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,4 +1,5 @@ import binascii +import functools import re from collections.abc import Callable, Sequence from enum import Enum @@ -15,7 +16,7 @@ from sqlalchemy.orm.query import Query from sqlalchemy.sql.type_api import TypeEngine -from tracksdata.attrs import AttrComparison, split_attr_comps +from tracksdata.attrs import AttrComparison, FilterInput, split_attr_comps from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph.filters._base_filter import BaseFilter @@ -58,13 +59,35 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() +def _to_sql_clause(f: FilterInput, table: type[DeclarativeBase]) -> Any: + """Translate an AttrComparison or AttrFilter into a SQLAlchemy clause.""" + if isinstance(f, AttrComparison): + return f.op(getattr(table, str(f.column)), f.other) + + if f.op == "not": + return sa.not_(_to_sql_clause(f.operands[0], table)) + + clauses = [_to_sql_clause(o, table) for o in f.operands] + if f.op == "and": + return sa.and_(*clauses) + if f.op == "or": + return sa.or_(*clauses) + # xor: reduce pairwise via (a OR b) AND NOT (a AND b) + return functools.reduce( + lambda a, b: sa.and_(sa.or_(a, b), sa.not_(sa.and_(a, b))), + clauses, + ) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], - attr_filters: list[AttrComparison], + attr_filters: Sequence[FilterInput], ) -> sa.Select: """ - Filter a query by a list of attribute filters. + Filter a query by a list of attribute filters (AND-ed together at the top + level). Each filter may itself be a compound AttrFilter combining + AttrComparisons with OR / AND / XOR / NOT. Parameters ---------- @@ -72,7 +95,7 @@ def _filter_query( The query to filter. table : type[DeclarativeBase] The table to filter. - attr_filters : list[AttrComparison] + attr_filters : Sequence[AttrComparison | AttrFilter] The attribute filters to apply. Returns @@ -81,16 +104,14 @@ def _filter_query( The filtered query. """ LOG.info("Filter query:\n%s", attr_filters) - query = query.filter( - *[attr_filter.op(getattr(table, str(attr_filter.column)), attr_filter.other) for attr_filter in attr_filters] - ) + query = query.filter(*[_to_sql_clause(f, table) for f in attr_filters]) return query class SQLFilter(BaseFilter): def __init__( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, graph: "SQLGraph", node_ids: Sequence[int] | None = None, include_targets: bool = False, @@ -728,7 +749,7 @@ def _update_max_id_per_time(self) -> None: def filter( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, node_ids: Sequence[int] | None = None, include_targets: bool = False, include_sources: bool = False, diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index f38e54c1..dba8c128 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -741,6 +741,98 @@ def test_homemorphism(graph_backend: BaseGraph) -> None: assert same_graph.edge_ids() == graph_with_data.edge_ids() +@parametrize_subgraph_tests +def test_filter_nodes_with_or_attr_filter( + graph_backend: BaseGraph, + use_subgraph: bool, +) -> None: + """OR-combined node filter selects the union of matching nodes.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph) + node_attrs = graph_with_data.node_attrs() + + nodes = graph_with_data.filter((NodeAttr("t") == 1) | (NodeAttr("t") == 3)).node_ids() + expected = node_attrs.filter(pl.col("t").is_in([1, 3]))[DEFAULT_ATTR_KEYS.NODE_ID].to_list() + assert set(nodes) == set(expected) + + +@parametrize_subgraph_tests +def test_filter_nodes_with_not_attr_filter( + graph_backend: BaseGraph, + use_subgraph: bool, +) -> None: + """NOT (inverted) node filter selects the complement of matching nodes.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph) + node_attrs = graph_with_data.node_attrs() + + nodes = graph_with_data.filter(~(NodeAttr("label") == "A")).node_ids() + expected = node_attrs.filter(pl.col("label") != "A")[DEFAULT_ATTR_KEYS.NODE_ID].to_list() + assert set(nodes) == set(expected) + + +@parametrize_subgraph_tests +def test_filter_nodes_with_xor_attr_filter( + graph_backend: BaseGraph, + use_subgraph: bool, +) -> None: + """XOR node filter selects nodes matching exactly one of the conditions.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph) + node_attrs = graph_with_data.node_attrs() + + nodes = graph_with_data.filter((NodeAttr("t") == 2) ^ (NodeAttr("label") == "A")).node_ids() + expected = node_attrs.filter((pl.col("t") == 2) ^ (pl.col("label") == "A"))[DEFAULT_ATTR_KEYS.NODE_ID].to_list() + assert set(nodes) == set(expected) + + +@parametrize_subgraph_tests +def test_filter_nodes_with_nested_compound( + graph_backend: BaseGraph, + use_subgraph: bool, +) -> None: + """Nested AND/OR filter trees evaluate correctly.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph) + node_attrs = graph_with_data.node_attrs() + + nodes = graph_with_data.filter( + (NodeAttr("label") == "A") & ((NodeAttr("t") == 1) | (NodeAttr("t") == 3)) + ).node_ids() + expected = node_attrs.filter((pl.col("label") == "A") & (pl.col("t").is_in([1, 3])))[ + DEFAULT_ATTR_KEYS.NODE_ID + ].to_list() + assert set(nodes) == set(expected) + + +def test_filter_edges_with_or_attr_filter(graph_backend: BaseGraph) -> None: + """OR-combined edge filter selects the union of matching edges.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + edge_attrs = graph_with_data.edge_attrs() + + edge_filter = graph_with_data.filter((EdgeAttr("weight") < 0.4) | (EdgeAttr("weight") > 0.8)) + selected_edges = edge_filter.edge_attrs()[DEFAULT_ATTR_KEYS.EDGE_ID].to_list() + expected = edge_attrs.filter((pl.col("weight") < 0.4) | (pl.col("weight") > 0.8))[ + DEFAULT_ATTR_KEYS.EDGE_ID + ].to_list() + assert set(selected_edges) == set(expected) + + +def test_filter_subgraph_with_or_attr_filter(graph_backend: BaseGraph) -> None: + """Building a subgraph from a compound (OR) filter yields the expected nodes/edges.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + node_attrs = graph_with_data.node_attrs() + + sub = graph_with_data.filter((NodeAttr("t") == 1) | (NodeAttr("t") == 2)).subgraph() + expected = node_attrs.filter(pl.col("t").is_in([1, 2]))[DEFAULT_ATTR_KEYS.NODE_ID].to_list() + assert set(sub.node_ids()) == set(expected) + + +def test_filter_compound_mixed_node_and_edge_raises(graph_backend: BaseGraph) -> None: + """A single compound filter cannot mix node and edge attributes.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + + bad_filter = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) + with pytest.raises(ValueError, match="cannot mix NodeAttr and EdgeAttr"): + graph_with_data.filter(bad_filter).node_ids() + + @parametrize_subgraph_tests def test_subgraph_overlaps_basic(graph_backend: BaseGraph, use_subgraph: bool) -> None: """Test basic overlap functionality in subgraphs.""" diff --git a/src/tracksdata/graph/filters/_indexed_filter.py b/src/tracksdata/graph/filters/_indexed_filter.py index 7c43cd5a..a3fb4beb 100644 --- a/src/tracksdata/graph/filters/_indexed_filter.py +++ b/src/tracksdata/graph/filters/_indexed_filter.py @@ -3,7 +3,7 @@ import polars as pl -from tracksdata.attrs import AttrComparison +from tracksdata.attrs import FilterInput from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._rustworkx_graph import ( IndexedRXGraph, @@ -21,7 +21,7 @@ class IndexRXFilter(RXFilter): def __init__( self, - *attr_comps: AttrComparison, + *attr_comps: FilterInput, graph: "GraphView | IndexedRXGraph", node_ids: Sequence[int] | None = None, include_targets: bool = False, From f0ee8735641dc16ce24db2bfa0c9f6b07a5a71b4 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sat, 30 May 2026 22:13:21 -0700 Subject: [PATCH 2/5] cleanup --- src/tracksdata/_test/test_attrs.py | 14 +++++ src/tracksdata/attrs.py | 90 ++++++++++++++++-------------- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index b4582643..7bce651f 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -690,6 +690,20 @@ def test_attr_filter_and_operator_between_comparisons() -> None: assert and_f.op == "and" +@pytest.mark.parametrize( + "op", + [operator.and_, operator.or_, operator.xor], +) +def test_attr_filter_logical_op_with_non_filter_raises(op: Callable) -> None: + """Combining a comparison with a non-filter operand is not meaningful.""" + comp = NodeAttr("a") > 0 + with pytest.raises(TypeError, match="Boolean operators on comparisons"): + op(comp, 5) + # reversed operand order goes through the reflected operator + with pytest.raises(TypeError, match="Boolean operators on comparisons"): + op(5, comp) + + def test_attr_filter_nested_composition() -> None: f = (NodeAttr("a") > 0) & ((NodeAttr("b") == 1) | (NodeAttr("b") == 2)) assert isinstance(f, AttrFilter) diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index d6e07512..53aafee5 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -172,7 +172,7 @@ def __getattr__(self, attr: str) -> Any: def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], reverse: bool = False) -> "Attr": return self.to_attr()._delegate_operator(other, op, reverse) - # Binary operators + # Arithmetic operators (auto-generated by `_setup_ops`, return Attr) def __add__(self, other: ExprInput) -> "Attr": ... def __sub__(self, other: ExprInput) -> "Attr": ... def __mul__(self, other: ExprInput) -> "Attr": ... @@ -180,11 +180,8 @@ def __truediv__(self, other: ExprInput) -> "Attr": ... def __floordiv__(self, other: ExprInput) -> "Attr": ... def __mod__(self, other: ExprInput) -> "Attr": ... def __pow__(self, other: ExprInput) -> "Attr": ... - def __and__(self, other: ExprInput) -> "Attr": ... - def __or__(self, other: ExprInput) -> "Attr": ... - def __xor__(self, other: ExprInput) -> "Attr": ... - # Reverse operators + # Reverse arithmetic operators def __radd__(self, other: Scalar) -> "Attr": ... def __rsub__(self, other: Scalar) -> "Attr": ... def __rmul__(self, other: Scalar) -> "Attr": ... @@ -192,9 +189,41 @@ def __rtruediv__(self, other: Scalar) -> "Attr": ... def __rfloordiv__(self, other: Scalar) -> "Attr": ... def __rmod__(self, other: Scalar) -> "Attr": ... def __rpow__(self, other: Scalar) -> "Attr": ... - def __rand__(self, other: Scalar) -> "Attr": ... - def __ror__(self, other: Scalar) -> "Attr": ... - def __rxor__(self, other: Scalar) -> "Attr": ... + + # Logical operators combine comparisons into an AttrFilter compound. + # `AttrFilter` is defined later in the module; the references below resolve + # at call time, so the forward reference is fine. + def _logical_op(self, op_name: str, other: Any, reverse: bool = False) -> "AttrFilter": + if not isinstance(other, AttrComparison | AttrFilter): + symbol = _FILTER_OP_SYMBOLS[op_name] + raise TypeError( + f"Cannot apply '{symbol}' between an AttrComparison and {type(other).__name__}. " + "Boolean operators on comparisons combine them into a filter; both operands " + "must be an AttrComparison or AttrFilter." + ) + operands = [other, self] if reverse else [self, other] + return AttrFilter(op_name, operands) + + def __and__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("and", other) + + def __rand__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("and", other, reverse=True) + + def __or__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("or", other) + + def __ror__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("or", other, reverse=True) + + def __xor__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("xor", other) + + def __rxor__(self, other: FilterInput) -> "AttrFilter": + return self._logical_op("xor", other, reverse=True) + + def __invert__(self) -> "AttrFilter": + return AttrFilter("not", [self]) # Comparison operators (always return Attr) def __eq__(self, other: ExprInput) -> "Attr": ... @@ -613,6 +642,7 @@ def _setup_ops() -> None: """ Setup the operator methods for the AttrExpr class. """ + # Arithmetic operators: generated for both Attr and AttrComparison. bin_ops = { "add": operator.add, "sub": operator.sub, @@ -621,6 +651,12 @@ def _setup_ops() -> None: "floordiv": operator.floordiv, "mod": operator.mod, "pow": operator.pow, + } + + # Logical operators: generated only for Attr (bitwise on the polars expr). + # AttrComparison defines its own `& | ^ ~` in the class body to build + # AttrFilter compounds, so they are intentionally excluded here. + logical_ops = { "and": operator.and_, "or": operator.or_, "xor": operator.xor, @@ -635,9 +671,11 @@ def _setup_ops() -> None: "ge": operator.ge, } - for op_name, op_func in bin_ops.items(): + for op_name, op_func in (bin_ops | logical_ops).items(): _add_operator(Attr, f"__{op_name}__", op_func, reverse=False) _add_operator(Attr, f"__r{op_name}__", op_func, reverse=True) + + for op_name, op_func in bin_ops.items(): _add_operator(AttrComparison, f"__{op_name}__", op_func, reverse=False) _add_operator(AttrComparison, f"__r{op_name}__", op_func, reverse=True) @@ -789,40 +827,6 @@ def _filter_attr_kind(f: FilterInput) -> type[Attr]: return kinds.pop() -# --- AttrComparison boolean operator overrides -------------------------------- -# These run after _setup_ops() and replace the auto-generated `__and__`, -# `__or__`, `__xor__` on AttrComparison so that combining comparisons builds an -# AttrFilter compound instead of an Attr expression. The boolean methods on -# `Attr` itself are unchanged. - - -def _attr_comparison_logical(self: "AttrComparison", other: Any, op_name: str, py_op: Callable) -> Any: - if isinstance(other, AttrComparison | AttrFilter): - return AttrFilter(op_name, [self, other]) - return self._delegate_operator(other, py_op, reverse=False) - - -def _attr_comparison_r_logical(self: "AttrComparison", other: Any, op_name: str, py_op: Callable) -> Any: - if isinstance(other, AttrComparison | AttrFilter): - return AttrFilter(op_name, [other, self]) - return self._delegate_operator(other, py_op, reverse=True) - - -AttrComparison.__and__ = functools.partialmethod(_attr_comparison_logical, op_name="and", py_op=operator.and_) -AttrComparison.__rand__ = functools.partialmethod(_attr_comparison_r_logical, op_name="and", py_op=operator.and_) -AttrComparison.__or__ = functools.partialmethod(_attr_comparison_logical, op_name="or", py_op=operator.or_) -AttrComparison.__ror__ = functools.partialmethod(_attr_comparison_r_logical, op_name="or", py_op=operator.or_) -AttrComparison.__xor__ = functools.partialmethod(_attr_comparison_logical, op_name="xor", py_op=operator.xor) -AttrComparison.__rxor__ = functools.partialmethod(_attr_comparison_r_logical, op_name="xor", py_op=operator.xor) - - -def _attr_comparison_invert(self: "AttrComparison") -> "AttrFilter": - return AttrFilter("not", [self]) - - -AttrComparison.__invert__ = _attr_comparison_invert - - def split_attr_comps( attr_comps: Sequence[FilterInput], ) -> tuple[list[FilterInput], list[FilterInput]]: From 79c1d5961b3fbef2caca6548e5b0b2d876b2f384 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sat, 30 May 2026 22:29:22 -0700 Subject: [PATCH 3/5] ignored devcontainer --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b652a88f..2695912e 100644 --- a/.gitignore +++ b/.gitignore @@ -198,3 +198,4 @@ src/tracksdata/__about__.py # Claude .claude/ +.devcontainer/ From a4afbcfe5ddbb8dc5f1a85b2e8b296ce1d1f1c6b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 31 May 2026 09:22:34 -0700 Subject: [PATCH 4/5] merge to attr --- src/tracksdata/_filter.py | 57 ++ src/tracksdata/_test/test_attrs.py | 326 ++++---- src/tracksdata/attrs.py | 781 ++++++++------------ src/tracksdata/graph/_base_graph.py | 6 +- src/tracksdata/graph/_rustworkx_graph.py | 65 +- src/tracksdata/graph/_sql_graph.py | 36 +- src/tracksdata/graph/_test/test_subgraph.py | 9 +- 7 files changed, 595 insertions(+), 685 deletions(-) create mode 100644 src/tracksdata/_filter.py diff --git a/src/tracksdata/_filter.py b/src/tracksdata/_filter.py new file mode 100644 index 00000000..a6c3c4a9 --- /dev/null +++ b/src/tracksdata/_filter.py @@ -0,0 +1,57 @@ +"""Internal filter AST shared by graph backends. + +Each `Attr` carries an optional `_FilterNode` that records the structured form +of a boolean filter expression (leaf comparison or compound). Backends walk +this AST to translate filters into SQL clauses, polars predicates, or Python +dict checks. + +The AST is intentionally minimal: leaf comparisons hold `(column, op, other)` +plus the originating `Attr` subclass so node/edge dispatch survives, and +compound nodes hold a logical op and a list of children. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from tracksdata.attrs import Attr + + +CompoundOp = Literal["and", "or", "xor", "not"] + + +@dataclass(frozen=True) +class _FilterLeaf: + """A single column comparison: ``column op other``. + + `kind` is the originating `Attr` subclass (`NodeAttr` / `EdgeAttr` / `Attr`) + used by backend dispatch to decide which graph table the filter targets. + """ + + column: str + op: Callable + other: object + kind: type[Attr] + + +@dataclass(frozen=True) +class _FilterCompound: + """A boolean combination of filter nodes.""" + + op: CompoundOp + operands: tuple[_FilterNode, ...] + + +_FilterNode = _FilterLeaf | _FilterCompound + + +def walk_leaves(node: _FilterNode) -> Iterable[_FilterLeaf]: + """Yield all leaf comparisons under `node` in left-to-right order.""" + if isinstance(node, _FilterLeaf): + yield node + return + for child in node.operands: + yield from walk_leaves(child) diff --git a/src/tracksdata/_test/test_attrs.py b/src/tracksdata/_test/test_attrs.py index 7bce651f..0a0723e6 100644 --- a/src/tracksdata/_test/test_attrs.py +++ b/src/tracksdata/_test/test_attrs.py @@ -6,10 +6,9 @@ import polars as pl import pytest +from tracksdata._filter import _FilterCompound, _FilterLeaf from tracksdata.attrs import ( Attr, - AttrComparison, - AttrFilter, EdgeAttr, NodeAttr, attr_comps_to_strs, @@ -312,110 +311,101 @@ def test_duplicated_columns() -> None: def test_attr_reverse_comparison() -> None: - """Test basic initialization of AttrComparison.""" + """`5 == attr` is dispatched via Python's symmetric `__eq__` fallback.""" attr = Attr("test_column") comp = 5 == attr # reversed on purpose - assert comp.attr == attr - assert comp.column == "test_column" - assert comp.op == operator.eq - assert comp.other == 5 + leaf = comp._filter + assert isinstance(leaf, _FilterLeaf) + assert leaf.column == "test_column" + assert leaf.op == operator.eq + assert leaf.other == 5 def test_attr_numpy_comparison() -> None: - """Test basic initialization of AttrComparison.""" + """numpy scalars are cast to Python scalars in the leaf.""" attr = Attr("test_column") comp = attr == np.asarray(5) - assert comp.attr == attr - assert comp.column == "test_column" - assert comp.op == operator.eq - assert comp.other == 5 + leaf = comp._filter + assert isinstance(leaf, _FilterLeaf) + assert leaf.column == "test_column" + assert leaf.op == operator.eq + assert leaf.other == 5 + assert type(leaf.other) is int # native, not np.int64 def test_attr_comparison_repr() -> None: - """Test string representation of AttrComparison.""" + """Filter-shaped Attrs render as `Kind(col) op value`.""" attr = Attr("test_column") comp = attr > 10 assert repr(comp) == "Attr(test_column) > 10" -def test_attr_comparison_to_attr() -> None: - """Test converting AttrComparison back to Attr.""" +def test_attr_comparison_evaluates_as_boolean_series() -> None: + """A filter-shaped Attr evaluates directly as a polars boolean expression.""" df = pl.DataFrame({"test_column": [1, 2, 3, 4, 5]}) - attr = Attr("test_column") - comp = attr > 3 - - converted_attr = comp.to_attr() - result = converted_attr.evaluate(df) + comp = Attr("test_column") > 3 + result = comp.evaluate(df) assert result.to_list() == [False, False, False, True, True] def test_attr_comparison_getattr_delegation() -> None: - """Test that AttrComparison delegates attribute access to its Attr representation.""" - attr = Attr("test_column") - comp = attr == 5 + """Filter-shaped Attrs expose the usual `Attr` properties.""" + comp = Attr("test_column") == 5 - # Test that we can access Attr methods through AttrComparison assert comp.columns == ["test_column"] assert comp.expr_columns == ["test_column"] def test_attr_comparison_operator_delegation() -> None: - """Test that AttrComparison delegates operators to its Attr representation.""" + """Arithmetic on a filter-shaped Attr produces a non-filter numeric Attr.""" df = pl.DataFrame({"test_column": [1, 2, 3, 4, 5]}) - attr = Attr("test_column") - comp = attr > 3 + comp = Attr("test_column") > 3 - # Test that we can use operators on AttrComparison result_attr = comp + 10 + assert result_attr._filter is None # arithmetic clears the leaf result = result_attr.evaluate(df) - - # Should be (test_column > 3) + 10 expected = [(x > 3) + 10 for x in df["test_column"]] assert result.to_list() == expected -def test_attr_comparison_init_with_infinity_attr() -> None: - """Test that AttrComparison raises error when attr has infinity.""" - # Create an attr with infinity +def test_comparison_on_infinity_attr_raises() -> None: + """Comparing an Attr that carries infinity tracking is meaningless and raises.""" attr_with_inf = Attr("test") * math.inf - with pytest.raises(ValueError, match="Comparison operators are not supported for expressions with infinity"): - AttrComparison(attr_with_inf, operator.eq, 5) + _ = attr_with_inf == 5 -def test_attr_comparison_init_with_attr_other() -> None: - """Test that AttrComparison raises error when comparing two Attr objects.""" +def test_comparison_between_two_attrs_is_not_a_filter() -> None: + """`attr1 == attr2` is a polars boolean expression, not a pushdown filter.""" attr1 = Attr("col1") attr2 = Attr("col2") + comp = attr1 == attr2 + assert comp._filter is None - with pytest.raises(ValueError, match="Does not support comparison between expressions"): - AttrComparison(attr1, operator.eq, attr2) - -def test_attr_comparison_init_with_empty_columns() -> None: - """Test that AttrComparison raises error for empty expressions.""" - # Create an attr with no columns (literal) +def test_comparison_on_literal_attr_is_not_a_filter() -> None: + """Comparison on an empty-column attr falls back to a non-filter Attr.""" attr_no_cols = Attr(5) - - with pytest.raises(ValueError, match="Comparison operators are not supported for empty expressions"): - AttrComparison(attr_no_cols, operator.eq, 10) + comp = attr_no_cols == 10 + assert comp._filter is None -def test_attr_comparison_init_with_multiple_columns() -> None: - """Test that AttrComparison raises error for multiple columns.""" - # Create an attr with multiple columns +def test_comparison_on_multi_column_attr_is_not_a_filter() -> None: + """Comparison on a multi-column attr falls back to a non-filter Attr.""" attr_multi_cols = Attr("col1") + Attr("col2") - - with pytest.raises(ValueError, match="Comparison operators are not supported for multiple columns"): - AttrComparison(attr_multi_cols, operator.eq, 10) + comp = attr_multi_cols == 10 + assert comp._filter is None + # graph.filter() would later reject it via split_attr_comps: + with pytest.raises(ValueError, match="Expected a filter-shaped Attr"): + split_attr_comps([comp]) def test_attr_comparison_comparison_operators() -> None: - """Test all comparison operators with AttrComparison.""" + """All comparison operators produce filter-shaped Attrs that evaluate correctly.""" df = pl.DataFrame({"test_column": [1, 2, 3, 4, 5]}) attr = Attr("test_column") @@ -430,46 +420,37 @@ def test_attr_comparison_comparison_operators() -> None: for op, other, expected in test_cases: comp = op(attr, other) - result = comp.to_attr().evaluate(df) + assert isinstance(comp._filter, _FilterLeaf) + result = comp.evaluate(df) assert result.to_list() == expected def test_attr_comparison_binary_operators() -> None: - """Test binary operators with AttrComparison.""" + """Arithmetic on a leaf-filter Attr works as scalar polars math.""" df = pl.DataFrame({"test_column": [1, 2, 3, 4, 5]}) - attr = Attr("test_column") - comp = AttrComparison(attr, operator.gt, 3) + comp = Attr._leaf("test_column", operator.gt, 3) - # Test addition result = comp + 10 - result_series = result.evaluate(df) expected = [(x > 3) + 10 for x in df["test_column"]] - assert result_series.to_list() == expected + assert result.evaluate(df).to_list() == expected - # Test multiplication result = comp * 2 - result_series = result.evaluate(df) expected = [(x > 3) * 2 for x in df["test_column"]] - assert result_series.to_list() == expected + assert result.evaluate(df).to_list() == expected def test_attr_comparison_reverse_operators() -> None: - """Test reverse operators with AttrComparison.""" + """Reverse arithmetic on a leaf-filter Attr also works.""" df = pl.DataFrame({"test_column": [1, 2, 3, 4, 5]}) - attr = Attr("test_column") - comp = AttrComparison(attr, operator.gt, 3) + comp = Attr._leaf("test_column", operator.gt, 3) - # Test reverse addition result = 10 + comp - result_series = result.evaluate(df) expected = [10 + (x > 3) for x in df["test_column"]] - assert result_series.to_list() == expected + assert result.evaluate(df).to_list() == expected - # Test reverse multiplication result = 2 * comp - result_series = result.evaluate(df) expected = [2 * (x > 3) for x in df["test_column"]] - assert result_series.to_list() == expected + assert result.evaluate(df).to_list() == expected def test_split_attr_comps() -> None: @@ -489,10 +470,10 @@ def test_split_attr_comps() -> None: assert len(node_comps) == 2 assert len(edge_comps) == 2 - assert node_comps[0].column == "node_col1" - assert node_comps[1].column == "node_col2" - assert edge_comps[0].column == "edge_col1" - assert edge_comps[1].column == "edge_col2" + assert node_comps[0]._filter.column == "node_col1" + assert node_comps[1]._filter.column == "node_col2" + assert edge_comps[0]._filter.column == "edge_col1" + assert edge_comps[1]._filter.column == "edge_col2" def test_split_attr_comps_empty() -> None: @@ -504,24 +485,22 @@ def test_split_attr_comps_empty() -> None: def test_split_attr_comps_only_node() -> None: """Test splitting only node attribute comparisons.""" - node_attr = NodeAttr("node_col") - node_comp = AttrComparison(node_attr, operator.eq, 1) + node_comp = NodeAttr("node_col") == 1 node_comps, edge_comps = split_attr_comps([node_comp]) assert len(node_comps) == 1 assert len(edge_comps) == 0 - assert node_comps[0].column == "node_col" + assert node_comps[0]._filter.column == "node_col" def test_split_attr_comps_only_edge() -> None: """Test splitting only edge attribute comparisons.""" - edge_attr = EdgeAttr("edge_col") - edge_comp = AttrComparison(edge_attr, operator.gt, 5) + edge_comp = EdgeAttr("edge_col") > 5 node_comps, edge_comps = split_attr_comps([edge_comp]) assert len(node_comps) == 0 assert len(edge_comps) == 1 - assert edge_comps[0].column == "edge_col" + assert edge_comps[0]._filter.column == "edge_col" def test_split_attr_comps_invalid_type() -> None: @@ -633,8 +612,8 @@ def test_attr_is_in_creates_membership_expression() -> None: df = pl.DataFrame({"col": [1, 2, 3, 4]}) comp = Attr("col").is_in([1, 3, 4]) - assert isinstance(comp, AttrComparison) - evaluated = comp.to_attr().evaluate(df) + assert isinstance(comp._filter, _FilterLeaf) + evaluated = comp.evaluate(df) assert evaluated.to_list() == [True, False, True, True] @@ -649,78 +628,90 @@ def test_attr_is_in_accepts_numpy_arrays() -> None: df = pl.DataFrame({"col": [5, 6, 7]}) comp = Attr("col").is_in(np.array([6, 8], dtype=np.int64)) - evaluated = comp.to_attr().evaluate(df) + evaluated = comp.evaluate(df) assert evaluated.to_list() == [False, True, False] # --------------------------------------------------------------------------- -# AttrFilter (compound boolean filters built from AttrComparisons) +# Compound boolean filters (`& | ^ ~` on filter-shaped Attrs) # --------------------------------------------------------------------------- -def test_attr_filter_or_operator_returns_filter() -> None: +def test_filter_or_operator_returns_compound() -> None: comp1 = NodeAttr("t") == 1 comp2 = NodeAttr("t") == 2 f = comp1 | comp2 - assert isinstance(f, AttrFilter) - assert f.op == "or" - assert f.operands == [comp1, comp2] - assert f.columns == ["t"] + assert isinstance(f._filter, _FilterCompound) + assert f._filter.op == "or" + assert f._filter.operands == (comp1._filter, comp2._filter) + assert attr_comps_to_strs([f]) == ["t"] -def test_attr_filter_xor_and_invert_operators() -> None: +def test_filter_xor_and_invert_operators() -> None: comp1 = NodeAttr("a") > 0 comp2 = NodeAttr("b") > 0 xor_f = comp1 ^ comp2 - assert isinstance(xor_f, AttrFilter) - assert xor_f.op == "xor" + assert isinstance(xor_f._filter, _FilterCompound) + assert xor_f._filter.op == "xor" not_f = ~comp1 - assert isinstance(not_f, AttrFilter) - assert not_f.op == "not" - assert not_f.operands == [comp1] + assert isinstance(not_f._filter, _FilterCompound) + assert not_f._filter.op == "not" + assert not_f._filter.operands == (comp1._filter,) -def test_attr_filter_and_operator_between_comparisons() -> None: - comp1 = NodeAttr("a") > 0 - comp2 = NodeAttr("b") < 1 - and_f = comp1 & comp2 - assert isinstance(and_f, AttrFilter) - assert and_f.op == "and" +def test_filter_and_operator_between_comparisons() -> None: + and_f = (NodeAttr("a") > 0) & (NodeAttr("b") < 1) + assert isinstance(and_f._filter, _FilterCompound) + assert and_f._filter.op == "and" @pytest.mark.parametrize( "op", [operator.and_, operator.or_, operator.xor], ) -def test_attr_filter_logical_op_with_non_filter_raises(op: Callable) -> None: - """Combining a comparison with a non-filter operand is not meaningful.""" +def test_filter_logical_op_with_non_filter_raises(op: Callable) -> None: + """Combining a filter-shaped Attr with a non-filter operand raises.""" comp = NodeAttr("a") > 0 - with pytest.raises(TypeError, match="Boolean operators on comparisons"): + with pytest.raises(TypeError, match="Cannot apply"): op(comp, 5) # reversed operand order goes through the reflected operator - with pytest.raises(TypeError, match="Boolean operators on comparisons"): + with pytest.raises(TypeError, match="Cannot apply"): op(5, comp) -def test_attr_filter_nested_composition() -> None: +def test_filter_nested_composition() -> None: f = (NodeAttr("a") > 0) & ((NodeAttr("b") == 1) | (NodeAttr("b") == 2)) - assert isinstance(f, AttrFilter) - assert f.op == "and" - assert isinstance(f.operands[1], AttrFilter) - assert f.operands[1].op == "or" - assert sorted({leaf.column for leaf in f.leaves()}) == ["a", "b"] + assert isinstance(f._filter, _FilterCompound) + assert f._filter.op == "and" + assert isinstance(f._filter.operands[1], _FilterCompound) + assert f._filter.operands[1].op == "or" + leaf_columns = sorted( + { + leaf.column + for leaf in [op for op in f._filter.operands if isinstance(op, _FilterLeaf)] + + [op for op in f._filter.operands[1].operands if isinstance(op, _FilterLeaf)] + } + ) + assert leaf_columns == ["a", "b"] + +def test_filter_auto_flattens_associative_ops() -> None: + """`(a | b) | c` should produce a single 3-operand `or` compound.""" + f = (NodeAttr("t") == 1) | (NodeAttr("t") == 2) | (NodeAttr("t") == 3) + assert isinstance(f._filter, _FilterCompound) + assert f._filter.op == "or" + assert len(f._filter.operands) == 3 -def test_attr_filter_mixed_node_and_edge_raises_on_split() -> None: - """A single compound that mixes node and edge attributes must error in split.""" - f = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) - with pytest.raises(ValueError, match="cannot mix NodeAttr and EdgeAttr"): - split_attr_comps([f]) +def test_filter_mixed_node_and_edge_raises_on_compound() -> None: + """Mixing NodeAttr and EdgeAttr in one compound now raises at construction.""" + with pytest.raises(ValueError, match="Cannot combine NodeAttr and EdgeAttr"): + _ = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) -def test_attr_filter_split_attr_comps_with_compounds() -> None: + +def test_filter_split_attr_comps_with_compounds() -> None: node_f = (NodeAttr("t") == 1) | (NodeAttr("t") == 2) edge_f = (EdgeAttr("w") > 0.5) | (EdgeAttr("w") < -0.5) node_only = NodeAttr("label") == "A" @@ -730,27 +721,7 @@ def test_attr_filter_split_attr_comps_with_compounds() -> None: assert edges == [edge_f] -def test_attr_filter_invalid_op_raises() -> None: - with pytest.raises(ValueError, match="Unknown logical operator"): - AttrFilter("nor", [NodeAttr("a") == 1, NodeAttr("a") == 2]) - - -def test_attr_filter_not_with_multiple_operands_raises() -> None: - with pytest.raises(ValueError, match="'not' filter requires exactly one operand"): - AttrFilter("not", [NodeAttr("a") == 1, NodeAttr("a") == 2]) - - -def test_attr_filter_or_with_single_operand_raises() -> None: - with pytest.raises(ValueError, match="'or' filter requires at least two operands"): - AttrFilter("or", [NodeAttr("a") == 1]) - - -def test_attr_filter_rejects_non_filter_operands() -> None: - with pytest.raises(TypeError, match="must be AttrComparison or AttrFilter"): - AttrFilter("or", [NodeAttr("a") == 1, 5]) - - -def test_attr_filter_polars_reduce_or() -> None: +def test_filter_polars_reduce_or() -> None: df = pl.DataFrame({"t": [0, 1, 2, 3, 4]}) f = (NodeAttr("t") == 1) | (NodeAttr("t") == 3) expr = polars_reduce_attr_comps(df, [f], operator.and_) @@ -758,7 +729,7 @@ def test_attr_filter_polars_reduce_or() -> None: assert result.to_list() == [False, True, False, True, False] -def test_attr_filter_polars_reduce_xor() -> None: +def test_filter_polars_reduce_xor() -> None: df = pl.DataFrame({"a": [0, 1, 0, 1], "b": [0, 0, 1, 1]}) f = (NodeAttr("a") == 1) ^ (NodeAttr("b") == 1) expr = polars_reduce_attr_comps(df, [f], operator.and_) @@ -766,7 +737,7 @@ def test_attr_filter_polars_reduce_xor() -> None: assert result.to_list() == [False, True, True, False] -def test_attr_filter_polars_reduce_not() -> None: +def test_filter_polars_reduce_not() -> None: df = pl.DataFrame({"t": [0, 1, 2]}) f = ~(NodeAttr("t") == 1) expr = polars_reduce_attr_comps(df, [f], operator.and_) @@ -774,7 +745,76 @@ def test_attr_filter_polars_reduce_not() -> None: assert result.to_list() == [True, False, True] -def test_attr_filter_attr_comps_to_strs_with_compound() -> None: +def test_filter_attr_comps_to_strs_with_compound() -> None: f = (NodeAttr("a") == 1) | (NodeAttr("b") == 2) plain = NodeAttr("c") == 3 assert attr_comps_to_strs([f, plain]) == ["a", "b", "c"] + + +# --------------------------------------------------------------------------- +# Kind preservation (NodeAttr/EdgeAttr survive arithmetic and method delegation) +# --------------------------------------------------------------------------- + + +def test_node_attr_kind_survives_arithmetic_with_scalar() -> None: + assert type(NodeAttr("t") + 5) is NodeAttr + assert type(5 + NodeAttr("t")) is NodeAttr + assert type(NodeAttr("t") * 2) is NodeAttr + assert type(-NodeAttr("t")) is NodeAttr + + +def test_edge_attr_kind_survives_arithmetic_with_scalar() -> None: + assert type(EdgeAttr("w") - 1) is EdgeAttr + assert type(EdgeAttr("w") / 2) is EdgeAttr + assert type(abs(EdgeAttr("w"))) is EdgeAttr + + +def test_node_attr_kind_survives_method_delegation() -> None: + assert type(NodeAttr("t").log()) is NodeAttr + assert type(NodeAttr("t").alias("x")) is NodeAttr + + +def test_same_kind_binary_op_preserves_kind() -> None: + assert type(NodeAttr("a") + NodeAttr("b")) is NodeAttr + assert type(EdgeAttr("a") * EdgeAttr("b")) is EdgeAttr + + +def test_base_attr_defers_to_specific_kind() -> None: + assert type(Attr("a") + NodeAttr("b")) is NodeAttr + assert type(EdgeAttr("a") - Attr("b")) is EdgeAttr + + +def test_mixed_node_edge_arithmetic_raises() -> None: + with pytest.raises(ValueError, match="Cannot combine NodeAttr and EdgeAttr"): + NodeAttr("a") + EdgeAttr("b") + with pytest.raises(ValueError, match="Cannot combine EdgeAttr and NodeAttr"): + EdgeAttr("a") * NodeAttr("b") + + +def test_kind_preserved_through_comparison_filter() -> None: + """`(NodeAttr("t") + 5) == 0` used to fail kind detection; now must split as a node filter.""" + comp = (NodeAttr("t") + 5) == 0 + nodes, edges = split_attr_comps([comp]) + assert len(nodes) == 1 and len(edges) == 0 + + +def test_neg_swaps_infinity_trackers() -> None: + expr = NodeAttr("x") * math.inf + neg = -expr + assert len(neg.inf_exprs) == 0 + assert len(neg.neg_inf_exprs) == 1 + assert neg.neg_inf_exprs[0].columns == ["x"] + + +def test_invert_and_abs_propagate_infinity() -> None: + df = pl.DataFrame({"x": [True, False, True]}) + expr = NodeAttr("x") * math.inf + inverted = ~expr + assert len(inverted.inf_exprs) == 1 + assert inverted.inf_exprs[0].columns == ["x"] + + abs_expr = abs(NodeAttr("y") * math.inf) + assert len(abs_expr.inf_exprs) == 1 + assert abs_expr.inf_exprs[0].columns == ["y"] + # Smoke-check: clean expression still evaluates (drop side-effect) + _ = df diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 53aafee5..d1ff0a8c 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -36,18 +36,20 @@ import polars as pl from polars import DataFrame, Expr, Series +from tracksdata._filter import _FilterCompound, _FilterLeaf, _FilterNode, walk_leaves + Scalar = int | float | str | bool | complex | np.number -ExprInput = Union[str, Scalar, "Attr", Expr, "AttrComparison"] -MembershipExprInput = Sequence[Scalar] +ExprInput = Union[str, Scalar, "Attr", Expr] +MembershipExprInput = Sequence[Scalar] | np.ndarray -# Logical operators supported by AttrFilter compounds. -_FILTER_LOGICAL_OPS = ("and", "or", "xor", "not") -FilterInput = Union["AttrComparison", "AttrFilter"] +# A filter-shaped `Attr` (one whose `_filter` is set). Backends introspect +# `Attr._filter` to translate compound boolean filters into SQL / polars / +# Python-dict predicates. +FilterInput = "Attr" __all__ = [ - "AttrComparison", - "AttrFilter", + "Attr", "EdgeAttr", "NodeAttr", "attr_comps_to_strs", @@ -88,8 +90,16 @@ def _is_in_op(lhs: Any, values: MembershipExprInput) -> Any: } +_FILTER_OP_SYMBOLS = {"and": "&", "or": "|", "xor": "^", "not": "~"} +_BOOLEAN_OP_FUNCS: dict[str, Callable] = { + "and": operator.and_, + "or": operator.or_, + "xor": operator.xor, +} + + def _is_membership_expr_input(x: Any) -> TypeGuard[MembershipExprInput]: - if isinstance(x, Attr | AttrComparison | pl.Expr): + if isinstance(x, Attr | pl.Expr): return False if isinstance(x, Scalar): return False @@ -98,146 +108,26 @@ def _is_membership_expr_input(x: Any) -> TypeGuard[MembershipExprInput]: return isinstance(x, Sequence) -class AttrComparison: - """ - Class to store a comparison between an [Attr][tracksdata.attrs.Attr] and a value - (a sequence of values for `is_in`). - It's mainly used for filtering. - Complex expression are transformed back to [Attr][tracksdata.attrs.Attr] objects - which can be used to evaluate the expression on a DataFrame. - - Parameters - ---------- - attr : Attr - The attribute to compare. - op : Callable - The operator to use for the comparison. - other : ExprInput | MembershipExprInput - The value to compare the attribute to. - """ - - def __init__(self, attr: "Attr", op: Callable, other: ExprInput | MembershipExprInput) -> None: - is_membership_expr = _is_membership_expr_input(other) - if is_membership_expr and op != _is_in_op: - raise ValueError( - f"Membership values can only be used with the 'is_in' method. Found '{_OPS_MATH_SYMBOLS[op]}'." - ) - elif not is_membership_expr and op == _is_in_op: - raise ValueError( - f"Cannot use 'is_in' method with non-membership values. Found '{other}' of type {type(other)}." - ) - - if attr.has_inf(): - raise ValueError("Comparison operators are not supported for expressions with infinity.") - - if isinstance(other, Attr): - raise ValueError(f"Does not support comparison between expressions. Found {other} and {attr}.") - - columns = attr.expr_columns - - if len(columns) == 0: - raise ValueError("Comparison operators are not supported for empty expressions.") - - elif len(columns) > 1: - raise ValueError(f"Comparison operators are not supported for multiple columns. Found {columns}.") - - self.attr = attr - self.column = columns[0] - self.op = op - - # casting numpy scalars to python scalars - # numpy scalars are problematic for sqlalchemy - if is_membership_expr: - if isinstance(other, np.ndarray): - other = other.tolist() - else: - other = list(other) - elif isinstance(other, np.ndarray): - other = other.item() - self.other = other - - def __repr__(self) -> str: - return f"{type(self.attr).__name__}({self.column}) {_OPS_MATH_SYMBOLS[self.op]} {self.other}" - - def to_attr(self) -> "Attr": - """ - Transform the comparison back to an [Attr][tracksdata.attrs.Attr] object. - This is useful for evaluating the expression on a DataFrame. - """ - return Attr(self.op(pl.col(self.column), self.other)) - - def __getattr__(self, attr: str) -> Any: - return getattr(self.to_attr(), attr) - - def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], reverse: bool = False) -> "Attr": - return self.to_attr()._delegate_operator(other, op, reverse) - - # Arithmetic operators (auto-generated by `_setup_ops`, return Attr) - def __add__(self, other: ExprInput) -> "Attr": ... - def __sub__(self, other: ExprInput) -> "Attr": ... - def __mul__(self, other: ExprInput) -> "Attr": ... - def __truediv__(self, other: ExprInput) -> "Attr": ... - def __floordiv__(self, other: ExprInput) -> "Attr": ... - def __mod__(self, other: ExprInput) -> "Attr": ... - def __pow__(self, other: ExprInput) -> "Attr": ... - - # Reverse arithmetic operators - def __radd__(self, other: Scalar) -> "Attr": ... - def __rsub__(self, other: Scalar) -> "Attr": ... - def __rmul__(self, other: Scalar) -> "Attr": ... - def __rtruediv__(self, other: Scalar) -> "Attr": ... - def __rfloordiv__(self, other: Scalar) -> "Attr": ... - def __rmod__(self, other: Scalar) -> "Attr": ... - def __rpow__(self, other: Scalar) -> "Attr": ... - - # Logical operators combine comparisons into an AttrFilter compound. - # `AttrFilter` is defined later in the module; the references below resolve - # at call time, so the forward reference is fine. - def _logical_op(self, op_name: str, other: Any, reverse: bool = False) -> "AttrFilter": - if not isinstance(other, AttrComparison | AttrFilter): - symbol = _FILTER_OP_SYMBOLS[op_name] - raise TypeError( - f"Cannot apply '{symbol}' between an AttrComparison and {type(other).__name__}. " - "Boolean operators on comparisons combine them into a filter; both operands " - "must be an AttrComparison or AttrFilter." - ) - operands = [other, self] if reverse else [self, other] - return AttrFilter(op_name, operands) - - def __and__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("and", other) - - def __rand__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("and", other, reverse=True) - - def __or__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("or", other) +def _cast_membership(values: MembershipExprInput) -> list: + if isinstance(values, np.ndarray): + return values.tolist() + return list(values) - def __ror__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("or", other, reverse=True) - def __xor__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("xor", other) +def _cast_scalar(value: Any) -> Any: + # numpy scalars are problematic for sqlalchemy; unwrap to Python types + if isinstance(value, np.ndarray): + return value.item() + return value - def __rxor__(self, other: FilterInput) -> "AttrFilter": - return self._logical_op("xor", other, reverse=True) - def __invert__(self) -> "AttrFilter": - return AttrFilter("not", [self]) - - # Comparison operators (always return Attr) - def __eq__(self, other: ExprInput) -> "Attr": ... - def __req__(self, other: ExprInput) -> "Attr": ... - def __ne__(self, other: ExprInput) -> "Attr": ... - def __rne__(self, other: ExprInput) -> "Attr": ... - def __lt__(self, other: ExprInput) -> "Attr": ... - def __rlt__(self, other: ExprInput) -> "Attr": ... - def __le__(self, other: ExprInput) -> "Attr": ... - def __rle__(self, other: ExprInput) -> "Attr": ... - def __gt__(self, other: ExprInput) -> "Attr": ... - def __rgt__(self, other: ExprInput) -> "Attr": ... - def __ge__(self, other: ExprInput) -> "Attr": ... - def __rge__(self, other: ExprInput) -> "Attr": ... +def _filter_repr(node: _FilterNode) -> str: + if isinstance(node, _FilterLeaf): + return f"{node.kind.__name__}({node.column}) {_OPS_MATH_SYMBOLS[node.op]} {node.other}" + if node.op == "not": + return f"~{_filter_repr(node.operands[0])}" + sep = f" {_FILTER_OP_SYMBOLS[node.op]} " + return "(" + sep.join(_filter_repr(o) for o in node.operands) + ")" class Attr: @@ -260,62 +150,107 @@ class Attr: """ expr: Expr + _filter: _FilterNode | None def __init__(self, value: ExprInput) -> None: - self._inf_exprs = [] # expressions multiplied by +inf - self._neg_inf_exprs = [] # expressions multiplied by -inf + self._inf_exprs: list[Attr] = [] # expressions multiplied by +inf + self._neg_inf_exprs: list[Attr] = [] # expressions multiplied by -inf + self._filter = None if isinstance(value, str): self.expr = pl.col(value) elif isinstance(value, Attr): self.expr = value.expr - # Copy infinity tracking from the other AttrExpr + # Copy infinity tracking; intentionally do NOT copy `_filter` — wrapping + # an Attr in another Attr is for rebinding/aliasing, not duplicating + # filter identity (operators set `_filter` on the new instance). self._inf_exprs = value.inf_exprs self._neg_inf_exprs = value.neg_inf_exprs - elif isinstance(value, AttrComparison): - attr = value.to_attr() - self.expr = attr.expr - self._inf_exprs = attr.inf_exprs - self._neg_inf_exprs = attr.neg_inf_exprs elif isinstance(value, Expr): self.expr = value else: self.expr = pl.lit(value) + @classmethod + def _leaf( + cls, + column: str, + op: Callable, + other: Any, + kind: type["Attr"] | None = None, + ) -> "Attr": + """Construct an `Attr` representing a single leaf filter `column op other`. + + Provided for tests and internal helpers that need to build filter nodes + without going through Python's operator dispatch. + """ + is_membership = _is_membership_expr_input(other) + if is_membership and op is not _is_in_op: + raise ValueError( + f"Membership values can only be used with the 'is_in' method. Found '{_OPS_MATH_SYMBOLS[op]}'." + ) + if not is_membership and op is _is_in_op: + raise ValueError( + f"Cannot use 'is_in' method with non-membership values. Found '{other}' of type {type(other)}." + ) + leaf_kind = kind if kind is not None else cls + other_cast = _cast_membership(other) if is_membership else _cast_scalar(other) + + if is_membership: + expr = pl.col(column).is_in(other_cast) + else: + expr = op(pl.col(column), other_cast) + + result = leaf_kind(expr) + result._filter = _FilterLeaf(column=column, op=op, other=other_cast, kind=leaf_kind) + return result + def _wrap(self, expr: ExprInput) -> Union["Attr", Any]: if isinstance(expr, Expr): - result = Attr(expr) + result = type(self)(expr) # Propagate infinity tracking result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() return result return expr - def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], reverse: bool = False) -> "Attr": + def _result_kind(self, other: "ExprInput") -> type["Attr"]: + """Pick the result class for a binary op so NodeAttr/EdgeAttr is preserved. + + The base `Attr` defers to a more specific operand; two specific kinds + must match — mixing `NodeAttr` and `EdgeAttr` in one expression raises + because they target different graph tables. """ - Delegate the operator to the expression. + self_kind = type(self) + if not isinstance(other, Attr): + return self_kind + other_kind = type(other) + if self_kind is Attr: + return other_kind + if other_kind is Attr or other_kind is self_kind: + return self_kind + raise ValueError( + f"Cannot combine {self_kind.__name__} and {other_kind.__name__} " + "in a single expression — they target different graph tables." + ) - Parameters - ---------- - other : ExprInput - The other expression to delegate the operator to. - op : Callable[[Expr, Expr], Expr] - The operator to delegate. - reverse : bool, optional - Whether the operator is reversed. + def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], reverse: bool = False) -> "Attr": + """ + Delegate a binary numeric/bitwise operator to the polars expression. - Returns - ------- - Attr - The result of the operator. + Arithmetic and pure-bitwise operations always clear `_filter` (the result + is no longer a filter-shaped Attr), so callers that need to combine + filter compounds must go through `_delegate_boolean_operator` instead. """ + cls = self._result_kind(other) + # Special handling for multiplication with infinity if op == operator.mul: # Check if we're multiplying with infinity scalar # In both reverse and non-reverse cases, 'other' is the infinity value # and 'self' is the AttrExpr we want to track if isinstance(other, int | float) and math.isinf(other): - result = Attr(pl.lit(0)) # Clean expression is zero (infinity term removed) + result = cls(pl.lit(0)) # Clean expression is zero (infinity term removed) # Copy existing infinity tracking result._inf_exprs = self._inf_exprs.copy() @@ -332,7 +267,7 @@ def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], # Regular operation - no infinity involved left = Attr(other).expr if reverse else self.expr right = self.expr if reverse else Attr(other).expr - result = Attr(op(left, right)) + result = cls(op(left, right)) # Combine infinity tracking from both operands if isinstance(other, Attr): @@ -354,47 +289,80 @@ def _delegate_operator(self, other: ExprInput, op: Callable[[Expr, Expr], Expr], return result - def _delegate_comparison_operator( - self, - other: ExprInput, - op: Callable, - reverse: bool = False, - ) -> "AttrComparison | Attr": + def _delegate_comparison_operator(self, other: ExprInput, op: Callable) -> "Attr": """ - Simplified version of `_delegate_operator` for comparison operators. - [AttrComparison][tracksdata.attrs.AttrComparison] has a limited scope and - it's mainly used for filtering. - If creating an [AttrComparison][tracksdata.attrs.AttrComparison] object is - not possible, it will return an [Attr][tracksdata.attrs.Attr] object. + Build a leaf-filter `Attr` for `self other` when possible. + + If `other` is itself an `Attr`, the result is a non-filter Attr that + evaluates as a polars boolean expression. If `self` has infinity + tracking, comparison is rejected as semantically meaningless. + Multi-column / literal LHS also falls back to a non-filter result — + such filters can't be pushed down to SQL and must be evaluated by + polars only. + """ + if self.has_inf(): + raise ValueError("Comparison operators are not supported for expressions with infinity.") - Parameters - ---------- - other : ExprInput - The other expression to delegate the operator to. - op : Callable - The operator to delegate. - reverse : bool, optional - Whether the operator is reversed. + if isinstance(other, Attr): + return self._delegate_operator(other, op) - Returns - ------- - AttrComparison | Attr - The result of the operator. + columns = self.expr_columns + if len(columns) != 1: + # Can't form a leaf — fall back to a non-filter Attr. + return self._delegate_operator(other, op) + + other_cast = _cast_scalar(other) + expr = op(self.expr, other_cast) + result = type(self)(expr) + result._filter = _FilterLeaf(column=columns[0], op=op, other=other_cast, kind=type(self)) + return result + + def _delegate_boolean_operator(self, other: "ExprInput", op_name: str, reverse: bool = False) -> "Attr": + """ + Combine two `Attr`s with a boolean op (`& | ^`). + + If both have `_filter` set, build a compound filter (auto-flattening + nested same-op compounds). If neither has `_filter`, fall through to + plain bitwise polars evaluation. Mixing a filter-shaped Attr with a + non-filter operand raises: implicit pushdown loss is too easy to miss. """ + op_func = _BOOLEAN_OP_FUNCS[op_name] + self_has = self._filter is not None + other_has = isinstance(other, Attr) and other._filter is not None + + if self_has != other_has: + symbol = _FILTER_OP_SYMBOLS[op_name] + raise TypeError( + f"Cannot apply '{symbol}' between a filter-shaped Attr and a non-filter operand. " + "Both operands must be filter-shaped (built from comparisons) or both non-filter." + ) + + if not self_has: + # Neither has filter structure — pure bitwise op, no `_filter` on result. + return self._delegate_operator(other, op_func, reverse=reverse) + + # Both have filter — combine into a compound, auto-flattening associative ops. + cls = self._result_kind(other) + first, second = (other, self) if reverse else (self, other) + operands: list[_FilterNode] = [] + for op_attr in (first, second): + f = op_attr._filter + if isinstance(f, _FilterCompound) and f.op == op_name: + operands.extend(f.operands) + else: + operands.append(f) + if reverse: - lhs = Attr(other) - rhs = self + expr = op_func(other.expr, self.expr) else: - lhs = self - rhs = other + expr = op_func(self.expr, other.expr) - if isinstance(other, Attr): - return self._delegate_operator(other, op, reverse=False) - - return AttrComparison(lhs, op, rhs) + result = cls(expr) + result._filter = _FilterCompound(op_name, tuple(operands)) + return result def alias(self, name: str) -> "Attr": - result = Attr(self.expr.alias(name)) + result = type(self)(self.expr.alias(name)) result._inf_exprs = self._inf_exprs.copy() result._neg_inf_exprs = self._neg_inf_exprs.copy() return result @@ -437,7 +405,7 @@ def expr_columns(self) -> list[str]: @property def inf_columns(self) -> list[str]: """Get the names of columns multiplied by positive infinity.""" - columns = [] + columns: list[str] = [] for attr_expr in self._inf_exprs: columns.extend(attr_expr.columns) return list(dict.fromkeys(columns)) @@ -445,7 +413,7 @@ def inf_columns(self) -> list[str]: @property def neg_inf_columns(self) -> list[str]: """Get the names of columns multiplied by negative infinity.""" - columns = [] + columns: list[str] = [] for attr_expr in self._neg_inf_exprs: columns.extend(attr_expr.columns) return list(dict.fromkeys(columns)) @@ -453,53 +421,65 @@ def neg_inf_columns(self) -> list[str]: def has_inf(self) -> bool: """ Check if any column in the expression is multiplied by infinity or negative infinity. - - Returns - ------- - bool - True if any column is multiplied by infinity, False otherwise. """ return self.has_pos_inf() or self.has_neg_inf() def has_pos_inf(self) -> bool: - """ - Check if any column in the expression is multiplied by positive infinity. - """ return len(self._inf_exprs) > 0 def has_neg_inf(self) -> bool: - """ - Check if any column in the expression is multiplied by negative infinity. - """ return len(self._neg_inf_exprs) > 0 - def is_in(self, values: MembershipExprInput) -> "AttrComparison": + def is_in(self, values: MembershipExprInput) -> "Attr": """ - Create a membership comparison between the attribute and a collection of literals. + Create a membership filter `self in values`. + + Returns a filter-shaped `Attr` suitable for `graph.filter()` and for + composition with `&`, `|`, `^`, `~`. Parameters ---------- - values : Iterable[Scalar] | Sequence[Scalar] | np.ndarray | Series + values : Iterable[Scalar] | Sequence[Scalar] | np.ndarray Values the attribute should belong to. - - Returns - ------- - AttrComparison - A comparison suitable for filtering across all graph backends. """ - return AttrComparison(self, _is_in_op, values) + if not _is_membership_expr_input(values): + raise ValueError( + f"Cannot use 'is_in' method with non-membership values. Found '{values}' of type {type(values)}." + ) + if self.has_inf(): + raise ValueError("Comparison operators are not supported for expressions with infinity.") + columns = self.expr_columns + if len(columns) != 1: + raise ValueError(f"'is_in' is only supported for single-column expressions. Found columns {columns}.") + values_cast = _cast_membership(values) + expr = self.expr.is_in(values_cast) + result = type(self)(expr) + result._filter = _FilterLeaf(column=columns[0], op=_is_in_op, other=values_cast, kind=type(self)) + return result def __invert__(self) -> "Attr": - return Attr(~self.expr) + result = type(self)(~self.expr) + result._inf_exprs = self._inf_exprs.copy() + result._neg_inf_exprs = self._neg_inf_exprs.copy() + if self._filter is not None: + result._filter = _FilterCompound("not", (self._filter,)) + return result def __neg__(self) -> "Attr": - return Attr(-self.expr) + result = type(self)(-self.expr) + # `-(x * inf)` is `x * -inf`: swap positive and negative trackers. + result._inf_exprs = self._neg_inf_exprs.copy() + result._neg_inf_exprs = self._inf_exprs.copy() + return result def __pos__(self) -> "Attr": - return Attr(+self.expr) + return self def __abs__(self) -> "Attr": - return Attr(abs(self.expr)) + result = type(self)(abs(self.expr)) + result._inf_exprs = self._inf_exprs.copy() + result._neg_inf_exprs = self._neg_inf_exprs.copy() + return result def __getattr__(self, attr: str) -> Any: # Don't delegate our internal attributes to the expr @@ -518,9 +498,13 @@ def _wrapped(*args, **kwargs): return expr_attr def __repr__(self) -> str: + if self._filter is not None: + return _filter_repr(self._filter) + # Non-filter Attrs always render as `Attr()` regardless of subclass — + # the kind is meaningful for filter dispatch, not for arbitrary expressions. return f"Attr({self.expr})" - # Binary operators + # Binary arithmetic operators (auto-generated by `_setup_ops`) def __add__(self, other: ExprInput) -> "Attr": ... def __sub__(self, other: ExprInput) -> "Attr": ... def __mul__(self, other: ExprInput) -> "Attr": ... @@ -528,11 +512,13 @@ def __truediv__(self, other: ExprInput) -> "Attr": ... def __floordiv__(self, other: ExprInput) -> "Attr": ... def __mod__(self, other: ExprInput) -> "Attr": ... def __pow__(self, other: ExprInput) -> "Attr": ... + + # Boolean / bitwise operators (auto-generated by `_setup_ops`) def __and__(self, other: ExprInput) -> "Attr": ... def __or__(self, other: ExprInput) -> "Attr": ... def __xor__(self, other: ExprInput) -> "Attr": ... - # Reverse operators + # Reverse arithmetic operators def __radd__(self, other: Scalar) -> "Attr": ... def __rsub__(self, other: Scalar) -> "Attr": ... def __rmul__(self, other: Scalar) -> "Attr": ... @@ -544,106 +530,54 @@ def __rand__(self, other: Scalar) -> "Attr": ... def __ror__(self, other: Scalar) -> "Attr": ... def __rxor__(self, other: Scalar) -> "Attr": ... - # Comparison operators with overloads + # Comparison operators with overloads (auto-generated by `_setup_ops`). + # No reflected `__r{eq,ne,lt,le,gt,ge}__` — Python uses the symmetric / opposite + # operator on the swapped operand instead, so those dunders are never invoked. @overload def __eq__(self, other: "Attr") -> "Attr": ... @overload - def __eq__(self, other: Scalar) -> "AttrComparison": ... - def __eq__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __req__(self, other: "Attr") -> "Attr": ... - @overload - def __req__(self, other: Scalar) -> "AttrComparison": ... - def __req__(self, other: ExprInput) -> "Attr | AttrComparison": ... + def __eq__(self, other: Scalar) -> "Attr": ... + def __eq__(self, other: ExprInput) -> "Attr": ... @overload def __ne__(self, other: "Attr") -> "Attr": ... @overload - def __ne__(self, other: Scalar) -> "AttrComparison": ... - def __ne__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __rne__(self, other: "Attr") -> "Attr": ... - @overload - def __rne__(self, other: Scalar) -> "AttrComparison": ... - def __rne__(self, other: ExprInput) -> "Attr | AttrComparison": ... + def __ne__(self, other: Scalar) -> "Attr": ... + def __ne__(self, other: ExprInput) -> "Attr": ... @overload def __lt__(self, other: "Attr") -> "Attr": ... @overload - def __lt__(self, other: Scalar) -> "AttrComparison": ... - def __lt__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __rlt__(self, other: "Attr") -> "Attr": ... - @overload - def __rlt__(self, other: Scalar) -> "AttrComparison": ... - def __rlt__(self, other: ExprInput) -> "Attr | AttrComparison": ... + def __lt__(self, other: Scalar) -> "Attr": ... + def __lt__(self, other: ExprInput) -> "Attr": ... @overload def __le__(self, other: "Attr") -> "Attr": ... @overload - def __le__(self, other: Scalar) -> "AttrComparison": ... - def __le__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __rle__(self, other: "Attr") -> "Attr": ... - @overload - def __rle__(self, other: Scalar) -> "AttrComparison": ... - def __rle__(self, other: ExprInput) -> "Attr | AttrComparison": ... + def __le__(self, other: Scalar) -> "Attr": ... + def __le__(self, other: ExprInput) -> "Attr": ... @overload def __gt__(self, other: "Attr") -> "Attr": ... @overload - def __gt__(self, other: Scalar) -> "AttrComparison": ... - def __gt__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __rgt__(self, other: "Attr") -> "Attr": ... - @overload - def __rgt__(self, other: Scalar) -> "AttrComparison": ... - def __rgt__(self, other: ExprInput) -> "Attr | AttrComparison": ... + def __gt__(self, other: Scalar) -> "Attr": ... + def __gt__(self, other: ExprInput) -> "Attr": ... @overload def __ge__(self, other: "Attr") -> "Attr": ... @overload - def __ge__(self, other: Scalar) -> "AttrComparison": ... - def __ge__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - @overload - def __rge__(self, other: "Attr") -> "Attr": ... - @overload - def __rge__(self, other: Scalar) -> "AttrComparison": ... - def __rge__(self, other: ExprInput) -> "Attr | AttrComparison": ... - - -# Auto-generate operator methods using functools.partialmethod -def _add_operator( - cls: type[Attr] | type[AttrComparison], - name: str, - op: Callable, - reverse: bool = False, -) -> None: - method = functools.partialmethod(cls._delegate_operator, op=op, reverse=reverse) - setattr(cls, name, method) - - -def _add_comparison_operator( - name: str, - op: Callable, - reverse: bool = False, -) -> None: - method = functools.partialmethod(Attr._delegate_comparison_operator, op=op, reverse=reverse) - setattr(Attr, name, method) + def __ge__(self, other: Scalar) -> "Attr": ... + def __ge__(self, other: ExprInput) -> "Attr": ... def _setup_ops() -> None: + """Auto-generate dunder methods on `Attr` from operator tables. + + Arithmetic ops use `_delegate_operator` (clears `_filter`); comparison ops + use `_delegate_comparison_operator` (sets `_filter` leaf when possible); + boolean ops use `_delegate_boolean_operator` (builds compounds). """ - Setup the operator methods for the AttrExpr class. - """ - # Arithmetic operators: generated for both Attr and AttrComparison. - bin_ops = { + arith_ops = { "add": operator.add, "sub": operator.sub, "mul": operator.mul, @@ -652,16 +586,7 @@ def _setup_ops() -> None: "mod": operator.mod, "pow": operator.pow, } - - # Logical operators: generated only for Attr (bitwise on the polars expr). - # AttrComparison defines its own `& | ^ ~` in the class body to build - # AttrFilter compounds, so they are intentionally excluded here. - logical_ops = { - "and": operator.and_, - "or": operator.or_, - "xor": operator.xor, - } - + bool_ops = ("and", "or", "xor") comp_ops = { "eq": operator.eq, "ne": operator.ne, @@ -671,21 +596,20 @@ def _setup_ops() -> None: "ge": operator.ge, } - for op_name, op_func in (bin_ops | logical_ops).items(): - _add_operator(Attr, f"__{op_name}__", op_func, reverse=False) - _add_operator(Attr, f"__r{op_name}__", op_func, reverse=True) + for name, func in arith_ops.items(): + setattr(Attr, f"__{name}__", functools.partialmethod(Attr._delegate_operator, op=func, reverse=False)) + setattr(Attr, f"__r{name}__", functools.partialmethod(Attr._delegate_operator, op=func, reverse=True)) - for op_name, op_func in bin_ops.items(): - _add_operator(AttrComparison, f"__{op_name}__", op_func, reverse=False) - _add_operator(AttrComparison, f"__r{op_name}__", op_func, reverse=True) - - for op_name, op_func in comp_ops.items(): - _add_comparison_operator(f"__{op_name}__", op_func, reverse=False) - _add_comparison_operator(f"__r{op_name}__", op_func, reverse=True) + for name in bool_ops: + setattr( + Attr, f"__{name}__", functools.partialmethod(Attr._delegate_boolean_operator, op_name=name, reverse=False) + ) + setattr( + Attr, f"__r{name}__", functools.partialmethod(Attr._delegate_boolean_operator, op_name=name, reverse=True) + ) - # attrr_comparision uses normal delegate_operator - _add_operator(AttrComparison, f"__{op_name}__", op_func, reverse=False) - _add_operator(AttrComparison, f"__r{op_name}__", op_func, reverse=True) + for name, func in comp_ops.items(): + setattr(Attr, f"__{name}__", functools.partialmethod(Attr._delegate_comparison_operator, op=func)) _setup_ops() @@ -713,218 +637,91 @@ class EdgeAttr(Attr): """ -_FILTER_OP_SYMBOLS = {"and": "&", "or": "|", "xor": "^", "not": "~"} - - -class AttrFilter: - """ - A compound boolean combination of [AttrComparison][tracksdata.attrs.AttrComparison] - (or nested `AttrFilter`) operands, used to express OR / XOR / AND / NOT - relationships when filtering nodes or edges in a graph. - - Use Python's bitwise operators on `AttrComparison` (or `AttrFilter`) - instances to build compounds: - - ```python - graph.filter((NodeAttr("t") == 1) | (NodeAttr("t") == 2)) - graph.filter(~(NodeAttr("t") == 0)) - graph.filter((EdgeAttr("w") > 0.5) ^ (EdgeAttr("w") < -0.5)) - ``` - - All leaves of a single `AttrFilter` must reference attributes of the same - kind (either all [NodeAttr][tracksdata.attrs.NodeAttr] or all - [EdgeAttr][tracksdata.attrs.EdgeAttr]). Mixing node and edge attributes - inside one compound is not supported because it would require joining the - node and edge tables in a way that conflicts with the existing AND-based - filter semantics. Top-level node/edge filters can still be combined via - positional arguments to `graph.filter()` (implicit AND). +def _filter_attr_kind(node: _FilterNode) -> type[Attr]: + """Return the leaf-attribute kind (`NodeAttr` / `EdgeAttr` / `Attr`) of a filter node. - Parameters - ---------- - op : str - Logical operator, one of `"and"`, `"or"`, `"xor"`, `"not"`. - operands : Sequence[AttrComparison | AttrFilter] - Operands. `"not"` requires exactly one operand; the others require at - least two. + Raises `ValueError` if the filter mixes `NodeAttr` and `EdgeAttr` leaves. + The base `Attr` kind defers to any more specific kind present. """ - - def __init__(self, op: str, operands: Sequence[FilterInput]) -> None: - if op not in _FILTER_LOGICAL_OPS: - raise ValueError(f"Unknown logical operator '{op}'. Expected one of {_FILTER_LOGICAL_OPS}.") - operands = list(operands) - for o in operands: - if not isinstance(o, AttrComparison | AttrFilter): - raise TypeError(f"AttrFilter operands must be AttrComparison or AttrFilter, got {type(o).__name__}.") - if op == "not": - if len(operands) != 1: - raise ValueError("'not' filter requires exactly one operand.") - else: - if len(operands) < 2: - raise ValueError(f"'{op}' filter requires at least two operands.") - self.op = op - self.operands = operands - - def __and__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("and", [self, other]) - - def __rand__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("and", [other, self]) - - def __or__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("or", [self, other]) - - def __ror__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("or", [other, self]) - - def __xor__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("xor", [self, other]) - - def __rxor__(self, other: FilterInput) -> "AttrFilter": - return AttrFilter("xor", [other, self]) - - def __invert__(self) -> "AttrFilter": - return AttrFilter("not", [self]) - - def leaves(self) -> list["AttrComparison"]: - """Flatten the filter tree to its leaf comparisons.""" - out: list[AttrComparison] = [] - for o in self.operands: - if isinstance(o, AttrFilter): - out.extend(o.leaves()) - else: - out.append(o) - return out - - @property - def columns(self) -> list[str]: - return list(dict.fromkeys(leaf.column for leaf in self.leaves())) - - def __repr__(self) -> str: - if self.op == "not": - return f"~{self.operands[0]!r}" - sep = f" {_FILTER_OP_SYMBOLS[self.op]} " - return "(" + sep.join(repr(o) for o in self.operands) + ")" - - -def _filter_attr_kind(f: FilterInput) -> type[Attr]: - """Return the leaf-attribute kind (NodeAttr / EdgeAttr) of a filter. - - Raises ValueError if the filter mixes node and edge attributes. - """ - if isinstance(f, AttrComparison): - if isinstance(f.attr, NodeAttr): - return NodeAttr - if isinstance(f.attr, EdgeAttr): - return EdgeAttr - raise ValueError(f"Expected comparisons of 'NodeAttr' or 'EdgeAttr' objects, got {type(f.attr)}") - - kinds = {_filter_attr_kind(o) for o in f.operands} - if len(kinds) > 1: + kinds = {leaf.kind for leaf in walk_leaves(node)} + specific = {k for k in kinds if k is not Attr} + if len(specific) > 1: raise ValueError( - "A single AttrFilter compound cannot mix NodeAttr and EdgeAttr comparisons. " + "A single compound filter cannot mix NodeAttr and EdgeAttr comparisons. " "Combine node and edge filters via separate positional arguments to graph.filter()." ) - return kinds.pop() + return specific.pop() if specific else Attr def split_attr_comps( - attr_comps: Sequence[FilterInput], -) -> tuple[list[FilterInput], list[FilterInput]]: + attr_comps: Sequence["Attr"], +) -> tuple[list["Attr"], list["Attr"]]: """ - Split a list of attribute comparisons (or compound filters) into node and - edge groups based on the kind of their leaf comparisons. + Split a list of filter-shaped Attrs into node and edge groups based on the + kind of their leaf comparisons. Parameters ---------- - attr_comps : Sequence[AttrComparison | AttrFilter] - The attribute comparisons or compound filters to split. + attr_comps : Sequence[Attr] + The filter-shaped Attrs to split. Each must have `_filter` set (i.e. + be built from comparisons + boolean ops). Returns ------- - tuple[list[AttrComparison | AttrFilter], list[AttrComparison | AttrFilter]] + tuple[list[Attr], list[Attr]] A tuple of lists of node and edge filters. """ - node_attr_comps: list[FilterInput] = [] - edge_attr_comps: list[FilterInput] = [] + node_attr_comps: list[Attr] = [] + edge_attr_comps: list[Attr] = [] for attr_comp in attr_comps: - kind = _filter_attr_kind(attr_comp) + if not isinstance(attr_comp, Attr) or attr_comp._filter is None: + raise ValueError(f"Expected a filter-shaped Attr (built from comparisons), got {type(attr_comp).__name__}.") + kind = _filter_attr_kind(attr_comp._filter) if kind is NodeAttr: node_attr_comps.append(attr_comp) - else: + elif kind is EdgeAttr: edge_attr_comps.append(attr_comp) + else: + raise ValueError(f"Expected comparisons of 'NodeAttr' or 'EdgeAttr' objects, got {kind.__name__}.") return node_attr_comps, edge_attr_comps -def attr_comps_to_strs(attr_comps: Sequence[FilterInput]) -> list[str]: +def attr_comps_to_strs(attr_comps: Sequence["Attr"]) -> list[str]: """ - Convert a list of attribute comparisons (or compound filters) to a list of - column names involved in them. - - Parameters - ---------- - attr_comps : Sequence[AttrComparison | AttrFilter] - The filters to extract column names from. - - Returns - ------- - list[str] - The column names referenced by the filters, deduplicated while - preserving order. + Convert a list of filter-shaped Attrs to the list of column names they + reference, deduplicated while preserving order. """ out: list[str] = [] for attr_comp in attr_comps: - if isinstance(attr_comp, AttrFilter): - out.extend(attr_comp.columns) - else: - out.append(str(attr_comp.column)) + if attr_comp._filter is None: + continue + for leaf in walk_leaves(attr_comp._filter): + out.append(leaf.column) return list(dict.fromkeys(out)) -def _polars_filter_expr(f: FilterInput, df: pl.DataFrame) -> pl.Expr | pl.Series: - """Translate a single AttrComparison/AttrFilter to a polars expression.""" - if isinstance(f, AttrComparison): - return f.op(df[str(f.column)], f.other) - - if f.op == "not": - return ~_polars_filter_expr(f.operands[0], df) - - child_exprs = [_polars_filter_expr(o, df) for o in f.operands] - if f.op == "and": - return functools.reduce(operator.and_, child_exprs) - if f.op == "or": - return functools.reduce(operator.or_, child_exprs) - # xor - return functools.reduce(operator.xor, child_exprs) - - def polars_reduce_attr_comps( df: pl.DataFrame, - attr_comps: Sequence[FilterInput], + attr_comps: Sequence["Attr"], reduce_op: Callable[[Expr, Expr], Expr], ) -> pl.Expr: """ - Reduce a list of attribute comparisons (or compound filters) into a single - polars expression, combined with `reduce_op` at the top level (AND-ed by - default in callers). + Reduce a list of filter-shaped Attrs into a single polars expression, + combined with `reduce_op` at the top level (AND-ed by default in callers). Parameters ---------- df : pl.DataFrame - The dataframe to reduce the attribute comparisons on. - attr_comps : Sequence[AttrComparison | AttrFilter] + Present for API compatibility; unused — each Attr already carries a + fully-formed polars expression in `attr.expr`. + attr_comps : Sequence[Attr] The filters to reduce. reduce_op : Callable[[Expr, Expr], Expr] The operation to reduce the top-level filters with. - - Returns - ------- - pl.Expr - The reduced polars expression. """ if not attr_comps: - # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - - return pl.reduce(reduce_op, [_polars_filter_expr(f, df) for f in attr_comps]) + del df # unused; kept for backward-compatible signature + return pl.reduce(reduce_op, [a.expr for a in attr_comps]) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index ebed4617..1eb405bb 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -628,12 +628,12 @@ def filter( Creates a filter object that can be used to create a subgraph or query ids and attributes. Multiple positional filters are implicitly AND-ed together. Each filter - can itself be a compound `AttrFilter` built from `AttrComparison`s using - `&`, `|`, `^`, `~` (e.g. `(NodeAttr("t") == 1) | (NodeAttr("t") == 2)`). + is a filter-shaped `Attr` (built from comparisons), optionally combined + via `&`, `|`, `^`, `~` (e.g. `(NodeAttr("t") == 1) | (NodeAttr("t") == 2)`). Parameters ---------- - *attr_filters : AttrComparison | AttrFilter + *attr_filters : Attr The attribute filters to apply. Positional args are AND-ed. node_ids : Sequence[int] | None The IDs of the nodes to include in the filter. diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 9387fe5d..d7ffc179 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -7,7 +7,8 @@ import polars as pl import rustworkx as rx -from tracksdata.attrs import AttrComparison, FilterInput, split_attr_comps +from tracksdata._filter import _FilterLeaf, _FilterNode +from tracksdata.attrs import Attr, FilterInput, split_attr_comps from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph._mapped_graph_mixin import MappedGraphMixin @@ -27,31 +28,28 @@ def _pop_time_eq( ) -> tuple[list[FilterInput], int | None]: """ Pop the top-level time equality filter from a list of attribute filters. - Compound (AttrFilter) entries are left untouched even if they reference - the time column. If multiple time equality filters are found at the top - level, an error is raised. + Compound entries are left untouched even if they reference the time + column. If multiple time equality filters are found at the top level, an + error is raised. Parameters ---------- - attrs : Sequence[AttrComparison | AttrFilter] + attrs : Sequence[Attr] The attribute filters to pop the time equality filter from. Returns ------- - tuple[list[AttrComparison | AttrFilter], int | None] + tuple[list[Attr], int | None] The attribute filters without the time equality filter and the time value. """ out_attrs: list[FilterInput] = [] time = None for attr_comp in attrs: - if ( - isinstance(attr_comp, AttrComparison) - and str(attr_comp.column) == DEFAULT_ATTR_KEYS.T - and attr_comp.op == operator.eq - ): + leaf = attr_comp._filter if isinstance(attr_comp, Attr) else None + if isinstance(leaf, _FilterLeaf) and str(leaf.column) == DEFAULT_ATTR_KEYS.T and leaf.op == operator.eq: if time is not None: raise ValueError(f"Multiple '{DEFAULT_ATTR_KEYS.T}' equality filters are not allowed\n {attrs}") - time = int(attr_comp.other) + time = int(leaf.other) else: out_attrs.append(attr_comp) @@ -88,25 +86,36 @@ def _list_to_pl_series(key: str, values: list[Any], schema: AttrSchema) -> pl.Se return s -def _eval_filter( - f: FilterInput, +def _eval_filter_node( + node: _FilterNode, attrs: dict[str, Any], schema: dict[str, AttrSchema], ) -> bool: - """Evaluate a single comparison or compound filter against an attrs dict.""" - if isinstance(f, AttrComparison): - value = attrs.get(f.column, schema[f.column].default_value) - return bool(f.op(value, f.other)) - - if f.op == "and": - return all(_eval_filter(o, attrs, schema) for o in f.operands) - if f.op == "or": - return any(_eval_filter(o, attrs, schema) for o in f.operands) - if f.op == "xor": - truthy_count = sum(1 for o in f.operands if _eval_filter(o, attrs, schema)) + """Evaluate a `_FilterNode` AST against an attrs dict.""" + if isinstance(node, _FilterLeaf): + value = attrs.get(node.column, schema[node.column].default_value) + return bool(node.op(value, node.other)) + + if node.op == "and": + return all(_eval_filter_node(o, attrs, schema) for o in node.operands) + if node.op == "or": + return any(_eval_filter_node(o, attrs, schema) for o in node.operands) + if node.op == "xor": + truthy_count = sum(1 for o in node.operands if _eval_filter_node(o, attrs, schema)) return truthy_count % 2 == 1 # not - return not _eval_filter(f.operands[0], attrs, schema) + return not _eval_filter_node(node.operands[0], attrs, schema) + + +def _eval_filter( + f: FilterInput, + attrs: dict[str, Any], + schema: dict[str, AttrSchema], +) -> bool: + """Evaluate a filter-shaped `Attr` against an attrs dict.""" + if not isinstance(f, Attr) or f._filter is None: + raise ValueError(f"Expected a filter-shaped Attr (built from comparisons), got {type(f).__name__}.") + return _eval_filter_node(f._filter, attrs, schema) def _create_filter_func( @@ -922,7 +931,7 @@ def _filter_nodes_by_attrs( Parameters ---------- - *attrs : AttrComparison | AttrFilter + *attrs : Attr The attributes to filter by, for example: node_ids : list[int] | None The IDs of the nodes to include in the filter. @@ -2052,7 +2061,7 @@ def remove_node(self, node_id: int) -> None: def filter( self, - *attr_filters: AttrComparison, + *attr_filters: FilterInput, node_ids: Sequence[int] | None = None, include_targets: bool = False, include_sources: bool = False, diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 01d748bf..32284419 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -16,7 +16,8 @@ from sqlalchemy.orm.query import Query from sqlalchemy.sql.type_api import TypeEngine -from tracksdata.attrs import AttrComparison, FilterInput, split_attr_comps +from tracksdata._filter import _FilterLeaf, _FilterNode +from tracksdata.attrs import Attr, FilterInput, split_attr_comps from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.graph.filters._base_filter import BaseFilter @@ -59,18 +60,18 @@ def _data_numpy_to_native(data: dict[str, Any]) -> None: data[k] = v.item() -def _to_sql_clause(f: FilterInput, table: type[DeclarativeBase]) -> Any: - """Translate an AttrComparison or AttrFilter into a SQLAlchemy clause.""" - if isinstance(f, AttrComparison): - return f.op(getattr(table, str(f.column)), f.other) +def _to_sql_clause_node(node: _FilterNode, table: type[DeclarativeBase]) -> Any: + """Translate a `_FilterNode` AST into a SQLAlchemy clause.""" + if isinstance(node, _FilterLeaf): + return node.op(getattr(table, str(node.column)), node.other) - if f.op == "not": - return sa.not_(_to_sql_clause(f.operands[0], table)) + if node.op == "not": + return sa.not_(_to_sql_clause_node(node.operands[0], table)) - clauses = [_to_sql_clause(o, table) for o in f.operands] - if f.op == "and": + clauses = [_to_sql_clause_node(o, table) for o in node.operands] + if node.op == "and": return sa.and_(*clauses) - if f.op == "or": + if node.op == "or": return sa.or_(*clauses) # xor: reduce pairwise via (a OR b) AND NOT (a AND b) return functools.reduce( @@ -79,6 +80,13 @@ def _to_sql_clause(f: FilterInput, table: type[DeclarativeBase]) -> Any: ) +def _to_sql_clause(f: FilterInput, table: type[DeclarativeBase]) -> Any: + """Translate a filter-shaped `Attr` into a SQLAlchemy clause.""" + if not isinstance(f, Attr) or f._filter is None: + raise ValueError(f"Expected a filter-shaped Attr (built from comparisons), got {type(f).__name__}.") + return _to_sql_clause_node(f._filter, table) + + def _filter_query( query: sa.Select, table: type[DeclarativeBase], @@ -86,8 +94,8 @@ def _filter_query( ) -> sa.Select: """ Filter a query by a list of attribute filters (AND-ed together at the top - level). Each filter may itself be a compound AttrFilter combining - AttrComparisons with OR / AND / XOR / NOT. + level). Each filter may itself be a compound filter combining leaf + comparisons with OR / AND / XOR / NOT. Parameters ---------- @@ -95,8 +103,8 @@ def _filter_query( The query to filter. table : type[DeclarativeBase] The table to filter. - attr_filters : Sequence[AttrComparison | AttrFilter] - The attribute filters to apply. + attr_filters : Sequence[Attr] + The filter-shaped Attrs to apply. Returns ------- diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index dba8c128..cea78d54 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -825,12 +825,11 @@ def test_filter_subgraph_with_or_attr_filter(graph_backend: BaseGraph) -> None: def test_filter_compound_mixed_node_and_edge_raises(graph_backend: BaseGraph) -> None: - """A single compound filter cannot mix node and edge attributes.""" - graph_with_data = create_test_graph(graph_backend, use_subgraph=False) + """A single compound filter cannot mix node and edge attributes — caught at construction.""" + _ = create_test_graph(graph_backend, use_subgraph=False) - bad_filter = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) - with pytest.raises(ValueError, match="cannot mix NodeAttr and EdgeAttr"): - graph_with_data.filter(bad_filter).node_ids() + with pytest.raises(ValueError, match="Cannot combine NodeAttr and EdgeAttr"): + _ = (NodeAttr("t") == 1) | (EdgeAttr("weight") > 0.5) @parametrize_subgraph_tests From 1107b6ef24b7def89a1c9fdace24504d7f514ca8 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 31 May 2026 17:45:11 -0700 Subject: [PATCH 5/5] rollback --- src/tracksdata/attrs.py | 55 +++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/tracksdata/attrs.py b/src/tracksdata/attrs.py index 53aafee5..c437c49f 100644 --- a/src/tracksdata/attrs.py +++ b/src/tracksdata/attrs.py @@ -785,6 +785,24 @@ def __rxor__(self, other: FilterInput) -> "AttrFilter": def __invert__(self) -> "AttrFilter": return AttrFilter("not", [self]) + def to_attr(self) -> "Attr": + """Translate the compound filter to an `Attr` holding the polars boolean expression. + + Mirrors [AttrComparison.to_attr][tracksdata.attrs.AttrComparison.to_attr] + and folds children polymorphically — both operand types expose `to_attr`, + so no parallel `(AttrComparison | AttrFilter)` walker is needed for + evaluation or column extraction. + """ + if self.op == "not": + return Attr(~self.operands[0].to_attr().expr) + child_exprs = [o.to_attr().expr for o in self.operands] + if self.op == "and": + return Attr(functools.reduce(operator.and_, child_exprs)) + if self.op == "or": + return Attr(functools.reduce(operator.or_, child_exprs)) + # xor + return Attr(functools.reduce(operator.xor, child_exprs)) + def leaves(self) -> list["AttrComparison"]: """Flatten the filter tree to its leaf comparisons.""" out: list[AttrComparison] = [] @@ -797,7 +815,7 @@ def leaves(self) -> list["AttrComparison"]: @property def columns(self) -> list[str]: - return list(dict.fromkeys(leaf.column for leaf in self.leaves())) + return self.to_attr().expr_columns def __repr__(self) -> str: if self.op == "not": @@ -873,30 +891,9 @@ def attr_comps_to_strs(attr_comps: Sequence[FilterInput]) -> list[str]: The column names referenced by the filters, deduplicated while preserving order. """ - out: list[str] = [] - for attr_comp in attr_comps: - if isinstance(attr_comp, AttrFilter): - out.extend(attr_comp.columns) - else: - out.append(str(attr_comp.column)) - return list(dict.fromkeys(out)) - - -def _polars_filter_expr(f: FilterInput, df: pl.DataFrame) -> pl.Expr | pl.Series: - """Translate a single AttrComparison/AttrFilter to a polars expression.""" - if isinstance(f, AttrComparison): - return f.op(df[str(f.column)], f.other) - - if f.op == "not": - return ~_polars_filter_expr(f.operands[0], df) - - child_exprs = [_polars_filter_expr(o, df) for o in f.operands] - if f.op == "and": - return functools.reduce(operator.and_, child_exprs) - if f.op == "or": - return functools.reduce(operator.or_, child_exprs) - # xor - return functools.reduce(operator.xor, child_exprs) + # Both `AttrComparison` (via `__getattr__` → `to_attr().columns`) and + # `AttrFilter` (via its `columns` property) expose `.columns`. + return list(dict.fromkeys(c for ac in attr_comps for c in ac.columns)) def polars_reduce_attr_comps( @@ -912,7 +909,8 @@ def polars_reduce_attr_comps( Parameters ---------- df : pl.DataFrame - The dataframe to reduce the attribute comparisons on. + Unused; kept for backward-compatible signature. Each filter already + produces a fully-formed polars expression via `to_attr`. attr_comps : Sequence[AttrComparison | AttrFilter] The filters to reduce. reduce_op : Callable[[Expr, Expr], Expr] @@ -924,7 +922,6 @@ def polars_reduce_attr_comps( The reduced polars expression. """ if not attr_comps: - # Return True for all rows by using the first column as a reference raise ValueError("No attribute comparisons provided.") - - return pl.reduce(reduce_op, [_polars_filter_expr(f, df) for f in attr_comps]) + del df # unused; preserved in the signature to avoid a breaking change + return pl.reduce(reduce_op, [f.to_attr().expr for f in attr_comps])