-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathevaluate_model.py
More file actions
65 lines (53 loc) · 1.74 KB
/
evaluate_model.py
File metadata and controls
65 lines (53 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import sys
from pathlib import Path
from modules.lab_utils import load_dataset
def evaluate_pytorch_model(dataset_path, model_path):
    """Load a trained PyTorch MLP regressor and score it on the dataset's test split.

    Args:
        dataset_path: Path to the dataset CSV; forwarded to ``load_dataset``.
        model_path: Path to the saved model file; forwarded to
            ``PytorchMLPReg(model_file=...)``.

    Returns:
        The value returned by ``PytorchMLPReg.score`` on the test split. It is
        also printed in scientific notation so the CLI output is unchanged.
    """
    # Imported inside the function so the model module (which presumably pulls
    # in torch) is only loaded when a PyTorch model is actually evaluated.
    from modules.pytorch_mlp import PytorchMLPReg

    mlp = PytorchMLPReg(model_file=model_path)
    # load_dataset yields four values; only the last two are used here.
    # NOTE(review): presumably (x_train, y_train, x_test, y_test) — confirm.
    _, _, x_test, y_test = load_dataset(dataset_path)
    score = mlp.score(x_test, y_test)
    print(f"Score on dataset: {score:.4e}")
    # Return the score so callers can use it programmatically, not just read stdout.
    return score
def main(argv=None):
    """Parse command-line arguments and evaluate the requested trained model.

    Args:
        argv: Argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the dataset or model file is
        missing (or is not a regular file), 2 if the requested model type is
        not implemented yet. Invalid arguments make argparse raise
        ``SystemExit(2)`` as usual.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a trained model on a dataset"
    )
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["pytorch", "calibrated"],
        default="pytorch",
        help="Model type: pytorch or calibrated",
    )
    # Default paths are relative, so they resolve against the current
    # working directory.
    parser.add_argument(
        "--dataset-path",
        type=Path,
        required=False,
        default=Path("data/results/blueleg_beam_sphere.csv"),
        help="Path to dataset CSV",
    )
    parser.add_argument(
        "--model-path",
        type=Path,
        required=False,
        default=Path("data/results/blueleg_beam_sphere.pth"),
        help="Path to trained model file",
    )
    args = parser.parse_args(argv)
    model_type = args.model_type
    dataset_path = args.dataset_path
    model_path = args.model_path

    # is_file() rather than exists(): a directory must not pass validation
    # only to fail later inside the loader with a less helpful error.
    if not dataset_path.is_file():
        print(f"Dataset file not found: {dataset_path}", file=sys.stderr)
        return 1
    if not model_path.is_file():
        print(f"Model file not found: {model_path}", file=sys.stderr)
        return 1

    if model_type == "pytorch":
        evaluate_pytorch_model(str(dataset_path), str(model_path))
        return 0
    if model_type == "calibrated":
        print("Calibrated model evaluation is not implemented yet.", file=sys.stderr)
        return 2
    # Unreachable while argparse `choices` restricts --model-type; kept as a
    # guard in case a new choice is added without a matching branch.
    print(f"Unknown model type: {model_type}", file=sys.stderr)
    return 1


if __name__ == "__main__":
    sys.exit(main())