Skip to content

Commit 9542d3e

Browse files
committed
Add provenance to LQP parser
1 parent 646e9e4 commit 9542d3e

31 files changed

Lines changed: 8132 additions & 9552 deletions

meta/src/meta/codegen_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def gen_list_type(self, element_type: str) -> str:
174174
return f"list[{element_type}]"
175175

176176
def gen_option_type(self, element_type: str) -> str:
177-
return f"Optional[{element_type}]"
177+
return f"{element_type} | None"
178178

179179
def gen_dict_type(self, key_type: str, value_type: str) -> str:
180180
return f"dict[{key_type}, {value_type}]"

meta/src/meta/codegen_templates.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ class BuiltinTemplate:
106106
"format_bytes": BuiltinTemplate('"0x" + {0}.hex()'),
107107
"pp_dispatch": BuiltinTemplate("None", ["self.pprint_dispatch({0})"]),
108108
"get_at": BuiltinTemplate("{0}[{1}]"),
109+
# Provenance tracking
110+
"push_path": BuiltinTemplate("None", ["self.push_path({0})"]),
111+
"pop_path": BuiltinTemplate("None", ["self.pop_path()"]),
112+
"span_start": BuiltinTemplate("self.span_start()"),
113+
"record_span": BuiltinTemplate("None", ["self.record_span({0})"]),
109114
}
110115

111116

@@ -191,6 +196,11 @@ class BuiltinTemplate:
191196
"format_bytes": BuiltinTemplate('"0x" * bytes2hex({0})'),
192197
"pp_dispatch": BuiltinTemplate("nothing", ["_pprint_dispatch(pp, {0})"]),
193198
"get_at": BuiltinTemplate("{0}[{1} + 1]"),
199+
# Provenance tracking
200+
"push_path": BuiltinTemplate("nothing", ["push_path!(parser, {0})"]),
201+
"pop_path": BuiltinTemplate("nothing", ["pop_path!(parser)"]),
202+
"span_start": BuiltinTemplate("span_start(parser)"),
203+
"record_span": BuiltinTemplate("nothing", ["record_span!(parser, {0})"]),
194204
}
195205

196206

@@ -280,6 +290,11 @@ class BuiltinTemplate:
280290
"format_bytes": BuiltinTemplate('fmt.Sprintf("0x%x", {0})'),
281291
"pp_dispatch": BuiltinTemplate("nil", ["p.pprintDispatch({0})"]),
282292
"get_at": BuiltinTemplate("{0}[{1}]"),
293+
# Provenance tracking
294+
"push_path": BuiltinTemplate("nil", ["p.pushPath(int({0}))"]),
295+
"pop_path": BuiltinTemplate("nil", ["p.popPath()"]),
296+
"span_start": BuiltinTemplate("int64(p.spanStart())"),
297+
"record_span": BuiltinTemplate("nil", ["p.recordSpan(int({0}))"]),
283298
}
284299

