Skip to content

Commit 338649a

Browse files
tmp
1 parent 6afb528 commit 338649a

3 files changed

Lines changed: 49 additions & 15 deletions

File tree

python/egglog/builtins.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def replace(self, old: StringLike, new: StringLike) -> String: ...
115115
def __add__(self, other: StringLike) -> String:
116116
return join(self, other)
117117

118+
@method(egg_fn="log")
119+
def log(self) -> Unit: ...
120+
118121

119122
StringLike: TypeAlias = String | str
120123

python/egglog/exp/array_api.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,15 +1534,7 @@ def eval(self) -> PyTupleValuesRecursive:
15341534

15351535
@array_api_ruleset.register
15361536
def _recursive_value(
1537-
i: Int,
1538-
i2: Int,
1539-
v: Value,
1540-
vs: Vec[RecursiveValue],
1541-
k: i64,
1542-
lt: Callable[[], RecursiveValue],
1543-
lf: Callable[[], RecursiveValue],
1544-
vi: Vec[Int],
1545-
rv: RecursiveValue,
1537+
v: Value, vs: Vec[RecursiveValue], k: i64, vi: Vec[Int], vi1: Vec[Int], rv: RecursiveValue, rv1: RecursiveValue
15461538
):
15471539
yield rewrite(RecursiveValue(v).shape).to(TupleInt(()))
15481540
yield rewrite(RecursiveValue.vec(vs).shape).to(TupleInt((vs.length(),)) + vs[0].shape, vs.length() > 0)
@@ -1551,11 +1543,15 @@ def _recursive_value(
15511543
yield rewrite(RecursiveValue(v)[vi], subsume=True).to(v) # Assume ti is empty
15521544

15531545
yield rule(
1554-
eq(v).to(RecursiveValue.vec(vs)[vi]),
1546+
eq(rv).to(RecursiveValue.vec(vs)),
1547+
eq(v).to(rv[vi]),
15551548
vi.length() > 0,
15561549
eq(vi[0]).to(Int(k)),
1550+
eq(rv1).to(vs[k]),
1551+
eq(vi1).to(vi.remove(0)),
15571552
).then(
1558-
union(v).with_(vs[k][vi.remove(0)]),
1553+
union(v).with_(rv1[vi1]),
1554+
subsume(rv[vi]),
15591555
)
15601556

15611557

python/egglog/exp/vecdot_example.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,52 @@
11
from egglog.exp.array_api import *
22

3-
v = NDArray([[1, 2], [3, 4]])
4-
n = NDArray([3, 4])
3+
v = NDArray([[1], [3]])
4+
n = NDArray([3])
55
res = vecdot(v, n)
6-
egraph = EGraph()
6+
egraph = EGraph(save_egglog_string=True)
77
egraph.register(res.to_recursive_value())
88
egraph.run(array_api_schedule)
9+
print(egraph.run(array_api_ruleset).updated)
10+
print(egraph.run(array_api_ruleset).updated)
11+
12+
print(egraph.extract(res.to_recursive_value()))
913

10-
new_res = egraph.extract(res.to_recursive_value())
14+
15+
@egraph.run
16+
@ruleset
17+
def _recursive_value(
18+
i: Int,
19+
i2: Int,
20+
v: Value,
21+
v1: Value,
22+
vs: Vec[RecursiveValue],
23+
k: i64,
24+
lt: Callable[[], RecursiveValue],
25+
lf: Callable[[], RecursiveValue],
26+
vi: Vec[Int],
27+
vi1: Vec[Int],
28+
rv: RecursiveValue,
29+
rv1: RecursiveValue,
30+
):
31+
yield rule(
32+
eq(rv).to(RecursiveValue.vec(vs)),
33+
eq(v).to(rv[vi]),
34+
vi.length() > 0,
35+
eq(vi[0]).to(Int(k)),
36+
eq(rv1).to(vs[k]),
37+
eq(vi1).to(vi.remove(0)),
38+
).then(
39+
union(v).with_(rv1[vi1]),
40+
subsume(rv[vi]),
41+
)
1142

1243

1344
egraph.debug_print()
45+
new_res = egraph.extract(res.to_recursive_value())
1446
print(new_res)
47+
# print(egraph.as_egglog_string)
48+
# egraph.debug_print()
49+
# print(new_res)
1550

1651
# new_egraph = EGraph()
1752
# new_egraph.register(new_res)

0 commit comments

Comments
 (0)