-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
27 lines (19 loc) · 820 Bytes
/
main.py
File metadata and controls
27 lines (19 loc) · 820 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import argparse
import time
from datetime import datetime, timedelta

import numpy as np

from kde_numpy.kde import get_KDE
from mnist_data_processing import load_processed_datasets
def _positive_int(text):
    """argparse ``type=`` hook: parse *text* as an int and reject values < 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f'expected a positive integer, got {text!r}')
    return value


def _positive_float(text):
    """argparse ``type=`` hook: parse *text* as a float and reject values <= 0."""
    value = float(text)
    # Written as `not value > 0` so that NaN is rejected too.
    if not value > 0:
        raise argparse.ArgumentTypeError(f'expected a positive number, got {text!r}')
    return value


# The parser is built at import time (a pure definition), but arguments are only
# parsed inside main(), so importing this module no longer consumes sys.argv.
parser = argparse.ArgumentParser(
    description='Fit a KDE on the training split and report the mean '
                'log-probability of the test split.',
)
# Both underscore and dash spellings are accepted for every flag; the original
# spellings (--batch_size, --kernel-type) keep working.
parser.add_argument('--batch_size', '--batch-size', dest='batch_size',
                    type=_positive_int, default=128,
                    help='batch size forwarded to get_KDE (default: 128)')
parser.add_argument('--sigma', type=_positive_float, default=0.2,
                    help='kernel scale; get_KDE receives bandwidth=sigma**2 '
                         '(default: 0.2)')
parser.add_argument('--kernel-type', '--kernel_type', dest='kernel_type',
                    type=str, default='gaussian',
                    help='kernel name forwarded to get_KDE (default: gaussian)')


def main(argv=None):
    """Fit the KDE on the training data and print the test mean log-probability.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Whatever ``kde.mean_log_prob(test_X)`` returns. The same value is printed.
    """
    args = parser.parse_args(argv)

    train_X, test_X = load_processed_datasets()
    print(f'KDE on {np.shape(test_X)} test data')

    # perf_counter is monotonic, unlike datetime.now(), so system clock
    # adjustments cannot distort the measured duration.
    start = time.perf_counter()
    kde = get_KDE(kernel_type=args.kernel_type, batch_size=args.batch_size,
                  bandwidth=args.sigma ** 2)
    kde.fit(train_X)
    mean_log_prob = kde.mean_log_prob(test_X)
    print(f'Mean log prob: {mean_log_prob}')

    # Wrapping the duration in timedelta keeps the original H:MM:SS.ffffff output.
    elapsed = timedelta(seconds=time.perf_counter() - start)
    print(f'Custom KDE time elapsed: {elapsed}')
    return mean_log_prob


if __name__ == '__main__':
    main()