11import gc
2+ import os
3+ import shutil
24
35from ....electromagnetics .frequency_domain .simulation import BaseFDEMSimulation as Sim
46from ....utils import Zero
@@ -50,7 +52,9 @@ def receiver_derivs(survey, mesh, fields, blocks):
5052 return field_derivatives
5153
5254
53- def eval_block (simulation , Ainv_deriv_u , deriv_indices , deriv_m , fields , address ):
55+ def compute_rows (
56+ simulation , Ainv_deriv_u , deriv_indices , deriv_m , fields , address , Jmatrix
57+ ):
5458 """
5559 Evaluate the sensitivities for the block or data
5660 """
@@ -92,7 +96,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address
9296 if not isinstance (deriv_m , Zero ):
9397 du_dmT += deriv_m
9498
95- return np .array (du_dmT , dtype = complex ).reshape ((du_dmT .shape [0 ], - 1 )).real .T
99+ values = np .array (du_dmT , dtype = complex ).reshape ((du_dmT .shape [0 ], - 1 )).real .T
100+
101+ if isinstance (Jmatrix , zarr .Array ):
102+ Jmatrix .set_orthogonal_selection ((address [1 ][1 ], slice (None )), values )
103+ else :
104+ Jmatrix [address [1 ][1 ], :] = values
105+
106+ return None
96107
97108
98109def getSourceTerm (self , freq , source = None ):
@@ -195,28 +206,39 @@ def compute_J(self, m, f=None):
195206 "Consider creating one misfit per frequency."
196207 )
197208
209+ client , worker = self ._get_client_worker ()
210+
198211 A_i = list (Ainv .values ())[0 ]
199212 m_size = m .size
213+ compute_row_size = np .ceil (self .max_chunk_size / (A_i .A .shape [0 ] * 32.0 * 1e-6 ))
214+ blocks = get_parallel_blocks (
215+ self .survey .source_list , compute_row_size , optimize = True
216+ )
200217
201218 if self .store_sensitivities == "disk" :
219+
220+ chunk_size = np .median (
221+ [np .sum ([len (chunk [1 ][1 ]) for chunk in block ]) for block in blocks ]
222+ ).astype (int )
223+
224+ if os .path .exists (self .sensitivity_path ):
225+ shutil .rmtree (self .sensitivity_path )
226+
202227 Jmatrix = zarr .open (
203228 self .sensitivity_path ,
204229 mode = "w" ,
205230 shape = (self .survey .nD , m_size ),
206- chunks = (self . max_chunk_size , m_size ),
231+ chunks = (chunk_size , m_size ),
207232 )
208233 else :
209234 Jmatrix = np .zeros ((self .survey .nD , m_size ), dtype = np .float32 )
210235
211- compute_row_size = np .ceil (self .max_chunk_size / (A_i .A .shape [0 ] * 32.0 * 1e-6 ))
212- blocks = get_parallel_blocks (
213- self .survey .source_list , compute_row_size , optimize = False
214- )
236+ if client :
237+ Jmatrix = client .scatter (Jmatrix , workers = worker )
238+
215239 fields_array = f [:, self ._solutionType ]
216240 blocks_receiver_derivs = []
217241
218- client , worker = self ._get_client_worker ()
219-
220242 if client :
221243 fields_array = client .scatter (f [:, self ._solutionType ], workers = worker )
222244 fields = client .scatter (f , workers = worker )
@@ -270,7 +292,6 @@ def compute_J(self, m, f=None):
270292 addresses_chunks ,
271293 client ,
272294 worker ,
273- store_sensitivities = self .store_sensitivities ,
274295 )
275296
276297 for A in Ainv .values ():
@@ -295,7 +316,6 @@ def parallel_block_compute(
295316 addresses ,
296317 client ,
297318 worker = None ,
298- store_sensitivities = "disk" ,
299319):
300320 m_size = m .size
301321 block_stack = sp .hstack (blocks_receiver_derivs ).toarray ()
@@ -306,29 +326,29 @@ def parallel_block_compute(
306326 ATinvdf_duT = client .scatter (ATinvdf_duT , workers = worker )
307327 else :
308328 ATinvdf_duT = delayed (ATinvdf_duT )
329+
309330 count = 0
310- rows = []
311331 block_delayed = []
312-
313332 for address , dfduT in zip (addresses , blocks_receiver_derivs ):
314333 n_cols = dfduT .shape [1 ]
315334 n_rows = address [1 ][2 ]
316335
317336 if client :
318337 block_delayed .append (
319338 client .submit (
320- eval_block ,
339+ compute_rows ,
321340 simulation ,
322341 ATinvdf_duT ,
323342 np .arange (count , count + n_cols ),
324343 Zero (),
325344 fields_array ,
326345 address ,
346+ Jmatrix ,
327347 workers = worker ,
328348 )
329349 )
330350 else :
331- delayed_eval = delayed (eval_block )
351+ delayed_eval = delayed (compute_rows )
332352 block_delayed .append (
333353 array .from_delayed (
334354 delayed_eval (
@@ -338,35 +358,22 @@ def parallel_block_compute(
338358 Zero (),
339359 fields_array ,
340360 address ,
361+ Jmatrix ,
341362 ),
342363 dtype = np .float32 ,
343364 shape = (n_rows , m_size ),
344365 )
345366 )
346367 count += n_cols
347- rows += address [1 ][1 ].tolist ()
348-
349- indices = np .hstack (rows )
350368
351369 if client :
352- block_delayed = client .gather (block_delayed )
353- block = np .vstack (block_delayed )
354- else :
355- block = compute (array .vstack (block_delayed ))[0 ]
356-
357- if store_sensitivities == "disk" :
358- Jmatrix .set_orthogonal_selection (
359- (indices , slice (None )),
360- block ,
361- )
370+ client .gather (block_delayed )
362371 else :
363- # Dask process to compute row and store
364- Jmatrix [indices , :] = block
372+ compute (block_delayed )
365373
366374 return Jmatrix
367375
368376
369- Sim .parallel_block_compute = parallel_block_compute
370377Sim .compute_J = compute_J
371378Sim .getJtJdiag = getJtJdiag
372379Sim .Jvec = Jvec
0 commit comments