@@ -1895,7 +1895,8 @@ def build_expression(tracer):
18951895 np .testing .assert_allclose (outputs [key ], expected [key ])
18961896
18971897
1898- def test_nested_function_calls (ctx_factory ):
1898+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1899+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
18991900 from functools import partial
19001901
19011902 ctx = ctx_factory ()
@@ -1929,6 +1930,14 @@ def call_bar(tracer, x, y):
19291930 "out2" : call_bar (pt .trace_call , x2 , y2 )}
19301931 )
19311932 result = pt .tag_all_calls_to_be_inlined (result )
1933+ if should_concatenate_bar :
1934+ from pytato .transform .calls import CallsiteCollector
1935+ assert len (CallsiteCollector (())(result )) == 4
1936+ result = pt .concatenate_calls (
1937+ result ,
1938+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1939+ assert len (CallsiteCollector (())(result )) == 2
1940+
19321941 expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
19331942 "out2" : call_bar (ref_tracer , x2 , y2 )}
19341943 )
@@ -1941,6 +1950,109 @@ def call_bar(tracer, x, y):
19411950 np .testing .assert_allclose (result_out [k ], expect_out [k ])
19421951
19431952
1953+ def test_concatenate_calls_no_nested (ctx_factory ):
1954+ rng = np .random .default_rng (0 )
1955+
1956+ ctx = ctx_factory ()
1957+ cq = cl .CommandQueue (ctx )
1958+
1959+ def foo (x , y ):
1960+ return 3 * x + 4 * y + 42 * pt .sin (x ) + 1729 * pt .tan (y )* pt .maximum (x , y )
1961+
1962+ x1 = pt .make_placeholder ("x1" , (10 , 4 ), np .float64 )
1963+ x2 = pt .make_placeholder ("x2" , (10 , 4 ), np .float64 )
1964+
1965+ y1 = pt .make_placeholder ("y1" , (10 , 4 ), np .float64 )
1966+ y2 = pt .make_placeholder ("y2" , (10 , 4 ), np .float64 )
1967+
1968+ z1 = pt .make_placeholder ("z1" , (10 , 4 ), np .float64 )
1969+ z2 = pt .make_placeholder ("z2" , (10 , 4 ), np .float64 )
1970+
1971+ result = pt .make_dict_of_named_arrays ({"out1" : 2 * pt .trace_call (foo , 2 * x1 , 3 * x2 ),
1972+ "out2" : 4 * pt .trace_call (foo , 4 * y1 , 9 * y2 ),
1973+ "out3" : 6 * pt .trace_call (foo , 7 * z1 , 8 * z2 )
1974+ })
1975+
1976+ concatenated_result = pt .concatenate_calls (
1977+ result , lambda x : pt .tags .FunctionIdentifier ("foo" ) in x .call .function .tags )
1978+
1979+ result = pt .tag_all_calls_to_be_inlined (result )
1980+ concatenated_result = pt .tag_all_calls_to_be_inlined (concatenated_result )
1981+
1982+ assert (pt .analysis .get_num_nodes (pt .inline_calls (result ))
1983+ > pt .analysis .get_num_nodes (pt .inline_calls (concatenated_result )))
1984+
1985+ x1_np , x2_np , y1_np , y2_np , z1_np , z2_np = rng .random ((6 , 10 , 4 ))
1986+
1987+ _ , out_dict1 = pt .generate_loopy (result )(cq ,
1988+ x1 = x1_np , x2 = x2_np ,
1989+ y1 = y1_np , y2 = y2_np ,
1990+ z1 = z1_np , z2 = z2_np )
1991+
1992+ _ , out_dict2 = pt .generate_loopy (concatenated_result )(cq ,
1993+ x1 = x1_np , x2 = x2_np ,
1994+ y1 = y1_np , y2 = y2_np ,
1995+ z1 = z1_np , z2 = z2_np )
1996+ assert out_dict1 .keys () == out_dict2 .keys ()
1997+
1998+ for key in out_dict1 :
1999+ np .testing .assert_allclose (out_dict1 [key ], out_dict2 [key ])
2000+
2001+
2002+ def test_concatenation_via_constant_expressions (ctx_factory ):
2003+
2004+ from pytato .transform .calls import CallsiteCollector
2005+
2006+ rng = np .random .default_rng (0 )
2007+
2008+ ctx = ctx_factory ()
2009+ cq = cl .CommandQueue (ctx )
2010+
2011+ def resampling (coords , iels ):
2012+ return coords [iels ]
2013+
2014+ n_el = 1000
2015+ n_dof = 20
2016+ n_dim = 3
2017+
2018+ n_left_els = 17
2019+ n_right_els = 29
2020+
2021+ coords_dofs_np = rng .random ((n_el , n_dim , n_dof ), np .float64 )
2022+ left_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_left_els )
2023+ right_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_right_els )
2024+
2025+ coords_dofs = pt .make_data_wrapper (coords_dofs_np )
2026+ left_bnd_iels = pt .make_data_wrapper (left_bnd_iels_np )
2027+ right_bnd_iels = pt .make_data_wrapper (right_bnd_iels_np )
2028+
2029+ lcoords = pt .trace_call (resampling , coords_dofs , left_bnd_iels )
2030+ rcoords = pt .trace_call (resampling , coords_dofs , right_bnd_iels )
2031+
2032+ result = pt .make_dict_of_named_arrays ({"lcoords" : lcoords ,
2033+ "rcoords" : rcoords })
2034+ result = pt .tag_all_calls_to_be_inlined (result )
2035+
2036+ assert len (CallsiteCollector (())(result )) == 2
2037+ concated_result = pt .concatenate_calls (
2038+ result ,
2039+ lambda cs : pt .tags .FunctionIdentifier ("resampling" ) in cs .call .function .tags
2040+ )
2041+ assert len (CallsiteCollector (())(concated_result )) == 1
2042+
2043+ _ , out_result = pt .generate_loopy (result )(cq )
2044+ np .testing .assert_allclose (out_result ["lcoords" ],
2045+ coords_dofs_np [left_bnd_iels_np ])
2046+ np .testing .assert_allclose (out_result ["rcoords" ],
2047+ coords_dofs_np [right_bnd_iels_np ])
2048+
2049+ _ , out_concated_result = pt .generate_loopy (result )(cq )
2050+ np .testing .assert_allclose (out_concated_result ["lcoords" ],
2051+ coords_dofs_np [left_bnd_iels_np ])
2052+ np .testing .assert_allclose (out_concated_result ["rcoords" ],
2053+ coords_dofs_np [right_bnd_iels_np ])
2054+
2055+
19442056if __name__ == "__main__" :
19452057 if len (sys .argv ) > 1 :
19462058 exec (sys .argv [1 ])
0 commit comments