9090 Terminal ,
9191)
9292from .grammar_utils import is_epsilon , rhs_elements
93+ from .proto_ast import ProtoMessage
9394from .target import (
9495 Assign ,
9596 BaseType ,
100101 ListExpr ,
101102 ListType ,
102103 Lit ,
104+ NewMessage ,
105+ OneOf ,
103106 ParseNonterminal ,
104107 ParseNonterminalDef ,
105108 Seq ,
@@ -131,29 +134,117 @@ class AmbiguousGrammarError(Exception):
131134 pass
132135
133136
137+ def _build_param_field_numbers (
138+ action : Lambda ,
139+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None ,
140+ ) -> dict [int , int ]:
141+ """Map lambda parameter indices to protobuf field numbers.
142+
143+ Inspects the semantic action's body. When it is a NewMessage, cross-references
144+ field names with proto_messages to find field numbers.
145+
146+ Returns dict mapping parameter index to proto field number.
147+ """
148+ if proto_messages is None :
149+ return {}
150+ body = action .body
151+ if not isinstance (body , NewMessage ):
152+ return {}
153+ key = (body .module , body .name )
154+ proto_msg = proto_messages .get (key )
155+ if proto_msg is None :
156+ return {}
157+ # Build name -> field number map from proto definition
158+ name_to_number : dict [str , int ] = {}
159+ for f in proto_msg .fields :
160+ name_to_number [f .name ] = f .number
161+ for oneof in proto_msg .oneofs :
162+ for f in oneof .fields :
163+ name_to_number [f .name ] = f .number
164+ # Map each message field to the lambda param that provides its value
165+ param_names = [p .name for p in action .params ]
166+ result : dict [int , int ] = {}
167+ for field_name , field_expr in body .fields :
168+ field_num = name_to_number .get (field_name )
169+ if field_num is None :
170+ continue
171+ # Determine which param provides this field
172+ param_name = _extract_param_name (field_expr , param_names )
173+ if param_name is not None and param_name in param_names :
174+ param_idx = param_names .index (param_name )
175+ result [param_idx ] = field_num
176+ return result
177+
178+
179+ def _extract_param_name (expr : TargetExpr , param_names : list [str ]) -> str | None :
180+ """Extract the parameter name from a field expression.
181+
182+ Handles:
183+ - Direct Var reference
184+ - Call(OneOf(...), [Var(...)]) wrapper
185+ - Call(Builtin(...), [Var(...), ...]) e.g. unwrap_option_or
186+ """
187+ from .target import Builtin
188+
189+ if isinstance (expr , Var ) and expr .name in param_names :
190+ return expr .name
191+ if isinstance (expr , Call ):
192+ if isinstance (expr .func , OneOf ) and expr .args :
193+ return _extract_param_name (expr .args [0 ], param_names )
194+ if isinstance (expr .func , Builtin ) and expr .args :
195+ return _extract_param_name (expr .args [0 ], param_names )
196+ return None
197+
198+
134199def generate_parse_functions (
135- grammar : Grammar , indent : str = ""
200+ grammar : Grammar ,
201+ indent : str = "" ,
202+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None = None ,
136203) -> list [ParseNonterminalDef ]:
137204 parser_methods = []
138205 reachable , _ = grammar .analysis .partition_nonterminals_by_reachability ()
139206 for nt in reachable :
140207 rules = grammar .rules [nt ]
141- method_code = _generate_parse_method (nt , rules , grammar , indent )
208+ method_code = _generate_parse_method (nt , rules , grammar , indent , proto_messages )
142209 parser_methods .append (method_code )
143210 return parser_methods
144211
145212
213+ def _wrap_with_span (body : TargetExpr , return_type ) -> TargetExpr :
214+ """Wrap a nonterminal body with span_start/record_span."""
215+ span_var = Var (gensym ("span_start" ), BaseType ("Int64" ))
216+ result_var = Var (gensym ("result" ), return_type )
217+ return Let (
218+ span_var ,
219+ Call (make_builtin ("span_start" ), []),
220+ Let (
221+ result_var ,
222+ body ,
223+ Seq (
224+ [
225+ Call (make_builtin ("record_span" ), [span_var ]),
226+ result_var ,
227+ ]
228+ ),
229+ ),
230+ )
231+
232+
146233def _generate_parse_method (
147- lhs : Nonterminal , rules : list [Rule ], grammar : Grammar , indent : str = ""
234+ lhs : Nonterminal ,
235+ rules : list [Rule ],
236+ grammar : Grammar ,
237+ indent : str = "" ,
238+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None = None ,
148239) -> ParseNonterminalDef :
149- """Generate parse method code as string (preserving existing logic) ."""
240+ """Generate parse method for a nonterminal with provenance tracking ."""
150241 return_type = None
151242 rhs = None
152243 follow_set = FollowSet (grammar , lhs )
153244 if len (rules ) == 1 :
154245 rule = rules [0 ]
155246 rhs = _generate_parse_rhs_ir (
156- rule .rhs , grammar , follow_set , True , rule .constructor
247+ rule .rhs , grammar , follow_set , True , rule .constructor , proto_messages
157248 )
158249 return_type = rule .constructor .return_type
159250 else :
@@ -171,7 +262,6 @@ def _generate_parse_method(
171262 ],
172263 )
173264 for i , rule in enumerate (rules ):
174- # Ensure the return type is the same for all actions for this nonterminal.
175265 assert return_type is None or return_type == rule .constructor .return_type , (
176266 f"Return type mismatch at rule { i } : { return_type } != { rule .constructor .return_type } "
177267 )
@@ -183,12 +273,18 @@ def _generate_parse_method(
183273 make_builtin ("equal" ), [Var (prediction , BaseType ("Int64" )), Lit (i )]
184274 ),
185275 _generate_parse_rhs_ir (
186- rule .rhs , grammar , follow_set , True , rule .constructor
276+ rule .rhs ,
277+ grammar ,
278+ follow_set ,
279+ True ,
280+ rule .constructor ,
281+ proto_messages ,
187282 ),
188283 tail ,
189284 )
190285 rhs = Let (Var (prediction , BaseType ("Int64" )), predictor , tail )
191286 assert return_type is not None
287+ rhs = _wrap_with_span (rhs , return_type )
192288 return ParseNonterminalDef (lhs , [], return_type , rhs , indent )
193289
194290
@@ -377,30 +473,19 @@ def _generate_parse_rhs_ir(
377473 follow_set : TerminalSequenceSet ,
378474 apply_action : bool = False ,
379475 action : Lambda | None = None ,
476+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None = None ,
380477) -> TargetExpr :
381- """Generate IR for parsing an RHS.
382-
383- Args:
384- rhs: The RHS to parse
385- grammar: The grammar
386- follow_set: TerminalSequenceSet for computing follow lazily
387- apply_action: Whether to apply the semantic action
388- action: The semantic action to apply (required if apply_action is True)
389-
390- Returns IR expression for leaf nodes (Literal, Terminal, Nonterminal).
391- Returns None for complex cases that still use string generation.
392- """
478+ """Generate IR for parsing an RHS with provenance tracking."""
393479 if isinstance (rhs , Sequence ):
394480 return _generate_parse_rhs_ir_sequence (
395- rhs , grammar , follow_set , apply_action , action
481+ rhs , grammar , follow_set , apply_action , action , proto_messages
396482 )
397483 elif isinstance (rhs , LitTerminal ):
398484 parse_expr = Call (make_builtin ("consume_literal" ), [Lit (rhs .name )])
399485 if apply_action and action :
400486 return Seq ([parse_expr , apply_lambda (action , [])])
401487 return parse_expr
402488 elif isinstance (rhs , NamedTerminal ):
403- # Use terminal's actual type for consume_terminal instead of generic Token
404489 from .target import FunctionType
405490
406491 terminal_type = rhs .target_type ()
@@ -427,14 +512,18 @@ def _generate_parse_rhs_ir(
427512 elif isinstance (rhs , Option ):
428513 assert grammar is not None
429514 predictor = _build_option_predictor (grammar , rhs .rhs , follow_set )
430- parse_result = _generate_parse_rhs_ir (rhs .rhs , grammar , follow_set , False , None )
515+ parse_result = _generate_parse_rhs_ir (
516+ rhs .rhs , grammar , follow_set , False , None , proto_messages
517+ )
431518 return IfElse (predictor , Call (make_builtin ("some" ), [parse_result ]), Lit (None ))
432519 elif isinstance (rhs , Star ):
433520 assert grammar is not None
434521 xs = Var (gensym ("xs" ), ListType (rhs .rhs .target_type ()))
435522 cond = Var (gensym ("cond" ), BaseType ("Boolean" ))
436523 predictor = _build_option_predictor (grammar , rhs .rhs , follow_set )
437- parse_item = _generate_parse_rhs_ir (rhs .rhs , grammar , follow_set , False , None )
524+ parse_item = _generate_parse_rhs_ir (
525+ rhs .rhs , grammar , follow_set , False , None , proto_messages
526+ )
438527 item = Var (gensym ("item" ), rhs .rhs .target_type ())
439528 loop_body = Seq (
440529 [
@@ -452,16 +541,31 @@ def _generate_parse_rhs_ir(
452541 raise NotImplementedError (f"Unsupported Rhs type: { type (rhs )} " )
453542
454543
544+ def _wrap_with_path (field_num : int , var : Var , inner : TargetExpr ) -> list [TargetExpr ]:
545+ """Return statements that push path, assign inner to var, then pop path."""
546+ return [
547+ Call (make_builtin ("push_path" ), [Lit (field_num )]),
548+ Assign (var , inner ),
549+ Call (make_builtin ("pop_path" ), []),
550+ ]
551+
552+
455553def _generate_parse_rhs_ir_sequence (
456554 rhs : Sequence ,
457555 grammar : Grammar ,
458556 follow_set : TerminalSequenceSet ,
459557 apply_action : bool = False ,
460558 action : Lambda | None = None ,
559+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None = None ,
461560) -> TargetExpr :
462561 if is_epsilon (rhs ):
463562 return Lit (None )
464563
564+ # Compute param->field_number mapping for provenance
565+ param_field_numbers : dict [int , int ] = {}
566+ if action is not None and proto_messages is not None :
567+ param_field_numbers = _build_param_field_numbers (action , proto_messages )
568+
465569 exprs = []
466570 arg_vars = []
467571 elems = list (rhs_elements (rhs ))
@@ -473,7 +577,9 @@ def _generate_parse_rhs_ir_sequence(
473577 follow_set_i = ConcatSet (first_following , follow_set )
474578 else :
475579 follow_set_i = follow_set
476- elem_ir = _generate_parse_rhs_ir (elem , grammar , follow_set_i , False , None )
580+ elem_ir = _generate_parse_rhs_ir (
581+ elem , grammar , follow_set_i , False , None , proto_messages
582+ )
477583 if isinstance (elem , LitTerminal ):
478584 exprs .append (elem_ir )
479585 else :
@@ -487,23 +593,80 @@ def _generate_parse_rhs_ir_sequence(
487593 )
488594 var_name = gensym ("arg" )
489595 var = Var (var_name , elem .target_type ())
490- exprs .append (Assign (var , elem_ir ))
596+ field_num = param_field_numbers .get (non_literal_count )
597+ if field_num is not None and isinstance (elem , Star ):
598+ # For repeated fields: push field number around the whole
599+ # Star loop, and push/pop a 0-based index inside the loop
600+ stmts = _wrap_star_with_index_path (
601+ elem , var , grammar , follow_set_i , field_num , proto_messages
602+ )
603+ exprs .extend (stmts )
604+ elif field_num is not None :
605+ exprs .extend (_wrap_with_path (field_num , var , elem_ir ))
606+ else :
607+ exprs .append (Assign (var , elem_ir ))
491608 arg_vars .append (var )
492609 non_literal_count += 1
493610 if apply_action and action :
494611 lambda_call = apply_lambda (action , arg_vars )
495612 exprs .append (lambda_call )
496613 elif len (arg_vars ) > 1 :
497- # Multiple values - wrap in tuple
498614 exprs .append (Call (make_builtin ("tuple" ), arg_vars ))
499615 elif len (arg_vars ) == 1 :
500- # Single value - return the variable
501616 exprs .append (arg_vars [0 ])
502617 else :
503- # no non-literal elements, return None
504618 return Lit (None )
505619
506620 if len (exprs ) == 1 :
507621 return exprs [0 ]
508622 else :
509623 return Seq (exprs )
624+
625+
626+ def _wrap_star_with_index_path (
627+ star : Star ,
628+ result_var : Var ,
629+ grammar : Grammar ,
630+ follow_set : TerminalSequenceSet ,
631+ field_num : int ,
632+ proto_messages : dict [tuple [str , str ], ProtoMessage ] | None = None ,
633+ ) -> list [TargetExpr ]:
634+ """Return statements for a Star loop wrapped with push_path(field_num)
635+ and push_path(index)/pop_path() around each element.
636+ Assigns the resulting list to result_var."""
637+ xs = Var (gensym ("xs" ), ListType (star .rhs .target_type ()))
638+ cond = Var (gensym ("cond" ), BaseType ("Boolean" ))
639+ idx = Var (gensym ("idx" ), BaseType ("Int64" ))
640+ predictor = _build_option_predictor (grammar , star .rhs , follow_set )
641+ parse_item = _generate_parse_rhs_ir (
642+ star .rhs , grammar , follow_set , False , None , proto_messages
643+ )
644+ item = Var (gensym ("item" ), star .rhs .target_type ())
645+ loop_body = Seq (
646+ [
647+ Call (make_builtin ("push_path" ), [idx ]),
648+ Assign (item , parse_item ),
649+ Call (make_builtin ("pop_path" ), []),
650+ Call (make_builtin ("list_push" ), [xs , item ]),
651+ Assign (idx , Call (make_builtin ("add" ), [idx , Lit (1 )])),
652+ Assign (cond , predictor ),
653+ ]
654+ )
655+ inner = Let (
656+ xs ,
657+ ListExpr ([], star .rhs .target_type ()),
658+ Let (
659+ cond ,
660+ predictor ,
661+ Let (
662+ idx ,
663+ Lit (0 ),
664+ Seq ([While (cond , loop_body ), xs ]),
665+ ),
666+ ),
667+ )
668+ return [
669+ Call (make_builtin ("push_path" ), [Lit (field_num )]),
670+ Assign (result_var , inner ),
671+ Call (make_builtin ("pop_path" ), []),
672+ ]
0 commit comments