Skip to content

Commit 1eff19b

Browse files
committed
serialization: version mismatch info, added pytest
1 parent 8f807a1 commit 1eff19b

3 files changed

Lines changed: 74 additions & 27 deletions

File tree

serialization_temp_example.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/xtc/graphs/xtc/builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .graph import XTCGraph
88
from .context import XTCGraphContext
99
from .expr import XTCTensorExpr
10+
from .operators import XTCOperator
1011
from . import op_factory
1112
from yaml import safe_load
1213

@@ -52,8 +53,14 @@ def tuplify(obj: Any) -> Any:
5253
if "name" in node:
5354
args.append(node["name"])
5455
if not hasattr(op_factory, expr["op"]["name"]):
56+
version_mismatch = (
57+
"version mismatch detected, "
58+
if XTCOperator.version_string() != graph_dict["ops_version"]
59+
else ""
60+
)
5561
raise ValueError(
56-
f"serialized op {expr['op']['name']} is not implemented!"
62+
version_mismatch
63+
+ f"serialized op {expr['op']['name']} is not implemented."
5764
)
5865
op_func = getattr(op_factory, expr["op"]["name"])
5966
expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"]))
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import tempfile
2+
import xtc.graphs.xtc.op as O
3+
4+
def test_matmul_relu_to_from_dict():
5+
I, J, K, dtype = 4, 32, 512, "float32"
6+
a = O.tensor((I, K), dtype, name="A")
7+
b = O.tensor((K, J), dtype, name="B")
8+
9+
with O.graph(name="matmul_relu") as gb:
10+
m = O.matmul(a, b, name="matmul")
11+
O.relu(m, name="relu")
12+
13+
graph_dict = gb.graph.to_dict()
14+
with O.graph() as gb2:
15+
gb2.from_dict(graph_dict)
16+
assert graph_dict != {}
17+
assert graph_dict == gb2.graph.to_dict()
18+
19+
20+
def test_conv2d_pad_sdump_sload():
21+
N, H, W, F, R, S, C, SH, SW, dtype = 1, 8, 8, 16, 5, 5, 3, 2, 2, "float32"
22+
a = O.tensor((N, H, W, C), dtype, name="I")
23+
b = O.tensor((R, S, C, F), dtype, name="W")
24+
25+
with O.graph(name="pad_conv2d_nhwc_mini") as gb:
26+
p = O.pad(a, padding={1: (2), 2: (2, 2)}, name="pad")
27+
O.conv2d(p, b, stride=(SH, SW), name="conv")
28+
29+
graph_str = gb.graph.dumps()
30+
with O.graph(name="matmul_relu") as gb2:
31+
gb2.loads(graph_str)
32+
assert graph_str != ""
33+
assert graph_str == gb2.graph.dumps()
34+
35+
with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f:
36+
gb.graph.dump(f.name)
37+
with O.graph() as gb3:
38+
gb3.load(f.name)
39+
assert gb.graph.to_dict() == gb3.graph.to_dict()
40+
41+
def test_mlp_fc_custom_output():
42+
img = O.tensor()
43+
w1 = O.tensor()
44+
w2 = O.tensor()
45+
w3 = O.tensor()
46+
w4 = O.tensor()
47+
fc = lambda i, w, nout: O.matmul(O.reshape(i, shape=(1, -1)), O.reshape(w, shape=(-1, nout)))
48+
# Multi Layer Perceptron with 3 relu(fc) + 1 fc
49+
with O.graph(name="mlp4") as gb:
50+
with O.graph(name="l1"):
51+
l1 = O.relu(fc(img, w1, 512))
52+
with O.graph(name="l2"):
53+
l2 = O.relu(fc(l1, w2, 256))
54+
with O.graph(name="l3"):
55+
l3 = O.relu(fc(l2, w3, 128))
56+
with O.graph(name="l4"):
57+
l4 = fc(l3, w4, 10)
58+
O.reshape(l4, shape=(-1,))
59+
O.outputs(l1)
60+
61+
graph_dict = gb.graph.to_dict()
62+
with O.graph() as gb2:
63+
gb2.from_dict(graph_dict)
64+
assert graph_dict != {}
65+
assert graph_dict == gb2.graph.to_dict()
66+

0 commit comments

Comments
 (0)