Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LiteralNode,
ElementVariableNode,
SetVariableNode,
VariableLiteralNode,
OperatorNode,
FunctionNode,
SelectNode,
Expand All @@ -37,6 +38,7 @@
'LiteralNode',
'ElementVariableNode',
'SetVariableNode',
'VariableLiteralNode',
'OperatorNode',
'FunctionNode',
'SelectNode',
Expand Down
1 change: 1 addition & 0 deletions core/ast/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class NodeType(Enum):
# VarSQL specific
VAR = "var"
VARSET = "varset"
VAR_LITERAL = "var_literal"

# Operators
OPERATOR = "operator"
Expand Down
33 changes: 29 additions & 4 deletions core/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __hash__(self):

class TableNode(Node):
"""Table reference node"""
def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs):
super().__init__(NodeType.TABLE, **kwargs)
self.name = _name
self.alias = _alias
Expand Down Expand Up @@ -80,7 +80,7 @@ def __hash__(self):

class ColumnNode(Node):
"""Column reference node"""
def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Optional[str] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs):
def __init__(self, _name: str, _alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, _parent: Optional[TableNode|SubqueryNode] = None, **kwargs):
super().__init__(NodeType.COLUMN, **kwargs)
self.name = _name
self.alias = _alias
Expand Down Expand Up @@ -172,7 +172,7 @@ def __hash__(self):

class ElementVariableNode(Node):
"""Rule element variable ``<name>`` (see ``VarType.ElementVariable`` in rule_parser_v2)."""
def __init__(self, _name: str, parent_alias: Optional[str] = None, alias: Optional[str] = None, **kwargs):
def __init__(self, _name: str, parent_alias: Optional[Union[str, 'ElementVariableNode']] = None, alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs):
super().__init__(NodeType.VAR, **kwargs)
self.name = _name
self.parent_alias = parent_alias
Expand Down Expand Up @@ -202,6 +202,31 @@ def __hash__(self):
return hash((super().__hash__(), self.name))


class VariableLiteralNode(Node):
"""A string literal placeholder, e.g. ``'%<x1>%'`` in a LIKE predicate.

``prefix`` and ``suffix`` capture surrounding wildcard characters so
``LIKE '%foo%'`` → ``VariableLiteralNode('x1', prefix='%', suffix='%')``.
"""
def __init__(self, _name: str, prefix: str = "", suffix: str = "",
_alias: Optional[str] = None, **kwargs):
super().__init__(NodeType.VAR_LITERAL, **kwargs)
self.name = _name
self.prefix = prefix
self.suffix = suffix
self.alias = _alias

def __eq__(self, other):
if not isinstance(other, VariableLiteralNode):
return False
return (super().__eq__(other) and self.name == other.name
and self.prefix == other.prefix and self.suffix == other.suffix
and self.alias == other.alias)

def __hash__(self):
return hash((super().__hash__(), self.name, self.prefix, self.suffix, self.alias))


class OperatorNode(Node):
"""Operator node"""
def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwargs):
Expand Down Expand Up @@ -233,7 +258,7 @@ def __init__(self, _operand: Node, _name: str, **kwargs):

class FunctionNode(Node):
"""Function call node"""
def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs):
def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[Union[str, 'ElementVariableNode']] = None, **kwargs):
if _args is None:
_args = []
super().__init__(NodeType.FUNCTION, children=_args, **kwargs)
Expand Down
22 changes: 0 additions & 22 deletions core/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,10 @@

from __future__ import annotations

import re
from typing import List

from core.ast.node import Node, OperatorNode, UnaryOperatorNode

_PLACEHOLDER_PREFIXES = ("x", "y")


def is_placeholder_name(name: str) -> bool:
"""Return True when name is a generator-internal placeholder identifier.

Matches the parser-friendly tokens (__rv_x?__, __rvs_y?__) and bare x?/y?
external names.
"""
lower = name.lower()
if re.fullmatch(r"__rv_[xy]\d+__", lower):
return True
if re.fullmatch(r"__rvs_[xy]\d+__", lower):
return True
for prefix in _PLACEHOLDER_PREFIXES:
if lower.startswith(prefix):
suffix = lower[len(prefix):]
if suffix.isdigit():
return True
return False


