Skip to content

Commit e3188e3

Browse files
authored
Merge branch 'jammy_flow_integration' into main
2 parents 652f194 + 49576cb commit e3188e3

13 files changed

Lines changed: 545 additions & 62 deletions

File tree

.github/actions/install/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@ runs:
3838
run: |
3939
echo requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }}
4040
pip install -r requirements/torch_${{ inputs.hardware }}.txt ${{ env.PIP_FLAGS }} .${{ inputs.extras }}
41+
pip install git+https://github.com/thoglu/jammy_flows.git
4142
shell: bash

.github/workflows/build.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ jobs:
6363
uses: ./.github/actions/install
6464
with:
6565
editable: true
66+
- name: Print packages in pip
67+
run: |
68+
pip show torch
69+
pip show torch-geometric
70+
pip show torch-cluster
71+
pip show torch-sparse
72+
pip show torch-scatter
73+
pip show jammy_flows
6674
- name: Run unit tests and generate coverage report
6775
run: |
6876
coverage run --source=graphnet -m pytest tests/ --ignore=tests/examples/04_training --ignore=tests/utilities
@@ -110,6 +118,8 @@ jobs:
110118
pip show torch-sparse
111119
pip show torch-scatter
112120
pip show numpy
121+
122+
113123
- name: Run unit tests and generate coverage report
114124
run: |
115125
set -o pipefail # To propagate exit code from pytest

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,20 @@ repos:
1010
rev: 4.0.1
1111
hooks:
1212
- id: flake8
13+
language_version: python3
1314
- repo: https://github.com/pycqa/docformatter
1415
rev: v1.5.0
1516
hooks:
1617
- id: docformatter
18+
language_version: python3
1719
- repo: https://github.com/pycqa/pydocstyle
1820
rev: 6.1.1
1921
hooks:
2022
- id: pydocstyle
23+
language_version: python3
2124
- repo: https://github.com/pre-commit/mirrors-mypy
2225
rev: v0.982
2326
hooks:
2427
- id: mypy
2528
args: [--follow-imports=silent, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls]
29+
language_version: python3

docs/source/installation/quick-start.html

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,20 @@
107107
}
108108

109109
if (os == "linux" && cuda != "cpu" && torch != "no_torch"){
110-
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`);
110+
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
111111
}
112112
else if (os == "linux" && cuda == "cpu" && torch != "no_torch"){
113-
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]`);
113+
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
114114
}
115115
else if (os == "linux" && cuda == "cpu" && torch == "no_torch"){
116-
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]`);
116+
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_${$("#command").attr("cuda")}.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
117117
}
118118

119119
if (os == "macos" && cuda == "cpu" && torch != "no_torch"){
120-
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]`);
120+
$("#command pre").text(`git clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[torch,develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
121121
}
122122
if (os == "macos" && cuda == "cpu" && torch == "no_torch"){
123-
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]`);
123+
$("#command pre").text(`# Installations without PyTorch are intended for file conversion only\ngit clone https://github.com/graphnet-team/graphnet.git\ncd graphnet\n\npip install -r requirements/torch_macos.txt -e .[develop]\n\n#Optionally, install jammy_flows for normalizing flow support:\npip install git+https://github.com/thoglu/jammy_flows.git`);
124124
}
125125
}
126126

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
"""Example of training a conditional NormalizingFlow."""
2+
3+
import os
4+
from typing import Any, Dict, List, Optional
5+
6+
from pytorch_lightning.loggers import WandbLogger
7+
import torch
8+
from torch.optim.adam import Adam
9+
10+
from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
11+
from graphnet.data.constants import FEATURES, TRUTH
12+
from graphnet.models.detector.prometheus import Prometheus
13+
from graphnet.models.gnn import DynEdge
14+
from graphnet.models.graphs import KNNGraph
15+
from graphnet.training.callbacks import PiecewiseLinearLR
16+
from graphnet.training.utils import make_train_validation_dataloader
17+
from graphnet.utilities.argparse import ArgumentParser
18+
from graphnet.utilities.logging import Logger
19+
from graphnet.utilities.imports import has_jammy_flows_package
20+
21+
# Make sure the jammy_flows package is installed
22+
# `assert` statements are removed when Python runs with -O, so the optional
# dependency is checked with an explicit condition instead. AssertionError is
# kept as the raised type so the failure mode seen by callers is unchanged.
if not has_jammy_flows_package():
    raise AssertionError(
        "This example requires the package `jammy_flows` to be installed. "
        "It appears that the package is not installed. "
        "Please install the package."
    )

# Imported only after the availability check has passed.
from graphnet.models import NormalizingFlow  # noqa: E402