285300
__all__ = [

meta/src/meta/parser_gen.py

Lines changed: 191 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
Terminal,
9191
)
9292
from .grammar_utils import is_epsilon, rhs_elements
93+
from .proto_ast import ProtoMessage
9394
from .target import (
9495
Assign,
9596
BaseType,
@@ -100,6 +101,8 @@
100101
ListExpr,
101102
ListType,
102103
Lit,
104+
NewMessage,
105+
OneOf,
103106
ParseNonterminal,
104107
ParseNonterminalDef,
105108
Seq,
@@ -131,29 +134,117 @@ class AmbiguousGrammarError(Exception):
131134
pass
132135

133136

137+
def _build_param_field_numbers(
138+
action: Lambda,
139+
proto_messages: dict[tuple[str, str], ProtoMessage] | None,
140+
) -> dict[int, int]:
141+
"""Map lambda parameter indices to protobuf field numbers.
142+
143+
Inspects the semantic action's body. When it is a NewMessage, cross-references
144+
field names with proto_messages to find field numbers.
145+
146+
Returns dict mapping parameter index to proto field number.
147+
"""
148+
if proto_messages is None:
149+
return {}
150+
body = action.body
151+
if not isinstance(body, NewMessage):
152+
return {}
153+
key = (body.module, body.name)
154+
proto_msg = proto_messages.get(key)
155+
if proto_msg is None:
156+
return {}
157+
# Build name -> field number map from proto definition
158+
name_to_number: dict[str, int] = {}
159+
for f in proto_msg.fields:
160+
name_to_number[f.name] = f.number
161+
for oneof in proto_msg.oneofs:
162+
for f in oneof.fields:
163+
name_to_number[f.name] = f.number
164+
# Map each message field to the lambda param that provides its value
165+
param_names = [p.name for p in action.params]
166+
result: dict[int, int] = {}
167+
for field_name, field_expr in body.fields:
168+
field_num = name_to_number.get(field_name)
169+
if field_num is None:
170+
continue
171+
# Determine which param provides this field
172+
param_name = _extract_param_name(field_expr, param_names)
173+
if param_name is not None and param_name in param_names:
174+
param_idx = param_names.index(param_name)
175+
result[param_idx] = field_num
176+
return result
177+
178+
179+
def _extract_param_name(expr: TargetExpr, param_names: list[str]) -> str | None:
180+
"""Extract the parameter name from a field expression.
181+
182+
Handles:
183+
- Direct Var reference
184+
- Call(OneOf(...), [Var(...)]) wrapper
185+
- Call(Builtin(...), [Var(...), ...]) e.g. unwrap_option_or
186+
"""
187+
from .target import Builtin
188+
189+
if isinstance(expr, Var) and expr.name in param_names:
190+
return expr.name
191+
if isinstance(expr, Call):
192+
if isinstance(expr.func, OneOf) and expr.args:
193+
return _extract_param_name(expr.args[0], param_names)
194+
if isinstance(expr.func, Builtin) and expr.args:
195+
return _extract_param_name(expr.args[0], param_names)
196+
return None
197+
198+
134199
def generate_parse_functions(
135-
grammar: Grammar, indent: str = ""
200+
grammar: Grammar,
201+
indent: str = "",
202+
proto_messages: dict[tuple[str, str], ProtoMessage] | None = None,
136203
) -> list[ParseNonterminalDef]:
137204
parser_methods = []
138205
reachable, _ = grammar.analysis.partition_nonterminals_by_reachability()
139206
for nt in reachable:
140207
rules = grammar.rules[nt]
141-
method_code = _generate_parse_method(nt, rules, grammar, indent)
208+
method_code = _generate_parse_method(nt, rules, grammar, indent, proto_messages)
142209
parser_methods.append(method_code)
143210
return parser_methods
144211

145212

213+
def _wrap_with_span(body: TargetExpr, return_type) -> TargetExpr:
214+
"""Wrap a nonterminal body with span_start/record_span."""
215+
span_var = Var(gensym("span_start"), BaseType("Int64"))
216+
result_var = Var(gensym("result"), return_type)
217+
return Let(
218+
span_var,
219+
Call(make_builtin("span_start"), []),
220+
Let(
221+
result_var,
222+
body,
223+
Seq(
224+
[
225+
Call(make_builtin("record_span"), [span_var]),
226+
result_var,
227+
]
228+
),
229+
),
230+
)
231+
232+
146233
def _generate_parse_method(
147-
lhs: Nonterminal, rules: list[Rule], grammar: Grammar, indent: str = ""
234+
lhs: Nonterminal,
235+
rules: list[Rule],
236+
grammar: Grammar,
237+
indent: str = "",
238+
proto_messages: dict[tuple[str, str], ProtoMessage] | None = None,
148239
) -> ParseNonterminalDef:
149-
"""Generate parse method code as string (preserving existing logic)."""
240+
"""Generate parse method for a nonterminal with provenance tracking."""
150241
return_type = None
151242
rhs = None
152243
follow_set = FollowSet(grammar, lhs)
153244
if len(rules) == 1:
154245
rule = rules[0]
155246
rhs = _generate_parse_rhs_ir(
156-
rule.rhs, grammar, follow_set, True, rule.constructor
247+
rule.rhs, grammar, follow_set, True, rule.constructor, proto_messages
157248
)
158249
return_type = rule.constructor.return_type
159250
else:
@@ -171,7 +262,6 @@ def _generate_parse_method(
171262
],
172263
)
173264
for i, rule in enumerate(rules):
174-
# Ensure the return type is the same for all actions for this nonterminal.
175265
assert return_type is None or return_type == rule.constructor.return_type, (
176266
f"Return type mismatch at rule {i}: {return_type} != {rule.constructor.return_type}"
177267
)
@@ -183,12 +273,18 @@ def _generate_parse_method(
183273
make_builtin("equal"), [Var(prediction, BaseType("Int64")), Lit(i)]
184274
),
185275
_generate_parse_rhs_ir(
186-
rule.rhs, grammar, follow_set, True, rule.constructor
276+
rule.rhs,
277+
grammar,
278+
follow_set,
279+
True,
280+
rule.constructor,
281+
proto_messages,
187282
),
188283
tail,
189284
)
190285
rhs = Let(Var(prediction, BaseType("Int64")), predictor, tail)
191286
assert return_type is not None
287+
rhs = _wrap_with_span(rhs, return_type)
192288
return ParseNonterminalDef(lhs, [], return_type, rhs, indent)
193289

