1- import pytest
2- from parser .ast import (
3- Program , Identifier , Literal , Assignment , BinaryOperation ,
4- FunctionDefinition , FunctionCall
5- )
6-
7- # Test Identifier Node
8- def test_identifier_node ():
9- ident = Identifier ("my_var" )
10- assert ident .name == "my_var"
11- assert str (ident ) == "<Identifier:my_var>"
12-
13- # Test Literal Node
14- @pytest .mark .parametrize (
15- "value, expected_str" ,
16- [
17- (123 , "<Literal:123>" ),
18- ("hello" , "<Literal:hello>" ),
19- (True , "<Literal:True>" ),
20- (None , "<Literal:None>" ),
21- ]
22- )
23- def test_literal_node (value , expected_str ):
24- literal = Literal (value )
25- assert literal .value == value
26- assert str (literal ) == expected_str
27-
28- # Test Assignment Node
29- def test_assignment_node ():
30- var = Identifier ("x" )
31- val = Literal (10 )
32- assign = Assignment (var , val )
33- assert assign .variable == var
34- assert assign .value == val
35- assert str (assign ) == "<Assignment:<Identifier:x> = <Literal:10>>"
36-
37- # Test BinaryOperation Node
38- def test_binary_operation_node ():
39- left = Identifier ("a" )
40- right = Literal (5 )
41- op = BinaryOperation (left , "+" , right )
42- assert op .left == left
43- assert op .operator == "+"
44- assert op .right == right
45- assert str (op ) == "<BinaryOp:<Identifier:a> + <Literal:5>>"
46-
47- # Test FunctionDefinition Node
48- def test_function_definition_node ():
49- name = Identifier ("my_func" )
50- params = [Identifier ("p1" ), Identifier ("p2" )]
51- body = [
52- Assignment (Identifier ("local_var" ), Literal (1 )),
53- BinaryOperation (Identifier ("p1" ), "+" , Identifier ("p2" ))
54- ]
55- func_def = FunctionDefinition (name , params , body )
56- assert func_def .name == name
57- assert func_def .parameters == params
58- assert func_def .body == body
59- expected_str = (
60- "<FunctionDef:<Identifier:my_func>(<Identifier:p1>, <Identifier:p2>)>\n "
61- " <Assignment:<Identifier:local_var> = <Literal:1>>\n "
62- " <BinaryOp:<Identifier:p1> + <Identifier:p2>>"
63- )
64- assert str (func_def ) == expected_str
65-
66- def test_function_definition_no_params_no_body ():
67- name = Identifier ("empty_func" )
68- func_def = FunctionDefinition (name , None , None )
69- assert func_def .name == name
70- assert func_def .parameters == []
71- assert func_def .body == []
72- assert str (func_def ) == "<FunctionDef:<Identifier:empty_func>()>\n "
73-
74- # Test FunctionCall Node
75- def test_function_call_node ():
76- func = Identifier ("call_me" )
77- args = [Literal (10 ), Identifier ("arg2" )]
78- func_call = FunctionCall (func , args )
79- assert func_call .function == func
80- assert func_call .arguments == args
81- assert str (func_call ) == "<FunctionCall:<Identifier:call_me>(<Literal:10>, <Identifier:arg2>)>"
82-
83- def test_function_call_no_args ():
84- func = Identifier ("no_args_call" )
85- func_call = FunctionCall (func , None )
86- assert func_call .function == func
87- assert func_call .arguments == []
88- assert str (func_call ) == "<FunctionCall:<Identifier:no_args_call>()>"
89-
90- # Test Program Node
91- def test_program_node ():
92- statements = [
93- Assignment (Identifier ("a" ), Literal (1 )),
94- FunctionCall (Identifier ("print" ), [Identifier ("a" )])
95- ]
96- program = Program (statements )
97- assert program .statements == statements
98- expected_str = (
99- "<Program>\n "
100- " <Assignment:<Identifier:a> = <Literal:1>>\n "
101- " <FunctionCall:<Identifier:print>(<Identifier:a>)>"
102- )
103- assert str (program ) == expected_str
104-
105- def test_program_empty ():
106- program = Program ([])
107- assert program .statements == []
108- assert str (program ) == "<Program>\n "
1+ """
2+ Abstract Syntax Tree (AST) node definitions for a simple programming language parser.
3+ """
1094
1105
6+ class ASTNode :
7+ """Base class for all AST nodes."""
8+ pass
9+
10+
11+ class Identifier (ASTNode ):
12+ """Represents an identifier/variable name."""
13+
14+ def __init__ (self , name ):
15+ self .name = name
16+
17+ def __str__ (self ):
18+ return f"<Identifier:{ self .name } >"
19+
20+
21+ class Literal (ASTNode ):
22+ """Represents a literal value (number, string, boolean, etc.)."""
23+
24+ def __init__ (self , value ):
25+ self .value = value
26+
27+ def __str__ (self ):
28+ return f"<Literal:{ self .value } >"
29+
30+
31+ class Assignment (ASTNode ):
32+ """Represents an assignment statement (variable = value)."""
33+
34+ def __init__ (self , variable , value ):
35+ self .variable = variable
36+ self .value = value
37+
38+ def __str__ (self ):
39+ return f"<Assignment:{ self .variable } = { self .value } >"
40+
41+
42+ class BinaryOperation (ASTNode ):
43+ """Represents a binary operation (left operator right)."""
44+
45+ def __init__ (self , left , operator , right ):
46+ self .left = left
47+ self .operator = operator
48+ self .right = right
49+
50+ def __str__ (self ):
51+ return f"<BinaryOp:{ self .left } { self .operator } { self .right } >"
52+
53+
54+ class FunctionDefinition (ASTNode ):
55+ """Represents a function definition."""
56+
57+ def __init__ (self , name , parameters = None , body = None ):
58+ self .name = name
59+ self .parameters = parameters if parameters is not None else []
60+ self .body = body if body is not None else []
61+
62+ def __str__ (self ):
63+ # Format parameters
64+ if self .parameters :
65+ params_str = ", " .join (str (param ) for param in self .parameters )
66+ else :
67+ params_str = ""
68+
69+ # Start with function signature
70+ result = f"<FunctionDef:{ self .name } ({ params_str } )>\n "
71+
72+ # Add body statements with indentation
73+ for statement in self .body :
74+ result += f" { statement } \n "
75+
76+ # Remove trailing newline if there are body statements
77+ if self .body :
78+ result = result .rstrip ('\n ' )
79+
80+ return result
81+
82+
83+ class FunctionCall (ASTNode ):
84+ """Represents a function call."""
85+
86+ def __init__ (self , function , arguments = None ):
87+ self .function = function
88+ self .arguments = arguments if arguments is not None else []
89+
90+ def __str__ (self ):
91+ # Format arguments
92+ if self .arguments :
93+ args_str = ", " .join (str (arg ) for arg in self .arguments )
94+ else :
95+ args_str = ""
96+
97+ return f"<FunctionCall:{ self .function } ({ args_str } )>"
98+
99+
100+ class Program (ASTNode ):
101+ """Represents the root of the AST - a program containing statements."""
102+
103+ def __init__ (self , statements ):
104+ self .statements = statements
105+
106+ def __str__ (self ):
107+ result = "<Program>\n "
108+
109+ # Add each statement with indentation
110+ for statement in self .statements :
111+ result += f" { statement } \n "
112+
113+ # Remove trailing newline if there are statements
114+ if self .statements :
115+ result = result .rstrip ('\n ' )
116+
117+ return result
118+
119+
120+ # Additional utility functions for working with AST nodes
121+
122+ def pretty_print_ast (node , indent = 0 ):
123+ """
124+ Pretty print an AST node with proper indentation.
125+ This is an alternative to the __str__ methods for more detailed output.
126+ """
127+ indent_str = " " * indent
128+
129+ if isinstance (node , Program ):
130+ print (f"{ indent_str } Program:" )
131+ for stmt in node .statements :
132+ pretty_print_ast (stmt , indent + 1 )
133+
134+ elif isinstance (node , FunctionDefinition ):
135+ params = ", " .join (param .name for param in node .parameters )
136+ print (f"{ indent_str } FunctionDef: { node .name .name } ({ params } )" )
137+ for stmt in node .body :
138+ pretty_print_ast (stmt , indent + 1 )
139+
140+ elif isinstance (node , Assignment ):
141+ print (f"{ indent_str } Assignment:" )
142+ print (f"{ indent_str } Variable:" )
143+ pretty_print_ast (node .variable , indent + 2 )
144+ print (f"{ indent_str } Value:" )
145+ pretty_print_ast (node .value , indent + 2 )
146+
147+ elif isinstance (node , BinaryOperation ):
148+ print (f"{ indent_str } BinaryOp: { node .operator } " )
149+ print (f"{ indent_str } Left:" )
150+ pretty_print_ast (node .left , indent + 2 )
151+ print (f"{ indent_str } Right:" )
152+ pretty_print_ast (node .right , indent + 2 )
153+
154+ elif isinstance (node , FunctionCall ):
155+ print (f"{ indent_str } FunctionCall:" )
156+ print (f"{ indent_str } Function:" )
157+ pretty_print_ast (node .function , indent + 2 )
158+ if node .arguments :
159+ print (f"{ indent_str } Arguments:" )
160+ for arg in node .arguments :
161+ pretty_print_ast (arg , indent + 2 )
162+
163+ elif isinstance (node , Identifier ):
164+ print (f"{ indent_str } Identifier: { node .name } " )
165+
166+ elif isinstance (node , Literal ):
167+ print (f"{ indent_str } Literal: { node .value } " )
168+
169+ else :
170+ print (f"{ indent_str } Unknown node type: { type (node )} " )
171+
172+
173+ def traverse_ast (node , visitor_func ):
174+ """
175+ Traverse an AST and apply a visitor function to each node.
176+ The visitor function should accept a single node parameter.
177+ """
178+ visitor_func (node )
179+
180+ if isinstance (node , Program ):
181+ for stmt in node .statements :
182+ traverse_ast (stmt , visitor_func )
183+
184+ elif isinstance (node , FunctionDefinition ):
185+ traverse_ast (node .name , visitor_func )
186+ for param in node .parameters :
187+ traverse_ast (param , visitor_func )
188+ for stmt in node .body :
189+ traverse_ast (stmt , visitor_func )
190+
191+ elif isinstance (node , Assignment ):
192+ traverse_ast (node .variable , visitor_func )
193+ traverse_ast (node .value , visitor_func )
194+
195+ elif isinstance (node , BinaryOperation ):
196+ traverse_ast (node .left , visitor_func )
197+ traverse_ast (node .right , visitor_func )
198+
199+ elif isinstance (node , FunctionCall ):
200+ traverse_ast (node .function , visitor_func )
201+ for arg in node .arguments :
202+ traverse_ast (arg , visitor_func )
203+
204+
205+ def find_identifiers (node ):
206+ """
207+ Find all identifier names used in an AST.
208+ Returns a set of identifier names.
209+ """
210+ identifiers = set ()
211+
212+ def collect_identifier (n ):
213+ if isinstance (n , Identifier ):
214+ identifiers .add (n .name )
215+
216+ traverse_ast (node , collect_identifier )
217+ return identifiers
218+
219+
220+ def count_nodes_by_type (node ):
221+ """
222+ Count the number of nodes of each type in an AST.
223+ Returns a dictionary with node type names as keys and counts as values.
224+ """
225+ counts = {}
226+
227+ def count_node (n ):
228+ node_type = type (n ).__name__
229+ counts [node_type ] = counts .get (node_type , 0 ) + 1
230+
231+ traverse_ast (node , count_node )
232+ return counts
0 commit comments