Commit cbc4228

Merge: 2 parents fbafb46 + dc7fa4f

3 files changed

Lines changed: 259 additions & 2 deletions

src/graphnet/models/gnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@
 from .dynedge_kaggle_tito import DynEdgeTITO
 from .RNN_tito import RNN_TITO
 from .icemix import DeepIce
+from .particlenet import ParticleNeT

src/graphnet/models/gnn/particlenet.py

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
"""Implementation of the ParticleNet GNN model architecture."""
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, LongTensor
from torch_geometric.data import Data
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynEdgeConv
from graphnet.models.gnn.gnn import GNN

GLOBAL_POOLINGS = {
    "min": scatter_min,
    "max": scatter_max,
    "sum": scatter_sum,
    "mean": scatter_mean,
}


class ParticleNeT(GNN):
    """ParticleNeT (dynamical edge convolutional) model.

    Inspired by: https://arxiv.org/abs/1902.08570
    """

    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 16,
        features_subset: Optional[Union[List[int], slice]] = None,
        dynamic: bool = True,
        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling_schemes: Optional[Union[str, List[str]]] = "mean",
        activation_layer: Optional[str] = "relu",
        add_batchnorm_layer: bool = True,
        dropout_readout: float = 0.1,
        skip_readout: bool = False,
    ):
        """Construct `ParticleNeT`.

        Args:
            nb_inputs: Number of input features on each node.
            nb_neighbours: Number of neighbours to use in the k-nearest
                neighbour clustering which is performed after each (dynamical)
                edge convolution.
            features_subset: The subset of latent features on each node that
                are used as metric dimensions when performing the k-nearest
                neighbours clustering. Defaults to [0, 1, 2].
            dynamic: Whether or not to update the edges after every
                `DynEdgeConv` block.
            dynedge_layer_sizes: The layer sizes, or latent feature dimensions,
                used in the `DynEdgeConv` layers. Each entry in
                `dynedge_layer_sizes` corresponds to a single `DynEdgeConv`
                layer; the integers in the corresponding tuple correspond to
                the layer sizes in the multi-layer perceptron (MLP) that is
                applied within each `DynEdgeConv` layer. That is, a list of
                size-three tuples means that all `DynEdgeConv` layers contain
                a three-layer MLP.
                Defaults to [(64, 64, 64), (128, 128, 128), (256, 256, 256)].
            readout_layer_sizes: Hidden layer sizes in the MLP following the
                post-processing _and_ optional global pooling. As this is the
                last layer in the model, it yields the output of the
                `ParticleNeT` model. Defaults to [256].
            global_pooling_schemes: The list of global pooling schemes to use.
                Options are: "min", "max", "mean", and "sum".
                Defaults to "mean".
            activation_layer: The activation function to use in the model.
                Defaults to "relu".
            add_batchnorm_layer: Whether to add a batch normalization layer
                after each linear layer. Defaults to True.
            dropout_readout: Dropout value to use in the readout layer(s).
                Defaults to 0.1.
            skip_readout: Whether to skip the readout layer(s). If `True`, the
                output of the last `DynEdgeConv` block is returned directly.
        """
        # Latent feature subset for computing nearest neighbours in model
        if features_subset is None:
            features_subset = slice(0, 3)

        # DynEdge layer sizes
        if dynedge_layer_sizes is None:
            dynedge_layer_sizes = [
                (64, 64, 64),
                (128, 128, 128),
                (256, 256, 256),
            ]

        dynedge_layer_sizes_check = []
        for sizes in dynedge_layer_sizes:
            if isinstance(sizes, list):
                sizes = tuple(sizes)
            dynedge_layer_sizes_check.append(sizes)

        assert isinstance(dynedge_layer_sizes_check, list)
        assert len(dynedge_layer_sizes_check)
        assert all(
            isinstance(sizes, tuple) for sizes in dynedge_layer_sizes_check
        )
        assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes_check)
        assert all(
            all(size > 0 for size in sizes)
            for sizes in dynedge_layer_sizes_check
        )

        self._dynedge_layer_sizes = dynedge_layer_sizes_check

        # Read-out layer sizes
        if readout_layer_sizes is None:
            readout_layer_sizes = [256]

        assert isinstance(readout_layer_sizes, list)
        assert len(readout_layer_sizes)
        assert all(size > 0 for size in readout_layer_sizes)

        self._readout_layer_sizes = readout_layer_sizes

        # Global pooling scheme(s)
        if isinstance(global_pooling_schemes, str):
            global_pooling_schemes = [global_pooling_schemes]

        if isinstance(global_pooling_schemes, list):
            for pooling_scheme in global_pooling_schemes:
                assert (
                    pooling_scheme in GLOBAL_POOLINGS
                ), f"Global pooling scheme {pooling_scheme} not supported."
        else:
            assert global_pooling_schemes is None

        self._global_pooling_schemes = global_pooling_schemes

        if activation_layer is None or activation_layer.lower() == "relu":
            activation_layer = torch.nn.ReLU()
        elif activation_layer.lower() == "gelu":
            activation_layer = torch.nn.GELU()
        else:
            raise ValueError(
                f"Activation layer {activation_layer} not supported."
            )

        # Base class constructor
        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

        # Remaining member variables
        self._activation = activation_layer
        self._nb_inputs = nb_inputs
        self._nb_neighbours = nb_neighbours
        self._features_subset = features_subset
        self._dynamic = dynamic
        self._add_batchnorm_layer = add_batchnorm_layer
        self._dropout_readout = dropout_readout
        self._skip_readout = skip_readout

        self._construct_layers()

    # Builds the network
    def _construct_layers(self) -> None:
        """Construct layers (torch.nn.Modules)."""
        # Convolutional operations
        nb_input_features = self._nb_inputs

        self._conv_layers = torch.nn.ModuleList()
        nb_latent_features = nb_input_features
        for sizes in self._dynedge_layer_sizes:
            layers = []
            layer_sizes = [nb_latent_features] + list(sizes)
            for ix, (nb_in, nb_out) in enumerate(
                zip(layer_sizes[:-1], layer_sizes[1:])
            ):
                if ix == 0:
                    # `DynEdgeConv` (EdgeConv) applies its MLP to the
                    # concatenation [x_i, x_j - x_i], so the first layer
                    # sees twice the feature dimension.
                    nb_in *= 2
                layers.append(torch.nn.Linear(nb_in, nb_out))
                if self._add_batchnorm_layer:
                    layers.append(torch.nn.BatchNorm1d(nb_out))
                layers.append(self._activation)

            conv_layer = DynEdgeConv(
                torch.nn.Sequential(*layers),
                aggr="mean",
                nb_neighbors=self._nb_neighbours,
                features_subset=self._features_subset,
            )
            self._conv_layers.append(conv_layer)

            nb_latent_features = nb_out

        # Read-out operations
        nb_poolings = (
            len(self._global_pooling_schemes)
            if self._global_pooling_schemes
            else 1
        )
        nb_latent_features = nb_out * nb_poolings

        readout_layers = []
        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            readout_layers.append(torch.nn.Linear(nb_in, nb_out))
            readout_layers.append(self._activation)
            readout_layers.append(torch.nn.Dropout(self._dropout_readout))

        self._readout = torch.nn.Sequential(*readout_layers)

    def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Perform global pooling."""
        assert self._global_pooling_schemes
        pooled = []
        for pooling_scheme in self._global_pooling_schemes:
            pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
            pooled_x = pooling_fn(x, index=batch, dim=0)
            if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
                # `scatter_{min,max}` return a (values, arg-indices) tuple,
                # unlike `scatter_{mean,sum}`, which return a single tensor.
                pooled_x, _ = pooled_x
            pooled.append(pooled_x)

        return torch.cat(pooled, dim=1)

    def forward(self, data: Data) -> Tensor:
        """Apply learnable forward pass."""
        # Convenience variables
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # DynEdge-convolutions
        for conv_layer in self._conv_layers:
            if self._dynamic:
                x, edge_index = conv_layer(x, edge_index, batch)
            else:
                x, _ = conv_layer(x, edge_index, batch)

        # Read-out
        if not self._skip_readout:
            # (Optional) Global pooling
            if self._global_pooling_schemes:
                x = self._global_pooling(x, batch=batch)

            # Read-out
            x = self._readout(x)

        return x
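
For context, a minimal usage sketch of the new class (not part of the commit). It assumes graphnet and torch_geometric (with torch-scatter/torch-cluster) are installed; the feature count and graph sizes are illustrative, and the initial edges are seeded with a k-NN graph, standing in for what a graphnet `GraphDefinition` would normally provide.

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import knn_graph

from graphnet.models.gnn import ParticleNeT

# Illustrative feature count; in practice this comes from the detector
# / graph definition.
model = ParticleNeT(nb_inputs=7)

def make_graph(n_nodes: int) -> Data:
    x = torch.randn(n_nodes, 7)
    # Seed edges with a k-NN graph on the first three features, mirroring
    # the default `features_subset = slice(0, 3)`.
    return Data(x=x, edge_index=knn_graph(x[:, :3], k=16))

batch = Batch.from_data_list([make_graph(30), make_graph(25)])
model.eval()  # disable dropout / batch-norm updates for the demo
with torch.no_grad():
    out = model(batch)
print(out.shape)  # torch.Size([2, 256]): one pooled vector per graph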

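The `_global_pooling` step can also be illustrated in isolation. This standalone snippet (again, not part of the commit) shows why the result of `scatter_max` is unpacked: the min/max variants of torch_scatter also return the arg-indices of the selected elements.

import torch
from torch_scatter import scatter_max, scatter_mean

# Five nodes across two graphs: nodes 0-2 in graph 0, nodes 3-4 in graph 1.
x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
batch = torch.tensor([0, 0, 0, 1, 1])

mean_x = scatter_mean(x, index=batch, dim=0)        # tensor([[2.0], [4.5]])
max_x, argmax = scatter_max(x, index=batch, dim=0)  # values plus node indices

# With `global_pooling_schemes=["mean", "max"]`, the model concatenates the
# per-graph vectors along the feature dimension, as in `_global_pooling`:
pooled = torch.cat([mean_x, max_x], dim=1)          # shape (2, 2)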
src/graphnet/models/standard_model.py

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@
 from torch.optim import Adam
 
 from graphnet.models.gnn.gnn import GNN
+from graphnet.models import Model
 from .easy_model import EasySyntax
 from graphnet.models.task import StandardLearnedTask
 from graphnet.models.graphs import GraphDefinition
@@ -25,7 +26,7 @@ def __init__(
         self,
         graph_definition: GraphDefinition,
         tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
-        backbone: GNN = None,
+        backbone: Model = None,
         gnn: Optional[GNN] = None,
         optimizer_class: Type[torch.optim.Optimizer] = Adam,
         optimizer_kwargs: Optional[Dict] = None,
@@ -60,7 +61,7 @@ def __init__(
         )
 
         # Checks
-        assert isinstance(backbone, GNN)
+        assert isinstance(backbone, Model)
         assert isinstance(graph_definition, GraphDefinition)
 
         # Member variable(s)
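
The relaxed check means any `Model` subclass, not only a `GNN`, can now serve as the backbone. A hedged sketch of wiring the new `ParticleNeT` into a `StandardModel`, assuming the `IceCube86` detector, `KNNGraph`, `EnergyReconstruction` task, and `LogCoshLoss` used in graphnet's training examples (exact keyword arguments may differ across versions):

import torch

from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCube86
from graphnet.models.gnn import ParticleNeT
from graphnet.models.graphs import KNNGraph
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.loss_functions import LogCoshLoss

# The graph definition fixes the number of node features.
graph_definition = KNNGraph(detector=IceCube86())
backbone = ParticleNeT(nb_inputs=graph_definition.nb_outputs)

model = StandardModel(
    graph_definition=graph_definition,
    backbone=backbone,  # any `Model` subclass is now accepted here
    tasks=[
        EnergyReconstruction(
            hidden_size=backbone.nb_outputs,  # 256 with the default read-out
            target_labels="energy",
            loss_function=LogCoshLoss(),
            transform_prediction_and_target=torch.log10,
        )
    ],
)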
