Skip to content

Commit 7da0cb4

Browse files
authored
Merge pull request #116 from MiraGeoscience/GEOPY-2439
GEOPY-2439: Dask.distributed crash on compute source term for large ATEM survey
2 parents 5561fb5 + c9c2188 commit 7da0cb4

1 file changed

Lines changed: 8 additions & 46 deletions

File tree

simpeg/dask/electromagnetics/time_domain/simulation.py

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -62,47 +62,9 @@ def getSourceTerm(self, tInd):
6262
elif getattr(self, "_stashed_sources", None) is None:
6363
self._stashed_sources = {}
6464

65-
try:
66-
client = get_client()
67-
sim = client.scatter(self, workers=self.worker)
68-
except ValueError:
69-
client = None
70-
sim = self
71-
72-
source_list = self.survey.source_list
73-
source_block = np.array_split(
74-
np.arange(len(source_list)), self.n_threads(client=client)
75-
)
76-
77-
if client:
78-
sim = client.scatter(self, workers=self.worker)
79-
source_list = client.scatter(source_list, workers=self.worker)
80-
else:
81-
delayed_source_eval = delayed(source_evaluation)
82-
sim = self
83-
84-
block_compute = []
85-
for block in source_block:
86-
if client:
87-
block_compute.append(
88-
client.submit(
89-
source_evaluation,
90-
sim,
91-
block,
92-
self.times[tInd],
93-
source_list,
94-
workers=self.worker,
95-
)
96-
)
97-
else:
98-
block_compute.append(
99-
delayed_source_eval(self, block, self.times[tInd], source_list)
100-
)
101-
102-
if client:
103-
blocks = client.gather(block_compute)
104-
else:
105-
blocks = dask.compute(block_compute)[0]
65+
blocks = []
66+
for source in self.survey.source_list:
67+
blocks.append(source_evaluation(self, self.times[tInd], source))
10668

10769
s_m, s_e = [], []
10870
for block in blocks:
@@ -283,12 +245,12 @@ def field_projection(field_array, src_list, array_ind, time_ind, func):
283245
return new_array
284246

285247

286-
def source_evaluation(simulation, indices, time_channel, sources):
248+
def source_evaluation(simulation, time_channel, source):
287249
s_m, s_e = [], []
288-
for ind in indices:
289-
sm, se = sources[ind].eval(simulation, time_channel)
290-
s_m.append(sm)
291-
s_e.append(se)
250+
251+
sm, se = source.eval(simulation, time_channel)
252+
s_m.append(sm)
253+
s_e.append(se)
292254

293255
return s_m, s_e
294256

0 commit comments

Comments
 (0)