Skip to content

Commit 6b0d338

Browse files
committed
fix some type-checking issues and make generic Array compatible with pytato arrays
1 parent 6a5d798 commit 6b0d338

File tree

8 files changed

+125
-197
lines changed

8 files changed

+125
-197
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -22633,14 +22633,6 @@
2263322633
"lineCount": 1
2263422634
}
2263522635
},
22636-
{
22637-
"code": "reportUnusedFunction",
22638-
"range": {
22639-
"startColumn": 4,
22640-
"endColumn": 8,
22641-
"lineCount": 1
22642-
}
22643-
},
2264422636
{
2264522637
"code": "reportMissingParameterType",
2264622638
"range": {
@@ -24307,30 +24299,6 @@
2430724299
"lineCount": 1
2430824300
}
2430924301
},
24310-
{
24311-
"code": "reportAttributeAccessIssue",
24312-
"range": {
24313-
"startColumn": 25,
24314-
"endColumn": 29,
24315-
"lineCount": 1
24316-
}
24317-
},
24318-
{
24319-
"code": "reportAttributeAccessIssue",
24320-
"range": {
24321-
"startColumn": 25,
24322-
"endColumn": 29,
24323-
"lineCount": 1
24324-
}
24325-
},
24326-
{
24327-
"code": "reportAttributeAccessIssue",
24328-
"range": {
24329-
"startColumn": 14,
24330-
"endColumn": 31,
24331-
"lineCount": 1
24332-
}
24333-
},
2433424302
{
2433524303
"code": "reportAttributeAccessIssue",
2433624304
"range": {
@@ -24339,70 +24307,6 @@
2433924307
"lineCount": 1
2434024308
}
2434124309
},
24342-
{
24343-
"code": "reportAttributeAccessIssue",
24344-
"range": {
24345-
"startColumn": 43,
24346-
"endColumn": 47,
24347-
"lineCount": 1
24348-
}
24349-
},
24350-
{
24351-
"code": "reportAttributeAccessIssue",
24352-
"range": {
24353-
"startColumn": 43,
24354-
"endColumn": 47,
24355-
"lineCount": 1
24356-
}
24357-
},
24358-
{
24359-
"code": "reportArgumentType",
24360-
"range": {
24361-
"startColumn": 22,
24362-
"endColumn": 36,
24363-
"lineCount": 1
24364-
}
24365-
},
24366-
{
24367-
"code": "reportAttributeAccessIssue",
24368-
"range": {
24369-
"startColumn": 37,
24370-
"endColumn": 41,
24371-
"lineCount": 1
24372-
}
24373-
},
24374-
{
24375-
"code": "reportAttributeAccessIssue",
24376-
"range": {
24377-
"startColumn": 42,
24378-
"endColumn": 50,
24379-
"lineCount": 1
24380-
}
24381-
},
24382-
{
24383-
"code": "reportArgumentType",
24384-
"range": {
24385-
"startColumn": 22,
24386-
"endColumn": 37,
24387-
"lineCount": 1
24388-
}
24389-
},
24390-
{
24391-
"code": "reportAttributeAccessIssue",
24392-
"range": {
24393-
"startColumn": 38,
24394-
"endColumn": 42,
24395-
"lineCount": 1
24396-
}
24397-
},
24398-
{
24399-
"code": "reportAttributeAccessIssue",
24400-
"range": {
24401-
"startColumn": 43,
24402-
"endColumn": 51,
24403-
"lineCount": 1
24404-
}
24405-
},
2440624310
{
2440724311
"code": "reportMissingParameterType",
2440824312
"range": {
@@ -24615,22 +24519,6 @@
2461524519
"lineCount": 1
2461624520
}
2461724521
},
24618-
{
24619-
"code": "reportUnusedFunction",
24620-
"range": {
24621-
"startColumn": 4,
24622-
"endColumn": 28,
24623-
"lineCount": 1
24624-
}
24625-
},
24626-
{
24627-
"code": "reportUnusedFunction",
24628-
"range": {
24629-
"startColumn": 4,
24630-
"endColumn": 30,
24631-
"lineCount": 1
24632-
}
24633-
},
2463424522
{
2463524523
"code": "reportMissingParameterType",
2463624524
"range": {
@@ -24663,14 +24551,6 @@
2466324551
"lineCount": 1
2466424552
}
2466524553
},
24666-
{
24667-
"code": "reportUnusedFunction",
24668-
"range": {
24669-
"startColumn": 4,
24670-
"endColumn": 23,
24671-
"lineCount": 1
24672-
}
24673-
},
2467424554
{
2467524555
"code": "reportGeneralTypeIssues",
2467624556
"range": {
@@ -24735,14 +24615,6 @@
2473524615
"lineCount": 1
2473624616
}
2473724617
},
24738-
{
24739-
"code": "reportUnusedFunction",
24740-
"range": {
24741-
"startColumn": 4,
24742-
"endColumn": 26,
24743-
"lineCount": 1
24744-
}
24745-
},
2474624618
{
2474724619
"code": "reportMissingParameterType",
2474824620
"range": {

arraycontext/container/traversal.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777

7878
import numpy as np
7979

80+
from pymbolic.typing import Integer
81+
8082
from arraycontext.container import (
8183
ArrayContainer,
8284
NotAnArrayContainerError,
@@ -91,7 +93,6 @@
9193
ArrayOrContainer,
9294
ArrayOrContainerOrScalar,
9395
ArrayOrContainerT,
94-
ArrayT,
9596
ScalarLike,
9697
)
9798

@@ -399,33 +400,27 @@ def keyed_map_array_container(
399400
])
400401

401402

402-
def _rec_keyed_map_array_container_rec(
403-
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
404-
keys: tuple[SerializationKey, ...],
405-
ary_: ArrayOrContainerT
406-
) -> ArrayOrContainerT:
407-
try:
408-
iterable = serialize_container(ary_)
409-
except NotAnArrayContainerError:
410-
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
411-
else:
412-
return deserialize_container(ary_, [
413-
(key, _rec_keyed_map_array_container_rec(
414-
f, (*keys, key), subary)) for key, subary in iterable
415-
])
416-
417-
418403
def rec_keyed_map_array_container(
419-
f: Callable[[tuple[SerializationKey, ...], ArrayT], ArrayT],
420-
ary: ArrayOrContainerT) -> ArrayOrContainerT:
404+
f: Callable[[tuple[SerializationKey, ...], Array], Array],
405+
ary: ArrayOrContainer) -> ArrayOrContainer:
421406
"""
422407
Works similarly to :func:`rec_map_array_container`, except that *f* also
423408
takes in a traversal path to the leaf array. The traversal path argument is
424409
passed in as a tuple of identifiers of the arrays traversed before reaching
425410
the current array.
426411
"""
412+
def rec(keys: tuple[SerializationKey, ...],
413+
ary_: ArrayOrContainer) -> ArrayOrContainer:
414+
try:
415+
iterable = serialize_container(ary_)
416+
except NotAnArrayContainerError:
417+
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
418+
else:
419+
return deserialize_container(ary_, [
420+
(key, rec((*keys, key), subary)) for key, subary in iterable
421+
])
427422

