Skip to content

Commit ec8c269

Browse files
authored
Merge pull request #5 from lincc-frameworks/awo/update-2026-03-017
Updated model with current best practices, added getting started notebook
2 parents a52c1bf + 783593f commit ec8c269

7 files changed

Lines changed: 406 additions & 40 deletions

File tree

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,9 @@ _html/
148148

149149
# Project initialization script
150150
.initialize_new_project.sh
151+
152+
# Default Hyrax results directory
153+
results/
154+
docs/notebooks/results/
155+
docs/pre_executed/results/
156+
data/

docs/pre_executed/model_usage_example.ipynb

Lines changed: 275 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ classifiers = [
1717
dynamic = ["version"]
1818
requires-python = ">=3.11"
1919
dependencies = [
20+
"hyrax", # The main dependency of this project
21+
"torch", # Used for the example model in this project
2022
]
2123

2224
[project.urls]
@@ -30,6 +32,9 @@ dev = [
3032
"pytest",
3133
"pytest-cov", # Used to report total code coverage
3234
"ruff", # Used for static linting of files
35+
"numpy", # Required by example notebooks
36+
"matplotlib", # Required by example notebooks
37+
"scikit-learn", # Required by example notebooks (imported as sklearn)
3338
]
3439

3540
[build-system]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .example_model import ExampleModel
1+
from .models.vgg11 import VGG11
22

3-
__all__ = ["ExampleModel"]
3+
__all__ = ["VGG11"]
Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1-
[model]
2-
[model.ExampleModel]
3-
layer = 10
1+
[external_hyrax_example]
2+
3+
[external_hyrax_example.VGG11]
4+
dropout = 0.5
5+
num_classes = 10
6+
batch_norm = true
7+
8+
# The libpath that would be used for runtime config
9+
# name = "external_hyrax_example.models.vgg11.VGG11"

src/external_hyrax_example/example_model.py

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from typing import Union, cast
2+
3+
import torch
4+
import torch.nn as nn
5+
from hyrax.models.model_registry import hyrax_model
6+
7+
cfgs = {
8+
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
9+
}
10+
11+
12+
@hyrax_model
13+
class VGG11(nn.Module):
14+
"""Copy of the PyTorch VGG11 model for testing and demonstration
15+
purposes.
16+
https://docs.pytorch.org/vision/main/models/generated/torchvision.models.vgg11.html#torchvision.models.vgg11
17+
"""
18+
19+
def __init__(self, config, data_sample=None):
20+
"""Basic initialization with architecture definition"""
21+
super().__init__()
22+
if data_sample is None:
23+
raise ValueError(
24+
"VGG11 expected 'data_sample' to be provided at construction time "
25+
"so that input channel dimensions can be inferred, but received None."
26+
)
27+
image_sample = data_sample[0]
28+
self.in_channels, width, height = image_sample.shape
29+
self.config = config
30+
31+
dropout = self.config["external_hyrax_example"]["VGG11"]["dropout"]
32+
num_classes = self.config["external_hyrax_example"]["VGG11"]["num_classes"]
33+
batch_norm = self.config["external_hyrax_example"]["VGG11"]["batch_norm"]
34+
35+
self.features = self._make_layers(cfgs["A"], batch_norm=batch_norm)
36+
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
37+
self.classifier = nn.Sequential(
38+
nn.Linear(512 * 7 * 7, 4096),
39+
nn.ReLU(True),
40+
nn.Dropout(p=dropout),
41+
nn.Linear(4096, 4096),
42+
nn.ReLU(True),
43+
nn.Dropout(p=dropout),
44+
nn.Linear(4096, num_classes),
45+
)
46+
47+
def _make_layers(self, cfg: list[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
48+
"""Helper function to create the convolutional layers of the VGG11 architecture"""
49+
layers: list[nn.Module] = []
50+
in_channels = self.in_channels
51+
for v in cfg:
52+
if v == "M":
53+
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
54+
else:
55+
v = cast(int, v)
56+
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
57+
if batch_norm:
58+
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
59+
else:
60+
layers += [conv2d, nn.ReLU(inplace=True)]
61+
in_channels = v
62+
return nn.Sequential(*layers)
63+
64+
def forward(self, batch: tuple) -> torch.Tensor:
65+
"""The innermost logic in the forward pass"""
66+
x, _ = batch
67+
x = self.features(x)
68+
x = self.avgpool(x)
69+
x = torch.flatten(x, 1)
70+
x = self.classifier(x)
71+
return x
72+
73+
def infer_batch(self, batch):
74+
"""The innermost logic in the inference loop"""
75+
return self(batch)
76+
77+
def train_batch(self, batch):
78+
"""The innermost logic in the training loop"""
79+
_, labels = batch
80+
self.optimizer.zero_grad()
81+
outputs = self(batch)
82+
loss = self.criterion(outputs, labels)
83+
loss.backward()
84+
self.optimizer.step()
85+
return {"loss": loss.item()}
86+
87+
def validate_batch(self, batch):
88+
"""The innermost logic in the validation loop"""
89+
_, labels = batch
90+
outputs = self(batch)
91+
loss = self.criterion(outputs, labels)
92+
return {"loss": loss.item()}
93+
94+
def test_batch(self, batch):
95+
"""The innermost logic in the testing loop"""
96+
_, labels = batch
97+
outputs = self(batch)
98+
loss = self.criterion(outputs, labels)
99+
return {"loss": loss.item()}
100+
101+
@staticmethod
102+
def prepare_data(data_dict):
103+
"""Method that converts the data in dictionary into the form the model expects"""
104+
image = data_dict["data"]["image"]
105+
106+
label = None
107+
if "label" in data_dict["data"]:
108+
label = data_dict["data"]["label"]
109+
return (image, label)

0 commit comments

Comments
 (0)