Skip to content

Commit 7564c09

Browse files
authored
Merge pull request #121 from MiraGeoscience/GEOPY-425
GEOPY-425: Crash on Zarr file shape for tiled inversions with disk storage
2 parents 78903c9 + 523df6a commit 7564c09

4 files changed

Lines changed: 200 additions & 179 deletions

File tree

simpeg/dask/electromagnetics/frequency_domain/simulation.py

Lines changed: 38 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
import gc
2+
import os
3+
import shutil
24

35
from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim
46
from ....utils import Zero
@@ -50,7 +52,9 @@ def receiver_derivs(survey, mesh, fields, blocks):
5052
return field_derivatives
5153

5254

53-
def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
55+
def compute_rows(
56+
simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address, Jmatrix
57+
):
5458
"""
5559
Evaluate the sensitivities for the block or data
5660
"""
@@ -92,7 +96,14 @@ def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address
9296
if not isinstance(deriv_m, Zero):
9397
du_dmT += deriv_m
9498

95-
return np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T
99+
values = np.array(du_dmT, dtype=complex).reshape((du_dmT.shape[0], -1)).real.T
100+
101+
if isinstance(Jmatrix, zarr.Array):
102+
Jmatrix.set_orthogonal_selection((address[1][1], slice(None)), values)
103+
else:
104+
Jmatrix[address[1][1], :] = values
105+
106+
return None
96107

97108

98109
def getSourceTerm(self, freq, source=None):
@@ -195,28 +206,39 @@ def compute_J(self, m, f=None):
195206
"Consider creating one misfit per frequency."
196207
)
197208

209+
client, worker = self._get_client_worker()
210+
198211
A_i = list(Ainv.values())[0]
199212
m_size = m.size
213+
compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
214+
blocks = get_parallel_blocks(
215+
self.survey.source_list, compute_row_size, optimize=True
216+
)
200217

201218
if self.store_sensitivities == "disk":
219+
220+
chunk_size = np.median(
221+
[np.sum([len(chunk[1][1]) for chunk in block]) for block in blocks]
222+
).astype(int)
223+
224+
if os.path.exists(self.sensitivity_path):
225+
shutil.rmtree(self.sensitivity_path)
226+
202227
Jmatrix = zarr.open(
203228
self.sensitivity_path,
204229
mode="w",
205230
shape=(self.survey.nD, m_size),
206-
chunks=(self.max_chunk_size, m_size),
231+
chunks=(chunk_size, m_size),
207232
)
208233
else:
209234
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
210235

211-
compute_row_size = np.ceil(self.max_chunk_size / (A_i.A.shape[0] * 32.0 * 1e-6))
212-
blocks = get_parallel_blocks(
213-
self.survey.source_list, compute_row_size, optimize=False
214-
)
236+
if client:
237+
Jmatrix = client.scatter(Jmatrix, workers=worker)
238+
215239
fields_array = f[:, self._solutionType]
216240
blocks_receiver_derivs = []
217241

218-
client, worker = self._get_client_worker()
219-
220242
if client:
221243
fields_array = client.scatter(f[:, self._solutionType], workers=worker)
222244
fields = client.scatter(f, workers=worker)
@@ -270,7 +292,6 @@ def compute_J(self, m, f=None):
270292
addresses_chunks,
271293
client,
272294
worker,
273-
store_sensitivities=self.store_sensitivities,
274295
)
275296

