66 isolated and modular computations.
77"""
88
9- from typing import Dict , Iterable , Tuple , Union
9+ from typing import Dict , Iterable , Union , Optional
1010
1111import torch
12+ from torch import Tensor
1213
13- from .names import TARGETS , PREDS , RESERVED_KEYS , RESERVED_MODEL_KEYS
14+ from .names import INPUTS , TARGETS , PREDS , RESERVED_KEYS , RESERVED_MODEL_KEYS
1415
1516ContextValue = Union [torch .Tensor , torch .nn .Module ]
1617
@@ -47,17 +48,9 @@ def add(self, **items: ContextValue) -> "Context":
4748 where keys are the names of the tensors.
4849 """
4950
50- for k , v in items .items ():
51- if k in RESERVED_KEYS and not isinstance (v , torch .Tensor ):
52- raise ReservedKeyTypeError (
53- f"Reserved key '{ k } ' must be a torch.Tensor, got { type (v )} "
54- )
55- elif k in RESERVED_MODEL_KEYS and not isinstance (v , torch .nn .Module ):
56- raise ReservedKeyTypeError (
57- f"Reserved key '{ k } ' must be a torch.nn.Module, got { type (v )} "
58- )
59-
60- self ._store .update (items )
51+ for key , value in items .items ():
52+ self [key ] = value
53+
6154 return self
6255
6356 def require (self , keys : Iterable [str ]) -> None :
@@ -81,15 +74,18 @@ def as_kwargs(self) -> Dict[str, ContextValue]:
8174 """
8275 return self ._store
8376
84- def as_metric_args (self ) -> Tuple [ ContextValue , ContextValue ]:
77+ def as_metric_args (self ) -> tuple [ Tensor , Tensor ]:
8578 """
8679 Returns the predictions and targets tensors for
8780 Image quality assessment metric computation.
8881 Intended use: metric.update(*ctx.as_metric_args())
82+
83+ :return: A tuple (preds, targets) of tensors.
84+ :raises ValueError: If either preds or targets is missing.
8985 """
90- self .require ([PREDS , TARGETS ])
91- preds = self ._store [ PREDS ]
92- targs = self ._store [ TARGETS ]
86+ self .require (keys = [PREDS , TARGETS ])
87+ preds : Tensor = self .preds
88+ targs : Tensor = self .targets
9389 return (preds , targs )
9490
9591 def __repr__ (self ) -> str :
@@ -109,6 +105,28 @@ def __repr__(self) -> str:
109105 # --- Methods for dict like behavior of context class ---
110106
111107 def __setitem__ (self , key : str , value : ContextValue ) -> None :
108+ """
109+ Sets a context item, with checks for reserved keys.
110+
111+ :param key: The name of the context item.
112+ :param value: The tensor/module to store.
113+ """
114+ # Only allow torch.Tensor or torch.nn.Module values
115+ if not isinstance (value , (torch .Tensor , torch .nn .Module )):
116+ raise TypeError (
117+ f"Context values must be torch.Tensor or torch.nn.Module, got { type (value )} "
118+ )
119+
120+ # Further type check matching for reserved keys
121+ if key in RESERVED_KEYS and not isinstance (value , torch .Tensor ):
122+ raise ReservedKeyTypeError (
123+ f"Reserved key '{ key } ' must be a torch.Tensor, got { type (value )} "
124+ )
125+ elif key in RESERVED_MODEL_KEYS and not isinstance (value , torch .nn .Module ):
126+ raise ReservedKeyTypeError (
127+ f"Reserved key '{ key } ' must be a torch.nn.Module, got { type (value )} "
128+ )
129+
112130 self ._store [key ] = value
113131
114132 def __contains__ (self , key : str ) -> bool :
@@ -123,7 +141,7 @@ def __iter__(self):
123141 def __len__ (self ):
124142 return len (self ._store )
125143
126- def get (self , key : str , default : ContextValue = None ) -> ContextValue :
144+ def get (self , key : str , default : Optional [ ContextValue ] = None ) -> Optional [ ContextValue ] :
127145 return self ._store .get (key , default )
128146
129147 def values (self ):
@@ -134,3 +152,59 @@ def items(self):
134152
135153 def keys (self ):
136154 return self ._store .keys ()
155+
156+ def pop (self , key : str ) -> ContextValue :
157+ """
158+ Remove and return the value for key if key is in the context,
159+ else raises a KeyError.
160+ """
161+ if key not in self ._store :
162+ raise KeyError (f"Key '{ key } ' not found in Context." )
163+ return self ._store .pop (key )
164+
165+ def __or__ (self , other : "Context" ) -> "Context" :
166+ """
167+ Merge two Context objects using the | operator.
168+ Returns a new Context with items from both contexts.
169+ Items from the right operand (other) take precedence in case of key conflicts.
170+
171+ :param other: Another Context object to merge with.
172+ :return: A new Context object containing items from both contexts.
173+ """
174+ if not isinstance (other , Context ):
175+ raise NotImplementedError (
176+ "__or__ operation only supported between Context objects."
177+ )
178+ new_context = Context (** self ._store )
179+ new_context .add (** other ._store )
180+ return new_context
181+
182+ def __ror__ (self , other : "Context" ) -> "Context" :
183+ """
184+ Reverse merge (right | operator) for Context objects.
185+ Called when the left operand doesn't support __or__ with Context.
186+
187+ :param other: Another Context object to merge with.
188+ :return: A new Context object containing items from both contexts.
189+ """
190+ if not isinstance (other , Context ):
191+ raise NotImplementedError (
192+ "__or__ operation only supported between Context objects."
193+ )
194+ new_context = Context (** other ._store )
195+ new_context .add (** self ._store )
196+ return new_context
197+
198+ # --- Properties for robust typing for reserved keys ---
199+ # let fail if key is not present
200+ @property
201+ def inputs (self ) -> Tensor :
202+ return self ._store [INPUTS ] # type: ignore
203+
204+ @property
205+ def targets (self ) -> Tensor :
206+ return self ._store [TARGETS ] # type: ignore
207+
208+ @property
209+ def preds (self ) -> Tensor :
210+ return self ._store [PREDS ] # type: ignore
0 commit comments