@@ -88,6 +88,19 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
8888 to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
8989 any, for now.)
9090 """
91+ def __init__ (self , queue : "pyopencl.CommandQueue" ,
92+ allocator : Optional ["pyopencl.tools.AllocatorInterface" ] = None ,
93+ wait_event_queue_length : Optional [int ] = None ,
94+ force_device_scalars : bool = False ) -> None :
95+
96+ if allocator is None :
97+ from warnings import warn
98+ warn ("No memory allocator specified, please pass one. "
99+ "(Preferably a pyopencl.tools.MemoryPool in order "
100+ "to reduce device allocations)" )
101+
102+ super ().__init__ (queue , allocator ,
103+ wait_event_queue_length , force_device_scalars )
91104
92105# }}}
93106
@@ -99,6 +112,13 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase):
99112 Extends it to understand :mod:`grudge`-specific transform metadata. (Of
100113 which there isn't any, for now.)
101114 """
115+ def __init__ (self , queue , allocator = None ):
116+ if allocator is None :
117+ from warnings import warn
118+ warn ("No memory allocator specified, please pass one. "
119+ "(Preferably a pyopencl.tools.MemoryPool in order "
120+ "to reduce device allocations)" )
121+ super ().__init__ (queue , allocator )
102122
103123# }}}
104124
@@ -210,6 +230,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
210230 out_dict = execute_distributed_partition (
211231 self .distributed_partition , self .part_id_to_prg ,
212232 self .actx .queue , self .actx .mpi_communicator ,
233+ allocator = self .actx .allocator ,
213234 input_args = input_args_for_prg )
214235
215236 def to_output_template (keys , _ ):
@@ -224,6 +245,12 @@ class MPIPytatoArrayContextBase(MPIBasedArrayContext):
224245 def __init__ (
225246 self , mpi_communicator , queue , * , mpi_base_tag , allocator = None
226247 ) -> None :
248+ if allocator is None :
249+ from warnings import warn
250+ warn ("No memory allocator specified, please pass one. "
251+ "(Preferably a pyopencl.tools.MemoryPool in order "
252+ "to reduce device allocations)" )
253+
227254 super ().__init__ (queue , allocator )
228255
229256 self .mpi_communicator = mpi_communicator
0 commit comments