Skip to content

Commit f347139

Browse files
authored
Merge pull request #24 from wli51/dev-gan-trainer
Dev gan trainer
2 parents 7b1e4ce + 1ef9dad commit f347139

29 files changed

Lines changed: 2752 additions & 291 deletions

examples/3.training_wgan_with_logging_example.ipynb

Lines changed: 787 additions & 0 deletions
Large diffs are not rendered by default.

examples/nbconverted/3.training_wgan_with_logging_example.py

Lines changed: 444 additions & 0 deletions
Large diffs are not rendered by default.

src/virtual_stain_flow/engine/context.py

Lines changed: 92 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
isolated and modular computations.
77
"""
88

9-
from typing import Dict, Iterable, Tuple, Union
9+
from typing import Dict, Iterable, Union, Optional
1010

1111
import torch
12+
from torch import Tensor
1213

13-
from .names import TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS
14+
from .names import INPUTS, TARGETS, PREDS, RESERVED_KEYS, RESERVED_MODEL_KEYS
1415

1516
ContextValue = Union[torch.Tensor, torch.nn.Module]
1617

@@ -47,17 +48,9 @@ def add(self, **items: ContextValue) -> "Context":
4748
where keys are the names of the tensors.
4849
"""
4950

50-
for k, v in items.items():
51-
if k in RESERVED_KEYS and not isinstance(v, torch.Tensor):
52-
raise ReservedKeyTypeError(
53-
f"Reserved key '{k}' must be a torch.Tensor, got {type(v)}"
54-
)
55-
elif k in RESERVED_MODEL_KEYS and not isinstance(v, torch.nn.Module):
56-
raise ReservedKeyTypeError(
57-
f"Reserved key '{k}' must be a torch.nn.Module, got {type(v)}"
58-
)
59-
60-
self._store.update(items)
51+
for key, value in items.items():
52+
self[key] = value
53+
6154
return self
6255

6356
def require(self, keys: Iterable[str]) -> None:
@@ -81,15 +74,18 @@ def as_kwargs(self) -> Dict[str, ContextValue]:
8174
"""
8275
return self._store
8376

84-
def as_metric_args(self) -> tuple[Tensor, Tensor]:
    """
    Return the (preds, targets) tensor pair for image-quality-assessment
    metric computation.

    Intended use: ``metric.update(*ctx.as_metric_args())``

    :return: A tuple (preds, targets) of tensors.
    :raises ValueError: If either preds or targets is missing.
    """
    # Validate presence first so a missing key surfaces as the
    # require() error rather than a bare KeyError from the store.
    self.require(keys=[PREDS, TARGETS])
    return (self.preds, self.targets)
9490

9591
def __repr__(self) -> str:
@@ -109,6 +105,28 @@ def __repr__(self) -> str:
109105
# --- Methods for dict like behavior of context class ---
110106

111107
def __setitem__(self, key: str, value: ContextValue) -> None:
    """
    Store a context item, enforcing the reserved-key type contract.

    :param key: The name of the context item.
    :param value: The tensor/module to store.
    :raises TypeError: If value is neither a torch.Tensor nor a torch.nn.Module.
    :raises ReservedKeyTypeError: If a reserved key is paired with the wrong type.
    """
    is_tensor = isinstance(value, torch.Tensor)
    is_module = isinstance(value, torch.nn.Module)

    # Only torch.Tensor or torch.nn.Module values may live in the context.
    if not (is_tensor or is_module):
        raise TypeError(
            f"Context values must be torch.Tensor or torch.nn.Module, got {type(value)}"
        )

    # Reserved keys additionally pin down which of the two types is allowed.
    if key in RESERVED_KEYS and not is_tensor:
        raise ReservedKeyTypeError(
            f"Reserved key '{key}' must be a torch.Tensor, got {type(value)}"
        )
    if key in RESERVED_MODEL_KEYS and not is_module:
        raise ReservedKeyTypeError(
            f"Reserved key '{key}' must be a torch.nn.Module, got {type(value)}"
        )

    self._store[key] = value
113131

114132
def __contains__(self, key: str) -> bool:
@@ -123,7 +141,7 @@ def __iter__(self):
123141
def __len__(self):
124142
return len(self._store)
125143

126-
def get(self, key: str, default: Optional[ContextValue] = None) -> Optional[ContextValue]:
    """Return the value for *key*, or *default* when absent (dict.get semantics)."""
    return self._store.get(key, default)
128146

129147
def values(self):
@@ -134,3 +152,59 @@ def items(self):
134152

135153
def keys(self):
136154
return self._store.keys()
155+
156+
def pop(self, key: str) -> ContextValue:
    """
    Remove and return the value stored under *key*.

    :raises KeyError: If *key* is not present in the context.
    """
    # EAFP: let the store raise, then re-raise with the Context-specific
    # message (identical to the original membership-check behavior).
    try:
        return self._store.pop(key)
    except KeyError:
        raise KeyError(f"Key '{key}' not found in Context.") from None
164+
165+
def __or__(self, other: "Context") -> "Context":
    """
    Merge two Context objects using the | operator.

    Returns a new Context with items from both contexts; items from the
    right operand (other) take precedence in case of key conflicts.

    :param other: Another Context object to merge with.
    :return: A new Context object containing items from both contexts.
    """
    # Per the Python data model, a binary dunder signals an unsupported
    # operand by *returning* the NotImplemented sentinel (so Python can
    # try the reflected operation and finally raise TypeError), rather
    # than raising NotImplementedError.
    if not isinstance(other, Context):
        return NotImplemented
    merged = Context(**self._store)
    merged.add(**other._store)
    return merged
181+
182+
def __ror__(self, other: "Context") -> "Context":
    """
    Reflected merge (``other | self``) for Context objects.

    Invoked when the left operand does not implement __or__ for Context.
    Items from self (the right operand) take precedence on key conflicts.

    :param other: The left-hand object to merge with.
    :return: A new Context object containing items from both contexts.
    """
    # NOTE(review): __ror__ is only ever invoked when the left operand is
    # NOT a Context, so the original isinstance guard could never pass and
    # the method always raised. Returning NotImplemented here restores the
    # standard unsupported-operand behavior (Python raises TypeError).
    if not isinstance(other, Context):
        return NotImplemented
    merged = Context(**other._store)
    merged.add(**self._store)
    return merged
197+
198+
# --- Properties for robust typing for reserved keys ---
# Deliberately index the store directly so a missing reserved key fails
# loudly with KeyError instead of silently returning None.

@property
def inputs(self) -> Tensor:
    """Tensor stored under the reserved INPUTS key."""
    return self._store[INPUTS]  # type: ignore

@property
def targets(self) -> Tensor:
    """Tensor stored under the reserved TARGETS key."""
    return self._store[TARGETS]  # type: ignore

@property
def preds(self) -> Tensor:
    """Tensor stored under the reserved PREDS key."""
    return self._store[PREDS]  # type: ignore

src/virtual_stain_flow/engine/forward_groups.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
"""
3232

