|
120 | 120 |
|
121 | 121 | from collections.abc import Hashable, Sequence |
122 | 122 | from functools import singledispatch |
123 | | -from types import GenericAlias, UnionType |
124 | 123 | from typing import ( |
125 | 124 | TYPE_CHECKING, |
126 | 125 | TypeAlias, |
|
133 | 132 | import numpy as np |
134 | 133 | from typing_extensions import TypeIs |
135 | 134 |
|
136 | | -from pytools.obj_array import ObjectArrayND as ObjectArrayND |
| 135 | +from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND |
137 | 136 |
|
138 | 137 | from arraycontext.typing import ( |
| 138 | + ArithArrayContainer, |
139 | 139 | ArrayContainer, |
140 | 140 | ArrayContainerT, |
141 | 141 | ArrayOrArithContainer, |
142 | 142 | ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar, |
143 | 143 | ArrayOrContainerOrScalar, |
| 144 | + _UserDefinedArithArrayContainer, |
| 145 | + _UserDefinedArrayContainer, |
144 | 146 | ) |
145 | 147 |
|
146 | 148 |
|
147 | 149 | if TYPE_CHECKING: |
| 150 | + from types import GenericAlias, UnionType |
| 151 | + |
148 | 152 | from pymbolic.geometric_algebra import CoeffT, MultiVector |
149 | 153 |
|
150 | 154 | from arraycontext.context import ArrayContext |
@@ -217,17 +221,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool: |
217 | 221 | function will say that :class:`numpy.ndarray` is an array container |
218 | 222 | type, only object arrays *actually are* array containers. |
219 | 223 | """ |
220 | | - if cls is ArrayContainer: |
| 224 | + if cls is ArrayContainer or cls is ArithArrayContainer: |
221 | 225 | return True |
222 | 226 |
|
223 | | - while isinstance(cls, GenericAlias): |
224 | | - cls = get_origin(cls) |
| 227 | + origin = get_origin(cls) |
| 228 | + if origin is not None: |
| 229 | + cls = origin # pyright: ignore[reportAny] |
225 | 230 |
|
226 | 231 | assert isinstance(cls, type), ( |
227 | 232 | f"must pass a {type!r}, not a '{cls!r}'") |
228 | 233 |
|
229 | 234 | return ( |
230 | | - cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] |
| 235 | + cls is ObjectArray |
| 236 | + or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] |
| 237 | + or cls is _UserDefinedArrayContainer |
| 238 | + or cls is _UserDefinedArithArrayContainer |
231 | 239 | or (serialize_container.dispatch(cls) |
232 | 240 | is not serialize_container.__wrapped__)) # type:ignore[attr-defined] |
233 | 241 |
|
|
0 commit comments