@@ -424,6 +424,32 @@ def test_idx_lambda_to_hlo():
424424 == BroadcastOp (a ))
425425
426426
427+ def test_stringify ():
428+ x = pt .make_placeholder ("x" , (10 , 4 ), np .int64 )
429+ y = pt .make_placeholder ("y" , (10 , 4 ), np .int64 )
430+
431+ assert (str (3 * x + 4 * y )
432+ == "3*x + 4*y" )
433+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
434+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
435+ assert (str (pt .roll (x .reshape (2 , 20 ).reshape (- 1 ), 3 ))
436+ == "roll(reshape(reshape(x, (2, 20)), 40), 3)" )
437+ assert (str (y * pt .not_equal (x , 3 ))
438+ == "y*(x != 3)" )
439+ assert (str (3 * y @ pt .sum (x , axis = 0 ))
440+ == "3*y @ sum(x, axis=0)" )
441+ assert (str (x [y [:, 2 :3 ], x [2 , :]])
442+ == "x[y[::, 2:3:], x[2]]" )
443+ assert (str (pt .stack ([x [y [:, 2 :3 ], x [2 , :]].T , y [x [:, 2 :3 ], y [2 , :]].T ]))
444+ == ("stack([transpose(x[y[::, 2:3:], x[2]]),"
445+ " transpose(y[x[::, 2:3:], y[2]])])" ))
446+ assert (str (pt .concatenate ([x [y [:, 2 :3 ], x [2 , :]],
447+ y [x [:, 2 :3 ], y [2 , :]]]))
448+ == "concatenate([x[y[::, 2:3:], x[2]], y[x[::, 2:3:], y[2]]])" )
449+ assert (str (pt .einsum ("ij,i->i" , 2 * x , pt .sum (y , axis = 1 )))
450+ == 'einsum("ij, i -> i", 2*x, sum(y, axis=1))' )
451+
452+
427453if __name__ == "__main__" :
428454 if len (sys .argv ) > 1 :
429455 exec (sys .argv [1 ])
0 commit comments