Skip to content

Commit fdfad32

Browse files
committed
Tweak is_array_container_type logic following typing improvements
1 parent a914f86 commit fdfad32

3 files changed

Lines changed: 16 additions & 71 deletions

File tree

.basedpyright/baseline.json

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4129,30 +4129,6 @@
41294129
"lineCount": 1
41304130
}
41314131
},
4132-
{
4133-
"code": "reportUnknownMemberType",
4134-
"range": {
4135-
"startColumn": 15,
4136-
"endColumn": 30,
4137-
"lineCount": 1
4138-
}
4139-
},
4140-
{
4141-
"code": "reportAny",
4142-
"range": {
4143-
"startColumn": 59,
4144-
"endColumn": 63,
4145-
"lineCount": 1
4146-
}
4147-
},
4148-
{
4149-
"code": "reportAny",
4150-
"range": {
4151-
"startColumn": 67,
4152-
"endColumn": 73,
4153-
"lineCount": 1
4154-
}
4155-
},
41564132
{
41574133
"code": "reportUnnecessaryComparison",
41584134
"range": {
@@ -6189,30 +6165,6 @@
61896165
"lineCount": 1
61906166
}
61916167
},
6192-
{
6193-
"code": "reportUnknownVariableType",
6194-
"range": {
6195-
"startColumn": 12,
6196-
"endColumn": 16,
6197-
"lineCount": 1
6198-
}
6199-
},
6200-
{
6201-
"code": "reportOperatorIssue",
6202-
"range": {
6203-
"startColumn": 19,
6204-
"endColumn": 77,
6205-
"lineCount": 1
6206-
}
6207-
},
6208-
{
6209-
"code": "reportUnknownArgumentType",
6210-
"range": {
6211-
"startColumn": 62,
6212-
"endColumn": 66,
6213-
"lineCount": 1
6214-
}
6215-
},
62166168
{
62176169
"code": "reportPrivateImportUsage",
62186170
"range": {
@@ -9979,14 +9931,6 @@
99799931
"lineCount": 3
99809932
}
99819933
},
9982-
{
9983-
"code": "reportUnknownArgumentType",
9984-
"range": {
9985-
"startColumn": 29,
9986-
"endColumn": 75,
9987-
"lineCount": 1
9988-
}
9989-
},
99909934
{
99919935
"code": "reportUnknownArgumentType",
99929936
"range": {
@@ -10043,14 +9987,6 @@
100439987
"lineCount": 1
100449988
}
100459989
},
10046-
{
10047-
"code": "reportUnknownVariableType",
10048-
"range": {
10049-
"startColumn": 8,
10050-
"endColumn": 14,
10051-
"lineCount": 1
10052-
}
10053-
},
100549990
{
100559991
"code": "reportUnknownArgumentType",
100569992
"range": {

arraycontext/container/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@
120120

121121
from collections.abc import Hashable, Sequence
122122
from functools import singledispatch
123-
from types import GenericAlias, UnionType
124123
from typing import (
125124
TYPE_CHECKING,
126125
TypeAlias,
@@ -133,18 +132,23 @@
133132
import numpy as np
134133
from typing_extensions import TypeIs
135134

136-
from pytools.obj_array import ObjectArrayND as ObjectArrayND
135+
from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND
137136

138137
from arraycontext.typing import (
138+
ArithArrayContainer,
139139
ArrayContainer,
140140
ArrayContainerT,
141141
ArrayOrArithContainer,
142142
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
143143
ArrayOrContainerOrScalar,
144+
_UserDefinedArithArrayContainer,
145+
_UserDefinedArrayContainer,
144146
)
145147

146148

147149
if TYPE_CHECKING:
150+
from types import GenericAlias, UnionType
151+
148152
from pymbolic.geometric_algebra import CoeffT, MultiVector
149153

150154
from arraycontext.context import ArrayContext
@@ -217,17 +221,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
217221
function will say that :class:`numpy.ndarray` is an array container
218222
type, only object arrays *actually are* array containers.
219223
"""
220-
if cls is ArrayContainer:
224+
if cls is ArrayContainer or cls is ArithArrayContainer:
221225
return True
222226

223-
while isinstance(cls, GenericAlias):
224-
cls = get_origin(cls)
227+
origin = get_origin(cls)
228+
if origin is not None:
229+
cls = origin # pyright: ignore[reportAny]
225230

226231
assert isinstance(cls, type), (
227232
f"must pass a {type!r}, not a '{cls!r}'")
228233

229234
return (
230-
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
235+
cls is ObjectArray
236+
or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
237+
or cls is _UserDefinedArrayContainer
238+
or cls is _UserDefinedArithArrayContainer
231239
or (serialize_container.dispatch(cls)
232240
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
233241

arraycontext/container/dataclass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
A type variable. Represents the dataclass being turned into an array container.
1212
"""
1313
from __future__ import annotations
14+
from types import GenericAlias, UnionType
1415

1516

1617
__copyright__ = """
@@ -81,7 +82,7 @@ class _Field(NamedTuple):
8182
type: type
8283

8384

84-
def _is_array_or_container_type(tp: type, /) -> bool:
85+
def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
8586
if tp is np.ndarray:
8687
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
8788
"This is deprecated and will stop working in 2026. "

0 commit comments

Comments
 (0)