"""Implementation of the ParticleNet GNN model architecture."""
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, LongTensor
from torch_geometric.data import Data
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

from graphnet.models.components.layers import DynEdgeConv
from graphnet.models.gnn.gnn import GNN

GLOBAL_POOLINGS = {
    "min": scatter_min,
    "max": scatter_max,
    "sum": scatter_sum,
    "mean": scatter_mean,
}
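# Aggregations available for graph-level ("global") pooling of node
# features; selected via the `global_pooling_schemes` argument below.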


class ParticleNeT(GNN):
    """ParticleNeT (dynamic edge convolution) model.

    Inspired by: https://arxiv.org/abs/1902.08570
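
    A minimal usage sketch (the feature count below is hypothetical;
    any `torch_geometric.data.Data` batch with `x`, `edge_index` and
    `batch` attributes should work):

    >>> model = ParticleNeT(nb_inputs=7)
    >>> prediction = model(data)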
    """

    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 16,
        features_subset: Optional[Union[List[int], slice]] = None,
        dynamic: bool = True,
        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling_schemes: Optional[Union[str, List[str]]] = "mean",
        activation_layer: Optional[str] = "relu",
        add_batchnorm_layer: bool = True,
        dropout_readout: float = 0.1,
        skip_readout: bool = False,
    ):
        """Construct `ParticleNeT`.

        Args:
            nb_inputs: Number of input features on each node.
            nb_neighbours: Number of neighbours to use in the k-nearest
                neighbour clustering which is performed after each
                (dynamical) edge convolution.
            features_subset: The subset of latent features on each node
                that are used as metric dimensions when performing the
                k-nearest neighbours clustering. Defaults to [0, 1, 2].
            dynamic: Whether or not to update the edges after every
                `DynEdgeConv` block.
            dynedge_layer_sizes: The layer sizes, or latent feature
                dimensions, used in the `DynEdgeConv` layers. Each entry
                in `dynedge_layer_sizes` corresponds to a single
                `DynEdgeConv` layer; the integers in the corresponding
                tuple correspond to the layer sizes in the multi-layer
                perceptron (MLP) that is applied within each
                `DynEdgeConv` layer. That is, a list of size-three
                tuples means that all `DynEdgeConv` layers contain a
                three-layer MLP.
                Defaults to [(64, 64, 64), (128, 128, 128), (256, 256, 256)].
            readout_layer_sizes: Hidden layer sizes in the MLP following
                the optional global pooling. As this is the last layer
                in the model, it yields the output of the `ParticleNeT`
                model. Defaults to [256,].
            global_pooling_schemes: The list of global pooling schemes
                to use. Options are: "min", "max", "mean", and "sum".
                Defaults to "mean".
            activation_layer: The activation function to use in the
                model. Options are "relu" and "gelu". Defaults to "relu".
            add_batchnorm_layer: Whether to add a batch normalization
                layer after each linear layer. Defaults to True.
            dropout_readout: Dropout value to use in the readout
                layer(s). Defaults to 0.1.
            skip_readout: Whether to skip the readout layer(s). If
                `True`, the output of the last `DynEdgeConv` block is
                returned directly.
        """
        # Latent feature subset for computing nearest neighbours in model
        if features_subset is None:
            features_subset = slice(0, 3)

        # DynEdge layer sizes
        if dynedge_layer_sizes is None:
            dynedge_layer_sizes = [
                (64, 64, 64),
                (128, 128, 128),
                (256, 256, 256),
            ]

        dynedge_layer_sizes_check = []
        for sizes in dynedge_layer_sizes:
            if isinstance(sizes, list):
                sizes = tuple(sizes)
            dynedge_layer_sizes_check.append(sizes)

        assert isinstance(dynedge_layer_sizes_check, list)
        assert len(dynedge_layer_sizes_check)
        assert all(
            isinstance(sizes, tuple) for sizes in dynedge_layer_sizes_check
        )
        assert all(len(sizes) > 0 for sizes in dynedge_layer_sizes_check)
        assert all(
            all(size > 0 for size in sizes)
            for sizes in dynedge_layer_sizes_check
        )

        self._dynedge_layer_sizes = dynedge_layer_sizes_check

        # Read-out layer sizes
        if readout_layer_sizes is None:
            readout_layer_sizes = [
                256,
            ]

        assert isinstance(readout_layer_sizes, list)
        assert len(readout_layer_sizes)
        assert all(size > 0 for size in readout_layer_sizes)

        self._readout_layer_sizes = readout_layer_sizes

        # Global pooling scheme(s)
        if isinstance(global_pooling_schemes, str):
            global_pooling_schemes = [global_pooling_schemes]

        if isinstance(global_pooling_schemes, list):
            for pooling_scheme in global_pooling_schemes:
                assert (
                    pooling_scheme in GLOBAL_POOLINGS
                ), f"Global pooling scheme {pooling_scheme} not supported."
        else:
            assert global_pooling_schemes is None

        self._global_pooling_schemes = global_pooling_schemes

        if activation_layer is None or activation_layer.lower() == "relu":
            activation_layer = torch.nn.ReLU()
        elif activation_layer.lower() == "gelu":
            activation_layer = torch.nn.GELU()
        else:
            raise ValueError(
                f"Activation layer {activation_layer} not supported."
            )

        # Base class constructor
        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

        # Remaining member variables
        self._activation = activation_layer
        self._nb_inputs = nb_inputs
        self._nb_neighbours = nb_neighbours
        self._features_subset = features_subset
        self._dynamic = dynamic
        self._add_batchnorm_layer = add_batchnorm_layer
        self._dropout_readout = dropout_readout
        self._skip_readout = skip_readout

        self._construct_layers()

    # Builds the network
    def _construct_layers(self) -> None:
        """Construct layers (torch.nn.Modules)."""
        # Convolutional operations
        nb_input_features = self._nb_inputs

        self._conv_layers = torch.nn.ModuleList()
        nb_latent_features = nb_input_features
        for sizes in self._dynedge_layer_sizes:
            layers = []
            layer_sizes = [nb_latent_features] + list(sizes)
            for ix, (nb_in, nb_out) in enumerate(
                zip(layer_sizes[:-1], layer_sizes[1:])
            ):
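                # EdgeConv-style layers act on concatenated per-edge
                # features (node + neighbour difference), so the first
                # MLP layer sees twice the latent dimension.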
                if ix == 0:
                    nb_in *= 2
                layers.append(torch.nn.Linear(nb_in, nb_out))
                if self._add_batchnorm_layer:
                    layers.append(torch.nn.BatchNorm1d(nb_out))
                layers.append(self._activation)

            conv_layer = DynEdgeConv(
                torch.nn.Sequential(*layers),
                aggr="mean",
                nb_neighbors=self._nb_neighbours,
                features_subset=self._features_subset,
            )
            self._conv_layers.append(conv_layer)

            nb_latent_features = nb_out

        # Read-out operations
        nb_poolings = (
            len(self._global_pooling_schemes)
            if self._global_pooling_schemes
            else 1
        )
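        # Pooled per-graph vectors are concatenated across schemes, so
        # the read-out input size scales with the number of schemes.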
        nb_latent_features = nb_out * nb_poolings

        readout_layers = []
        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            readout_layers.append(torch.nn.Linear(nb_in, nb_out))
            readout_layers.append(self._activation)
            readout_layers.append(torch.nn.Dropout(self._dropout_readout))

        self._readout = torch.nn.Sequential(*readout_layers)

    def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Perform global pooling."""
        assert self._global_pooling_schemes
        pooled = []
        for pooling_scheme in self._global_pooling_schemes:
            pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
            pooled_x = pooling_fn(x, index=batch, dim=0)
            if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
                # `scatter_{min,max}` also return argmin/argmax indices,
                # unlike `scatter_{mean,sum}`; keep only the values.
                pooled_x, _ = pooled_x
            pooled.append(pooled_x)
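
        # Concatenated pooled features have shape
        # [num_graphs, num_schemes * num_latent_features].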
        return torch.cat(pooled, dim=1)

    def forward(self, data: Data) -> Tensor:
        """Apply learnable forward pass."""
        # Convenience variables
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # DynEdge-convolutions
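        # When `dynamic=True`, the k-NN graph recomputed by each block
        # (in latent space, on `features_subset`) is carried to the
        # next block; otherwise the input `edge_index` is reused.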
        for conv_layer in self._conv_layers:
            if self._dynamic:
                x, edge_index = conv_layer(x, edge_index, batch)
            else:
                x, _ = conv_layer(x, edge_index, batch)

        # Read-out
        if not self._skip_readout:
            # (Optional) Global pooling
            if self._global_pooling_schemes:
                x = self._global_pooling(x, batch=batch)

            # Read-out
            x = self._readout(x)

        return x