194290

@@ -377,30 +473,19 @@ def _generate_parse_rhs_ir(
377473
follow_set: TerminalSequenceSet,
378474
apply_action: bool = False,
379475
action: Lambda | None = None,
476+
proto_messages: dict[tuple[str, str], ProtoMessage] | None = None,
380477
) -> TargetExpr:
381-
"""Generate IR for parsing an RHS.
382-
383-
Args:
384-
rhs: The RHS to parse
385-
grammar: The grammar
386-
follow_set: TerminalSequenceSet for computing follow lazily
387-
apply_action: Whether to apply the semantic action
388-
action: The semantic action to apply (required if apply_action is True)
389-
390-
Returns IR expression for leaf nodes (Literal, Terminal, Nonterminal).
391-
Returns None for complex cases that still use string generation.
392-
"""
478+
"""Generate IR for parsing an RHS with provenance tracking."""
393479
if isinstance(rhs, Sequence):
394480
return _generate_parse_rhs_ir_sequence(
395-
rhs, grammar, follow_set, apply_action, action
481+
rhs, grammar, follow_set, apply_action, action, proto_messages
396482
)
397483
elif isinstance(rhs, LitTerminal):
398484
parse_expr = Call(make_builtin("consume_literal"), [Lit(rhs.name)])
399485
if apply_action and action:
400486
return Seq([parse_expr, apply_lambda(action, [])])
401487
return parse_expr
402488
elif isinstance(rhs, NamedTerminal):
403-
# Use terminal's actual type for consume_terminal instead of generic Token
404489
from .target import FunctionType
405490

