Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -9266,6 +9266,16 @@
}
}
],
"./arraycontext/impl/pytato/parallelize.py": [
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 51,
"endColumn": 79,
"lineCount": 1
}
}
],
"./arraycontext/loopy.py": [
{
"code": "reportUnknownMemberType",
Expand Down Expand Up @@ -11895,4 +11905,4 @@
}
]
}
}
}
7 changes: 6 additions & 1 deletion arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .impl.pytato import (
PytatoJAXArrayContext,
PytatoParallelPyOpenCLArrayContext,
PytatoPyOpenCLArrayContext,
)
from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
Expand Down Expand Up @@ -140,6 +144,7 @@
"NumpyArrayContext",
"PyOpenCLArrayContext",
"PytatoJAXArrayContext",
"PytatoParallelPyOpenCLArrayContext",
"PytatoPyOpenCLArrayContext",
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
Expand Down
162 changes: 160 additions & 2 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
The following :mod:`pytato`-based array contexts are provided:

.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoParallelPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext


Expand All @@ -28,7 +29,8 @@
.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
Copyright (C) 2020-6 University of Illinois Board of Trustees
Copyright (C) 2022-3 Kaushik Kulkarni
"""

__license__ = """
Expand Down Expand Up @@ -139,6 +141,32 @@ class _NotOnlyDataWrappers(Exception): # noqa: N818
pass


class _PaddedAllocator:
    """Callable adapter that inflates every request to a :mod:`pyopencl`
    allocator.

    Motivation: the Intel CPU OpenCL runtime runs the surplus work-items at
    the tail of a partial work-group (the ones the kernel's bounds guard
    masks off) and still commits their global stores. Those stores land
    beyond the end of the output buffer and corrupt the host heap. Asking
    the wrapped allocator for extra room gives the stray stores valid memory
    to hit. The returned buffer is never smaller than what was requested, so
    results are unaffected.

    Padding policy: the overrun is a fraction of the data extent, hence the
    request is padded by its own size; *min_pad_bytes* acts as a floor for
    buffers so small that the overrun may exceed their size. This is a
    heuristic shield against a runtime bug, not a provably tight bound.

    :arg allocator: the allocator that performs the actual allocation.
    :arg min_pad_bytes: lower bound on the number of padding bytes added to
        each request.
    """

    def __init__(
                self, allocator: cl_array.Allocator, *,
                min_pad_bytes: int = 1 << 16) -> None:
        self._min_pad_bytes: int = min_pad_bytes
        self._allocator: cl_array.Allocator = allocator

    def __call__(self, nbytes: int):
        """Return whatever the wrapped allocator yields for the padded size."""
        # Pad by the request itself, but never by less than the floor.
        if nbytes > self._min_pad_bytes:
            padding = nbytes
        else:
            padding = self._min_pad_bytes

        padded_nbytes = nbytes + padding
        return self._allocator(padded_nbytes)


# {{{ _BasePytatoArrayContext

