Skip to content

Commit eed269f

Browse files
committed
add embeds to CalcEvaluator
1 parent 20dcfd2 commit eed269f

6 files changed

Lines changed: 125 additions & 104 deletions

File tree

langs/calc/CalcEvaluator.scala

Lines changed: 39 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -30,103 +30,72 @@ object CalcEvaluator extends PassSeq:
3030
def inputWellformed: Wellformed = lang.wf
3131

3232
val passes = List(
33-
SimplifyPass,
34-
RemoveLayerPass,
33+
ConstantsPass,
34+
EvaluatorPass,
35+
StripExpressionPass,
3536
)
3637

37-
object SimplifyPass extends PassSeq.Pass:
38+
object ConstantsPass extends Pass:
3839
val wellformed = prevWellformed.makeDerived:
39-
Node.Top ::=! Expression
40+
Expression.removeCases(Number)
41+
Expression.addCases(EmbedMeta[Int])
42+
val rules = pass(once = true, strategy = pass.bottomUp)
43+
.rules:
44+
on(
45+
tok(Expression) *> onlyChild(tok(Number)),
46+
).rewrite: num =>
47+
splice(Expression(Node.Embed(num.sourceRange.decodeString().toInt)))
48+
49+
object EvaluatorPass extends Pass:
50+
val wellformed = prevWellformed.makeDerived:
51+
Expression ::=! embedded[Int]
4052
val rules = pass(once = false, strategy = pass.bottomUp)
4153
.rules:
4254
on(
43-
field(tok(Expression)) *> onlyChild(
55+
tok(Expression) *> onlyChild(
4456
tok(Add).withChildren:
45-
field(tok(Expression) *> onlyChild(tok(Number)))
46-
~ field(tok(Expression) *> onlyChild(tok(Number)))
57+
field(tok(Expression) *> onlyChild(embed[Int]))
58+
~ field(tok(Expression) *> onlyChild(embed[Int]))
4759
~ eof,
4860
),
4961
).rewrite: (left, right) =>
50-
val leftNum = left.unparent().sourceRange.decodeString().toInt
51-
val rightNum = right.unparent().sourceRange.decodeString().toInt
52-
53-
splice(
54-
Expression(
55-
Number(
56-
(leftNum + rightNum).toString(),
57-
),
58-
),
59-
)
62+
splice(Expression(Node.Embed[Int](left + right)))
6063
| on(
61-
field(tok(Expression)) *> onlyChild(
64+
tok(Expression) *> onlyChild(
6265
tok(Sub).withChildren:
63-
field(tok(Expression) *> onlyChild(tok(Number)))
64-
~ field(tok(Expression) *> onlyChild(tok(Number)))
66+
field(tok(Expression) *> onlyChild(embed[Int]))
67+
~ field(tok(Expression) *> onlyChild(embed[Int]))
6568
~ eof,
6669
),
6770
).rewrite: (left, right) =>
68-
val leftNum = left.unparent().sourceRange.decodeString().toInt
69-
val rightNum = right.unparent().sourceRange.decodeString().toInt
70-
71-
splice(
72-
Expression(
73-
Number(
74-
(leftNum - rightNum).toString(),
75-
),
76-
),
77-
)
71+
splice(Expression(Node.Embed[Int](left - right)))
7872
| on(
79-
field(tok(Expression)) *> onlyChild(
73+
tok(Expression) *> onlyChild(
8074
tok(Mul).withChildren:
81-
field(tok(Expression) *> onlyChild(tok(Number)))
82-
~ field(tok(Expression) *> onlyChild(tok(Number)))
75+
field(tok(Expression) *> onlyChild(embed[Int]))
76+
~ field(tok(Expression) *> onlyChild(embed[Int]))
8377
~ eof,
8478
),
8579
).rewrite: (left, right) =>
86-
val leftNum = left.unparent().sourceRange.decodeString().toInt
87-
val rightNum = right.unparent().sourceRange.decodeString().toInt
88-
89-
splice(
90-
Expression(
91-
Number(
92-
(leftNum * rightNum).toString(),
93-
),
94-
),
95-
)
80+
splice(Expression(Node.Embed[Int](left * right)))
9681
| on(
97-
field(tok(Expression)) *> onlyChild(
82+
tok(Expression) *> onlyChild(
9883
tok(Div).withChildren:
99-
field(tok(Expression) *> onlyChild(tok(Number)))
100-
~ field(tok(Expression) *> onlyChild(tok(Number)))
84+
field(tok(Expression) *> onlyChild(embed[Int]))
85+
~ field(tok(Expression) *> onlyChild(embed[Int]))
10186
~ eof,
10287
),
10388
).rewrite: (left, right) =>
104-
val leftNum = left.unparent().sourceRange.decodeString().toInt
105-
val rightNum = right.unparent().sourceRange.decodeString().toInt
106-
107-
splice(
108-
Expression(
109-
Number(
110-
(leftNum / rightNum).toString(),
111-
),
112-
),
113-
)
89+
splice(Expression(Node.Embed[Int](left / right)))
11490
end rules
115-
end SimplifyPass
91+
end EvaluatorPass
11692

117-
object RemoveLayerPass extends Pass:
93+
object StripExpressionPass extends Pass:
11894
val wellformed = prevWellformed.makeDerived:
119-
Node.Top ::=! Number
120-
val rules = pass(once = true, strategy = pass.topDown)
95+
Node.Top ::=! embedded[Int]
96+
val rules = pass(once = true, strategy = pass.bottomUp)
12197
.rules:
122-
on(
123-
tok(Expression).withChildren:
124-
field(tok(Number))
125-
~ eof,
126-
).rewrite: (number) =>
127-
splice(
128-
number.unparent(),
129-
)
130-
end rules
131-
end RemoveLayerPass
98+
on(tok(Expression) *> onlyChild(embed[Int])).rewrite: i =>
99+
splice(Node.Embed(i))
100+
end StripExpressionPass
132101
end CalcEvaluator

langs/calc/README.md

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ The ```rules``` method uses byte-level pattern matching to create tokens for num
2727

2828
The wellformed definition adds new ```Operation``` token types that have 2 ```Expression``` children. Also, ```Expression``` tokens have a new definition, being able to wrap both ```Number``` tokens as well as ```Operation``` tokens.
2929

30-
```mulDivPass``` and ```addSubPass``` both create nested expressions and splicing the old ```Op``` tokens that were previously defined. Both methods do this by pattern matching on the sequence of ```(Expression, Op, Expression)``` and replacing this sequence with
30+
The passes are a sequence of transformations defined in the `passes` field, in this consisting of 2 binary operation parsing passes.
31+
```MulDivPass``` and ```AddSubPass``` both create nested expressions and splicing the old ```Op``` tokens that were previously defined. Both methods do this by pattern matching on the sequence of ```(Expression, Op, Expression)``` and replacing this sequence with
3132

3233
```
3334
Expression(
@@ -38,38 +39,21 @@ Expression(
3839
)
3940
```
4041

41-
```mulDivPass``` is executed before ```addSubPass``` to create precedence, allowing multiplication and division operations to be nested deeper than addition and subtraction operations in the AST.
42+
```MulDivPass``` is executed before ```AddSubPass``` to create precedence, allowing multiplication and division operations to be nested deeper than addition and subtraction operations in the AST.
4243

4344

4445
### 4. ```CalcEvaluator.scala```
4546
```CalcEvaluator``` simplifies the AST and computes the value of the arithmetic expression.
4647

4748
The wellformed definition is imported from ```package.scala```, picking up with the AST structure of where ```CalcParser``` left off.
48-
49-
```simplifyPass``` splices all expressions repeatedly until there's only a single expression node at the top of the AST structure. The pass uses a bottom-up strategy to begin with simplifying the base-case expressions with no nesting as it goes up the AST.
50-
51-
The pass sequence pattern matches on
52-
53-
```
54-
Expression(
55-
Operation(
56-
Expression(
57-
Number
58-
),
59-
Expression(
60-
Number
61-
)
62-
)
63-
)
64-
```
65-
66-
and replaces the sequence with ```Expression(Number)```. The remaining ```Expression``` token at the end of the pass contains the value of arithmetic expression.
67-
68-
```removeLayerPass``` splices the ```Expression``` token at the top of the AST and replaces it with the ```Number``` token that was wrapped inside. Pattern matching is done on the sequence of ```Expression(Number)``` and replaces it with just ```Number```.
49+
Each pass is annotated with its own input / output grammars, expressed as changes to the previous pass's output grammar.
50+
- `ConstantsPass` converts each number to a native Scala `Int` for easier arithmetic, showing an example of the `Node.Embed` feature.
51+
- `EvaluatorPass` is a ruleset that describes basic arithmetic evaluation, which will fold a wellformed tree into a single node of the form `Expression(Node.Embed[Int](???))`.
52+
- `StripExpressionPass` simply cleans up the previous pass's tree, leaving just `Node.Embed[Int](???)` containing the result of evaluating the expression.
6953

7054

7155
## Usage
72-
To learn how to use the calculator, ```CalcReader.test.scala``` contains methods (```parse```, ```read```, ```evaluate```) that execute the different components of the calculator.
56+
To learn how to use the calculator, ```package.test.scala``` contains methods (```parse```, ```read```, ```evaluate```) that execute the different components of the calculator.
7357

7458

7559
## Example
Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
package forja.langs.calc
1616

1717
import forja.*
18+
import forja.dsl.*
19+
import forja.manip.RewriteDebugTracer
1820
import forja.source.{Source, SourceRange}
1921

2022
import Builtin.{Error, SourceMarker}
2123

22-
class CalcReaderTests extends munit.FunSuite:
24+
class CalcTests extends munit.FunSuite:
2325
extension (str: String)
2426
def read: Node.Top =
2527
forja.langs.calc.read
@@ -45,8 +47,8 @@ class CalcReaderTests extends munit.FunSuite:
4547
val top = parse
4648

4749
// format: off
48-
// instrumentWithTracer(Manip.RewriteDebugTracer(os.pwd / "dbg_calc_evaluator_passes")):
49-
CalcEvaluator(top)
50+
instrumentWithTracer(RewriteDebugTracer(os.pwd / "dbg_calc_evaluator_passes")):
51+
CalcEvaluator(top)
5052
// format: on
5153

5254
// os.write.over(
@@ -265,46 +267,46 @@ class CalcReaderTests extends munit.FunSuite:
265267
assertEquals(
266268
"5 + 11".evaluate,
267269
Node.Top(
268-
lang.Number("16"),
270+
Node.Embed(16),
269271
),
270272
)
271273

272274
test("multiplication calculation"):
273275
assertEquals(
274276
"5 * 11".evaluate,
275277
Node.Top(
276-
lang.Number("55"),
278+
Node.Embed(55),
277279
),
278280
)
279281

280282
test("full calculation"):
281283
assertEquals(
282284
"5 + 11 * 4".evaluate,
283285
Node.Top(
284-
lang.Number("49"),
286+
Node.Embed(49),
285287
),
286288
)
287289

288290
test("full calculation 2"):
289291
assertEquals(
290292
"5 * 4 + 4 / 2".evaluate,
291293
Node.Top(
292-
lang.Number("22"),
294+
Node.Embed(22),
293295
),
294296
)
295297

296298
test("full calculation 3"):
297299
assertEquals(
298300
"5 * 4 + 4 / 2 - 6".evaluate,
299301
Node.Top(
300-
lang.Number("16"),
302+
Node.Embed(16),
301303
),
302304
)
303305

304306
test("full calculation 4"):
305307
assertEquals(
306308
"5 * 4 + 4 / 2 - 6 * 2".evaluate,
307309
Node.Top(
308-
lang.Number("10"),
310+
Node.Embed(10),
309311
),
310312
)

src/EmbedMeta.test.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2024-2025 Forja Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package forja
16+
17+
import forja.dsl.*
18+
19+
class EmbedMetaTests extends munit.FunSuite:
20+
test("int == int"):
21+
assert(EmbedMeta[Int] == EmbedMeta[Int])
22+
23+
test("int != long"):
24+
assert(EmbedMeta[Int] != EmbedMeta[Long])
25+
26+
val n42 = Node.Embed(42)
27+
28+
test("node with int"):
29+
assert(n42.meta == EmbedMeta[Int])
30+
31+
test("pattern match"):
32+
// for parent purposes
33+
val top = Node.Top(n42)
34+
35+
val manip =
36+
initNode(n42):
37+
on(embed[Int]).value
38+
39+
assertEquals(manip.perform(), 42)
40+
41+
// try rewriting, to be sure
42+
val manip2 =
43+
initNode(n42):
44+
pass(strategy = pass.bottomUp, once = true)
45+
.rules:
46+
on(embed[Int]).rewrite: i =>
47+
splice(Node.Embed(43))
48+
49+
manip2.perform()
50+
51+
assertEquals(top, Node.Top(Node.Embed(43)))
52+
53+
// test("serialization"):
54+
/* val serialized =
55+
* Source.fromWritable(n42.toCompactWritable(Wellformed.empty)) */
56+
// val tree = sexpr.parse.fromSourceRange(SourceRange.entire(serialized))
57+
// val desern42 = Wellformed.empty.deserializeTree(tree)
58+
// assertEquals(desern42, n42)
59+
60+
// test("embed toString"):
61+
// assertEquals(Node.Embed(42).toString(), "")

src/manip/SeqPatternOps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ trait SeqPatternOps:
184184
def lastChild[T](using DebugInfo)(pattern: SeqPattern[T]): SeqPattern[T] =
185185
refine(atLastChild(on(pattern).value))
186186

187+
def embed[T: EmbedMeta](using DebugInfo): SeqPattern[T] =
188+
anyChild.restrict:
189+
case embed @ Node.Embed(t) if embed.meta == EmbedMeta[T] =>
190+
t.asInstanceOf[T]
191+
187192
extension [P <: Node.Parent](parentPattern: SeqPattern[P])
188193
def withChildren[T](using
189194
DebugInfo,

src/wf/Wellformed.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ object Wellformed:
541541
/* maybe it helps to assert something here, but it is technically
542542
* correct to just do nothing */
543543

544-
def addCases(cases: Token*): Unit =
544+
def addCases(cases: (Token | EmbedMeta[?])*): Unit =
545545
existingShape match
546546
case Shape.Choice(choices) =>
547547
token ::=! Shape.Choice(choices ++ cases)

0 commit comments

Comments
 (0)