-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrieve.py
More file actions
47 lines (38 loc) · 1.48 KB
/
retrieve.py
File metadata and controls
47 lines (38 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import os
import wandb
from pathlib import Path
from retrieval_utils import run_retrieval
# Set WANDB_SILENT in the process environment at module load, so it is already
# in place when wandb.init() is called further down.
# NOTE(review): presumably this suppresses wandb's console output — confirm
# against the wandb environment-variable documentation.
os.environ["WANDB_SILENT"] = "true"
def main():
    """Parse the command line, start a wandb run and launch retrieval.

    Command-line options:
        --ckpt-name      (required) checkpoint name; selects the embedding
                         directory and is part of the run name.
        --wandb-user     (required) passed to ``wandb.init`` as ``entity``.
        --wandb-project  (required) passed to ``wandb.init`` as ``project``.
        --data-root      root data directory (default: ``data``).
        --query-arch     architecture of the query NeRFs (default: ``mlp``).
        --gallery-arch   architecture of the gallery NeRFs
                         (default: ``triplane``).

    The query and gallery architectures must differ; otherwise the script
    exits through ``parser.error`` (usage message, exit status 2).
    """
    archs = ["mlp", "triplane", "hash"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-name", type=str, required=True)
    parser.add_argument("--wandb-user", required=True, type=str)
    parser.add_argument("--wandb-project", required=True, type=str)
    parser.add_argument("--data-root", type=str, default="data")
    parser.add_argument("--query-arch", type=str, default="mlp", choices=archs)
    parser.add_argument("--gallery-arch", type=str, default="triplane", choices=archs)
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`: assertions are removed
    # when Python runs with -O, and a bare AssertionError tells the user
    # nothing about which options are in conflict.
    if args.query_arch == args.gallery_arch:
        parser.error("--query-arch and --gallery-arch must be different")

    # Directory layout, all relative to --data-root:
    #   emb/<ckpt_name>/shapenet   embeddings for the chosen checkpoint
    #   nerf/shapenet/<arch>       NeRFs, one directory per architecture
    data_root = Path(args.data_root)
    emb_root = data_root / "emb" / args.ckpt_name / "shapenet"
    nerf_root = data_root / "nerf" / "shapenet"
    query_nerf_root = nerf_root / args.query_arch
    gallery_nerf_root = nerf_root / args.gallery_arch

    # The output directory uses the bare run name; the wandb run gets the
    # same name with a "retrieval_" prefix.
    run_name = f"{args.ckpt_name}_query_{args.query_arch}_gallery_{args.gallery_arch}"
    save_root = Path("retrieval") / run_name
    run_name = f"retrieval_{run_name}"

    # The split is fixed to "test" by this script.
    split = "test"

    wandb.init(
        name=run_name,
        entity=args.wandb_user,
        project=args.wandb_project,
    )

    run_retrieval(
        emb_root,
        query_nerf_root,
        gallery_nerf_root,
        save_root,
        split,
        args.query_arch,
        args.gallery_arch,
    )


if __name__ == "__main__":
    main()