428-
return _rec_keyed_map_array_container_rec(f, (), ary)
423+
return rec((), ary)
429424

430425
# }}}
431426

@@ -782,7 +777,7 @@ def unflatten(
782777
checks are skipped.
783778
"""
784779
# NOTE: https://github.com/python/mypy/issues/7057
785-
offset = 0
780+
offset: int = 0
786781
common_dtype = None
787782

788783
def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
@@ -795,7 +790,11 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
795790

796791
# {{{ validate subary
797792

798-
if (offset + template_subary_c.size) > ary.size:
793+
if (
794+
isinstance(offset, Integer)
795+
and isinstance(template_subary_c.size, Integer)
796+
and isinstance(ary.size, Integer)
797+
and (offset + template_subary_c.size) > ary.size):
799798
raise ValueError("'template' and 'ary' sizes do not match: "
800799
"'template' is too large") from None
801800

@@ -818,6 +817,12 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
818817

819818
# {{{ reshape
820819

820+
if not isinstance(template_subary_c.size, Integer):
821+
raise NotImplementedError(
822+
"unflatten is not implemented for arrays with array-valued "
823+
"size.") from None
824+
825+
# FIXME: Not sure how to make the slicing part work for Array-valued sizes
821826
flat_subary = ary[offset:offset + template_subary_c.size]
822827
try:
823828
subary = actx.np.reshape(flat_subary,
@@ -876,15 +881,15 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
876881

877882

878883
def flat_size_and_dtype(
879-
ary: ArrayOrContainer) -> tuple[int, np.dtype[Any] | None]:
884+
ary: ArrayOrContainer) -> tuple[Array | Integer, np.dtype[Any] | None]:
880885
"""
881886
:returns: a tuple ``(size, dtype)`` that would be the length and
882887
:class:`numpy.dtype` of the one-dimensional array returned by
883888
:func:`flatten`.
884889
"""
885890
common_dtype = None
886891

887-
def _flat_size(subary: ArrayOrContainer) -> int:
892+
def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
888893
nonlocal common_dtype
889894

890895
try:

arraycontext/context.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
import numpy as np
175175
from typing_extensions import Self
176176

177+
from pymbolic.typing import Integer, Scalar as _Scalar
177178
from pytools import memoize_method
178179
from pytools.tag import ToTagSetConvertible
179180

@@ -202,11 +203,11 @@ class Array(Protocol):
202203
"""
203204

204205
@property
205-
def shape(self) -> tuple[int, ...]:
206+
def shape(self) -> tuple[Array | Integer, ...]:
206207
...
207208

208209
@property
209-
def size(self) -> int:
210+
def size(self) -> Array | Integer:
210211
...
211212

212213
@property
@@ -221,21 +222,21 @@ def __getitem__(self, index: Any) -> Array:
221222
...
222223

223224
# some basic arithmetic that's supposed to work
224-
def __neg__(self) -> Self: ...
225-
def __abs__(self) -> Self: ...
226-
def __add__(self, other: Self | ScalarLike) -> Self: ...
227-
def __radd__(self, other: Self | ScalarLike) -> Self: ...
228-
def __sub__(self, other: Self | ScalarLike) -> Self: ...
229-
def __rsub__(self, other: Self | ScalarLike) -> Self: ...
230-
def __mul__(self, other: Self | ScalarLike) -> Self: ...
231-
def __rmul__(self, other: Self | ScalarLike) -> Self: ...
232-
def __truediv__(self, other: Self | ScalarLike) -> Self: ...
233-
def __rtruediv__(self, other: Self | ScalarLike) -> Self: ...
225+
def __neg__(self) -> Array: ...
226+
def __abs__(self) -> Array: ...
227+
def __add__(self, other: Self | ScalarLike) -> Array: ...
228+
def __radd__(self, other: Self | ScalarLike) -> Array: ...
229+
def __sub__(self, other: Self | ScalarLike) -> Array: ...
230+
def __rsub__(self, other: Self | ScalarLike) -> Array: ...
231+
def __mul__(self, other: Self | ScalarLike) -> Array: ...
232+
def __rmul__(self, other: Self | ScalarLike) -> Array: ...
233+
def __truediv__(self, other: Self | ScalarLike) -> Array: ...
234+
def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
234235

235236

236237
# deprecated, use ScalarLike instead
237-
ScalarLike: TypeAlias = int | float | complex | np.generic
238-
Scalar = ScalarLike
238+
Scalar = _Scalar
239+
ScalarLike = Scalar
239240
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)
240241

241242
# NOTE: I'm kind of not sure about the *Tc versions of these type variables.

arraycontext/impl/pyopencl/taggable_cl_array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class TaggableCLArray(cla.Array, Taggable):
7474
record application-specific metadata to drive the optimizations in
7575
:meth:`arraycontext.PyOpenCLArrayContext.transform_loopy_program`.
7676
"""
77+
tags: frozenset[Tag]
78+
axes: tuple[Axis, ...]
79+
7780
def __init__(self, cq, shape, dtype, order="C", allocator=None,
7881
data=None, offset=0, strides=None, events=None, _flags=None,
7982
_fast=False, _size=None, _context=None, _queue=None,

0 commit comments

Comments
 (0)