Skip to content

Commit eb69531

Browse files
kaushikcfdinducer
authored andcommitted
Add regressions for pt.concatenate_calls
1 parent 3655fb4 commit eb69531

1 file changed

Lines changed: 113 additions & 1 deletion

File tree

test/test_codegen.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1895,7 +1895,8 @@ def build_expression(tracer):
18951895
np.testing.assert_allclose(outputs[key], expected[key])
18961896

18971897

1898-
def test_nested_function_calls(ctx_factory):
1898+
@pytest.mark.parametrize("should_concatenate_bar", (False, True))
1899+
def test_nested_function_calls(ctx_factory, should_concatenate_bar):
18991900
from functools import partial
19001901

19011902
ctx = ctx_factory()
@@ -1929,6 +1930,14 @@ def call_bar(tracer, x, y):
19291930
"out2": call_bar(pt.trace_call, x2, y2)}
19301931
)
19311932
result = pt.tag_all_calls_to_be_inlined(result)
1933+
if should_concatenate_bar:
1934+
from pytato.transform.calls import CallsiteCollector
1935+
assert len(CallsiteCollector(())(result)) == 4
1936+
result = pt.concatenate_calls(
1937+
result,
1938+
lambda x: pt.tags.FunctionIdentifier("bar") in x.call.function.tags)
1939+
assert len(CallsiteCollector(())(result)) == 2
1940+
19321941
expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
19331942
"out2": call_bar(ref_tracer, x2, y2)}
19341943
)
@@ -1941,6 +1950,109 @@ def call_bar(tracer, x, y):
19411950
np.testing.assert_allclose(result_out[k], expect_out[k])
19421951

19431952

1953+
def test_concatenate_calls_no_nested(ctx_factory):
1954+
rng = np.random.default_rng(0)
1955+
1956+
ctx = ctx_factory()
1957+
cq = cl.CommandQueue(ctx)
1958+
1959+
def foo(x, y):
1960+
return 3*x + 4*y + 42*pt.sin(x) + 1729*pt.tan(y)*pt.maximum(x, y)
1961+
1962+
x1 = pt.make_placeholder("x1", (10, 4), np.float64)
1963+
x2 = pt.make_placeholder("x2", (10, 4), np.float64)
1964+
1965+
y1 = pt.make_placeholder("y1", (10, 4), np.float64)
1966+
y2 = pt.make_placeholder("y2", (10, 4), np.float64)
1967+
1968+
z1 = pt.make_placeholder("z1", (10, 4), np.float64)
1969+
z2 = pt.make_placeholder("z2", (10, 4), np.float64)
1970+
1971+
result = pt.make_dict_of_named_arrays({"out1": 2*pt.trace_call(foo, 2*x1, 3*x2),
1972+
"out2": 4*pt.trace_call(foo, 4*y1, 9*y2),
1973+
"out3": 6*pt.trace_call(foo, 7*z1, 8*z2)
1974+
})
1975+
1976+
concatenated_result = pt.concatenate_calls(
1977+
result, lambda x: pt.tags.FunctionIdentifier("foo") in x.call.function.tags)
1978+
1979+
result = pt.tag_all_calls_to_be_inlined(result)
1980+
concatenated_result = pt.tag_all_calls_to_be_inlined(concatenated_result)
1981+
1982+
assert (pt.analysis.get_num_nodes(pt.inline_calls(result))
1983+
> pt.analysis.get_num_nodes(pt.inline_calls(concatenated_result)))
1984+
1985+
x1_np, x2_np, y1_np, y2_np, z1_np, z2_np = rng.random((6, 10, 4))
1986+
1987+
_, out_dict1 = pt.generate_loopy(result)(cq,
1988+
x1=x1_np, x2=x2_np,
1989+
y1=y1_np, y2=y2_np,
1990+
z1=z1_np, z2=z2_np)
1991+
1992+
_, out_dict2 = pt.generate_loopy(concatenated_result)(cq,
1993+
x1=x1_np, x2=x2_np,
1994+
y1=y1_np, y2=y2_np,
1995+
z1=z1_np, z2=z2_np)
1996+
assert out_dict1.keys() == out_dict2.keys()
1997+
1998+
for key in out_dict1:
1999+
np.testing.assert_allclose(out_dict1[key], out_dict2[key])
2000+
2001+
2002+
def test_concatenation_via_constant_expressions(ctx_factory):
2003+
2004+
from pytato.transform.calls import CallsiteCollector
2005+
2006+
rng = np.random.default_rng(0)
2007+
2008+
ctx = ctx_factory()
2009+
cq = cl.CommandQueue(ctx)
2010+
2011+
def resampling(coords, iels):
2012+
return coords[iels]
2013+
2014+
n_el = 1000
2015+
n_dof = 20
2016+
n_dim = 3
2017+
2018+
n_left_els = 17
2019+
n_right_els = 29
2020+
2021+
coords_dofs_np = rng.random((n_el, n_dim, n_dof), np.float64)
2022+
left_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_left_els)
2023+
right_bnd_iels_np = rng.integers(low=0, high=n_el, size=n_right_els)
2024+
2025+
coords_dofs = pt.make_data_wrapper(coords_dofs_np)
2026+
left_bnd_iels = pt.make_data_wrapper(left_bnd_iels_np)
2027+
right_bnd_iels = pt.make_data_wrapper(right_bnd_iels_np)
2028+
2029+
lcoords = pt.trace_call(resampling, coords_dofs, left_bnd_iels)
2030+
rcoords = pt.trace_call(resampling, coords_dofs, right_bnd_iels)
2031+
2032+
result = pt.make_dict_of_named_arrays({"lcoords": lcoords,
2033+
"rcoords": rcoords})
2034+
result = pt.tag_all_calls_to_be_inlined(result)
2035+
2036+
assert len(CallsiteCollector(())(result)) == 2
2037+
concated_result = pt.concatenate_calls(
2038+
result,
2039+
lambda cs: pt.tags.FunctionIdentifier("resampling") in cs.call.function.tags
2040+
)
2041+
assert len(CallsiteCollector(())(concated_result)) == 1
2042+
2043+
_, out_result = pt.generate_loopy(result)(cq)
2044+
np.testing.assert_allclose(out_result["lcoords"],
2045+
coords_dofs_np[left_bnd_iels_np])
2046+
np.testing.assert_allclose(out_result["rcoords"],
2047+
coords_dofs_np[right_bnd_iels_np])
2048+
2049+
_, out_concated_result = pt.generate_loopy(result)(cq)
2050+
np.testing.assert_allclose(out_concated_result["lcoords"],
2051+
coords_dofs_np[left_bnd_iels_np])
2052+
np.testing.assert_allclose(out_concated_result["rcoords"],
2053+
coords_dofs_np[right_bnd_iels_np])
2054+
2055+
19442056
if __name__ == "__main__":
19452057
if len(sys.argv) > 1:
19462058
exec(sys.argv[1])

0 commit comments

Comments
 (0)