276297
for A in Ainv.values():
@@ -295,7 +316,6 @@ def parallel_block_compute(
295316
addresses,
296317
client,
297318
worker=None,
298-
store_sensitivities="disk",
299319
):
300320
m_size = m.size
301321
block_stack = sp.hstack(blocks_receiver_derivs).toarray()
@@ -306,29 +326,29 @@ def parallel_block_compute(
306326
ATinvdf_duT = client.scatter(ATinvdf_duT, workers=worker)
307327
else:
308328
ATinvdf_duT = delayed(ATinvdf_duT)
329+
309330
count = 0
310-
rows = []
311331
block_delayed = []
312-
313332
for address, dfduT in zip(addresses, blocks_receiver_derivs):
314333
n_cols = dfduT.shape[1]
315334
n_rows = address[1][2]
316335

317336
if client:
318337
block_delayed.append(
319338
client.submit(
320-
eval_block,
339+
compute_rows,
321340
simulation,
322341
ATinvdf_duT,
323342
np.arange(count, count + n_cols),
324343
Zero(),
325344
fields_array,
326345
address,
346+
Jmatrix,
327347
workers=worker,
328348
)
329349
)
330350
else:
331-
delayed_eval = delayed(eval_block)
351+
delayed_eval = delayed(compute_rows)
332352
block_delayed.append(
333353
array.from_delayed(
334354
delayed_eval(
@@ -338,35 +358,22 @@ def parallel_block_compute(
338358
Zero(),
339359
fields_array,
340360
address,
361+
Jmatrix,
341362
),
342363
dtype=np.float32,
343364
shape=(n_rows, m_size),
344365
)
345366
)
346367
count += n_cols
347-
rows += address[1][1].tolist()
348-
349-
indices = np.hstack(rows)
350368

351369
if client:
352-
block_delayed = client.gather(block_delayed)
353-
block = np.vstack(block_delayed)
354-
else:
355-
block = compute(array.vstack(block_delayed))[0]
356-
357-
if store_sensitivities == "disk":
358-
Jmatrix.set_orthogonal_selection(
359-
(indices, slice(None)),
360-
block,
361-
)
370+
client.gather(block_delayed)
362371
else:
363-
# Dask process to compute row and store
364-
Jmatrix[indices, :] = block
372+
compute(block_delayed)
365373

366374
return Jmatrix
367375

368376

369-
Sim.parallel_block_compute = parallel_block_compute
370377
Sim.compute_J = compute_J
371378
Sim.getJtJdiag = getJtJdiag
372379
Sim.Jvec = Jvec

simpeg/dask/electromagnetics/static/resistivity/simulation.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
from ....simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
44

55
from .....utils import Zero
6-
6+
import shutil
7+
import os
78
import dask.array as da
89
import numpy as np
910
from scipy import sparse as sp
@@ -42,23 +43,29 @@ def compute_J(self, m, f=None):
4243

4344
f, Ainv = self.fields(m=m, return_Ainv=True)
4445

45-
m_size = m.size
46+
n_cells = m.size
4647
row_chunks = int(
4748
np.ceil(
4849
float(self.survey.nD)
49-
/ np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size)
50+
/ np.ceil(
51+
float(n_cells) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size
52+
)
5053
)
5154
)
5255

5356
if self.store_sensitivities == "disk":
57+
58+
if os.path.exists(self.sensitivity_path):
59+
shutil.rmtree(self.sensitivity_path)
60+
5461
Jmatrix = zarr.open(
55-
self.sensitivity_path + "J.zarr",
62+
self.sensitivity_path,
5663
mode="w",
57-
shape=(self.survey.nD, m_size),
58-
chunks=(row_chunks, m_size),
64+
shape=(self.survey.nD, n_cells),
65+
chunks=(row_chunks, n_cells),
5966
)
6067
else:
61-
Jmatrix = np.zeros((self.survey.nD, m_size), dtype=np.float32)
68+
Jmatrix = np.zeros((self.survey.nD, n_cells), dtype=np.float32)
6269

6370
blocks = []
6471
count = 0
@@ -92,7 +99,7 @@ def compute_J(self, m, f=None):
9299
du_dmT += df_dmT
93100

94101
#
95-
du_dmT = du_dmT.T.reshape((-1, m_size))
102+
du_dmT = du_dmT.T.reshape((-1, n_cells))
96103

97104
if len(blocks) == 0:
98105
blocks = du_dmT
@@ -130,7 +137,7 @@ def compute_J(self, m, f=None):
130137

131138
if self.store_sensitivities == "disk":
132139
del Jmatrix
133-
self._Jmatrix = da.from_zarr(self.sensitivity_path + "J.zarr")
140+
self._Jmatrix = da.from_zarr(self.sensitivity_path)
134141
else:
135142
self._Jmatrix = Jmatrix
136143

0 commit comments

Comments
 (0)