From 802e1019b2dfa79e3912d735c5ce11316bb93f15 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 21 May 2026 03:45:04 -0700 Subject: [PATCH 1/3] update parser to emit ElementVariableNode for variable columns and tables --- core/query_rewriter_v2.py | 68 +++++++++++++++++++++++++++++++++- core/rule_generator_v2.py | 72 ++++++------------------------------ core/rule_parser_v2.py | 28 +++++++++++--- tests/test_rule_parser_v2.py | 22 ++++++----- 4 files changed, 114 insertions(+), 76 deletions(-) diff --git a/core/query_rewriter_v2.py b/core/query_rewriter_v2.py index 11fad47..620c759 100644 --- a/core/query_rewriter_v2.py +++ b/core/query_rewriter_v2.py @@ -167,6 +167,34 @@ def _match_node( # --- variable nodes in pattern --- if isinstance(p, ElementVariableNode): + # Qualified column variable: ElementVariableNode with parent_alias that is a variable name + if p.parent_alias is not None and _is_var_name(p.parent_alias, mapping): + if not isinstance(q, ColumnNode): + return False + if q.parent_alias is None: + return False + if not _bind(p.name, q.name, memo): + return False + return _bind(p.parent_alias, q.parent_alias, memo) + # Table variable with variable alias: ElementVariableNode with alias that is a variable name + # e.g. ElementVariableNode("tb1", alias="t1") where "tb1" -> TableNode(name), "t1" -> alias string + # Bind a stripped TableNode (no alias) to p.name so that when p.name appears as a bare + # table variable in the rewrite (e.g. inner FROM ), it materializes as TableNode(name) + # without leaking the alias. The alias is separately captured via p.alias. + if p.alias is not None and _is_var_name(p.alias, mapping): + if not isinstance(q, TableNode): + return False + if not _bind(p.name, TableNode(q.name), memo): + return False + if q.alias is None: + return False + return _bind(p.alias, q.alias, memo) + # Default: whole-node binding. + # JoinNode is a compound structural node (not an atomic value) that should never be + # bound to a bare element variable — it would violate the type contract and cause + # spurious second-pass matches after joins have been introduced. + if isinstance(q, JoinNode): + return False return _bind(p.name, q, memo) if isinstance(p, SetVariableNode): @@ -639,11 +667,15 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: """Convert an element-variable binding into a concrete AST node. Rules: - - If bound to a Node, return it (but strip `.alias` to avoid leaking output aliases - unless the rewrite explicitly carries them). + - If bound to a TableNode, return it directly (table aliases must be preserved + so that qualified column references like e1.col remain valid in the rewrite). + - If bound to any other Node, return it (but strip `.alias` to avoid leaking + output aliases unless the rewrite explicitly carries them). - If bound to scalar identifiers, materialize as ColumnNode/LiteralNode so the formatter can emit SQL. """ + if isinstance(val, TableNode): + return val if isinstance(val, Node): if hasattr(val, "alias") and getattr(val, "alias") is not None: cloned = copy.deepcopy(val) @@ -659,6 +691,38 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if isinstance(node, ElementVariableNode): val = memo.get(node.name, node) + # If variable has a parent_alias that is a variable name, reconstruct qualified column + # e.g. ElementVariableNode("a1", parent_alias="t1") where "a1" -> "id", "t1" -> "e1" + # produces ColumnNode("id", _parent_alias="e1") + # The node's own alias (if any) is a literal string from the rewrite template; preserve it. + if node.parent_alias is not None and node.parent_alias in memo: + pa_val = memo[node.parent_alias] + col_name = val if isinstance(val, str) else (val.name if isinstance(val, Node) and hasattr(val, 'name') else None) + if col_name is not None: + if isinstance(pa_val, TableNode): + pa_str = pa_val.alias if pa_val.alias is not None else pa_val.name + elif isinstance(pa_val, str): + pa_str = pa_val + else: + pa_str = None + if pa_str is not None: + return ColumnNode(col_name, _alias=node.alias, _parent_alias=pa_str) + # If variable has an alias that is a variable name, reconstruct a TableNode + # e.g. ElementVariableNode("tb1", alias="t1") where "tb1" -> TableNode("employee", ...), "t1" -> "e1" + # produces TableNode("employee", "e1") + if node.alias is not None and node.alias in memo: + alias_val = memo[node.alias] + # val may be a whole TableNode (from binding) or a string + if isinstance(val, TableNode): + table_name = val.name + elif isinstance(val, str): + table_name = val + else: + table_name = None + if table_name is not None: + alias_str = alias_val if isinstance(alias_val, str) else None + if alias_str is not None: + return TableNode(table_name, alias_str) materialized = _materialize_element_binding(val) if materialized is not None: return materialized diff --git a/core/rule_generator_v2.py b/core/rule_generator_v2.py index c57edf1..c5c2c4e 100644 --- a/core/rule_generator_v2.py +++ b/core/rule_generator_v2.py @@ -471,7 +471,7 @@ def _alias_token(name: Optional[str]) -> Optional[str]: tuple(RuleGeneratorV2._recommendation_ast_signature(child, state) for child in node.children), ) if isinstance(node, ElementVariableNode): - return ("EVAR", RuleGeneratorV2._fingerPrint(node.name), _alias_token(node.parent_alias)) + return ("EVAR", f"VAR:{RuleGeneratorV2._fingerPrint(node.name)}", _alias_token(node.parent_alias)) if isinstance(node, SetVariableNode): return ("SVAR", RuleGeneratorV2._fingerPrint(node.name)) if isinstance(node, CompoundQueryNode): @@ -840,15 +840,14 @@ def generalize_branches(rule: Dict[str, object]) -> Dict[str, object]: def columns(pattern_ast: Node, rewrite_ast: Node) -> List[str]: """Return the deterministic, sorted set of un-variablized column names in pattern_ast. - Columns variablized by the generator are now ElementVariableNode instances (not ColumnNode), - so those are automatically excluded. User-written variable placeholders (e.g. ) are - still represented as ColumnNode by RuleParserV2, so is_placeholder_name still filters those out. + Variable columns are represented as ElementVariableNode, so isinstance(node, ColumnNode) + naturally excludes them — only concrete column names are returned. rewrite_ast is accepted but ignored. """ del rewrite_ast # accepted for API compatibility found: Set[str] = set() for node in RuleGeneratorV2._walk(pattern_ast): - if isinstance(node, ColumnNode) and node.name and not is_placeholder_name(node.name): + if isinstance(node, ColumnNode) and node.name: found.add(node.name) # Sort deterministically so generalize_columns is hash-seed independent. return sorted(found) @@ -936,10 +935,8 @@ def tables(pattern_ast: Node, rewrite_ast: Node) -> List[Dict[str, str]]: def _tables_of_ast(ast: Node) -> List[Dict[str, str]]: """Return {"value", "name"} descriptors for every concrete TableNode in ast. - Tables variablized by the generator are now ElementVariableNode instances (not TableNode), - so those are automatically excluded by isinstance(node, TableNode). - User-written table variable placeholders (e.g. ) are still represented as TableNode - by RuleParserV2, so is_placeholder_name still filters those out. + Variable tables are represented as ElementVariableNode, so isinstance(node, TableNode) + naturally excludes them — only concrete table references are returned. name is the alias when present, otherwise the table value. """ found: List[Dict[str, str]] = [] @@ -948,11 +945,7 @@ def _tables_of_ast(ast: Node) -> List[Dict[str, str]]: continue if not isinstance(node.name, str): continue - if is_placeholder_name(node.name): - continue alias = node.alias if isinstance(node.alias, str) else node.name - if is_placeholder_name(alias): - continue found.append({"value": node.name, "name": alias}) return found @@ -1019,12 +1012,6 @@ def _visit(node: Node, parent: Optional[Node] = None) -> None: if isinstance(child, ElementVariableNode) and child.parent_alias is None: # Only bare variables (not qualified column vars like .) names.append(child.name) - elif ( - isinstance(child, ColumnNode) - and child.parent_alias is None - and is_placeholder_name(child.name) - ): - names.append(child.name) if names: out.append(names) elif ( @@ -1154,17 +1141,9 @@ def _is_subtree_candidate(node: Node, parent: Optional[Node] = None) -> bool: ): return False - if isinstance(node, ElementVariableNode) and node.parent_alias is not None: - # Qualified column variable (e.g. . → ElementVariableNode("x5", parent_alias="x1")) - # acts as a standalone SELECT, GROUP BY, or ORDER BY item, mirroring the ColumnNode case. - return isinstance(parent, (SelectNode, GroupByNode, OrderByItemNode)) - - if isinstance(node, ColumnNode): - # Column refs that act as standalone SELECT, GROUP BY, or ORDER BY - # items are subtree candidates. Bare column refs inside operators - # or functions, such as JOIN ON, WHERE, and expressions, are not. - if not RuleGeneratorV2._node_is_fully_variablized_column(node): - return False + if isinstance(node, ElementVariableNode): + # Column variables (qualified or bare) are subtree candidates only as + # standalone SELECT, GROUP BY, or ORDER BY items. return isinstance(parent, (SelectNode, GroupByNode, OrderByItemNode)) if isinstance(node, SetVariableNode): @@ -1203,11 +1182,6 @@ def _is_subtree_candidate(node: Node, parent: Optional[Node] = None) -> bool: if isinstance(child, (ElementVariableNode, SetVariableNode)): var_count += 1 continue - if isinstance(child, ColumnNode): - if RuleGeneratorV2._node_is_fully_variablized_column(child): - var_count += 1 - continue - return False if isinstance(child, LiteralNode): value = getattr(child, "value", None) if isinstance(value, str): @@ -1401,10 +1375,8 @@ def _is_branch_node(node: Node) -> bool: # generator-variablized table — ok pass elif isinstance(child, TableNode): - # user-written table variable placeholder (RuleParserV2 produces TableNode for ) - # or a concrete table. Only ok if placeholder name. - if not is_placeholder_name(child.name): - return False + # Concrete table — branch is not fully variablized. + return False elif isinstance(child, JoinNode): if not RuleGeneratorV2._is_branch_node(child): return False @@ -1419,9 +1391,7 @@ def _is_branch_node(node: Node) -> bool: # generator-variablized table — ok pass elif isinstance(child, TableNode): - # user-written table variable placeholder or concrete table - if not is_placeholder_name(child.name): - return False + return False else: if RuleGeneratorV2._tables_of_ast(child): return False @@ -1771,12 +1741,6 @@ def _visit(node: Node, parent: Optional[Node]) -> Node: variable_name: Optional[str] = None if isinstance(child, ElementVariableNode): variable_name = child.name - elif ( - isinstance(child, ColumnNode) - and child.parent_alias is None - and is_placeholder_name(child.name) - ): - variable_name = child.name if variable_name is not None and variable_name in variable_set: if not pending: @@ -2232,16 +2196,4 @@ def _walk(node: Optional[Node]) -> Iterator[Node]: for child in children: yield from RuleGeneratorV2._walk(child) - @staticmethod - def _node_is_fully_variablized_column(node: Node) -> bool: - # Generator-variablized columns are ElementVariableNode. - if isinstance(node, ElementVariableNode): - return True - # User-written variable column placeholders (e.g. .) remain as ColumnNode - # with placeholder names from RuleParserV2. - if isinstance(node, ColumnNode) and is_placeholder_name(node.name): - if node.parent_alias is None: - return True - return is_placeholder_name(node.parent_alias) - return False diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index 66b6c41..d4644f0 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -352,26 +352,44 @@ def _replace_internal_in_string(s: str) -> str: nm = col.name new_alias = _replace_internal_in_string(col.alias) if isinstance(col.alias, str) else col.alias new_pa = _replace_internal_in_string(pa) if isinstance(pa, str) else pa + + # Bare column variable (no qualifier): promote to ElementVariableNode if pa is None and nm in rev: return RuleParserV2._placeholder_varnode(nm, rev[nm]) + + # Both name and parent_alias are variables if pa is not None and pa in rev and nm in rev: - return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=rev[pa]) + return ElementVariableNode(rev[nm], parent_alias=rev[pa], alias=new_alias) + + # Only parent_alias is a variable (concrete column, variable table qualifier) if pa is not None and pa in rev: return ColumnNode(nm, _alias=new_alias, _parent_alias=rev[pa]) + + # Only column name is a variable (concrete table qualifier) if pa is not None and nm in rev: - return ColumnNode(rev[nm], _alias=new_alias, _parent_alias=new_pa) + return ElementVariableNode(rev[nm], parent_alias=new_pa, alias=new_alias) + return ColumnNode(nm, _alias=new_alias, _parent_alias=new_pa) if node.type == NodeType.TABLE: t = node if not isinstance(t, TableNode): return node - # If table name is a SET variable placeholder (<>), promote to SetVariableNode - # so it matches any table or list of tables in the FROM clause. - # Element variable tokens (EV...) stay as TableNode so _match_node handles them. sv_base = VarTypesInfo[VarType.SetVariable]["internalBase"] + ev_base = VarTypesInfo[VarType.ElementVariable]["internalBase"] + + # SET variable table: promote to SetVariableNode if isinstance(t.name, str) and t.name in rev and t.name.startswith(sv_base): return SetVariableNode(rev[t.name]) + + # ELEMENT variable table: promote to ElementVariableNode + if isinstance(t.name, str) and t.name in rev and t.name.startswith(ev_base): + # alias may also be a variable + if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: + return ElementVariableNode(rev[t.name], alias=rev[t.alias]) + return ElementVariableNode(rev[t.name]) + + # Concrete table new_name = rev.get(t.name, t.name) if isinstance(t.name, str) else t.name if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: new_alias = rev[t.alias] diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index a25ce79..3902dd4 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -470,13 +470,16 @@ def test_parse_where_scope_strips_select_and_from(): # ═══════════════════════════════════════════════════════════════════════════════ def test_parse_ast_from_scope(): + # After parser update: EV tokens in TABLE position are promoted to ElementVariableNode. + # The concrete alias "li" is not a variable, so it's dropped from the ElementVariableNode. + # The element variable "t" captures the whole table reference during matching. result = RuleParserV2.parse("FROM li", "FROM li") assert result.mapping == {"t": "EV001"} assert isinstance(result.pattern_ast, QueryNode) frm = next(c for c in result.pattern_ast.children if c.type == NodeType.FROM) assert isinstance(frm, FromNode) tab = list(frm.children)[0] - assert isinstance(tab, TableNode) and tab.name == "t" and tab.alias == "li" + assert isinstance(tab, ElementVariableNode) and tab.name == "t" def test_parse_from_scope_strips_select(): @@ -536,8 +539,8 @@ def test_parse_self_join_rule(): ) _assert_varnodes_declared(result) _assert_no_internal_tokens(result) - assert len(_find_all(result.pattern_ast, TableNode)) >= 2 - assert len(_find_all(result.rewrite_ast, TableNode)) >= 1 + assert len(_find_all(result.pattern_ast, ElementVariableNode)) >= 2 + assert len(_find_all(result.rewrite_ast, ElementVariableNode)) >= 1 pat_svs = [n for n in _walk(result.pattern_ast) if isinstance(n, SetVariableNode)] assert len(pat_svs) >= 2 # s1 and p1 @@ -632,16 +635,17 @@ def test_parse_set_variable_in_select_and_where(): # ═══════════════════════════════════════════════════════════════════════════════ def test_qualified_column_both_parts_substituted(): - """. — both parent_alias and name should become external names.""" + """. — both parent_alias and name should become external names (ElementVariableNode).""" result = RuleParserV2.parse(". = 1", ". = 1") _assert_varnodes_declared(result) _assert_no_internal_tokens(result) - cols = _find_all(result.pattern_ast, ColumnNode) - qualified = [c for c in cols if c.parent_alias is not None] + # When both parts are variables, _substitute_placeholders returns ElementVariableNode + evars = _find_all(result.pattern_ast, ElementVariableNode) + qualified = [e for e in evars if e.parent_alias is not None] assert len(qualified) >= 1 - for c in qualified: - assert c.parent_alias in result.mapping - assert c.name in result.mapping + for e in qualified: + assert e.parent_alias in result.mapping + assert e.name in result.mapping def test_qualified_column_only_parent_alias_is_var(): From d5797dd716556e4cd2d39ce25cbc294fd85f86bc Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 21 May 2026 15:00:06 -0700 Subject: [PATCH 2/3] add VariableLiteralNode --- core/ast/__init__.py | 2 + core/ast/enums.py | 1 + core/ast/node.py | 33 +++++++- core/ast/utils.py | 22 ------ core/query_formatter.py | 59 ++++++++------ core/query_rewriter_v2.py | 132 ++++++++++++++++---------------- core/rule_generator_v2.py | 89 +++++++++------------ core/rule_parser_v2.py | 38 ++++++--- tests/test_rule_generator_v2.py | 10 +-- tests/test_rule_parser_v2.py | 48 ++++++------ 10 files changed, 231 insertions(+), 203 deletions(-) diff --git a/core/ast/__init__.py b/core/ast/__init__.py index 1474504..5f7b0dd 100644 --- a/core/ast/__init__.py +++ b/core/ast/__init__.py @@ -13,6 +13,7 @@ LiteralNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, OperatorNode, FunctionNode, SelectNode, @@ -37,6 +38,7 @@ 'LiteralNode', 'ElementVariableNode', 'SetVariableNode', + 'VariableLiteralNode', 'OperatorNode', 'FunctionNode', 'SelectNode', diff --git a/core/ast/enums.py b/core/ast/enums.py index 63f79dc..fa53cd3 100644 --- a/core/ast/enums.py +++ b/core/ast/enums.py @@ -20,6 +20,7 @@ class NodeType(Enum): # VarSQL specific VAR = "var" VARSET = "varset" + VAR_LITERAL = "var_literal" # Operators OPERATOR = "operator" diff --git a/core/ast/node.py b/core/ast/node.py index b204931..0eddc76 100644 --- a/core/ast/node.py +++ b/core/ast/node.py @@ -46,7 +46,7 @@ def __hash__(self): class TableNode(Node): """Table reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs): + def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): super().__init__(NodeType.TABLE, **kwargs) self.name = _name self.alias = _alias @@ -80,7 +80,7 @@ def __hash__(self): class ColumnNode(Node): """Column reference node""" - def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): + def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs): super().__init__(NodeType.COLUMN, **kwargs) self.name = _name self.alias = _alias @@ -172,7 +172,7 @@ def __hash__(self): class ElementVariableNode(Node): """Rule element variable ```` (see ``VarType.ElementVariable`` in rule_parser_v2).""" - def __init__(self, _name: str, parent_alias: Optional[str] = None, alias: Optional[str] = None, **kwargs): + def __init__(self, _name: str, parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): super().__init__(NodeType.VAR, **kwargs) self.name = _name self.parent_alias = parent_alias @@ -202,6 +202,31 @@ def __hash__(self): return hash((super().__hash__(), self.name)) +class VariableLiteralNode(Node): + """A string literal placeholder, e.g. ``'%%'`` in a LIKE predicate. + + ``prefix`` and ``suffix`` capture surrounding wildcard characters so + ``LIKE '%foo%'`` → ``VariableLiteralNode('x1', prefix='%', suffix='%')``. + """ + def __init__(self, _name: str, prefix: str = "", suffix: str = "", + _alias: Optional[str] = None, **kwargs): + super().__init__(NodeType.VAR_LITERAL, **kwargs) + self.name = _name + self.prefix = prefix + self.suffix = suffix + self.alias = _alias + + def __eq__(self, other): + if not isinstance(other, VariableLiteralNode): + return False + return (super().__eq__(other) and self.name == other.name + and self.prefix == other.prefix and self.suffix == other.suffix + and self.alias == other.alias) + + def __hash__(self): + return hash((super().__hash__(), self.name, self.prefix, self.suffix, self.alias)) + + class OperatorNode(Node): """Operator node""" def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs): @@ -233,7 +258,7 @@ def __init__(self, _operand: Node, _name: str, **kwargs): class FunctionNode(Node): """Function call node""" - def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs): + def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs): if _args is None: _args = [] super().__init__(NodeType.FUNCTION, children=_args, **kwargs) diff --git a/core/ast/utils.py b/core/ast/utils.py index f516bf7..284deff 100644 --- a/core/ast/utils.py +++ b/core/ast/utils.py @@ -2,32 +2,10 @@ from __future__ import annotations -import re from typing import List from core.ast.node import Node, OperatorNode, UnaryOperatorNode -_PLACEHOLDER_PREFIXES = ("x", "y") - - -def is_placeholder_name(name: str) -> bool: - """Return True when name is a generator-internal placeholder identifier. - - Matches the parser-friendly tokens (__rv_x?__, __rvs_y?__) and bare x?/y? - external names. - """ - lower = name.lower() - if re.fullmatch(r"__rv_[xy]\d+__", lower): - return True - if re.fullmatch(r"__rvs_[xy]\d+__", lower): - return True - for prefix in _PLACEHOLDER_PREFIXES: - if lower.startswith(prefix): - suffix = lower[len(prefix):] - if suffix.isdigit(): - return True - return False - def flatten_logical_operands(node: Node, op_name: str) -> List[Node]: """Flatten a left-associative tree of binary ``op_name`` (e.g. AND/OR). diff --git a/core/query_formatter.py b/core/query_formatter.py index 13afb4c..0b0da7e 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -14,10 +14,11 @@ SubqueryNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, ) from core.ast.enums import NodeType, JoinType from core.ast.node import Node -from core.ast.utils import flatten_logical_operands, is_placeholder_name +from core.ast.utils import flatten_logical_operands def _placeholder_token(name: str) -> str: @@ -33,6 +34,18 @@ def _normalize_placeholder_tokens(sql: str) -> str: return out +def _render_alias(alias) -> str: + """Render an alias/parent_alias field to a string for mosql output. + + Concrete string aliases pass through unchanged; ElementVariableNode aliases + emit their placeholder token (``__rv_name__``), which ``_normalize_placeholder_tokens`` + later converts to ````. + """ + if isinstance(alias, ElementVariableNode): + return f"__rv_{alias.name}__" + return alias + + def _replace_wrapped_tokens(text: str, prefix: str, suffix: str, open_marker: str, close_marker: str) -> str: out = text start = 0 @@ -121,11 +134,18 @@ def ast_to_json(node: Node) -> dict: result['orderby'] = format_order_by(child) elif child.type == NodeType.LIMIT: lv = child.limit - if isinstance(lv, str) and is_placeholder_name(lv) and not (lv.startswith("__rv_") or lv.startswith("__rvs_")): - lv = _placeholder_token(lv) + if isinstance(lv, ElementVariableNode): + lv = _placeholder_token(lv.name) + elif isinstance(lv, SetVariableNode): + lv = f"__rvs_{lv.name}__" result['limit'] = lv elif child.type == NodeType.OFFSET: - result['offset'] = child.offset + ov = child.offset + if isinstance(ov, ElementVariableNode): + ov = _placeholder_token(ov.name) + elif isinstance(ov, SetVariableNode): + ov = f"__rvs_{ov.name}__" + result['offset'] = ov return result @@ -147,7 +167,7 @@ def format_select(select_node: SelectNode) -> dict: for child in children: item = {'value': format_expression(child)} if hasattr(child, 'alias') and child.alias: - item['name'] = child.alias + item['name'] = _render_alias(child.alias) items.append(item) select_key = 'select_distinct' if select_node.distinct else 'select' @@ -252,28 +272,21 @@ def format_source(node: Node) -> dict: subquery_child = list(node.children)[0] result = {'value': ast_to_json(subquery_child)} if node.alias: - result['name'] = node.alias + result['name'] = _render_alias(node.alias) return result elif node.type == NodeType.VAR: result = {'value': f"__rv_{node.name}__"} if node.alias: - result['name'] = node.alias + result['name'] = _render_alias(node.alias) return result raise ValueError(f"Unsupported source type: {node.type}") def format_table(table_node: TableNode) -> dict: - """Format a table reference. - - TableNode names from the rule generator are always concrete after the refactor; - however RuleParserV2 still produces TableNode(name="x1") for user-written table - variable placeholders like , so is_placeholder_name is kept as a safety - fallback for those cases. - """ - name = _placeholder_token(table_node.name) if is_placeholder_name(table_node.name) else table_node.name - result = {'value': name} + """Format a table reference.""" + result = {'value': table_node.name} if table_node.alias: - result['name'] = table_node.alias + result['name'] = _render_alias(table_node.alias) return result @@ -335,17 +348,19 @@ def format_expression(node: Node): if node.type == NodeType.VAR: token = f"__rv_{node.name}__" if node.parent_alias: - pa = _placeholder_token(node.parent_alias) if is_placeholder_name(node.parent_alias) else node.parent_alias + pa = _render_alias(node.parent_alias) return f"{pa}.{token}" return token if node.type == NodeType.VARSET: return f"__rvs_{node.name}__" + if node.type == NodeType.VAR_LITERAL: + return {'literal': f"{node.prefix}__rv_{node.name}__{node.suffix}"} + if node.type == NodeType.COLUMN: - col_token = _placeholder_token(node.name) if is_placeholder_name(node.name) else node.name if node.parent_alias: - pa_token = _placeholder_token(node.parent_alias) if is_placeholder_name(node.parent_alias) else node.parent_alias - return f"{pa_token}.{col_token}" - return col_token + pa_token = _render_alias(node.parent_alias) + return f"{pa_token}.{node.name}" + return node.name elif node.type == NodeType.LITERAL: if node.value is None: diff --git a/core/query_rewriter_v2.py b/core/query_rewriter_v2.py index 620c759..774731f 100644 --- a/core/query_rewriter_v2.py +++ b/core/query_rewriter_v2.py @@ -60,6 +60,7 @@ TableNode, TimeUnitNode, UnaryOperatorNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -151,10 +152,6 @@ def _bind(var_name: str, value: Any, memo: dict) -> bool: return True -def _is_var_name(s: Any, mapping: dict) -> bool: - """True if s is a string that is an external variable name in the rule mapping.""" - return isinstance(s, str) and s in mapping - # ============================================================================ # Core matching @@ -168,27 +165,27 @@ def _match_node( # --- variable nodes in pattern --- if isinstance(p, ElementVariableNode): # Qualified column variable: ElementVariableNode with parent_alias that is a variable name - if p.parent_alias is not None and _is_var_name(p.parent_alias, mapping): + if isinstance(p.parent_alias, ElementVariableNode): if not isinstance(q, ColumnNode): return False if q.parent_alias is None: return False if not _bind(p.name, q.name, memo): return False - return _bind(p.parent_alias, q.parent_alias, memo) + return _bind(p.parent_alias.name, q.parent_alias, memo) # Table variable with variable alias: ElementVariableNode with alias that is a variable name - # e.g. ElementVariableNode("tb1", alias="t1") where "tb1" -> TableNode(name), "t1" -> alias string + # e.g. ElementVariableNode("tb1", alias=ElementVariableNode("t1")) where "tb1" -> TableNode(name), "t1" -> alias string # Bind a stripped TableNode (no alias) to p.name so that when p.name appears as a bare # table variable in the rewrite (e.g. inner FROM ), it materializes as TableNode(name) # without leaking the alias. The alias is separately captured via p.alias. - if p.alias is not None and _is_var_name(p.alias, mapping): + if isinstance(p.alias, ElementVariableNode): if not isinstance(q, TableNode): return False if not _bind(p.name, TableNode(q.name), memo): return False if q.alias is None: return False - return _bind(p.alias, q.alias, memo) + return _bind(p.alias.name, q.alias, memo) # Default: whole-node binding. # JoinNode is a compound structural node (not an atomic value) that should never be # bound to a bare element variable — it would violate the type contract and cause @@ -215,24 +212,30 @@ def _match_node( # --- type must be compatible --- if not isinstance(q, type(p)) and not isinstance(p, type(q)): - # Allow OperatorNode / UnaryOperatorNode subclass relationship - if not (isinstance(q, OperatorNode) and isinstance(p, OperatorNode)): - return False + # VariableLiteralNode matches against LiteralNode; allow it before the strict type guard + if not isinstance(p, VariableLiteralNode): + # Allow OperatorNode / UnaryOperatorNode subclass relationship + if not (isinstance(q, OperatorNode) and isinstance(p, OperatorNode)): + return False # --- leaf nodes --- + if isinstance(p, VariableLiteralNode): + if not isinstance(q, LiteralNode): + return False + qv = q.value + if not isinstance(qv, str): + return False + if p.prefix and not qv.startswith(p.prefix): + return False + if p.suffix and not qv.endswith(p.suffix): + return False + inner = qv[len(p.prefix): len(qv) - len(p.suffix) if p.suffix else len(qv)] + return _bind(p.name, LiteralNode(inner), memo) + if isinstance(p, LiteralNode): if not isinstance(q, LiteralNode): return False qv, pv = q.value, p.value - # RuleParserV2 may represent placeholders inside string literals like `''` - # as LiteralNode("s") (where "s" is a declared rule variable). In that case, - # treat it as a bindable placeholder rather than a concrete string. - - # TODO: We hope to further flatten variables in the literal, e.g., - # q: {like: [name, '%joe%']} -> Func(like, [Col('name'), LiteralNode('%joe%')]) - # p: {like: [x, '%y%']} -> Func(like, [EV(x),  LitrlComb([LiteralNode('%'), EV(y), LiteralNode('%')])]) - if isinstance(pv, str) and _is_var_name(pv, mapping): - return _bind(pv, q, memo) if isinstance(qv, str) and isinstance(pv, str): return qv.lower() == pv.lower() return qv == pv @@ -243,28 +246,20 @@ def _match_node( if isinstance(p, TimeUnitNode): return isinstance(q, TimeUnitNode) and q.name.upper() == p.name.upper() - # --- TableNode: name and alias may be variable names --- + # --- TableNode: name is always concrete; alias may be a variable name --- if isinstance(p, TableNode): if not isinstance(q, TableNode): return False - if _is_var_name(p.name, mapping) and p.alias is None: - # Variable stands for the entire table reference (name + alias). - # Bind to the whole TableNode so the rewrite can reproduce it faithfully. - return _bind(p.name, q, memo) - if _is_var_name(p.name, mapping): - if not _bind(p.name, q.name, memo): - return False - else: - if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): - return False + if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): + return False if p.alias is not None: # Pattern requires an alias (even if it's a variable). Do not match # unaliased tables, otherwise the alias var would bind to None and # rewrites expecting a real alias/identifier become nonsensical. if q.alias is None: return False - if _is_var_name(p.alias, mapping): - if not _bind(p.alias, q.alias, memo): + if isinstance(p.alias, ElementVariableNode): + if not _bind(p.alias.name, q.alias, memo): return False else: qa = q.alias or "" @@ -272,23 +267,19 @@ def _match_node( return False return True - # --- ColumnNode: name and parent_alias may be variable names --- + # --- ColumnNode: name is always concrete; parent_alias may be a variable name --- if isinstance(p, ColumnNode): if not isinstance(q, ColumnNode): return False - if _is_var_name(p.name, mapping): - if not _bind(p.name, q.name, memo): - return False - else: - if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): - return False + if not isinstance(q.name, str) or q.name.lower() != p.name.lower(): + return False if p.parent_alias is not None: # Pattern requires a qualifier (even if it's a variable). Do not match # unqualified columns, otherwise the qualifier var would bind to None. if q.parent_alias is None: return False - if _is_var_name(p.parent_alias, mapping): - if not _bind(p.parent_alias, q.parent_alias, memo): + if isinstance(p.parent_alias, ElementVariableNode): + if not _bind(p.parent_alias.name, q.parent_alias, memo): return False else: qpa = q.parent_alias or "" @@ -346,10 +337,10 @@ def _match_node( if p.alias is not None: if q.alias is None: return False - if _is_var_name(p.alias, mapping): - if not _bind(p.alias, q.alias, memo): + if isinstance(p.alias, ElementVariableNode): + if not _bind(p.alias.name, q.alias, memo): return False - elif q.alias.lower() != p.alias.lower(): + elif isinstance(q.alias, str) and q.alias.lower() != p.alias.lower(): return False return _match_children_list(list(q.children), list(p.children), memo, mode, mapping) @@ -401,16 +392,16 @@ def _match_node( if isinstance(p, LimitNode): if not isinstance(q, LimitNode): return False - if isinstance(p.limit, str) and _is_var_name(p.limit, mapping): - return _bind(p.limit, q.limit, memo) + if isinstance(p.limit, ElementVariableNode): + return _bind(p.limit.name, q.limit, memo) return q.limit == p.limit # --- OffsetNode --- if isinstance(p, OffsetNode): if not isinstance(q, OffsetNode): return False - if isinstance(p.offset, str) and _is_var_name(p.offset, mapping): - return _bind(p.offset, q.offset, memo) + if isinstance(p.offset, ElementVariableNode): + return _bind(p.offset.name, q.offset, memo) return q.offset == p.offset # --- JoinNode --- @@ -692,11 +683,12 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if isinstance(node, ElementVariableNode): val = memo.get(node.name, node) # If variable has a parent_alias that is a variable name, reconstruct qualified column - # e.g. ElementVariableNode("a1", parent_alias="t1") where "a1" -> "id", "t1" -> "e1" + # e.g. ElementVariableNode("a1", parent_alias=ElementVariableNode("t1")) where "a1" -> "id", "t1" -> "e1" # produces ColumnNode("id", _parent_alias="e1") # The node's own alias (if any) is a literal string from the rewrite template; preserve it. - if node.parent_alias is not None and node.parent_alias in memo: - pa_val = memo[node.parent_alias] + pa_key = node.parent_alias.name if isinstance(node.parent_alias, ElementVariableNode) else node.parent_alias + if pa_key is not None and pa_key in memo: + pa_val = memo[pa_key] col_name = val if isinstance(val, str) else (val.name if isinstance(val, Node) and hasattr(val, 'name') else None) if col_name is not None: if isinstance(pa_val, TableNode): @@ -708,10 +700,11 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if pa_str is not None: return ColumnNode(col_name, _alias=node.alias, _parent_alias=pa_str) # If variable has an alias that is a variable name, reconstruct a TableNode - # e.g. ElementVariableNode("tb1", alias="t1") where "tb1" -> TableNode("employee", ...), "t1" -> "e1" + # e.g. ElementVariableNode("tb1", alias=ElementVariableNode("t1")) where "tb1" -> TableNode("employee", ...), "t1" -> "e1" # produces TableNode("employee", "e1") - if node.alias is not None and node.alias in memo: - alias_val = memo[node.alias] + alias_key = node.alias.name if isinstance(node.alias, ElementVariableNode) else node.alias + if alias_key is not None and alias_key in memo: + alias_val = memo[alias_key] # val may be a whole TableNode (from binding) or a string if isinstance(val, TableNode): table_name = val.name @@ -732,6 +725,13 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: # Should not appear at this level; caller handles list expansion return node + if isinstance(node, VariableLiteralNode): + bound = memo.get(node.name) + if bound is None: + return node + val = bound.value if isinstance(bound, LiteralNode) else str(bound) + return LiteralNode(f"{node.prefix}{val}{node.suffix}") + if isinstance(node, (LiteralNode, DataTypeNode, TimeUnitNode)): # For string literals, substitute any variable names embedded in the value # (e.g. LiteralNode('%y%') with memo['y']=LiteralNode('iphone') to LiteralNode('%iphone%')) @@ -813,16 +813,16 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if isinstance(node, LimitNode): val = node.limit - if isinstance(val, str): - val = memo.get(val, val) + if isinstance(val, ElementVariableNode): + val = memo.get(val.name, val) if isinstance(val, LiteralNode): val = val.value return LimitNode(val) if isinstance(node, OffsetNode): val = node.offset - if isinstance(val, str): - val = memo.get(val, val) + if isinstance(val, ElementVariableNode): + val = memo.get(val.name, val) if isinstance(val, LiteralNode): val = val.value return OffsetNode(val) @@ -867,7 +867,7 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: def _subst_str(s: Any, memo: dict) -> Any: - """Substitute a string field if it matches a variable name in memo. + """Substitute a string or ElementVariableNode alias/parent_alias field from memo. Extracts a canonical string from the bound value: - str to return directly @@ -876,9 +876,11 @@ def _subst_str(s: Any, memo: dict) -> Any: - FunctionNode (rare): if an element variable bound to e.g. COUNT(col) in the SELECT list is substituted into a string-only context, unwrap COUNT(col) -> ``col`` name. Normally bindings are ColumnNode/TableNode/str here. + ElementVariableNode in alias/parent_alias fields (from widened field type) looks up .name in memo. """ - if isinstance(s, str) and s in memo: - val = memo[s] + key = s.name if isinstance(s, ElementVariableNode) else s + if isinstance(key, str) and key in memo: + val = memo[key] if isinstance(val, str): return val if isinstance(val, TableNode): @@ -923,7 +925,7 @@ def _replace_in_tree(tree: Node, target_id: int, replacement: Node) -> Node: return replacement if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode, TableNode, ColumnNode, - ElementVariableNode, SetVariableNode)): + ElementVariableNode, SetVariableNode, VariableLiteralNode)): return tree if isinstance(tree, FunctionNode): @@ -1046,7 +1048,7 @@ def _node_subst(tree: Any, src: Any, tgt: Any) -> Any: if isinstance(tree, TableNode): return TableNode(_subst_val(tree.name, src, tgt), _subst_val(tree.alias, src, tgt)) - if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode)): + if isinstance(tree, (LiteralNode, DataTypeNode, TimeUnitNode, VariableLiteralNode)): return tree if isinstance(tree, FunctionNode): diff --git a/core/rule_generator_v2.py b/core/rule_generator_v2.py index c5c2c4e..09dad5c 100644 --- a/core/rule_generator_v2.py +++ b/core/rule_generator_v2.py @@ -56,13 +56,13 @@ SubqueryNode, TableNode, UnaryOperatorNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) from core.query_parser import QueryParser -from core.query_formatter import QueryFormatter, _placeholder_token +from core.query_formatter import QueryFormatter from core.rule_parser_v2 import RuleParserV2, Scope, VarType, VarTypesInfo -from core.ast.utils import is_placeholder_name @functools.lru_cache(maxsize=None) @@ -392,13 +392,14 @@ def _table_token(name: Optional[str]) -> Optional[str]: state["tables"][name] = mapped return mapped - def _alias_token(name: Optional[str]) -> Optional[str]: + def _alias_token(name) -> Optional[str]: if name is None: return None - mapped = state["aliases"].get(name) + key = name.name if isinstance(name, ElementVariableNode) else name + mapped = state["aliases"].get(key) if mapped is None: mapped = f"A{len(state['aliases']) + 1}" - state["aliases"][name] = mapped + state["aliases"][key] = mapped return mapped if isinstance(node, QueryNode): @@ -426,16 +427,13 @@ def _alias_token(name: Optional[str]) -> Optional[str]: return ("ORDERBY_ITEM", node.sort.value if node.sort else None, RuleGeneratorV2._recommendation_ast_signature(inner, state)) if isinstance(node, LimitNode): value = node.limit - # TODO: LimitNode/OffsetNode use string placeholders (from _replace_literal_in_ast) - # rather than ElementVariableNode, for the same reason as string literals. - # is_placeholder_name check here is the one remaining generator-level token reference. - if isinstance(value, str) and is_placeholder_name(value): - value = f"VAR:{RuleGeneratorV2._fingerPrint(value)}" + if isinstance(value, ElementVariableNode): + value = f"VAR:{RuleGeneratorV2._fingerPrint(value.name)}" return ("LIMIT", value) if isinstance(node, OffsetNode): value = node.offset - if isinstance(value, str) and is_placeholder_name(value): - value = f"VAR:{RuleGeneratorV2._fingerPrint(value)}" + if isinstance(value, ElementVariableNode): + value = f"VAR:{RuleGeneratorV2._fingerPrint(value.name)}" return ("OFFSET", value) if isinstance(node, TableNode): return ("TABLE", _table_token(node.name), _alias_token(node.alias)) @@ -447,6 +445,8 @@ def _alias_token(name: Optional[str]) -> Optional[str]: return ("COLUMN", node.name, _alias_token(node.alias), _alias_token(node.parent_alias)) if isinstance(node, LiteralNode): return ("LITERAL", node.value, _alias_token(getattr(node, "alias", None))) + if isinstance(node, VariableLiteralNode): + return ("VAR_LITERAL", node.prefix, node.suffix) if isinstance(node, FunctionNode): return ( "FUNCTION", @@ -624,14 +624,12 @@ def variablize_literal(rule: Dict[str, object], literal: Union[str, numbers.Numb mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) new_rule["mapping"] = mapping - # TODO: remove placeholder_token once VariableLiteralNode is added for string literals - placeholder_token = _placeholder_token(external_name) for key in ("pattern_ast", "rewrite_ast"): ast = new_rule.get(key) if not isinstance(ast, Node): raise TypeError(f"rule['{key}'] must be an AST Node") - new_rule[key] = RuleGeneratorV2._replace_literal_in_ast(ast, literal, external_name, placeholder_token) + new_rule[key] = RuleGeneratorV2._replace_literal_in_ast(ast, literal, external_name) new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] @@ -880,13 +878,7 @@ def _literal_counts(ast: Node) -> Dict[Union[str, numbers.Number], int]: continue value = getattr(node, "value", None) if isinstance(value, str): - normalized = value.replace("%", "") - # TODO: string literals with embedded __rv_ tokens (from _replace_literal_in_ast) - # cannot be replaced with ElementVariableNode due to LIKE wildcard preservation. - # is_placeholder_name check here is the one remaining generator-level token reference. - if is_placeholder_name(normalized): - continue - counts[normalized] = counts.get(normalized, 0) + 1 + counts[value.replace("%", "")] = counts.get(value.replace("%", ""), 0) + 1 elif isinstance(value, numbers.Number): counts[value] = counts.get(value, 0) + 1 return counts @@ -1025,8 +1017,8 @@ def _visit(node: Node, parent: Optional[Node] = None) -> None: seen_and_ids.add(id(node)) elif isinstance(node, WhereNode) and len(node.children) == 1 and isinstance(node.children[0], ElementVariableNode): out.append([node.children[0].name]) - elif isinstance(node, LimitNode) and isinstance(node.limit, str) and is_placeholder_name(node.limit): - out.append([node.limit]) + elif isinstance(node, LimitNode) and isinstance(node.limit, ElementVariableNode): + out.append([node.limit.name]) elif isinstance(node, JoinNode) and node.on_condition is not None: oc = node.on_condition if isinstance(oc, ElementVariableNode): @@ -1164,12 +1156,7 @@ def _is_subtree_candidate(node: Node, parent: Optional[Node] = None) -> bool: return True return False - if isinstance(node, LiteralNode): - if isinstance(parent, ListNode): - return False - value = getattr(node, "value", None) - if isinstance(value, str) and is_placeholder_name(value): - return True + if isinstance(node, (LiteralNode, VariableLiteralNode)): return False var_count = 0 @@ -1179,15 +1166,10 @@ def _is_subtree_candidate(node: Node, parent: Optional[Node] = None) -> bool: if isinstance(child, list): return False if isinstance(child, Node): - if isinstance(child, (ElementVariableNode, SetVariableNode)): + if isinstance(child, (ElementVariableNode, SetVariableNode, VariableLiteralNode)): var_count += 1 continue if isinstance(child, LiteralNode): - value = getattr(child, "value", None) - if isinstance(value, str): - normalized = value.replace("%", "") - if is_placeholder_name(normalized): - var_count += 1 continue return False return var_count >= 1 @@ -1462,16 +1444,17 @@ def _replace_literal_in_ast( ast: Node, literal: Union[str, numbers.Number], external_name: str, - placeholder_token: str, ) -> Node: """Substitute every occurrence of literal in ast with the new variable. - String literals are rewritten in place (preserving any surrounding % LIKE wildcards) using placeholder_token; numeric literal nodes are swapped wholesale for an ElementVariableNode(external_name). Mutates ast in place and returns it. + String literals become VariableLiteralNode (preserving surrounding % wildcards). + Numeric literals and LIMIT/OFFSET values become ElementVariableNode. """ + to_replace = [] for node in RuleGeneratorV2._walk(ast): if isinstance(node, LimitNode): if isinstance(literal, numbers.Number) and node.limit == literal: - node.limit = placeholder_token + node.limit = ElementVariableNode(external_name) continue if node.type != NodeType.LITERAL: continue @@ -1479,17 +1462,21 @@ def _replace_literal_in_ast( if isinstance(literal, str) and isinstance(value, str): if value == literal: - node.value = placeholder_token # type: ignore[attr-defined] + to_replace.append((node, VariableLiteralNode(external_name))) elif value.replace("%", "") == literal: - node.value = value.replace(literal, placeholder_token) # type: ignore[attr-defined] + prefix = "%" if value.startswith("%") else "" + suffix = "%" if value.endswith("%") else "" + to_replace.append((node, VariableLiteralNode(external_name, prefix=prefix, suffix=suffix))) continue if isinstance(literal, numbers.Number) and isinstance(value, numbers.Number) and value == literal: - replacement = ElementVariableNode(external_name) - if node is ast: - ast = replacement - else: - RuleGeneratorV2._replace_node_reference(ast, node, replacement) + to_replace.append((node, ElementVariableNode(external_name))) + + for old_node, new_node in to_replace: + if old_node is ast: + ast = new_node + else: + RuleGeneratorV2._replace_node_reference(ast, old_node, new_node) return ast @staticmethod @@ -1503,7 +1490,7 @@ def _replace_table_in_ast( A bare-named reference to target_value is also matched even when its alias disagrees with target_name, so a single variable can cover both an aliased outer reference and a bare-named reference inside a subquery. placeholder_token here is actually the external_name (e.g. "x1") passed from variablize_table. - ColumnNode.parent_alias is set to this bare string; the formatter's is_placeholder_name check handles it. + ColumnNode.parent_alias is set to ElementVariableNode(placeholder_token) so the formatter and rewriter handle it via isinstance checks. """ # A bare-table reference, with no explicit alias, is also matched when # its value equals the target's value even if target_name differs. This @@ -1543,7 +1530,7 @@ def _replace_table_in_ast( or node.parent_alias == target_name ) ): - node.parent_alias = placeholder_token + node.parent_alias = ElementVariableNode(placeholder_token) return ast @staticmethod @@ -1770,8 +1757,8 @@ def _visit(node: Node, parent: Optional[Node]) -> Node: node.children[2] = replacement return node - if isinstance(node, LimitNode) and isinstance(node.limit, str) and node.limit in variable_set: - node.limit = set_name + if isinstance(node, LimitNode) and isinstance(node.limit, ElementVariableNode) and node.limit.name in variable_set: + node.limit = SetVariableNode(set_name) return node if ( @@ -2195,5 +2182,3 @@ def _walk(node: Optional[Node]) -> Iterator[Node]: return for child in children: yield from RuleGeneratorV2._walk(child) - - diff --git a/core/rule_parser_v2.py b/core/rule_parser_v2.py index d4644f0..3824746 100644 --- a/core/rule_parser_v2.py +++ b/core/rule_parser_v2.py @@ -35,6 +35,7 @@ UnaryOperatorNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -350,7 +351,12 @@ def _replace_internal_in_string(s: str) -> str: return node pa = col.parent_alias nm = col.name - new_alias = _replace_internal_in_string(col.alias) if isinstance(col.alias, str) else col.alias + if isinstance(col.alias, str) and col.alias in rev: + new_alias: Optional[Union[str, ElementVariableNode]] = ElementVariableNode(rev[col.alias]) + elif isinstance(col.alias, str): + new_alias = _replace_internal_in_string(col.alias) + else: + new_alias = col.alias new_pa = _replace_internal_in_string(pa) if isinstance(pa, str) else pa # Bare column variable (no qualifier): promote to ElementVariableNode @@ -359,11 +365,11 @@ def _replace_internal_in_string(s: str) -> str: # Both name and parent_alias are variables if pa is not None and pa in rev and nm in rev: - return ElementVariableNode(rev[nm], parent_alias=rev[pa], alias=new_alias) + return ElementVariableNode(rev[nm], parent_alias=ElementVariableNode(rev[pa]), alias=new_alias) # Only parent_alias is a variable (concrete column, variable table qualifier) if pa is not None and pa in rev: - return ColumnNode(nm, _alias=new_alias, _parent_alias=rev[pa]) + return ColumnNode(nm, _alias=new_alias, _parent_alias=ElementVariableNode(rev[pa])) # Only column name is a variable (concrete table qualifier) if pa is not None and nm in rev: @@ -386,13 +392,13 @@ def _replace_internal_in_string(s: str) -> str: if isinstance(t.name, str) and t.name in rev and t.name.startswith(ev_base): # alias may also be a variable if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: - return ElementVariableNode(rev[t.name], alias=rev[t.alias]) + return ElementVariableNode(rev[t.name], alias=ElementVariableNode(rev[t.alias])) return ElementVariableNode(rev[t.name]) # Concrete table new_name = rev.get(t.name, t.name) if isinstance(t.name, str) else t.name if t.alias is not None and isinstance(t.alias, str) and t.alias in rev: - new_alias = rev[t.alias] + new_alias = ElementVariableNode(rev[t.alias]) else: new_alias = t.alias return TableNode(new_name, new_alias) @@ -403,10 +409,15 @@ def _replace_internal_in_string(s: str) -> str: return node alias = _replace_internal_in_string(lit.alias) if isinstance(getattr(lit, "alias", None), str) else getattr(lit, "alias", None) if isinstance(lit.value, str): - # If the entire literal value is an internal placeholder token, promote to var node + # Exact match: the entire literal is a placeholder token → variable literal if lit.value in rev: - return LiteralNode(rev[lit.value], _alias=alias) - # Otherwise substitute any embedded tokens (e.g. '%EV001%' to '%x%') + return VariableLiteralNode(rev[lit.value], _alias=alias) + # Embedded token: e.g. '%EV001%' → VariableLiteralNode with surrounding wildcards + stripped = lit.value.replace("%", "") + if stripped in rev: + prefix = "%" if lit.value.startswith("%") else "" + suffix = "%" if lit.value.endswith("%") else "" + return VariableLiteralNode(rev[stripped], prefix=prefix, suffix=suffix, _alias=alias) return LiteralNode(_replace_internal_in_string(lit.value), _alias=alias) return LiteralNode(lit.value, _alias=alias) @@ -483,6 +494,8 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(lim, LimitNode): return node if isinstance(lim.limit, str): + if lim.limit in rev: + return LimitNode(ElementVariableNode(rev[lim.limit])) return LimitNode(_replace_internal_in_string(lim.limit)) return LimitNode(lim.limit) @@ -491,6 +504,8 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(off, OffsetNode): return node if isinstance(off.offset, str): + if off.offset in rev: + return OffsetNode(ElementVariableNode(rev[off.offset])) return OffsetNode(_replace_internal_in_string(off.offset)) return OffsetNode(off.offset) @@ -532,7 +547,12 @@ def _replace_internal_in_string(s: str) -> str: if not isinstance(f, FunctionNode): return node new_args = [RuleParserV2._substitute_placeholders(a, rev) for a in f.children] - alias = _replace_internal_in_string(f.alias) if isinstance(f.alias, str) else f.alias + if isinstance(f.alias, str) and f.alias in rev: + alias = ElementVariableNode(rev[f.alias]) + elif isinstance(f.alias, str): + alias = _replace_internal_in_string(f.alias) + else: + alias = f.alias return FunctionNode(f.name, _args=new_args, _alias=alias) if node.type == NodeType.LIST: diff --git a/tests/test_rule_generator_v2.py b/tests/test_rule_generator_v2.py index 4d2e3fa..a971827 100644 --- a/tests/test_rule_generator_v2.py +++ b/tests/test_rule_generator_v2.py @@ -1052,7 +1052,7 @@ def test_branches_4(): ) branches = RuleGeneratorV2.branches(result.pattern_ast, result.rewrite_ast) actual = {(b["key"], RuleGeneratorV2.deparse(b["value"])) for b in branches} - assert actual == {("eq_rhs", "TIMESTAMP('x')")} + assert actual == {("eq_rhs", "TIMESTAMP('')")} def test_branches_5(): @@ -1428,13 +1428,13 @@ def test_generate_general_rule_10(): FROM WHERE IN (SELECT FROM - WHERE = ) + WHERE = '') """ expected_rewrite = """ SELECT DISTINCT FROM , WHERE . = . - AND . = + AND . = '' """ _assert_matches_expected(q0, q1, expected_pattern, expected_rewrite) @@ -2082,8 +2082,8 @@ def test_generate_spreadsheet_id_18(): _assert_matches_expected( q0, q1, - "SELECT DISTINCT ON (.) <>, COALESCE(., ), FROM LEFT JOIN ON <> LEFT JOIN ON <> WHERE <> AND <> AND <> ORDER BY DESC", - "SELECT <>, COALESCE((SELECT . FROM WHERE <> AND <> LIMIT ), ), (SELECT FROM WHERE <> AND <> LIMIT ) FROM WHERE <>", + "SELECT DISTINCT ON (.) <>, COALESCE(., ''), FROM LEFT JOIN ON <> LEFT JOIN ON <> WHERE <> AND <> AND <> ORDER BY DESC", + "SELECT <>, COALESCE((SELECT . FROM WHERE <> AND <> LIMIT ), ''), (SELECT FROM WHERE <> AND <> LIMIT ) FROM WHERE <>", ) diff --git a/tests/test_rule_parser_v2.py b/tests/test_rule_parser_v2.py index 3902dd4..4f4ccfd 100644 --- a/tests/test_rule_parser_v2.py +++ b/tests/test_rule_parser_v2.py @@ -30,6 +30,7 @@ UnaryOperatorNode, ElementVariableNode, SetVariableNode, + VariableLiteralNode, WhenThenNode, WhereNode, ) @@ -87,7 +88,15 @@ def _assert_varnodes_declared(result: RuleParseResult) -> None: def _assert_no_internal_tokens(result: RuleParseResult) -> None: """No EV00x / SV00x tokens should survive in identifier-bearing AST fields.""" - internal_tokens = set(result.mapping.values()) + def _check_alias(label: str, field_name: str, value) -> None: + if isinstance(value, str): + assert not _TOKEN_RE.match(value), ( + f"{label} AST has raw internal token {value!r} as {field_name}" + ) + elif isinstance(value, ElementVariableNode): + assert not _TOKEN_RE.match(value.name), ( + f"{label} AST has raw internal token {value.name!r} inside ElementVariableNode at {field_name}" + ) for tree_label, tree in [("pattern", result.pattern_ast), ("rewrite", result.rewrite_ast)]: for n in _walk(tree): @@ -95,34 +104,22 @@ def _assert_no_internal_tokens(result: RuleParseResult) -> None: assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as ColumnNode.name" ) - if isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as ColumnNode.alias" - ) - if n.parent_alias in internal_tokens: - assert not _TOKEN_RE.match(n.parent_alias), ( - f"{tree_label} AST has raw internal token {n.parent_alias!r} " - f"as ColumnNode.parent_alias" - ) + _check_alias(tree_label, "ColumnNode.alias", n.alias) + _check_alias(tree_label, "ColumnNode.parent_alias", n.parent_alias) if isinstance(n, TableNode) and isinstance(n.name, str): assert not _TOKEN_RE.match(n.name), ( f"{tree_label} AST has raw internal token {n.name!r} as TableNode.name" ) - if isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as TableNode.alias" - ) + _check_alias(tree_label, "TableNode.alias", n.alias) if isinstance(n, SubqueryNode) and isinstance(n.alias, str): assert not _TOKEN_RE.match(n.alias), ( f"{tree_label} AST has raw internal token {n.alias!r} as SubqueryNode.alias" ) - if isinstance(n, FunctionNode) and isinstance(n.alias, str): - assert not _TOKEN_RE.match(n.alias), ( - f"{tree_label} AST has raw internal token {n.alias!r} as FunctionNode.alias" - ) + if isinstance(n, FunctionNode): + _check_alias(tree_label, "FunctionNode.alias", n.alias) # ═══════════════════════════════════════════════════════════════════════════════ @@ -364,14 +361,15 @@ def test_parse_ast_strpos_ilike_rule(): assert isinstance(lower, FunctionNode) and lower.name.lower() == "lower" assert isinstance(list(lower.children)[0], ElementVariableNode) assert list(lower.children)[0].name == "x" - assert isinstance(strpos_args[1], LiteralNode) + assert isinstance(strpos_args[1], VariableLiteralNode) + assert strpos_args[1].name == "s" # Rewrite: ILIKE rew = result.rewrite_ast assert isinstance(rew, FunctionNode) and rew.name.lower() == "ilike" ilike_args = list(rew.children) assert isinstance(ilike_args[0], ElementVariableNode) and ilike_args[0].name == "x" - assert isinstance(ilike_args[1], LiteralNode) - assert ilike_args[1].value == "%s%" + assert isinstance(ilike_args[1], VariableLiteralNode) + assert ilike_args[1].name == "s" and ilike_args[1].prefix == "%" and ilike_args[1].suffix == "%" def test_substitute_placeholders_limit_offset_string_tokens(): @@ -382,8 +380,8 @@ def test_substitute_placeholders_limit_offset_string_tokens(): off = RuleParserV2._substitute_placeholders( # type: ignore[arg-type] OffsetNode("EV002"), {"EV002": "y"} ) - assert isinstance(lim, LimitNode) and lim.limit == "x" - assert isinstance(off, OffsetNode) and off.offset == "y" + assert isinstance(lim, LimitNode) and isinstance(lim.limit, ElementVariableNode) and lim.limit.name == "x" + assert isinstance(off, OffsetNode) and isinstance(off.offset, ElementVariableNode) and off.offset.name == "y" def test_parse_substitutes_alias_fields(): @@ -640,11 +638,13 @@ def test_qualified_column_both_parts_substituted(): _assert_varnodes_declared(result) _assert_no_internal_tokens(result) # When both parts are variables, _substitute_placeholders returns ElementVariableNode + # with parent_alias stored as ElementVariableNode (widened field type). evars = _find_all(result.pattern_ast, ElementVariableNode) qualified = [e for e in evars if e.parent_alias is not None] assert len(qualified) >= 1 for e in qualified: - assert e.parent_alias in result.mapping + assert isinstance(e.parent_alias, ElementVariableNode) + assert e.parent_alias.name in result.mapping assert e.name in result.mapping From 55a3a66650d6b746293315fd6b0b05a0ededda2f Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 21 May 2026 16:24:43 -0700 Subject: [PATCH 3/3] clean up dead code and fix type annotations --- core/query_formatter.py | 35 +--- core/query_rewriter_v2.py | 36 +--- core/rule.py | 50 ++++++ core/rule_generator_v2.py | 290 ++++++++++++-------------------- data/rules.py | 27 ++- tests/test_rule_generator_v2.py | 21 ++- 6 files changed, 186 insertions(+), 273 deletions(-) create mode 100644 core/rule.py diff --git a/core/query_formatter.py b/core/query_formatter.py index 0b0da7e..1af1bb4 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -21,16 +21,9 @@ from core.ast.utils import flatten_logical_operands -def _placeholder_token(name: str) -> str: - if name.lower().startswith("y"): - return f"__rvs_{name}__" - return f"__rv_{name}__" - - def _normalize_placeholder_tokens(sql: str) -> str: - out = sql - out = _replace_wrapped_tokens(out, "__rvs_", "__", "<<", ">>") - out = _replace_wrapped_tokens(out, "__rv_", "__", "<", ">") + out = re.sub(r"__rvs_(\w+)__", r"<<\1>>", sql) + out = re.sub(r"__rv_(\w+)__", r"<\1>", out) return out @@ -46,26 +39,6 @@ def _render_alias(alias) -> str: return alias -def _replace_wrapped_tokens(text: str, prefix: str, suffix: str, open_marker: str, close_marker: str) -> str: - out = text - start = 0 - while True: - i = out.find(prefix, start) - if i < 0: - break - j = out.find(suffix, i + len(prefix)) - if j < 0: - break - inner = out[i + len(prefix):j] - if inner and all(ch.isalnum() or ch == "_" for ch in inner): - replacement = f"{open_marker}{inner}{close_marker}" - out = out[:i] + replacement + out[j + len(suffix):] - start = i + len(replacement) - else: - start = i + 1 - return out - - class QueryFormatter: def format(self, query: Node) -> str: # [1] AST -> JSON @@ -135,14 +108,14 @@ def ast_to_json(node: Node) -> dict: elif child.type == NodeType.LIMIT: lv = child.limit if isinstance(lv, ElementVariableNode): - lv = _placeholder_token(lv.name) + lv = f"__rv_{lv.name}__" elif isinstance(lv, SetVariableNode): lv = f"__rvs_{lv.name}__" result['limit'] = lv elif child.type == NodeType.OFFSET: ov = child.offset if isinstance(ov, ElementVariableNode): - ov = _placeholder_token(ov.name) + ov = f"__rv_{ov.name}__" elif isinstance(ov, SetVariableNode): ov = f"__rvs_{ov.name}__" result['offset'] = ov diff --git a/core/query_rewriter_v2.py b/core/query_rewriter_v2.py index 774731f..7a2cf63 100644 --- a/core/query_rewriter_v2.py +++ b/core/query_rewriter_v2.py @@ -23,7 +23,6 @@ import copy import logging -import re from contextlib import contextmanager from collections import deque from enum import Enum @@ -143,8 +142,6 @@ def _bind(var_name: str, value: Any, memo: dict) -> bool: # (table vars bind to whole TableNode; ColumnNode.parent_alias binds to string) if isinstance(existing, TableNode) and isinstance(value, str): return (existing.alias or existing.name) == value - if isinstance(value, TableNode) and isinstance(existing, str): - return (value.alias or value.name) == existing if isinstance(existing, Node) and isinstance(value, Node): return existing == value return existing == value @@ -675,9 +672,6 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: return val if isinstance(val, str): return ColumnNode(val) - if isinstance(val, (int, float, bool)) or val is None: - return LiteralNode(val) - # Fallback: caller will keep the variable node unchanged. return None if isinstance(node, ElementVariableNode): @@ -706,12 +700,7 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: if alias_key is not None and alias_key in memo: alias_val = memo[alias_key] # val may be a whole TableNode (from binding) or a string - if isinstance(val, TableNode): - table_name = val.name - elif isinstance(val, str): - table_name = val - else: - table_name = None + table_name = val.name if isinstance(val, TableNode) else None if table_name is not None: alias_str = alias_val if isinstance(alias_val, str) else None if alias_str is not None: @@ -733,32 +722,11 @@ def _materialize_element_binding(val: Any) -> Optional[Node]: return LiteralNode(f"{node.prefix}{val}{node.suffix}") if isinstance(node, (LiteralNode, DataTypeNode, TimeUnitNode)): - # For string literals, substitute any variable names embedded in the value - # (e.g. LiteralNode('%y%') with memo['y']=LiteralNode('iphone') to LiteralNode('%iphone%')) - if isinstance(node, LiteralNode) and isinstance(node.value, str): - new_val = node.value - for var_name, bound in memo.items(): - if not isinstance(var_name, str) or var_name.startswith("_"): - continue - if var_name not in new_val: - continue - if isinstance(bound, LiteralNode) and isinstance(bound.value, (str, int, float)): - new_val = re.sub(r"\b" + re.escape(var_name) + r"\b", str(bound.value), new_val) - elif isinstance(bound, str): - new_val = re.sub(r"\b" + re.escape(var_name) + r"\b", bound, new_val) - if new_val != node.value: - return LiteralNode(new_val) return node if isinstance(node, TableNode): - # If the name variable is bound to a whole TableNode, return it directly - if isinstance(node.name, str) and node.name in memo: - val = memo[node.name] - if isinstance(val, TableNode): - return val - new_name = _subst_str(node.name, memo) new_alias = _subst_str(node.alias, memo) if node.alias is not None else None - return TableNode(new_name, new_alias) + return TableNode(node.name, new_alias) if isinstance(node, ColumnNode): new_name = _subst_str(node.name, memo) diff --git a/core/rule.py b/core/rule.py new file mode 100644 index 0000000..77d14c6 --- /dev/null +++ b/core/rule.py @@ -0,0 +1,50 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional +from core.ast.node import Node + + +@dataclass +class RuleV2: + pattern: str + rewrite: str + pattern_ast: Node + rewrite_ast: Node + mapping: Dict[str, str] + source_pattern_ast: Optional[Node] = None + source_rewrite_ast: Optional[Node] = None + source_pattern_sql: str = "" + source_rewrite_sql: str = "" + constraints: str = "" + actions: str = "" + id: Optional[Any] = None + key: Optional[str] = None + children: Optional[List[RuleV2]] = field(default=None) + + @classmethod + def from_dict(cls, d: dict) -> RuleV2: + return cls( + pattern=d.get("pattern", ""), + rewrite=d.get("rewrite", ""), + pattern_ast=d["pattern_ast"], + rewrite_ast=d["rewrite_ast"], + mapping=d.get("mapping", {}), + source_pattern_ast=d.get("source_pattern_ast"), + source_rewrite_ast=d.get("source_rewrite_ast"), + source_pattern_sql=d.get("source_pattern_sql", ""), + source_rewrite_sql=d.get("source_rewrite_sql", ""), + constraints=d.get("constraints", ""), + actions=d.get("actions", ""), + id=d.get("id"), + key=d.get("key"), + children=d.get("children"), + ) + + def __getitem__(self, key: str) -> Any: + return getattr(self, key) + + def __setitem__(self, key: str, value: Any) -> None: + setattr(self, key, value) + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) diff --git a/core/rule_generator_v2.py b/core/rule_generator_v2.py index 09dad5c..e38765c 100644 --- a/core/rule_generator_v2.py +++ b/core/rule_generator_v2.py @@ -62,6 +62,7 @@ ) from core.query_parser import QueryParser from core.query_formatter import QueryFormatter +from core.rule import RuleV2 from core.rule_parser_v2 import RuleParserV2, Scope, VarType, VarTypesInfo @@ -243,27 +244,27 @@ def _internal_variable_token_length_delta(internal_name: str) -> int: return 0 @staticmethod - def initialize_seed_rule(q0: str, q1: str) -> Dict[str, object]: - """Build the initial (un-generalized) rule dict for the rewrite pair q0 -> q1. + def initialize_seed_rule(q0: str, q1: str) -> RuleV2: + """Build the initial (un-generalized) rule for the rewrite pair q0 -> q1. - Parses both sides via RuleParserV2, snapshots the source ASTs/SQL, and returns a fresh rule dict carrying pattern, rewrite, pattern_ast, rewrite_ast, mapping, and empty constraints/actions. + Parses both sides via RuleParserV2, snapshots the source ASTs/SQL, and returns a fresh RuleV2 carrying pattern, rewrite, pattern_ast, rewrite_ast, mapping, and empty constraints/actions. """ parsed = RuleParserV2.parse(q0, q1) pattern = RuleGeneratorV2.deparse(copy.deepcopy(parsed.pattern_ast)) rewrite = RuleGeneratorV2.deparse(copy.deepcopy(parsed.rewrite_ast)) - return { - "pattern": pattern, - "rewrite": rewrite, - "pattern_ast": parsed.pattern_ast, - "rewrite_ast": parsed.rewrite_ast, - "source_pattern_ast": copy.deepcopy(parsed.pattern_ast), - "source_rewrite_ast": copy.deepcopy(parsed.rewrite_ast), - "source_pattern_sql": q0, - "source_rewrite_sql": q1, - "mapping": parsed.mapping, - "constraints": "", - "actions": "", - } + return RuleV2( + pattern=pattern, + rewrite=rewrite, + pattern_ast=parsed.pattern_ast, + rewrite_ast=parsed.rewrite_ast, + source_pattern_ast=copy.deepcopy(parsed.pattern_ast), + source_rewrite_ast=copy.deepcopy(parsed.rewrite_ast), + source_pattern_sql=q0, + source_rewrite_sql=q1, + mapping=parsed.mapping, + constraints="", + actions="", + ) RuleGeneralizations = ( "generalize_tables", @@ -275,7 +276,7 @@ def initialize_seed_rule(q0: str, q1: str) -> Dict[str, object]: ) @staticmethod - def generate_general_rule(q0: str, q1: str) -> Dict[str, object]: + def generate_general_rule(q0: str, q1: str) -> RuleV2: """Repeatedly apply every generalize_* step until the rule's fingerprint stops changing. Returns the most general rule reachable from the seed by exhaustively variablizing tables/columns/literals/subtrees, merging variable lists, and dropping branches. @@ -292,7 +293,7 @@ def generate_general_rule(q0: str, q1: str) -> Dict[str, object]: return general_rule @staticmethod - def generate_rule_graph(q0: str, q1: str) -> Dict[str, object]: + def generate_rule_graph(q0: str, q1: str) -> RuleV2: """Build the full BFS graph of generalizations rooted at the seed rule for q0 -> q1. Each node's children list is populated with the rules reachable in one variabilization/merge/drop step; nodes with the same fingerprint are deduplicated, so the graph is a DAG, not a tree. @@ -300,7 +301,7 @@ def generate_rule_graph(q0: str, q1: str) -> Dict[str, object]: seed_rule = RuleGeneratorV2.initialize_seed_rule(q0, q1) seed_fp = RuleGeneratorV2.fingerPrint(seed_rule) visited = {seed_fp: seed_rule} - queue: deque[Dict[str, object]] = deque([seed_rule]) + queue: deque[RuleV2] = deque([seed_rule]) while queue: base_rule = queue.popleft() base_rule["children"] = [] @@ -323,18 +324,18 @@ def generate_rule_graph(q0: str, q1: str) -> Dict[str, object]: return seed_rule @staticmethod - def recommend_simple_rules(examples: List[Dict[str, str]]) -> List[Dict[str, object]]: + def recommend_simple_rules(examples: List[Dict[str, str]]) -> List[RuleV2]: """Pick a small set of generalized rules that together cover every (q0, q1) example. Generates candidate rules per example, fingerprints them, and greedy set-covers the still-uncovered examples, breaking ties toward fewer variables. """ fingerprint_to_examples: Dict[str, Set[int]] = defaultdict(set) - fingerprint_to_rule: Dict[str, Dict[str, object]] = {} - example_candidates: List[List[Tuple[str, Dict[str, object]]]] = [] + fingerprint_to_rule: Dict[str, RuleV2] = {} + example_candidates: List[List[Tuple[str, RuleV2]]] = [] for index, example in enumerate(examples): seed = RuleGeneratorV2.initialize_seed_rule(example["q0"], example["q1"]) - candidates_with_fingerprints: List[Tuple[str, Dict[str, object]]] = [] + candidates_with_fingerprints: List[Tuple[str, RuleV2]] = [] for rule in RuleGeneratorV2._recommendation_candidates(seed): fp = RuleGeneratorV2.fingerPrint(rule) candidates_with_fingerprints.append((fp, rule)) @@ -345,11 +346,11 @@ def recommend_simple_rules(examples: List[Dict[str, str]]) -> List[Dict[str, obj example_candidates.append(candidates_with_fingerprints) uncovered = set(range(len(examples))) - ans: List[Dict[str, object]] = [] + ans: List[RuleV2] = [] for index, _example in enumerate(examples): if index not in uncovered: continue - chosen: Optional[Dict[str, object]] = None + chosen: Optional[RuleV2] = None remaining = set(uncovered) for fp, rule in example_candidates[index]: covered = fingerprint_to_examples.get(fp, set()).intersection(remaining) @@ -365,7 +366,7 @@ def recommend_simple_rules(examples: List[Dict[str, str]]) -> List[Dict[str, obj return ans @staticmethod - def _recommendation_signature(rule: Dict[str, object]) -> str: + def _recommendation_signature(rule: RuleV2) -> str: pattern_ast = rule.get("pattern_ast") rewrite_ast = rule.get("rewrite_ast") if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): @@ -487,11 +488,11 @@ def _alias_token(name) -> Optional[str]: ) @staticmethod - def _recommendation_candidates(seed: Dict[str, object]) -> List[Dict[str, object]]: - candidates: List[Dict[str, object]] = [] + def _recommendation_candidates(seed: RuleV2) -> List[RuleV2]: + candidates: List[RuleV2] = [] seed_sig = RuleGeneratorV2._recommendation_signature(seed) seen: Set[str] = {seed_sig} - queue: deque[Dict[str, object]] = deque([seed]) + queue: deque[RuleV2] = deque([seed]) max_candidates = RuleGeneratorV2._MAX_RECOMMENDATION_CANDIDATES while queue and len(candidates) < max_candidates: @@ -518,27 +519,34 @@ def _recommendation_candidates(seed: Dict[str, object]) -> List[Dict[str, object return candidates @staticmethod - def variablize_tables(rule: Dict[str, object]) -> List[Dict[str, object]]: + def variablize_tables(rule: RuleV2) -> List[RuleV2]: """Return one child rule per table that can still be replaced with a fresh element variable. Each child is the result of substituting a single table reference with on both pattern and rewrite sides. """ - pattern_ast = rule.get("pattern_ast") - rewrite_ast = rule.get("rewrite_ast") + pattern_ast = rule.pattern_ast + rewrite_ast = rule.rewrite_ast if not isinstance(pattern_ast, Node) or not isinstance(rewrite_ast, Node): raise TypeError("rule ASTs must be Node instances") return [RuleGeneratorV2.variablize_table(rule, table) for table in RuleGeneratorV2.tables(pattern_ast, rewrite_ast)] @staticmethod - def variablize_table(rule: Dict[str, object], table: Dict[str, str]) -> Dict[str, object]: + def _sync_rule_strings(rule: RuleV2) -> None: + rule.pattern = RuleGeneratorV2.deparse(rule.pattern_ast) + rule.rewrite = RuleGeneratorV2.deparse(rule.rewrite_ast) + + @staticmethod + def variablize_table(rule: Union[RuleV2, dict], table: Dict[str, str]) -> RuleV2: """Return a new rule where the named table (and its qualified column refs) is replaced by a fresh element variable. table is a {"value": , "name": } descriptor as produced by tables. Both ASTs are rewritten and re-deparsed; the input rule is not mutated. """ + if isinstance(rule, dict): + rule = RuleV2.from_dict(rule) new_rule = copy.deepcopy(rule) - mapping = copy.deepcopy(new_rule["mapping"]) + mapping = copy.deepcopy(new_rule.mapping) if not isinstance(mapping, dict): - raise TypeError("rule['mapping'] must be a dict[str, str]") + raise TypeError("rule.mapping must be a dict[str, str]") target_value = table.get("value") target_name = table.get("name") @@ -546,25 +554,24 @@ def variablize_table(rule: Dict[str, object], table: Dict[str, str]) -> Dict[str raise TypeError("table must have string keys 'value' and 'name'") mapping, external_name = RuleGeneratorV2._find_next_element_variable(mapping) - new_rule["mapping"] = mapping + new_rule.mapping = mapping - for key in ("pattern_ast", "rewrite_ast"): - ast = new_rule.get(key) + for attr in ("pattern_ast", "rewrite_ast"): + ast = getattr(new_rule, attr) if not isinstance(ast, Node): - raise TypeError(f"rule['{key}'] must be an AST Node") - new_rule[key] = RuleGeneratorV2._replace_table_in_ast( + raise TypeError(f"rule.{attr} must be an AST Node") + setattr(new_rule, attr, RuleGeneratorV2._replace_table_in_ast( ast, target_value=target_value, target_name=target_name, placeholder_token=external_name, - ) + )) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def variablize_columns(rule: Dict[str, object]) -> List[Dict[str, object]]: + def variablize_columns(rule: RuleV2) -> List[RuleV2]: """Return one child rule per column that can still be replaced with a fresh element variable. Each child substitutes one un-variablized column name with on both sides. @@ -576,7 +583,7 @@ def variablize_columns(rule: Dict[str, object]) -> List[Dict[str, object]]: return [RuleGeneratorV2.variablize_column(rule, column) for column in RuleGeneratorV2.columns(pattern_ast, rewrite_ast)] @staticmethod - def variablize_column(rule: Dict[str, object], column: str) -> Dict[str, object]: + def variablize_column(rule: RuleV2, column: str) -> RuleV2: """Return a new rule where every occurrence of column (in both ASTs) is replaced by a fresh element variable. Allocates the next available and re-deparses both sides. The input rule is not mutated. @@ -595,12 +602,11 @@ def variablize_column(rule: Dict[str, object], column: str) -> Dict[str, object] raise TypeError(f"rule['{key}'] must be an AST Node") new_rule[key] = RuleGeneratorV2._replace_column_in_ast(ast, column, external_name) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def variablize_literals(rule: Dict[str, object]) -> List[Dict[str, object]]: + def variablize_literals(rule: RuleV2) -> List[Dict[str, object]]: """Return one child rule per literal that can still be replaced with a fresh element variable. Considers literals that recur within one side or are shared across both sides. @@ -612,7 +618,7 @@ def variablize_literals(rule: Dict[str, object]) -> List[Dict[str, object]]: return [RuleGeneratorV2.variablize_literal(rule, literal) for literal in RuleGeneratorV2.literals(pattern_ast, rewrite_ast)] @staticmethod - def variablize_literal(rule: Dict[str, object], literal: Union[str, numbers.Number]) -> Dict[str, object]: + def variablize_literal(rule: RuleV2, literal: Union[str, numbers.Number]) -> RuleV2: """Return a new rule where every occurrence of literal (in both ASTs) is replaced by a fresh element variable. Allocates the next available and re-deparses both sides. The input rule is not mutated. @@ -631,18 +637,17 @@ def variablize_literal(rule: Dict[str, object], literal: Union[str, numbers.Numb raise TypeError(f"rule['{key}'] must be an AST Node") new_rule[key] = RuleGeneratorV2._replace_literal_in_ast(ast, literal, external_name) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def variablize_subtrees(rule: Dict[str, object]) -> List[Dict[str, object]]: + def variablize_subtrees(rule: RuleV2) -> List[Dict[str, object]]: """Return one child rule per subtree shared by pattern and rewrite that can be collapsed into an element variable. """ - return [RuleGeneratorV2.variablize_subtree(rule, subtree) for subtree in RuleGeneratorV2.subtrees(rule["pattern_ast"], rule["rewrite_ast"])] # type: ignore[arg-type,index] + return [RuleGeneratorV2.variablize_subtree(rule, subtree) for subtree in RuleGeneratorV2.subtrees(rule.pattern_ast, rule.rewrite_ast)] @staticmethod - def variablize_subtree(rule: Dict[str, object], subtree: Node) -> Dict[str, object]: + def variablize_subtree(rule: RuleV2, subtree: Node) -> RuleV2: """Return a new rule where every occurrence of subtree (in both ASTs) is replaced by a fresh element variable. Allocates the next available in the mapping and re-deparses both sides. The input rule is not mutated. @@ -661,12 +666,11 @@ def variablize_subtree(rule: Dict[str, object], subtree: Node) -> Dict[str, obje raise TypeError(f"rule['{key}'] must be an AST Node") new_rule[key] = RuleGeneratorV2._replace_subtree_in_ast(ast, subtree, ElementVariableNode(external_name)) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def merge_variables(rule: Dict[str, object]) -> List[Dict[str, object]]: + def merge_variables(rule: RuleV2) -> List[Dict[str, object]]: """Return one child rule per element-variable list collapsible into a single set variable <>. Each candidate list is the intersection of an AND-chain or SELECT-list on both sides. @@ -678,7 +682,7 @@ def merge_variables(rule: Dict[str, object]) -> List[Dict[str, object]]: return [RuleGeneratorV2.merge_variable_list(rule, variable_list) for variable_list in RuleGeneratorV2.variable_lists(pattern_ast, rewrite_ast)] @staticmethod - def merge_variable_list(rule: Dict[str, object], variable_list: List[str]) -> Dict[str, object]: + def merge_variable_list(rule: RuleV2, variable_list: List[str]) -> RuleV2: """Return a new rule where the given element variables are collapsed into a single set variable <>. Allocates the next available set variable and rewrites both ASTs (and their deparsed forms) so consecutive members of variable_list share that one set variable. The input rule is not mutated. @@ -698,12 +702,11 @@ def merge_variable_list(rule: Dict[str, object], variable_list: List[str]) -> Di raise TypeError(f"rule['{key}'] must be an AST Node") new_rule[key] = RuleGeneratorV2._merge_variable_list_in_ast(ast, var_set, set_name) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def drop_branches(rule: Dict[str, object]) -> List[Dict[str, object]]: + def drop_branches(rule: RuleV2) -> List[Dict[str, object]]: """Return one child rule per droppable branch (a clause or AND/OR conjunct that is fully variablized on both sides). Each child removes one branch from both pattern and rewrite, producing a strictly more general rule. @@ -715,7 +718,7 @@ def drop_branches(rule: Dict[str, object]) -> List[Dict[str, object]]: return [RuleGeneratorV2.drop_branch(rule, branch) for branch in RuleGeneratorV2.branches(pattern_ast, rewrite_ast)] @staticmethod - def drop_branch(rule: Dict[str, object], branch: Dict[str, object]) -> Dict[str, object]: + def drop_branch(rule: RuleV2, branch: Dict[str, object]) -> RuleV2: """Return a new rule with branch removed from both pattern and rewrite ASTs. branch is a descriptor produced by branches (e.g. {"key": "where", "value": ...}). The input rule is not mutated. @@ -726,12 +729,11 @@ def drop_branch(rule: Dict[str, object], branch: Dict[str, object]) -> Dict[str, if not isinstance(ast, Node): raise TypeError(f"rule['{key}'] must be an AST Node") new_rule[key] = RuleGeneratorV2._drop_branch_in_ast(ast, branch) - new_rule["pattern"] = RuleGeneratorV2.deparse(new_rule["pattern_ast"]) # type: ignore[index] - new_rule["rewrite"] = RuleGeneratorV2.deparse(new_rule["rewrite_ast"]) # type: ignore[index] + RuleGeneratorV2._sync_rule_strings(new_rule) return new_rule @staticmethod - def generalize_tables(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_tables(rule: RuleV2) -> RuleV2: """Return a new rule with every replaceable table variabilized in one pass. Walks the candidate tables and applies variablize_table repeatedly. Returns a fresh dict; the input rule is not mutated. @@ -743,12 +745,12 @@ def generalize_tables(rule: Dict[str, object]) -> Dict[str, object]: raise TypeError("rule ASTs must be Node instances") for table in RuleGeneratorV2.tables(pattern_ast, rewrite_ast): new_rule = RuleGeneratorV2.variablize_table(new_rule, table) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @staticmethod - def generalize_columns(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_columns(rule: RuleV2) -> RuleV2: """Return a new rule with every replaceable column variabilized in one pass. Returns a fresh dict; the input is not mutated. @@ -760,12 +762,12 @@ def generalize_columns(rule: Dict[str, object]) -> Dict[str, object]: raise TypeError("rule ASTs must be Node instances") for column in RuleGeneratorV2.columns(pattern_ast, rewrite_ast): new_rule = RuleGeneratorV2.variablize_column(new_rule, column) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @staticmethod - def generalize_literals(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_literals(rule: RuleV2) -> RuleV2: """Return a new rule with every replaceable literal variabilized in one pass. Returns a fresh dict; the input is not mutated. @@ -777,12 +779,12 @@ def generalize_literals(rule: Dict[str, object]) -> Dict[str, object]: raise TypeError("rule ASTs must be Node instances") for literal in RuleGeneratorV2.literals(pattern_ast, rewrite_ast): new_rule = RuleGeneratorV2.variablize_literal(new_rule, literal) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @staticmethod - def generalize_subtrees(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_subtrees(rule: RuleV2) -> RuleV2: """Return a new rule with every shared, fully-variablized subtree collapsed into a single element variable. Returns a fresh dict; the input is not mutated. @@ -794,12 +796,12 @@ def generalize_subtrees(rule: Dict[str, object]) -> Dict[str, object]: raise TypeError("rule ASTs must be Node instances") for subtree in RuleGeneratorV2.subtrees(pattern_ast, rewrite_ast): new_rule = RuleGeneratorV2.variablize_subtree(new_rule, subtree) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @staticmethod - def generalize_variables(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_variables(rule: RuleV2) -> RuleV2: """Return a new rule with every mergeable element-variable list collapsed into a set variable. Returns a fresh dict; the input is not mutated. @@ -812,12 +814,12 @@ def generalize_variables(rule: Dict[str, object]) -> Dict[str, object]: for variable_list in RuleGeneratorV2.variable_lists(pattern_ast, rewrite_ast): if variable_list: new_rule = RuleGeneratorV2.merge_variable_list(new_rule, variable_list) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @staticmethod - def generalize_branches(rule: Dict[str, object]) -> Dict[str, object]: + def generalize_branches(rule: RuleV2) -> RuleV2: """Return a new rule with every droppable branch removed in one pass. Returns a fresh dict; the input is not mutated. @@ -829,8 +831,8 @@ def generalize_branches(rule: Dict[str, object]) -> Dict[str, object]: raise TypeError("rule ASTs must be Node instances") for branch in RuleGeneratorV2.branches(pattern_ast, rewrite_ast): new_rule = RuleGeneratorV2.drop_branch(new_rule, branch) - pattern_ast = new_rule["pattern_ast"] # type: ignore[assignment] - rewrite_ast = new_rule["rewrite_ast"] # type: ignore[assignment] + pattern_ast = new_rule.pattern_ast + rewrite_ast = new_rule.rewrite_ast return new_rule @@ -1342,11 +1344,8 @@ def _is_branch_clause(key: str, clause: Node) -> bool: return RuleGeneratorV2._is_branch_node(clause) return False if key == "where": - if isinstance(clause, WhereNode): - if len(clause.children) == 1: - return RuleGeneratorV2._is_branch_node(clause.children[0]) - return RuleGeneratorV2._is_branch_node(clause) - return RuleGeneratorV2._is_branch_node(clause) + if isinstance(clause, WhereNode) and len(clause.children) == 1: + return RuleGeneratorV2._is_branch_node(clause.children[0]) return RuleGeneratorV2._is_branch_node(clause) @staticmethod @@ -1560,13 +1559,10 @@ def _replace_subtree_in_ast(ast: Node, subtree: Node, replacement: Node, parent: replacements: List[Tuple[Node, Node]] = [] new_children: Set[Node] = set() for child in children: - if isinstance(child, Node): - new_child = RuleGeneratorV2._replace_subtree_in_ast(child, subtree, replacement, ast) - new_children.add(new_child) - if new_child is not child: - replacements.append((child, new_child)) - else: - new_children.add(child) # type: ignore[arg-type] + new_child = RuleGeneratorV2._replace_subtree_in_ast(child, subtree, replacement, ast) + new_children.add(new_child) + if new_child is not child: + replacements.append((child, new_child)) ast.children = new_children for old, new in replacements: RuleGeneratorV2._resync_parallel_attrs(ast, old, new) @@ -1785,13 +1781,10 @@ def _visit(node: Node, parent: Optional[Node]) -> Node: new_set: Set[Node] = set() replacements: List[Tuple[Node, Node]] = [] for child in children: - if isinstance(child, Node): - new_child = _visit(child, node) - new_set.add(new_child) - if new_child is not child: - replacements.append((child, new_child)) - else: - new_set.add(child) # type: ignore[arg-type] + new_child = _visit(child, node) + new_set.add(new_child) + if new_child is not child: + replacements.append((child, new_child)) node.children = new_set for old, new in replacements: RuleGeneratorV2._resync_parallel_attrs(node, old, new) @@ -1995,27 +1988,10 @@ def _extract_partial_sql(full_sql: str, scope: Scope) -> str: return full_sql.replace("SELECT * FROM t ", "", 1) return full_sql.replace("SELECT * FROM t WHERE ", "", 1) - @staticmethod - def _normalize_placeholder_numbers(text: str, start_token: str, end_token: str) -> str: - out = text - start = 0 - while True: - i = out.find(start_token, start) - if i < 0: - break - j = out.find(end_token, i + len(start_token)) - if j < 0: - break - inner = out[i + len(start_token):j] - if inner.isdigit(): - out = out[: i + len(start_token)] + out[j:] - start = i + len(start_token) - else: - start = j + len(end_token) - return out + @staticmethod - def fingerPrint(rule: Dict[str, object]) -> str: + def fingerPrint(rule: RuleV2) -> str: """Return a stable fingerprint string for rule based on its deparsed pattern. Variable indices are normalized so that two rules that differ only in variable numbering share a fingerprint. Used to deduplicate rules in the generalization graph. @@ -2033,12 +2009,12 @@ def _fingerPrint(fingerprint: str) -> str: out = re.sub(r"", "", out) out = re.sub(r"<>", "<>", out) out = re.sub(r"''", "''", out) - out = RuleGeneratorV2._normalize_placeholder_numbers(out, "") - out = RuleGeneratorV2._normalize_placeholder_numbers(out, "<>") + out = re.sub(r"", "", out) + out = re.sub(r"<>", "<>", out) return out @staticmethod - def numberOfVariables(rule: Dict[str, object]) -> int: + def numberOfVariables(rule: RuleV2) -> int: """Return the count of declared variables in rule['mapping']. Used as a tie-breaker when picking the simplest rule among equivalents. @@ -2057,65 +2033,15 @@ def unify_variable_names(q0: str, q1: str) -> Tuple[str, str]: mapping: Dict[str, str] = {} counter = 1 - def _scan_tokens(text: str) -> List[str]: - tokens: List[str] = [] - i = 0 - while i < len(text): - if text.startswith("<<", i): - j = text.find(">>", i + 2) - if j != -1: - token = text[i : j + 2] - inner = token[2:-2] - if inner and all(ch.isalnum() or ch == "_" for ch in inner): - tokens.append(token) - i = j + 2 - continue - if text[i] == "<": - j = text.find(">", i + 1) - if j != -1: - token = text[i : j + 1] - inner = token[1:-1] - if inner and all(ch.isalnum() or ch == "_" for ch in inner): - tokens.append(token) - i = j + 1 - continue - i += 1 - return tokens - - for token in _scan_tokens(q0) + _scan_tokens(q1): - if token in mapping: - continue - if token.startswith("<<") and token.endswith(">>"): - mapping[token] = f"<>" - else: - mapping[token] = f"" - counter += 1 - - def _replace_all(text: str) -> str: - out: List[str] = [] - i = 0 - while i < len(text): - if text.startswith("<<", i): - j = text.find(">>", i + 2) - if j != -1: - token = text[i : j + 2] - if token in mapping: - out.append(mapping[token]) - i = j + 2 - continue - if text[i] == "<": - j = text.find(">", i + 1) - if j != -1: - token = text[i : j + 1] - if token in mapping: - out.append(mapping[token]) - i = j + 1 - continue - out.append(text[i]) - i += 1 - return "".join(out) - - return _replace_all(q0), _replace_all(q1) + for token in re.findall(r"<<\w+>>|<\w+>", q0 + " " + q1): + if token not in mapping: + mapping[token] = f"<>" if token.startswith("<<") else f"" + counter += 1 + + def _replace(text: str) -> str: + return re.sub(r"<<\w+>>|<\w+>", lambda m: mapping.get(m.group(), m.group()), text) + + return _replace(q0), _replace(q1) @staticmethod def _find_next_element_variable(mapping: Dict[str, str]) -> Tuple[Dict[str, str], str]: diff --git a/data/rules.py b/data/rules.py index c99a6ff..3981f0a 100644 --- a/data/rules.py +++ b/data/rules.py @@ -2,6 +2,7 @@ from core.rule_parser import RuleParser from core.rule_parser_v2 import RuleParserV2 +from core.rule import RuleV2 rules = [ # PostgresSQL Rules @@ -764,7 +765,7 @@ def get_rule(key: str) -> dict: # fetch one rule by key using the v2 AST-based parser # -def get_rule_v2(key: str) -> dict: +def get_rule_v2(key: str) -> RuleV2: rule = next(filter(lambda x: x['key'] == key, rules), None) if rule is None: raise ValueError(f"Rule {key} not found") @@ -772,20 +773,16 @@ def get_rule_v2(key: str) -> dict: # TODO: reuse v1 parse_actions? identity_mapping = json.dumps({k: k for k in result.mapping}) actions_json = RuleParser.parse_actions(rule['actions'], identity_mapping) - return { - 'id': rule['id'], - 'key': rule['key'], - 'name': rule['name'], - 'pattern': rule['pattern'], - 'pattern_ast': result.pattern_ast, - 'rewrite': rule['rewrite'], - 'rewrite_ast': result.rewrite_ast, - 'mapping': result.mapping, - 'actions': rule['actions'], - 'actions_json': json.loads(actions_json), - 'database': rule['database'], - 'examples': rule['examples'], - } + return RuleV2( + id=rule['id'], + key=rule['key'], + pattern=rule['pattern'], + pattern_ast=result.pattern_ast, + rewrite=rule['rewrite'], + rewrite_ast=result.rewrite_ast, + mapping=result.mapping, + actions=rule['actions'], + ) # return a list of rules (json attributes are in str) diff --git a/tests/test_rule_generator_v2.py b/tests/test_rule_generator_v2.py index a971827..0479a58 100644 --- a/tests/test_rule_generator_v2.py +++ b/tests/test_rule_generator_v2.py @@ -7,21 +7,20 @@ from core.query_formatter import QueryFormatter from core.query_parser import QueryParser from core.rule_generator_v2 import RuleGeneratorV2 +from core.rule import RuleV2 from core.rule_parser_v2 import RuleParserV2, VarType from data.rules import get_rule_v2 as get_rule -def _build_rule(pattern: str, rewrite: str): +def _build_rule(pattern: str, rewrite: str) -> RuleV2: parsed = RuleParserV2.parse(pattern, rewrite) - return { - "pattern": pattern, - "rewrite": rewrite, - "pattern_ast": parsed.pattern_ast, - "rewrite_ast": parsed.rewrite_ast, - "mapping": parsed.mapping, - "constraints": "", - "actions": "", - } + return RuleV2( + pattern=pattern, + rewrite=rewrite, + pattern_ast=parsed.pattern_ast, + rewrite_ast=parsed.rewrite_ast, + mapping=parsed.mapping, + ) def _has_clause(query: QueryNode, clause_type: NodeType) -> bool: @@ -1884,7 +1883,7 @@ def test_generate_rule_graph_0(): q0 = "CAST(created_at AS DATE)" q1 = "created_at" root_rule = RuleGeneratorV2.generate_rule_graph(q0, q1) - assert isinstance(root_rule, dict) + assert isinstance(root_rule, RuleV2) children = root_rule["children"] assert len(children) == 1 child_rule = children[0]