-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
120 lines (103 loc) · 4.17 KB
/
train.py
File metadata and controls
120 lines (103 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Standard library
import argparse
import os
from pathlib import Path
# Third-party
from torch import nn
# Project-local: NeRF graph construction, config, and trainer variants
from gmn.graph_construct.hash_grid import MultiResHashGrid
from gmn.graph_construct.layers import TriplanarGridWithInputEncoding
from nf2vec import config as nerf_cfg
from trainers.l_rec import RecontructionTrainer
from trainers.l_rec_con import RecontructionContrastiveTrainer
from trainers.l_con import ContrastiveTrainer
# Suppress wandb console chatter for cleaner training logs.
os.environ["WANDB_SILENT"] = "true"
def _build_mlp_nerf() -> nn.Sequential:
    """Build the plain-MLP NeRF template from the nf2vec MLP config.

    Four bias-free linear layers with ReLU activations, sized by the
    shared config so saved weights can be loaded into it.
    """
    enc_dim = nerf_cfg.MLP_INPUT_SIZE_AFTER_ENCODING
    hid_dim = nerf_cfg.MLP_UNITS
    out_dim = nerf_cfg.MLP_PADDED_OUTPUT_SIZE
    return nn.Sequential(
        nn.Linear(enc_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, out_dim, bias=False),
    )


def _build_triplane_nerf() -> nn.Sequential:
    """Build the triplane NeRF template: a triplanar feature grid with
    input encoding followed by a bias-free MLP head.

    The first linear layer consumes the concatenation of the input
    encoding and the sampled triplane features (enc_dim + feat_dim).
    """
    res = nerf_cfg.TRIPLANE_RES
    feat_dim = nerf_cfg.TRIPLANE_FEAT_SIZE
    enc_dim = nerf_cfg.TRIPLANE_IN_SIZE_AFTER_ENC
    hid_dim = nerf_cfg.TRIPLANE_HID_UNITS
    out_dim = nerf_cfg.TRIPLANE_PADDED_OUT_SIZE
    return nn.Sequential(
        TriplanarGridWithInputEncoding(res, feat_dim, enc_dim),
        nn.Linear(enc_dim + feat_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, out_dim, bias=False),
    )


def _build_hash_nerf() -> nn.Sequential:
    """Build the hash-grid NeRF template: a multi-resolution hash grid
    encoder followed by a bias-free MLP head, sized by the shared config.
    """
    grid = MultiResHashGrid(
        nerf_cfg.HASH_IN_SIZE,
        nerf_cfg.HASH_LEVELS,
        nerf_cfg.HASH_FEATURES_PER_ENTRY,
        nerf_cfg.HASH_LOG2_TAB_SIZE,
        nerf_cfg.HASH_MIN_RES,
        nerf_cfg.HASH_MAX_RES,
    )
    hid_dim = nerf_cfg.HASH_HID_UNITS
    return nn.Sequential(
        grid,
        nn.Linear(nerf_cfg.HASH_PADDED_IN_SIZE, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, hid_dim, bias=False), nn.ReLU(),
        nn.Linear(hid_dim, nerf_cfg.HASH_PADDED_OUT_SIZE, bias=False),
    )


# Maps the --loss CLI choice to its trainer class; keys mirror the
# argparse `choices` list below.
_TRAINERS = {
    "l_rec": RecontructionTrainer,
    "l_rec_con": RecontructionContrastiveTrainer,
    "l_con": ContrastiveTrainer,
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--loss", required=True, type=str, choices=sorted(_TRAINERS))
    parser.add_argument("--wandb-user", required=True, type=str)
    parser.add_argument("--wandb-project", required=True, type=str)
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--data-root", type=str, default="data")
    parser.add_argument("--num-epochs", type=int, default=250)
    parser.add_argument("--batch-size", type=int, default=8)
    # BUGFIX: these were declared type=int with float defaults; any
    # user-supplied value such as "1e-4" or "0.001" would crash argparse
    # (int() cannot parse them), and an integer LR/WD would be wrong anyway.
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-2)
    parser.add_argument("--gnn-hidden-dim", type=int, default=128)
    parser.add_argument("--num-gnn-layers", type=int, default=4)
    args = parser.parse_args()

    data_root = Path(args.data_root)
    # NeRF weight roots are Paths; graph roots are passed as strings
    # (matching what the trainer constructors expect downstream).
    mlp_nerf_root = data_root / "nerf" / "shapenet" / "mlp"
    mlp_graph_root = str(data_root / "graph" / "shapenet" / "mlp")
    triplane_nerf_root = data_root / "nerf" / "shapenet" / "triplane"
    triplane_graph_root = str(data_root / "graph" / "shapenet" / "triplane")
    hash_nerf_root = data_root / "nerf" / "shapenet" / "hash"
    hash_graph_root = str(data_root / "graph" / "shapenet" / "hash")

    Trainer = _TRAINERS[args.loss]
    trainer = Trainer(
        mlp_nerf_root,
        mlp_graph_root,
        triplane_nerf_root,
        triplane_graph_root,
        hash_nerf_root,
        hash_graph_root,
        _build_mlp_nerf(),
        _build_triplane_nerf(),
        _build_hash_nerf(),
        args.num_epochs,
        args.batch_size,
        args.lr,
        args.weight_decay,
        args.gnn_hidden_dim,
        args.num_gnn_layers,
        args.wandb_user,
        args.wandb_project,
        args.wandb_run_name,
    )
    trainer.train()