@@ -1902,7 +1902,8 @@ def build_expression(tracer):
19021902 np .testing .assert_allclose (outputs [key ], expected [key ])
19031903
19041904
1905- def test_nested_function_calls (ctx_factory ):
1905+ @pytest .mark .parametrize ("should_concatenate_bar" , (False , True ))
1906+ def test_nested_function_calls (ctx_factory , should_concatenate_bar ):
19061907 from functools import partial
19071908
19081909 ctx = ctx_factory ()
@@ -1936,6 +1937,14 @@ def call_bar(tracer, x, y):
19361937 "out2" : call_bar (pt .trace_call , x2 , y2 )}
19371938 )
19381939 result = pt .tag_all_calls_to_be_inlined (result )
1940+ if should_concatenate_bar :
1941+ from pytato .transform .calls import CallsiteCollector
1942+ assert len (CallsiteCollector (())(result )) == 4
1943+ result = pt .concatenate_calls (
1944+ result ,
1945+ lambda x : pt .tags .FunctionIdentifier ("bar" ) in x .call .function .tags )
1946+ assert len (CallsiteCollector (())(result )) == 2
1947+
19391948 expect = pt .make_dict_of_named_arrays ({"out1" : call_bar (ref_tracer , x1 , y1 ),
19401949 "out2" : call_bar (ref_tracer , x2 , y2 )}
19411950 )
@@ -1948,6 +1957,109 @@ def call_bar(tracer, x, y):
19481957 np .testing .assert_allclose (result_out [k ], expect_out [k ])
19491958
19501959
1960+ def test_concatenate_calls_no_nested (ctx_factory ):
1961+ rng = np .random .default_rng (0 )
1962+
1963+ ctx = ctx_factory ()
1964+ cq = cl .CommandQueue (ctx )
1965+
1966+ def foo (x , y ):
1967+ return 3 * x + 4 * y + 42 * pt .sin (x ) + 1729 * pt .tan (y )* pt .maximum (x , y )
1968+
1969+ x1 = pt .make_placeholder ("x1" , (10 , 4 ), np .float64 )
1970+ x2 = pt .make_placeholder ("x2" , (10 , 4 ), np .float64 )
1971+
1972+ y1 = pt .make_placeholder ("y1" , (10 , 4 ), np .float64 )
1973+ y2 = pt .make_placeholder ("y2" , (10 , 4 ), np .float64 )
1974+
1975+ z1 = pt .make_placeholder ("z1" , (10 , 4 ), np .float64 )
1976+ z2 = pt .make_placeholder ("z2" , (10 , 4 ), np .float64 )
1977+
1978+ result = pt .make_dict_of_named_arrays ({"out1" : 2 * pt .trace_call (foo , 2 * x1 , 3 * x2 ),
1979+ "out2" : 4 * pt .trace_call (foo , 4 * y1 , 9 * y2 ),
1980+ "out3" : 6 * pt .trace_call (foo , 7 * z1 , 8 * z2 )
1981+ })
1982+
1983+ concatenated_result = pt .concatenate_calls (
1984+ result , lambda x : pt .tags .FunctionIdentifier ("foo" ) in x .call .function .tags )
1985+
1986+ result = pt .tag_all_calls_to_be_inlined (result )
1987+ concatenated_result = pt .tag_all_calls_to_be_inlined (concatenated_result )
1988+
1989+ assert (pt .analysis .get_num_nodes (pt .inline_calls (result ))
1990+ > pt .analysis .get_num_nodes (pt .inline_calls (concatenated_result )))
1991+
1992+ x1_np , x2_np , y1_np , y2_np , z1_np , z2_np = rng .random ((6 , 10 , 4 ))
1993+
1994+ _ , out_dict1 = pt .generate_loopy (result )(cq ,
1995+ x1 = x1_np , x2 = x2_np ,
1996+ y1 = y1_np , y2 = y2_np ,
1997+ z1 = z1_np , z2 = z2_np )
1998+
1999+ _ , out_dict2 = pt .generate_loopy (concatenated_result )(cq ,
2000+ x1 = x1_np , x2 = x2_np ,
2001+ y1 = y1_np , y2 = y2_np ,
2002+ z1 = z1_np , z2 = z2_np )
2003+ assert out_dict1 .keys () == out_dict2 .keys ()
2004+
2005+ for key in out_dict1 :
2006+ np .testing .assert_allclose (out_dict1 [key ], out_dict2 [key ])
2007+
2008+
2009+ def test_concatenation_via_constant_expressions (ctx_factory ):
2010+
2011+ from pytato .transform .calls import CallsiteCollector
2012+
2013+ rng = np .random .default_rng (0 )
2014+
2015+ ctx = ctx_factory ()
2016+ cq = cl .CommandQueue (ctx )
2017+
2018+ def resampling (coords , iels ):
2019+ return coords [iels ]
2020+
2021+ n_el = 1000
2022+ n_dof = 20
2023+ n_dim = 3
2024+
2025+ n_left_els = 17
2026+ n_right_els = 29
2027+
2028+ coords_dofs_np = rng .random ((n_el , n_dim , n_dof ), np .float64 )
2029+ left_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_left_els )
2030+ right_bnd_iels_np = rng .integers (low = 0 , high = n_el , size = n_right_els )
2031+
2032+ coords_dofs = pt .make_data_wrapper (coords_dofs_np )
2033+ left_bnd_iels = pt .make_data_wrapper (left_bnd_iels_np )
2034+ right_bnd_iels = pt .make_data_wrapper (right_bnd_iels_np )
2035+
2036+ lcoords = pt .trace_call (resampling , coords_dofs , left_bnd_iels )
2037+ rcoords = pt .trace_call (resampling , coords_dofs , right_bnd_iels )
2038+
2039+ result = pt .make_dict_of_named_arrays ({"lcoords" : lcoords ,
2040+ "rcoords" : rcoords })
2041+ result = pt .tag_all_calls_to_be_inlined (result )
2042+
2043+ assert len (CallsiteCollector (())(result )) == 2
2044+ concated_result = pt .concatenate_calls (
2045+ result ,
2046+ lambda cs : pt .tags .FunctionIdentifier ("resampling" ) in cs .call .function .tags
2047+ )
2048+ assert len (CallsiteCollector (())(concated_result )) == 1
2049+
2050+ _ , out_result = pt .generate_loopy (result )(cq )
2051+ np .testing .assert_allclose (out_result ["lcoords" ],
2052+ coords_dofs_np [left_bnd_iels_np ])
2053+ np .testing .assert_allclose (out_result ["rcoords" ],
2054+ coords_dofs_np [right_bnd_iels_np ])
2055+
2056+ _ , out_concated_result = pt .generate_loopy (result )(cq )
2057+ np .testing .assert_allclose (out_concated_result ["lcoords" ],
2058+ coords_dofs_np [left_bnd_iels_np ])
2059+ np .testing .assert_allclose (out_concated_result ["rcoords" ],
2060+ coords_dofs_np [right_bnd_iels_np ])
2061+
2062+
19512063if __name__ == "__main__" :
19522064 if len (sys .argv ) > 1 :
19532065 exec (sys .argv [1 ])
0 commit comments