3333
from abc import ABC, abstractmethod
34-
from typing import Optional, Tuple, Dict
34+
from typing import Optional, Dict
3535

3636
import torch
3737
import torch.optim as optim
3838
import torch.nn as nn
3939

40-
from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL
40+
from .names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL, DISCRIMINATOR_MODEL
4141
from .context import Context
4242

4343

@@ -52,9 +52,9 @@ class AbstractForwardGroup(ABC):
5252
"""
5353

5454
# Subclasses should override these with ordered tuples.
55-
input_keys: Tuple[str, ...]
56-
target_keys: Tuple[str, ...]
57-
output_keys: Tuple[str, ...]
55+
input_keys: tuple[str, ...]
56+
target_keys: tuple[str, ...]
57+
output_keys: tuple[str, ...]
5858

5959
def __init__(
6060
self,
@@ -75,7 +75,7 @@ def _move_tensors(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso
7575
}
7676

7777
@staticmethod
78-
def _normalize_outputs(raw) -> Tuple[torch.Tensor, ...]:
78+
def _normalize_outputs(raw) -> tuple[torch.Tensor, ...]:
7979
"""
8080
Normalize model outputs to a tuple of tensors while preserving order.
8181
@@ -140,9 +140,9 @@ class GeneratorForwardGroup(AbstractForwardGroup):
140140
metric_value = metric_fn(preds, targets)
141141
"""
142142

143-
input_keys: Tuple[str, ...] = (INPUTS,)
144-
target_keys: Tuple[str, ...] = (TARGETS,)
145-
output_keys: Tuple[str, ...] = (PREDS,)
143+
input_keys: tuple[str, ...] = (INPUTS,)
144+
target_keys: tuple[str, ...] = (TARGETS,)
145+
output_keys: tuple[str, ...] = (PREDS,)
146146

147147
def __init__(
148148
self,
@@ -207,3 +207,88 @@ def optimizer(self) -> Optional[optim.Optimizer]:
207207
Convenience property to access the generator optimizer directly.
208208
"""
209209
return self._optimizers[GENERATOR_MODEL]
210+
211+
212+
class DiscriminatorForwardGroup(AbstractForwardGroup):
    """
    Forward group for a simple single (GAN/wGAN) discriminator workflow.

    The discriminator is assumed to take in a "stack" of input and target
    images concatenated along the channel dimension, and output a
    score/probability. Relevant context values are input_keys, target_keys,
    output_keys for a single-discriminator model, where:
    - the forward is called as:
        p = discriminator(stack)
    - the evaluation is less straightforward, but typically involves
      computing losses/metrics based on p and real/fake labels:
        metric_value = metric_fn(p, real_or_fake_labels)
      or perhaps involving the discriminator model itself for wasserstein distance:
        metric_value = metric_fn(discriminator, stack, real_or_fake_labels)
    """

    input_keys: tuple[str, ...] = ("stack",)
    target_keys: tuple[str, ...] = ()
    output_keys: tuple[str, ...] = ("p",)

    def __init__(
        self,
        discriminator: nn.Module,
        optimizer: Optional[optim.Optimizer] = None,
        device: torch.device = torch.device("cpu"),
    ):
        """
        :param discriminator: The discriminator network.
        :param optimizer: Optional optimizer for the discriminator; when None
            the group can still run forward passes (e.g. evaluation only).
        :param device: Device the discriminator is moved to.
        """
        super().__init__(device=device)

        self._models[DISCRIMINATOR_MODEL] = discriminator
        self._models[DISCRIMINATOR_MODEL].to(self.device)
        self._optimizers[DISCRIMINATOR_MODEL] = optimizer

    def __call__(self, train: bool, **inputs: torch.Tensor) -> Context:
        """
        Executes the forward pass, managing training/eval modes and optimizer
        zero-grad. Subclasses may override this method if needed.

        :param train: Whether to run in training mode. Meant to be specified
            by the trainer to switch between train/eval modes and determine
            whether gradients should be computed.
        :param inputs: Keyword arguments of input tensors.
        :return: Context enriched with the discriminator outputs.
        :raises ValueError: If the model output arity does not match output_keys.
        """
        fp_model = self.model
        fp_optimizer = self.optimizer

        # 1) Stage and validate inputs/targets
        ctx = Context(**self._move_tensors(inputs), **{DISCRIMINATOR_MODEL: fp_model})
        ctx.require(self.input_keys)
        ctx.require(self.target_keys)

        # 2) Forward, with grad only when training
        fp_model.train(mode=train)
        # Explicit `if` instead of the original `train and opt is not None and
        # opt.zero_grad(...)` short-circuit chain: same effect, but the side
        # effect is no longer hidden inside a discarded boolean expression.
        if train and fp_optimizer is not None:
            fp_optimizer.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train):
            model_inputs = [ctx[k] for k in self.input_keys]  # ordered
            raw = fp_model(*model_inputs)
            y_tuple = self._normalize_outputs(raw)

        # 3) Arity check + map outputs to names
        if len(y_tuple) != len(self.output_keys):
            raise ValueError(
                f"Model returned {len(y_tuple)} outputs, "
                f"but output_keys expects {len(self.output_keys)}"
            )
        outputs = dict(zip(self.output_keys, y_tuple))

        # 4) Return enriched context (preds available for losses/metrics)
        return ctx.add(**outputs)

    @property
    def model(self) -> nn.Module:
        """
        Convenience property to access the discriminator model directly.
        """
        return self._models[DISCRIMINATOR_MODEL]

    @property
    def optimizer(self) -> Optional[optim.Optimizer]:
        """
        Convenience property to access the discriminator optimizer directly.
        """
        return self._optimizers[DISCRIMINATOR_MODEL]