def flatten_logical_operands(node: Node, op_name: str) -> List[Node]:
"""Flatten a left-associative tree of binary ``op_name`` (e.g. AND/OR).
Expand Down
86 changes: 37 additions & 49 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,29 @@
SubqueryNode,
ElementVariableNode,
SetVariableNode,
VariableLiteralNode,
)
from core.ast.enums import NodeType, JoinType
from core.ast.node import Node
from core.ast.utils import flatten_logical_operands, is_placeholder_name


def _placeholder_token(name: str) -> str:
if name.lower().startswith("y"):
return f"__rvs_{name}__"
return f"__rv_{name}__"
from core.ast.utils import flatten_logical_operands


def _normalize_placeholder_tokens(sql: str) -> str:
out = sql
out = _replace_wrapped_tokens(out, "__rvs_", "__", "<<", ">>")
out = _replace_wrapped_tokens(out, "__rv_", "__", "<", ">")
out = re.sub(r"__rvs_(\w+)__", r"<<\1>>", sql)
out = re.sub(r"__rv_(\w+)__", r"<\1>", out)
return out


def _replace_wrapped_tokens(text: str, prefix: str, suffix: str, open_marker: str, close_marker: str) -> str:
out = text
start = 0
while True:
i = out.find(prefix, start)
if i < 0:
break
j = out.find(suffix, i + len(prefix))
if j < 0:
break
inner = out[i + len(prefix):j]
if inner and all(ch.isalnum() or ch == "_" for ch in inner):
replacement = f"{open_marker}{inner}{close_marker}"
out = out[:i] + replacement + out[j + len(suffix):]
start = i + len(replacement)
else:
start = i + 1
return out
def _render_alias(alias) -> str:
"""Render an alias/parent_alias field to a string for mosql output.

Concrete string aliases pass through unchanged; ElementVariableNode aliases
emit their placeholder token (``__rv_name__``), which ``_normalize_placeholder_tokens``
later converts to ``<name>``.
"""
if isinstance(alias, ElementVariableNode):
return f"__rv_{alias.name}__"
return alias


class QueryFormatter:
Expand Down Expand Up @@ -121,11 +107,18 @@ def ast_to_json(node: Node) -> dict:
result['orderby'] = format_order_by(child)
elif child.type == NodeType.LIMIT:
lv = child.limit
if isinstance(lv, str) and is_placeholder_name(lv) and not (lv.startswith("__rv_") or lv.startswith("__rvs_")):
lv = _placeholder_token(lv)
if isinstance(lv, ElementVariableNode):
lv = f"__rv_{lv.name}__"
elif isinstance(lv, SetVariableNode):
lv = f"__rvs_{lv.name}__"
result['limit'] = lv
elif child.type == NodeType.OFFSET:
result['offset'] = child.offset
ov = child.offset
if isinstance(ov, ElementVariableNode):
ov = f"__rv_{ov.name}__"
elif isinstance(ov, SetVariableNode):
ov = f"__rvs_{ov.name}__"
result['offset'] = ov

return result

Expand All @@ -147,7 +140,7 @@ def format_select(select_node: SelectNode) -> dict:
for child in children:
item = {'value': format_expression(child)}
if hasattr(child, 'alias') and child.alias:
item['name'] = child.alias
item['name'] = _render_alias(child.alias)
items.append(item)

select_key = 'select_distinct' if select_node.distinct else 'select'
Expand Down Expand Up @@ -252,28 +245,21 @@ def format_source(node: Node) -> dict:
subquery_child = list(node.children)[0]
result = {'value': ast_to_json(subquery_child)}
if node.alias:
result['name'] = node.alias
result['name'] = _render_alias(node.alias)
return result
elif node.type == NodeType.VAR:
result = {'value': f"__rv_{node.name}__"}
if node.alias:
result['name'] = node.alias
result['name'] = _render_alias(node.alias)
return result
raise ValueError(f"Unsupported source type: {node.type}")


def format_table(table_node: TableNode) -> dict:
"""Format a table reference.

TableNode names from the rule generator are always concrete after the refactor;
however RuleParserV2 still produces TableNode(name="x1") for user-written table
variable placeholders like <x1>, so is_placeholder_name is kept as a safety
fallback for those cases.
"""
name = _placeholder_token(table_node.name) if is_placeholder_name(table_node.name) else table_node.name
result = {'value': name}
"""Format a table reference."""
result = {'value': table_node.name}
if table_node.alias:
result['name'] = table_node.alias
result['name'] = _render_alias(table_node.alias)
return result


Expand Down Expand Up @@ -335,17 +321,19 @@ def format_expression(node: Node):
if node.type == NodeType.VAR:
token = f"__rv_{node.name}__"
if node.parent_alias:
pa = _placeholder_token(node.parent_alias) if is_placeholder_name(node.parent_alias) else node.parent_alias
pa = _render_alias(node.parent_alias)
return f"{pa}.{token}"
return token
if node.type == NodeType.VARSET:
return f"__rvs_{node.name}__"
if node.type == NodeType.VAR_LITERAL:
return {'literal': f"{node.prefix}__rv_{node.name}__{node.suffix}"}

if node.type == NodeType.COLUMN:
col_token = _placeholder_token(node.name) if is_placeholder_name(node.name) else node.name
if node.parent_alias:
pa_token = _placeholder_token(node.parent_alias) if is_placeholder_name(node.parent_alias) else node.parent_alias
return f"{pa_token}.{col_token}"
return col_token
pa_token = _render_alias(node.parent_alias)
return f"{pa_token}.{node.name}"
return node.name

elif node.type == NodeType.LITERAL:
if node.value is None:
Expand Down
Loading