class _BasePytatoArrayContext(ArrayContext, abc.ABC):
Expand Down Expand Up @@ -377,8 +405,25 @@ def __init__(
self.using_svm = None

if allocator is None:
import pyopencl as cl
from pyopencl.characterize import has_coarse_grain_buffer_svm
has_svm = has_coarse_grain_buffer_svm(queue.device)

dev = queue.device
is_intel_cpu_cl = bool(
dev.type & cl.device_type.CPU
and "intel" in dev.platform.name.lower())

if has_svm and is_intel_cpu_cl:
# The Intel CPU OpenCL runtime writes past the end of output
# buffers (see the padding below), so we over-allocate to absorb
# those stray stores. That padding is incompatible with SVM:
# pyopencl's enqueue_svm_memcpy requires the source and
# destination sizes to match, so an over-allocated SVM array
# fails to transfer. Use buffer allocation, which tolerates an
# oversized backing buffer, instead.
has_svm = False

if has_svm:
self.using_svm = True

Expand All @@ -397,6 +442,13 @@ def __init__(
if use_memory_pool:
from pyopencl.tools import MemoryPool
allocator = MemoryPool(allocator)

if is_intel_cpu_cl:
# The Intel CPU OpenCL runtime writes past the end of the output
# buffer when executing the over-provisioned tail of a partial
# work-group, corrupting the host heap. Pad allocations so those
# stray stores land in valid memory.
allocator = _PaddedAllocator(allocator)
else:
# Check whether the passed allocator allocates SVM
try:
Expand Down Expand Up @@ -827,9 +879,15 @@ def compile(self,
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
                  ) -> pytato.AbstractResultWithNamedArrays:
    """
    Returns a transformed version of *dag*, where the applied transform is:

    #. Deduplicate the data wrappers in *dag*.
    #. Tag all calls to be inlined, then inline them.
    #. Materialize as per MPMS materialization heuristic.
    """
    import pytato as pt

    dag = pt.transform.deduplicate_data_wrappers(dag)

    dag = pt.tag_all_calls_to_be_inlined(dag)
    dag = pt.inline_calls(dag)

    # Single exit point: the previous early return of the materialized DAG
    # left the statements after it unreachable.
    dag = pt.transform.materialize_with_mpms(dag)

    return dag

@override
def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down Expand Up @@ -909,6 +967,106 @@ def clone(self):
# }}}


# {{{ PytatoParallelPyOpenCLArrayContext

class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
    """
    Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the
    device.

    .. note::

        Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
        details on the transformation algorithm provided by this array context.

    .. automethod:: transform_dag
    .. automethod:: transform_loopy_program
    """
    # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
    # should be calling, or should it be left for more-concrete derived array
    # contexts? If the latter, where should it live?
    def _materialize_einsum_inputs_and_outputs(
                self, dag: pytato.AbstractResultWithNamedArrays
            ) -> pytato.AbstractResultWithNamedArrays:
        """
        Returns a copy of *dag* in which every input and output of an einsum
        or reduction node is tagged with :class:`pytato.tags.ImplStored`.

        Instances of :class:`pytato.array.InputArgumentBase` are returned
        unchanged (presumably because they are already backed by storage).
        """
        import pytato as pt

        from .utils import (
            get_inputs_and_outputs_of_einsum,
            get_inputs_and_outputs_of_reduction_nodes,
        )

        einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
        redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
        reduction_inputs_outputs = (
            einsum_inputs | einsum_outputs | redn_inputs | redn_outputs)

        def materialize(
                    expr: pt.transform.ArrayOrNames
                ) -> pt.transform.ArrayOrNames:
            if (expr in reduction_inputs_outputs
                    and not isinstance(expr, pt.InputArgumentBase)):
                return expr.tagged(pt.tags.ImplStored())

            return expr

        return pt.transform.map_and_copy(dag, materialize)

    @override
    def transform_dag(
                self, dag: pytato.AbstractResultWithNamedArrays
            ) -> pytato.AbstractResultWithNamedArrays:
        """
        Returns a transformed version of *dag*, where the applied transform is:

        #. The transform of :meth:`PytatoPyOpenCLArrayContext.transform_dag`:
           data wrapper deduplication, call inlining and materialization as
           per the MPMS materialization heuristic.
        #. Materialize the inputs and outputs of every
           :class:`pytato.array.Einsum` and of every reduction node.
        """
        # The leading stages are exactly the parent's pipeline, so delegate
        # instead of keeping a second copy of it in sync by hand.
        dag = super().transform_dag(dag)

        return self._materialize_einsum_inputs_and_outputs(dag)

    def _parallelize_across_device(
                self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        """
        Returns *t_unit* after applying ``parallelize_disjoint_loop_sets``
        (given the number of compute units of *self*'s OpenCL device)
        followed by ``alias_global_temporaries``.
        """
        from .parallelize import (
            alias_global_temporaries,
            parallelize_disjoint_loop_sets,
        )

        t_unit = parallelize_disjoint_loop_sets(
            t_unit, self.queue.device.max_compute_units)

        # FIXME: Is this something that this abstract-ish
        # PytatoParallelPyOpenCLArrayContext class should be calling, or should
        # it be left for more-concrete derived array contexts? If the latter,
        # where should it live?
        t_unit = alias_global_temporaries(t_unit)

        return t_unit

    @override
    def transform_loopy_program(
                self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        r"""
        Returns a transformed version of *t_unit*, where the applied transform is:

        #. An execution grid size :math:`G` is selected based on *self*'s
           OpenCL-device.
        #. The iteration domain for each statement in the *t_unit* is divided
           equally among the work-items in :math:`G`.
        #. Kernel boundaries are drawn between every set of disjoint loops.
        #. Once the kernel boundaries are inferred, global temporaries are
           aliased to reduce the peak memory used by the transformed program.
        """
        return self._parallelize_across_device(t_unit)

# }}}


# {{{ PytatoJAXArrayContext

class PytatoJAXArrayContext(_BasePytatoArrayContext):
Expand Down
Loading
Loading