1+ from __future__ import annotations
2+
13import copy
4+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , cast , overload
5+
6+ T = TypeVar ("T" )
7+
8+ if TYPE_CHECKING :
9+ from collections .abc import Iterable , Iterator , Sequence
10+
11+ # The type aliases defined here are evaluated when the django-stubs mypy plugin
12+ # loads this module, so they must be able to execute under the lowest supported
13+ # Python VM:
14+ # - typing.List, typing.Tuple become obsolete in Pyton 3.9
15+ # - typing.Union becomes obsolete in Pyton 3.10
16+ from typing import List , Tuple , Union
17+
18+ from django_stubs_ext import StrOrPromise
19+
20+ # The type argument 'T' to 'Choices' is the database representation type.
21+ _Double = Tuple [T , StrOrPromise ]
22+ _Triple = Tuple [T , str , StrOrPromise ]
23+ _Group = Tuple [StrOrPromise , Sequence ["_Choice[T]" ]]
24+ _Choice = Union [_Double [T ], _Triple [T ], _Group [T ]]
25+ # Choices can only be given as a single string if 'T' is 'str'.
26+ _GroupStr = Tuple [StrOrPromise , Sequence ["_ChoiceStr" ]]
27+ _ChoiceStr = Union [str , _Double [str ], _Triple [str ], _GroupStr ]
28+ # Note that we only accept lists and tuples in groups, not arbitrary sequences.
29+ # However, annotating it as such causes many problems.
30+
31+ _DoubleRead = Union [_Double [T ], Tuple [StrOrPromise , Iterable ["_DoubleRead[T]" ]]]
32+ _DoubleCollector = List [Union [_Double [T ], Tuple [StrOrPromise , "_DoubleCollector[T]" ]]]
33+ _TripleCollector = List [Union [_Triple [T ], Tuple [StrOrPromise , "_TripleCollector[T]" ]]]
234
335
4- class Choices :
36+ class Choices ( Generic [ T ]) :
537 """
638 A class to encapsulate handy functionality for lists of choices
739 for a Django model field.
@@ -41,36 +73,60 @@ class Choices:
4173
4274 """
4375
44- def __init__ (self , * choices ):
76+ @overload
77+ def __init__ (self : Choices [str ], * choices : _ChoiceStr ):
78+ ...
79+
80+ @overload
81+ def __init__ (self , * choices : _Choice [T ]):
82+ ...
83+
84+ def __init__ (self , * choices : _ChoiceStr | _Choice [T ]):
4585 # list of choices expanded to triples - can include optgroups
46- self ._triples = []
86+ self ._triples : _TripleCollector [ T ] = []
4787 # list of choices as (db, human-readable) - can include optgroups
48- self ._doubles = []
88+ self ._doubles : _DoubleCollector [ T ] = []
4989 # dictionary mapping db representation to human-readable
50- self ._display_map = {}
90+ self ._display_map : dict [ T , StrOrPromise | list [ _Triple [ T ]]] = {}
5191 # dictionary mapping Python identifier to db representation
52- self ._identifier_map = {}
92+ self ._identifier_map : dict [ str , T ] = {}
5393 # set of db representations
54- self ._db_values = set ()
94+ self ._db_values : set [ T ] = set ()
5595
5696 self ._process (choices )
5797
58- def _store (self , triple , triple_collector , double_collector ):
98+ def _store (
99+ self ,
100+ triple : tuple [T , str , StrOrPromise ],
101+ triple_collector : _TripleCollector [T ],
102+ double_collector : _DoubleCollector [T ]
103+ ) -> None :
59104 self ._identifier_map [triple [1 ]] = triple [0 ]
60105 self ._display_map [triple [0 ]] = triple [2 ]
61106 self ._db_values .add (triple [0 ])
62107 triple_collector .append (triple )
63108 double_collector .append ((triple [0 ], triple [2 ]))
64109
65- def _process (self , choices , triple_collector = None , double_collector = None ):
110+ def _process (
111+ self ,
112+ choices : Iterable [_ChoiceStr | _Choice [T ]],
113+ triple_collector : _TripleCollector [T ] | None = None ,
114+ double_collector : _DoubleCollector [T ] | None = None
115+ ) -> None :
66116 if triple_collector is None :
67117 triple_collector = self ._triples
68118 if double_collector is None :
69119 double_collector = self ._doubles
70120
71- store = lambda c : self ._store (c , triple_collector , double_collector )
121+ def store (c : tuple [Any , str , StrOrPromise ]) -> None :
122+ self ._store (c , triple_collector , double_collector )
72123
73124 for choice in choices :
125+ # The type inference is not very accurate here:
126+ # - we lied in the type aliases, stating groups contain an arbitrary Sequence
127+ # rather than only list or tuple
128+ # - there is no way to express that _ChoiceStr is only used when T=str
129+ # - mypy 1.9.0 doesn't narrow types based on the value of len()
74130 if isinstance (choice , (list , tuple )):
75131 if len (choice ) == 3 :
76132 store (choice )
@@ -79,13 +135,13 @@ def _process(self, choices, triple_collector=None, double_collector=None):
79135 # option group
80136 group_name = choice [0 ]
81137 subchoices = choice [1 ]
82- tc = []
138+ tc : _TripleCollector [ T ] = []
83139 triple_collector .append ((group_name , tc ))
84- dc = []
140+ dc : _DoubleCollector [ T ] = []
85141 double_collector .append ((group_name , dc ))
86142 self ._process (subchoices , tc , dc )
87143 else :
88- store ((choice [0 ], choice [0 ], choice [1 ]))
144+ store ((choice [0 ], cast ( str , choice [0 ]), cast ( 'StrOrPromise' , choice [1 ]) ))
89145 else :
90146 raise ValueError (
91147 "Choices can't take a list of length %s, only 2 or 3"
@@ -94,54 +150,74 @@ def _process(self, choices, triple_collector=None, double_collector=None):
94150 else :
95151 store ((choice , choice , choice ))
96152
97- def __len__ (self ):
153+ def __len__ (self ) -> int :
98154 return len (self ._doubles )
99155
100- def __iter__ (self ):
156+ def __iter__ (self ) -> Iterator [ _DoubleRead [ T ]] :
101157 return iter (self ._doubles )
102158
103- def __reversed__ (self ):
159+ def __reversed__ (self ) -> Iterator [ _DoubleRead [ T ]] :
104160 return reversed (self ._doubles )
105161
106- def __getattr__ (self , attname ) :
162+ def __getattr__ (self , attname : str ) -> T :
107163 try :
108164 return self ._identifier_map [attname ]
109165 except KeyError :
110166 raise AttributeError (attname )
111167
112- def __getitem__ (self , key ) :
168+ def __getitem__ (self , key : T ) -> StrOrPromise | Sequence [ _Triple [ T ]] :
113169 return self ._display_map [key ]
114170
115- def __add__ (self , other ):
171+ @overload
172+ def __add__ (self : Choices [str ], other : Choices [str ] | Iterable [_ChoiceStr ]) -> Choices [str ]:
173+ ...
174+
175+ @overload
176+ def __add__ (self , other : Choices [T ] | Iterable [_Choice [T ]]) -> Choices [T ]:
177+ ...
178+
179+ def __add__ (self , other : Choices [Any ] | Iterable [_ChoiceStr | _Choice [Any ]]) -> Choices [Any ]:
180+ other_args : list [Any ]
116181 if isinstance (other , self .__class__ ):
117- other = other ._triples
182+ other_args = other ._triples
118183 else :
119- other = list (other )
120- return Choices (* (self ._triples + other ))
184+ other_args = list (other )
185+ return Choices (* (self ._triples + other_args ))
186+
187+ @overload
188+ def __radd__ (self : Choices [str ], other : Iterable [_ChoiceStr ]) -> Choices [str ]:
189+ ...
190+
191+ @overload
192+ def __radd__ (self , other : Iterable [_Choice [T ]]) -> Choices [T ]:
193+ ...
121194
122- def __radd__ (self , other ) :
195+ def __radd__ (self , other : Iterable [ _ChoiceStr ] | Iterable [ _Choice [ T ]]) -> Choices [ Any ] :
123196 # radd is never called for matching types, so we don't check here
124- other = list (other )
125- return Choices (* (other + self ._triples ))
197+ other_args = list (other )
198+ # The exact type of 'other' depends on our type argument 'T', which
199+ # is expressed in the overloading, but lost within this method body.
200+ return Choices (* (other_args + self ._triples )) # type: ignore[arg-type]
126201
127- def __eq__ (self , other ) :
202+ def __eq__ (self , other : object ) -> bool :
128203 if isinstance (other , self .__class__ ):
129204 return self ._triples == other ._triples
130205 return False
131206
132- def __repr__ (self ):
207+ def __repr__ (self ) -> str :
133208 return '{}({})' .format (
134209 self .__class__ .__name__ ,
135210 ', ' .join ("%s" % repr (i ) for i in self ._triples )
136211 )
137212
138- def __contains__ (self , item ) :
213+ def __contains__ (self , item : T ) -> bool :
139214 return item in self ._db_values
140215
141- def __deepcopy__ (self , memo ):
142- return self .__class__ (* copy .deepcopy (self ._triples , memo ))
216+ def __deepcopy__ (self , memo : dict [int , Any ] | None ) -> Choices [T ]:
217+ args : list [Any ] = copy .deepcopy (self ._triples , memo )
218+ return self .__class__ (* args )
143219
144- def subset (self , * new_identifiers ) :
220+ def subset (self , * new_identifiers : str ) -> Choices [ T ] :
145221 identifiers = set (self ._identifier_map .keys ())
146222
147223 if not identifiers .issuperset (new_identifiers ):
@@ -150,7 +226,8 @@ def subset(self, *new_identifiers):
150226 identifiers .symmetric_difference (new_identifiers ),
151227 )
152228
153- return self . __class__ ( * [
229+ args : list [ Any ] = [
154230 choice for choice in self ._triples
155231 if choice [1 ] in new_identifiers
156- ])
232+ ]
233+ return self .__class__ (* args )
0 commit comments