-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsaving.py
More file actions
24 lines (19 loc) · 740 Bytes
/
saving.py
File metadata and controls
24 lines (19 loc) · 740 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import json
import neural_network as nn
def save(net, path="network_data.json"):
    """Serialize the weights and biases of every layer of *net* to a JSON file.

    Args:
        net: Object exposing a ``layers`` sequence; each layer must have ``W``
            and ``bias`` attributes supporting ``.tolist()`` (NumPy arrays in
            practice).
        path: Destination file. Defaults to ``"network_data.json"``, which is
            resolved relative to the current working directory. Anything
            accepted by ``open()`` works.

    The file holds ``{"weigh": [...], "bias": [...]}``, with one nested list
    per layer in layer order. The key is spelled ``"weigh"`` (not
    ``"weights"``). Keep that spelling, because ``load_to`` reads exactly that
    key and previously written files depend on it.
    """
    layers = net.layers
    data = {
        'weigh': [layer.W.tolist() for layer in layers],
        'bias': [layer.bias.tolist() for layer in layers],
    }
    # Explicit encoding so the file is identical regardless of platform locale.
    with open(path, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file)
def load_to(net, path="network_data.json"):
    """Overwrite the weights and biases of *net* in place from a JSON file.

    Reads the ``{"weigh": [...], "bias": [...]}`` layout written by ``save``.
    Each nested list is converted with ``np.array`` and assigned to the
    corresponding layer's ``W`` and ``bias``, in layer order.

    Args:
        net: Object exposing a ``layers`` sequence whose items accept
            assignment to ``W`` and ``bias``.
        path: Source file. Defaults to ``"network_data.json"``, which is
            resolved relative to the current working directory.

    Raises:
        ValueError: If the number of weight entries, bias entries, and layers
            in ``net`` are not all equal. This is raised before any layer is
            modified.
        KeyError: If the file lacks the ``"weigh"`` or ``"bias"`` key.
        OSError / json.JSONDecodeError: If the file cannot be read or parsed.
    """
    with open(path, 'r', encoding='utf-8') as json_file:
        loaded_data = json.load(json_file)

    weights = [np.array(matrix) for matrix in loaded_data['weigh']]
    biases = [np.array(arr) for arr in loaded_data['bias']]

    # Check sizes up front. Without this, a short file silently leaves the
    # network part old / part new, and a long one fails midway through.
    layers = net.layers
    if not (len(weights) == len(biases) == len(layers)):
        raise ValueError(
            f"{path!r} holds {len(weights)} weight and {len(biases)} bias "
            f"entries, but the network has {len(layers)} layers"
        )

    # NOTE(review): per-layer array shapes are not checked against the
    # existing W/bias; a file from a differently-shaped network still loads.
    for layer, weight, bias in zip(layers, weights, biases):
        layer.W = weight
        layer.bias = bias