src/virtual_stain_flow/engine/loss_group.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626

2727
import torch
2828

29-
from .loss_utils import AbstractLoss, _get_loss_name, _scalar_from_ctx
30-
from .context import Context
29+
from .loss_utils import BaseLoss, _get_loss_name, _scalar_from_ctx
30+
from .context import Context, ContextValue
3131
from .names import PREDS, TARGETS
3232

33+
Scalar = Union[int, float, bool]
34+
3335

3436
@dataclass
3537
class LossItem:
@@ -53,7 +55,7 @@ class LossItem:
5355
losses and centralizes device management.
5456
)
5557
"""
56-
module: Union[torch.nn.Module, AbstractLoss]
58+
module: Union[torch.nn.Module, BaseLoss]
5759
args: Union[str, Tuple[str, ...]] = (PREDS, TARGETS)
5860
key: Optional[str] = None
5961
weight: float = 1.0
@@ -63,7 +65,7 @@ class LossItem:
6365

6466
def __post_init__(self):
6567

66-
self.key = self.key or _get_loss_name(self.module)
68+
self.key = str(self.key or _get_loss_name(self.module))
6769
self.args = (self.args,) if isinstance(self.args, str) else self.args
6870

6971
try:
@@ -94,7 +96,9 @@ def __call__(
9496

9597
if context is not None:
9698
context.require(self.args)
97-
inputs = {arg: context[arg] for arg in self.args}
99+
inputs: Dict[str, ContextValue] = {
100+
arg: context[arg] for arg in self.args
101+
}
98102

99103
if not self.enabled or (not train and not self.compute_at_val):
100104
zero = _scalar_from_ctx(0.0, inputs)
@@ -127,15 +131,15 @@ class LossGroup:
127131
items: Sequence[LossItem]
128132

129133
@property
130-
def item_names(self) -> List[str]:
134+
def item_names(self) -> List[Optional[str]]:
131135
return [item.key for item in self.items]
132136

133137
def __call__(
134138
self,
135139
train: bool,
136140
context: Optional[Context] = None,
137141
**inputs: torch.Tensor
138-
) -> Tuple[torch.Tensor, Dict[str, float]]:
142+
) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
139143
"""
140144
Compute the total loss and individual loss values.
141145
@@ -153,7 +157,7 @@ def __call__(
153157

154158
for item in self.items:
155159
raw, weighted = item(train, context=context, **inputs)
156-
logs[item.key] = raw.item()
160+
logs[item.key] = raw.item() # type: ignore
157161
total += weighted
158162

159163
return total, logs

0 commit comments

Comments
 (0)