Skip to content

Commit 14c9810

Browse files
authored
Merge pull request #1377 from calmdown539/dev-postgresql
Add the implementation of the utils functions for the generative model
2 parents 2a850a6 + 398ef47 commit 14c9810

1 file changed

Lines changed: 67 additions & 0 deletions

File tree

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
20+
import gzip
21+
import matplotlib.pyplot as plt
22+
import numpy as np
23+
import os
24+
import pickle
25+
import sys
26+
import time
27+
28+
try:
29+
import urllib.request as ul_request
30+
except ImportError:
31+
import urllib as ul_request
32+
33+
def print_log(s):
34+
t = time.ctime()
35+
print('[{}]{}'.format(t, s))
36+
37+
def load_data(filepath):
38+
with gzip.open(filepath, 'rb') as f:
39+
train_set, valid_set, test_set = pickle.load(f, encoding='bytes')
40+
traindata = train_set[0].astype(np.float32)
41+
validdata = valid_set[0].astype(np.float32)
42+
testdata = test_set[0].astype(np.float32)
43+
trainlabel = train_set[1].astype(np.float32)
44+
validlabel = valid_set[1].astype(np.float32)
45+
testlabel = test_set[1].astype(np.float32)
46+
return traindata, trainlabel, validdata, validlabel, testdata, testlabel
47+
48+
def download_data(gzfile, url):
49+
if os.path.exists(gzfile):
50+
print('Downloaded already!')
51+
sys.exit(0)
52+
print('Downloading data %s' % (url))
53+
ul_request.urlretrieve(url, gzfile)
54+
print('Finished!')
55+
56+
def show_images(filepath):
57+
with open(filepath, 'rb') as f:
58+
imgs = pickle.load(f)
59+
r, c = 5, 5
60+
fig, axs = plt.subplots(5, 5)
61+
cnt = 0
62+
for i in range(r):
63+
for j in range(c):
64+
axs[i, j].imshow(imgs[cnt, :, :, 0], cmap='gray')
65+
axs[i, j].axis('off')
66+
cnt += 1
67+
plt.show()

0 commit comments

Comments
 (0)