99import numpy as np
1010
1111_dir = os .path .dirname (__file__ )
12+ _test = os .path .join (os .path .dirname (_dir ), 'result' , 'matlab' )
13+
14+ os .makedirs (_test , exist_ok = True )
1215
1316
1417def test_wavefronts ():
@@ -70,19 +73,19 @@ def test_wavefronts():
7073
7174 plt .figure ()
7275 tike .view .plot_complex (probe0 [5 , 0 , 0 ])
73- plt .savefig ('forward-02.png' )
76+ plt .savefig (os . path . join ( _test , 'forward-02.png' ) )
7477
7578 plt .figure ()
7679 tike .view .plot_complex ((varying_probe [5 , 0 , 0 ]))
77- plt .savefig ('forward-03.png' )
80+ plt .savefig (os . path . join ( _test , 'forward-03.png' ) )
7881
7982 plt .figure ()
8083 tike .view .plot_complex (probe1 [0 , 0 , 0 ])
81- plt .savefig ('forward-04.png' )
84+ plt .savefig (os . path . join ( _test , 'forward-04.png' ) )
8285
8386 plt .figure ()
8487 tike .view .plot_complex ((varying_probe [0 , 0 , 1 ]))
85- plt .savefig ('forward-05.png' )
88+ plt .savefig (os . path . join ( _test , 'forward-05.png' ) )
8689
8790 np .testing .assert_allclose (
8891 probe0 ,
@@ -132,7 +135,7 @@ def test_wavefronts():
132135
133136 plt .figure ()
134137 plt .imshow (cp .asarray (psi , dtype = 'complex64' ).real .get ())
135- plt .savefig ('forward-00.png' )
138+ plt .savefig (os . path . join ( _test , 'forward-00.png' ) )
136139
137140 plt .figure ()
138141 for i in range (5 ):
@@ -143,7 +146,7 @@ def test_wavefronts():
143146 plt .subplot (5 , 3 , 3 * i + 3 )
144147 plt .imshow (new_patches [i ].real - patches [i ].real )
145148 plt .colorbar ()
146- plt .savefig ('forward-01.png' )
149+ plt .savefig (os . path . join ( _test , 'forward-01.png' ) )
147150
148151 np .testing .assert_allclose (
149152 new_patches ,
@@ -152,14 +155,15 @@ def test_wavefronts():
152155 )
153156
154157 plt .figure ()
155- tike .view .plot_complex (wavefronts [0 ,0 ])
156- plt .savefig ('forward-08.png' )
158+ tike .view .plot_complex (wavefronts [0 , 0 ])
159+ plt .savefig (os . path . join ( _test , 'forward-08.png' ) )
157160 plt .figure ()
158- tike .view .plot_complex (psi1 [0 ,0 , 0 ])
159- plt .savefig ('forward-09.png' )
161+ tike .view .plot_complex (psi1 [0 , 0 , 0 ])
162+ plt .savefig (os . path . join ( _test , 'forward-09.png' ) )
160163 plt .figure ()
161- tike .view .plot_complex (wavefronts [0 ,0 ] - psi1 [0 ,0 ,0 ])
162- plt .savefig ('forward-10.png' )
164+ tike .view .plot_complex (wavefronts [0 , 0 ] - psi1 [0 , 0 , 0 ])
165+ plt .savefig (os .path .join (_test , 'forward-10.png' ))
166+ plt .close ('all' )
163167
164168 # NOTE: MATLAB uses naive complex multiplication, but standard CUDA library
165169 # uses optimized (but still correct) complex multiplication. This accounts
0 commit comments