406491
terminal_type = rhs.target_type()
@@ -427,14 +512,18 @@ def _generate_parse_rhs_ir(
427512
elif isinstance(rhs, Option):
428513
assert grammar is not None
429514
predictor = _build_option_predictor(grammar, rhs.rhs, follow_set)
430-
parse_result = _generate_parse_rhs_ir(rhs.rhs, grammar, follow_set, False, None)
515+
parse_result = _generate_parse_rhs_ir(
516+
rhs.rhs, grammar, follow_set, False, None, proto_messages
517+
)
431518
return IfElse(predictor, Call(make_builtin("some"), [parse_result]), Lit(None))
432519
elif isinstance(rhs, Star):
433520
assert grammar is not None
434521
xs = Var(gensym("xs"), ListType(rhs.rhs.target_type()))
435522
cond = Var(gensym("cond"), BaseType("Boolean"))
436523
predictor = _build_option_predictor(grammar, rhs.rhs, follow_set)
437-
parse_item = _generate_parse_rhs_ir(rhs.rhs, grammar, follow_set, False, None)
524+
parse_item = _generate_parse_rhs_ir(
525+
rhs.rhs, grammar, follow_set, False, None, proto_messages
526+
)
438527
item = Var(gensym("item"), rhs.rhs.target_type())
439528
loop_body = Seq(
440529
[
@@ -452,16 +541,31 @@ def _generate_parse_rhs_ir(
452541
raise NotImplementedError(f"Unsupported Rhs type: {type(rhs)}")
453542

454543

544+
def _wrap_with_path(field_num: int, var: Var, inner: TargetExpr) -> list[TargetExpr]:
545+
"""Return statements that push path, assign inner to var, then pop path."""
546+
return [
547+
Call(make_builtin("push_path"), [Lit(field_num)]),
548+
Assign(var, inner),
549+
Call(make_builtin("pop_path"), []),
550+
]
551+
552+
455553
def _generate_parse_rhs_ir_sequence(
456554
rhs: Sequence,
457555
grammar: Grammar,
458556
follow_set: TerminalSequenceSet,
459557
apply_action: bool = False,
460558
action: Lambda | None = None,
559+
proto_messages: dict[tuple[str, str], ProtoMessage] | None = None,
461560
) -> TargetExpr:
462561
if is_epsilon(rhs):
463562
return Lit(None)
464563

564+
# Compute param->field_number mapping for provenance
565+
param_field_numbers: dict[int, int] = {}
566+
if action is not None and proto_messages is not None:
567+
param_field_numbers = _build_param_field_numbers(action, proto_messages)
568+
465569
exprs = []
466570
arg_vars = []
467571
elems = list(rhs_elements(rhs))
@@ -473,7 +577,9 @@ def _generate_parse_rhs_ir_sequence(
473577
follow_set_i = ConcatSet(first_following, follow_set)
474578
else:
475579
follow_set_i = follow_set
476-
elem_ir = _generate_parse_rhs_ir(elem, grammar, follow_set_i, False, None)
580+
elem_ir = _generate_parse_rhs_ir(
581+
elem, grammar, follow_set_i, False, None, proto_messages
582+
)
477583
if isinstance(elem, LitTerminal):
478584
exprs.append(elem_ir)
479585
else:
@@ -487,23 +593,80 @@ def _generate_parse_rhs_ir_sequence(
487593
)
488594
var_name = gensym("arg")
489595
var = Var(var_name, elem.target_type())
490-
exprs.append(Assign(var, elem_ir))
596+
field_num = param_field_numbers.get(non_literal_count)
597+
if field_num is not None and isinstance(elem, Star):
598+
# For repeated fields: push field number around the whole
599+
# Star loop, and push/pop a 0-based index inside the loop
600+
stmts = _wrap_star_with_index_path(
601+
elem, var, grammar, follow_set_i, field_num, proto_messages
602+
)
603+
exprs.extend(stmts)
604+
elif field_num is not None:
605+
exprs.extend(_wrap_with_path(field_num, var, elem_ir))
606+
else:
607+
exprs.append(Assign(var, elem_ir))
491608
arg_vars.append(var)
492609
non_literal_count += 1
493610
if apply_action and action:
494611
lambda_call = apply_lambda(action, arg_vars)
495612
exprs.append(lambda_call)
496613
elif len(arg_vars) > 1:
497-
# Multiple values - wrap in tuple
498614
exprs.append(Call(make_builtin("tuple"), arg_vars))
499615
elif len(arg_vars) == 1:
500-
# Single value - return the variable
501616
exprs.append(arg_vars[0])
502617
else:
503-
# no non-literal elements, return None
504618
return Lit(None)
505619

506620
if len(exprs) == 1:
507621
return exprs[0]
508622
else:
509623
return Seq(exprs)
624+
625+
626+
def _wrap_star_with_index_path(
627+
star: Star,
628+
result_var: Var,
629+
grammar: Grammar,
630+
follow_set: TerminalSequenceSet,
631+
field_num: int,
632+
proto_messages: dict[tuple[str, str], ProtoMessage] | None = None,
633+
) -> list[TargetExpr]:
634+
"""Return statements for a Star loop wrapped with push_path(field_num)
635+
and push_path(index)/pop_path() around each element.
636+
Assigns the resulting list to result_var."""
637+
xs = Var(gensym("xs"), ListType(star.rhs.target_type()))
638+
cond = Var(gensym("cond"), BaseType("Boolean"))
639+
idx = Var(gensym("idx"), BaseType("Int64"))
640+
predictor = _build_option_predictor(grammar, star.rhs, follow_set)
641+
parse_item = _generate_parse_rhs_ir(
642+
star.rhs, grammar, follow_set, False, None, proto_messages
643+
)
644+
item = Var(gensym("item"), star.rhs.target_type())
645+
loop_body = Seq(
646+
[
647+
Call(make_builtin("push_path"), [idx]),
648+
Assign(item, parse_item),
649+
Call(make_builtin("pop_path"), []),
650+
Call(make_builtin("list_push"), [xs, item]),
651+
Assign(idx, Call(make_builtin("add"), [idx, Lit(1)])),
652+
Assign(cond, predictor),
653+
]
654+
)
655+
inner = Let(
656+
xs,
657+
ListExpr([], star.rhs.target_type()),
658+
Let(
659+
cond,
660+
predictor,
661+
Let(
662+
idx,
663+
Lit(0),
664+
Seq([While(cond, loop_body), xs]),
665+
),
666+
),
667+
)
668+
return [
669+
Call(make_builtin("push_path"), [Lit(field_num)]),
670+
Assign(result_var, inner),
671+
Call(make_builtin("pop_path"), []),
672+
]

0 commit comments

Comments
 (0)