-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathGNN_Main.py
More file actions
executable file
·112 lines (91 loc) · 3.5 KB
/
GNN_Main.py
File metadata and controls
executable file
·112 lines (91 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import time
from shutil import copyfile
import argparse
import networkx as nx
import os
import scipy.io
import umap
import torch
import torch.nn as nn
import torch_geometric.data as data
from sklearn import metrics
from tifffile import imread
from torch_geometric.loader import DataLoader
from torch_geometric.utils.convert import to_networkx
from scipy.optimize import curve_fit
from scipy.spatial import Delaunay
from torchvision.transforms import GaussianBlur
import matplotlib
matplotlib.use("Agg") # Non-interactive backend
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.ticker import FuncFormatter
from prettytable import PrettyTable
from ParticleGraph.config import ParticleGraphConfig
from ParticleGraph.generators.graph_data_generator import *
from ParticleGraph.models.graph_trainer import *
from ParticleGraph.models.Siren_Network import *
from ParticleGraph.models.Ghost_Particles import Ghost_Particles
from ParticleGraph.models.utils import *
import warnings
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API")
def resolve_cli_options(options):
    """Translate the values passed to ``-o/--option`` into run settings.

    Args:
        options: The list argparse collected for ``-o/--option``, laid out as
            ``[task, config_file]`` or ``[task, config_file, best_model]``,
            or ``None`` when the flag was not given.

    Returns:
        A ``(task, config_list, best_model)`` tuple. When ``options`` is
        ``None`` the script defaults are returned.

    Raises:
        ValueError: If options were given but the config file name is missing.
    """
    if options is None:
        # No CLI options: fall back to the defaults hard-coded in this script.
        # task may be 'generate', 'train', 'test'
        return "generate", ["diffusiophoresis_test"], None

    if len(options) < 2:
        # Previously this case crashed with an IndexError on options[1].
        raise ValueError(
            "-o/--option expects: TASK CONFIG_FILE [BEST_MODEL] "
            f"(got {len(options)} value(s))"
        )

    best_model = options[2] if len(options) > 2 else None
    return options[0], [options[1]], best_model


if __name__ == "__main__":
    warnings.filterwarnings("ignore", category=FutureWarning)

    parser = argparse.ArgumentParser(description="ParticleGraph")
    parser.add_argument("-o", "--option", nargs="+", help="Option that takes multiple values")
    args = parser.parse_args()

    if args.option:
        print(f"Options: {args.option}")

    try:
        task, config_list, best_model = resolve_cli_options(args.option)
    except ValueError as err:
        # Report bad usage through argparse (usage message + exit status 2)
        # instead of letting a traceback escape.
        parser.error(str(err))

    # The config directory sits next to this script; it does not depend on
    # the config being processed, so compute it once.
    config_root = os.path.dirname(os.path.abspath(__file__)) + "/config"

    for config_file_ in config_list:
        print(" ")
        config_file, pre_folder = add_pre_folder(config_file_)
        config = ParticleGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
        config.dataset = pre_folder + config.dataset
        config.config_file = pre_folder + config_file_
        device = set_device(config.training.device)

        print(f"config_file  {config.config_file}")
        print(f"\033[92mdevice  {device}\033[0m")

        # `task` is matched by substring, so a single task string may select
        # several of the stages below.
        if "generate" in task:
            data_generate(
                config,
                device=device,
                visualize=True,
                run_vizualized=0,
                style="color",
                alpha=1,
                erase=True,
                bSave=True,
                step=100,
                timer=False
            )
            # data and figures are generated into "ParticleGraph/graphs_data/signal/signal_N2_a37"

        if "train" in task:
            data_train(config=config, erase=False, best_model=best_model, device=device)
            # temporary results are saved into "ParticleGraph/log/signal/signal_N2_a37_1/tmp_training"

        if "test" in task:
            # NOTE(review): the CLI-provided best_model is only forwarded to
            # training; testing always asks for 'best' — confirm this is intended.
            data_test(
                config=config,
                visualize=True,
                style="black color name",
                verbose=False,
                best_model='best',
                run=0,
                test_mode="",
                sample_embedding=False,
                step=5,
                device=device,
                particle_of_interest=0,
            )
            # rollout inference results are saved into "ParticleGraph/log/signal/signal_N2_a37_1/tmp_recons"