Skip to content

Commit e6e190b

Browse files
authored
Merge branch 'main' into multi-volume
2 parents 85d9812 + 7a8dd5d commit e6e190b

3 files changed

Lines changed: 44 additions & 15 deletions

File tree

examples/old_symbolics/dagrt-fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def isolate_function_calls_in_phase(phase, stmt_id_gen, var_name_gen):
243243
stmt_id_gen=stmt_id_gen,
244244
var_name_gen=var_name_gen)
245245

246-
for stmt in sorted(phase.statements, key=lambda stmt: stmt.id):
246+
for stmt in sorted(phase.statements, key=lambda stmt_: stmt_.id):
247247
new_deps = []
248248

249249
from dagrt.language import Assign

grudge/array_context.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
8888
to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
8989
any, for now.)
9090
"""
91+
def __init__(self, queue: "pyopencl.CommandQueue",
92+
allocator: Optional["pyopencl.tools.AllocatorInterface"] = None,
93+
wait_event_queue_length: Optional[int] = None,
94+
force_device_scalars: bool = False) -> None:
95+
96+
if allocator is None:
97+
from warnings import warn
98+
warn("No memory allocator specified, please pass one. "
99+
"(Preferably a pyopencl.tools.MemoryPool in order "
100+
"to reduce device allocations)")
101+
102+
super().__init__(queue, allocator,
103+
wait_event_queue_length, force_device_scalars)
91104

92105
# }}}
93106

@@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
99112
Extends it to understand :mod:`grudge`-specific transform metadata. (Of
100113
which there isn't any, for now.)
101114
"""
115+
def __init__(self, queue, allocator=None):
116+
if allocator is None:
117+
from warnings import warn
118+
warn("No memory allocator specified, please pass one. "
119+
"(Preferably a pyopencl.tools.MemoryPool in order "
120+
"to reduce device allocations)")
121+
super().__init__(queue, allocator)
102122

103123
# }}}
104124

@@ -210,6 +230,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
210230
out_dict = execute_distributed_partition(
211231
self.distributed_partition, self.part_id_to_prg,
212232
self.actx.queue, self.actx.mpi_communicator,
233+
allocator=self.actx.allocator,
213234
input_args=input_args_for_prg)
214235

215236
def to_output_template(keys, _):
@@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext):
224245
def __init__(
225246
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None
226247
) -> None:
248+
if allocator is None:
249+
from warnings import warn
250+
warn("No memory allocator specified, please pass one. "
251+
"(Preferably a pyopencl.tools.MemoryPool in order "
252+
"to reduce device allocations)")
253+
227254
super().__init__(queue, allocator)
228255

229256
self.mpi_communicator = mpi_communicator

test/test_mpi_communication.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,22 +151,24 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
151151
bdry_faces_func = op.project(dcoll, BTAG_ALL, dd_af,
152152
op.project(dcoll, dd_vol, BTAG_ALL, myfunc))
153153

154-
hopefully_zero = (
155-
op.project(
156-
dcoll, "int_faces", "all_faces",
157-
dcoll.opposite_face_connection(
158-
dof_desc.BoundaryDomainTag(
159-
dof_desc.FACE_RESTR_INTERIOR, dof_desc.VTAG_ALL)
160-
)(int_faces_func)
161-
)
162-
+ sum(op.project(dcoll, tpair.dd, "all_faces", tpair.int)
163-
for tpair in op.cross_rank_trace_pairs(dcoll, myfunc,
164-
comm_tag=SimpleTag))
165-
) - (all_faces_func - bdry_faces_func)
154+
def hopefully_zero():
155+
return (
156+
op.project(
157+
dcoll, "int_faces", "all_faces",
158+
dcoll.opposite_face_connection(
159+
dof_desc.BoundaryDomainTag(
160+
dof_desc.FACE_RESTR_INTERIOR, dof_desc.VTAG_ALL)
161+
)(int_faces_func)
162+
)
163+
+ sum(op.project(dcoll, tpair.dd, "all_faces", tpair.ext)
164+
for tpair in op.cross_rank_trace_pairs(dcoll, myfunc,
165+
comm_tag=SimpleTag))
166+
) - (all_faces_func - bdry_faces_func)
167+
168+
hopefully_zero_result = actx.compile(hopefully_zero)()
166169

167-
error = actx.to_numpy(flat_norm(hopefully_zero, ord=np.inf))
170+
error = actx.to_numpy(flat_norm(hopefully_zero_result, ord=np.inf))
168171

169-
print(__file__)
170172
with np.printoptions(threshold=100000000, suppress=True):
171173
logger.debug(hopefully_zero)
172174
logger.info("error: %.5e", error)

0 commit comments

Comments
 (0)