# Constants
# Passed as `features` / `truth` to the dataloader factory in `main`.
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS
35+
36+
37+
def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Train a conditional NormalizingFlow and save predictions and model.

    Args:
        path: Path to the dataset file. Passed as `db` to the dataloader
            factory; its base name (up to the first ".") names the output
            sub-directory.
        pulsemap: Passed as `pulsemaps` to the dataloader factory.
        target: Passed as `target_labels` to the model and used in the
            run name.
        truth_table: Passed as `truth_table` to the dataloader factory.
        gpus: Forwarded to `model.fit` and `model.predict_as_dataframe`.
        max_epochs: Forwarded to `model.fit`; also sets the last
            learning-rate scheduler milestone.
        early_stopping_patience: Forwarded to `model.fit`.
        batch_size: Passed to the dataloader factory.
        num_workers: Passed to the dataloader factory.
        wandb: If True, a `WandbLogger` is created, the configuration is
            logged to it, and it is passed to `model.fit`.
    """
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "dynedge_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # Define graph representation
    graph_definition = KNNGraph(detector=Prometheus())

    (
        training_dataloader,
        validation_dataloader,
    ) = make_train_validation_dataloader(
        db=config["path"],
        graph_definition=graph_definition,
        pulsemaps=config["pulsemap"],
        features=features,
        truth=truth,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        truth_table=truth_table,
        selection=None,
    )

    # Building model
    backbone = DynEdge(
        nb_inputs=graph_definition.nb_outputs,
        global_pooling_schemes=["min", "max", "mean", "sum"],
    )

    model = NormalizingFlow(
        graph_definition=graph_definition,
        backbone=backbone,
        optimizer_class=Adam,
        target_labels=config["target"],
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            # Milestones are in units of training steps (see
            # `scheduler_config` below): start, half of the first epoch,
            # and the end of training.
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file.
    # `os.path.basename` honours the platform's path separator, and the
    # output directory gets its own name so the `path` argument is not
    # overwritten.
    db_name = os.path.basename(path).split(".")[0]
    output_dir = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Save results as .csv
    results.to_csv(os.path.join(output_dir, "results.csv"))

    # Save full model (including weights) to .pth file - not version safe
    # Note: Models saved as .pth files in one version of graphnet
    #       may not be compatible with a different version of graphnet.
    model.save(os.path.join(output_dir, "model.pth"))

    # Save model config and state dict - Version safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(os.path.join(output_dir, "state_dict.pth"))
    model.save_config(os.path.join(output_dir, "model_config.yml"))
171+
172+
if __name__ == "__main__":

    # Build the command-line interface for this example.
    parser = ArgumentParser(
        description="""
Train conditional NormalizingFlow without the use of config files.
"""
    )

    # Options specific to this example, as (flag, help text, default value),
    # registered in the order listed.
    example_options = (
        (
            "--path",
            "Path to dataset file (default: %(default)s)",
            f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
        ),
        (
            "--pulsemap",
            "Name of pulsemap to use (default: %(default)s)",
            "total",
        ),
        (
            "--target",
            "Name of feature to use as regression target (default: "
            "%(default)s)",
            "total_energy",
        ),
        (
            "--truth-table",
            "Name of truth table to be used (default: %(default)s)",
            "mc_truth",
        ),
    )
    for flag, help_text, default_value in example_options:
        parser.add_argument(flag, help=help_text, default=default_value)

    # Shared options; tuples override the default for that option here.
    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    # The second return value is not used by this script.
    args, _unrecognised = parser.parse_known_args()

    main(
        path=args.path,
        pulsemap=args.pulsemap,
        target=args.target,
        truth_table=args.truth_table,
        gpus=args.gpus,
        max_epochs=args.max_epochs,
        early_stopping_patience=args.early_stopping_patience,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        wandb=args.wandb,
    )

src/graphnet/models/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
existing, purpose-built components and chain them together to form a complete
77
GNN
88
"""
9-
10-
9+
from graphnet.utilities.imports import has_jammy_flows_package
1110
from .model import Model
1211
from .standard_model import StandardModel
1312
from .standard_averaged_model import StandardAveragedModel
13+
14+
if has_jammy_flows_package():
15+
from .normalizing_flow import NormalizingFlow

src/graphnet/models/easy_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pytorch_lightning.loggers import Logger as LightningLogger
1717

1818
from graphnet.training.callbacks import ProgressBar
19-
from graphnet.models.graphs import GraphDefinition
2019
from graphnet.models.model import Model
2120
from graphnet.models.task import StandardLearnedTask
2221

@@ -292,6 +291,7 @@ def predict(
292291
dataloader: DataLoader,
293292
gpus: Optional[Union[List[int], int]] = None,
294293
distribution_strategy: Optional[str] = "auto",
294+
**trainer_kwargs: Any,
295295
) -> List[Tensor]:
296296
"""Return predictions for `dataloader`."""
297297
self.inference()
@@ -305,6 +305,7 @@ def predict(
305305
gpus=gpus,
306306
distribution_strategy=distribution_strategy,
307307
callbacks=callbacks,
308+
**trainer_kwargs,
308309
)
309310

310311
predictions_list = inference_trainer.predict(self, dataloader)
@@ -325,6 +326,7 @@ def predict_as_dataframe(
325326
additional_attributes: Optional[List[str]] = None,
326327
gpus: Optional[Union[List[int], int]] = None,
327328
distribution_strategy: Optional[str] = "auto",
329+
**trainer_kwargs: Any,
328330
) -> pd.DataFrame:
329331
"""Return predictions for `dataloader` as a DataFrame.
330332
@@ -357,6 +359,7 @@ def predict_as_dataframe(
357359
dataloader=dataloader,
358360
gpus=gpus,
359361
distribution_strategy=distribution_strategy,
362+
**trainer_kwargs,
360363
)
361364
predictions = (
362365
torch.cat(predictions_torch, dim=1).detach().cpu().numpy()

0 commit comments

Comments
 (0)