-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
76 lines (60 loc) · 2.85 KB
/
dataset.py
File metadata and controls
76 lines (60 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
from scipy.sparse import load_npz
import scipy.sparse as sp
from torch.utils.data import DataLoader, Dataset
class CSRDataset(Dataset):
    """Dataset over row-aligned scipy sparse matrices, densifying one row per item.

    Args:
        x_mat: sparse matrix of rating inputs, one row per sample.
        y_mat: sparse matrix of labels, same number of rows as ``x_mat``.
        z_mat: optional sparse matrix of extra (text) features, same row count.
        is_train: recorded on ``self.shuffle``; actual shuffling is done by
            the DataLoader that wraps this dataset, not here.
    """
    def __init__(self, x_mat, y_mat, z_mat=None, is_train=False):
        self.x_mat = x_mat
        self.y_mat = y_mat
        self.z_mat = z_mat
        self.shuffle = is_train
        # All matrices must describe the same set of samples (row-aligned).
        assert self.x_mat.shape[0] == self.y_mat.shape[0]
        if self.z_mat is not None:
            assert self.x_mat.shape[0] == self.z_mat.shape[0]
    def __len__(self):
        # Number of samples = number of rows.
        return self.x_mat.shape[0]
    def __getitem__(self, idx):
        # .toarray().ravel() densifies the (1, n) sparse row straight to a
        # 1-D ndarray of shape (n,). Equivalent to the previous
        # np.array(row.todense())[0], but avoids the deprecated np.matrix
        # intermediate that .todense() returns.
        outputs = {
            'idx': idx,
            'rating_input': self.x_mat[idx].toarray().ravel(),
            'label': self.y_mat[idx].toarray().ravel()
        }
        if self.z_mat is not None:
            outputs['text_input'] = self.z_mat[idx].toarray().ravel()
        return outputs
def get_npz_file(file_path):
    """Load and return the scipy sparse matrix stored at *file_path* (.npz)."""
    return load_npz(file_path)
def create_data_loader(x_mat, y_mat, z_mat, is_train, batch_size, num_workers):
    """Wrap the given sparse matrices in a CSRDataset and return its DataLoader.

    Shuffling is enabled only when *is_train* is true; memory pinning is
    requested and partial final batches are kept (drop_last=False).
    """
    dataset = CSRDataset(x_mat, y_mat, z_mat, is_train)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
    )
def get_data_loader(config_dict, num_workers, shuffle=False):
    """Build the train/val/test DataLoaders for the configured dataset.

    Reads ``./data/<dataset>/{train,val,test}_X.npz``, the matching ``*_Y.npz``
    label files, and ``user_{train,val,test}_X.npz`` text-feature files from
    disk. Only the train loader may shuffle (controlled by *shuffle*).

    Returns:
        ``((train_loader, val_loader, test_loader), data_stats)`` where
        ``data_stats`` carries ``num_items`` (columns of train_X) and
        ``num_words`` (columns of user_train_X).
    """
    data_dir = f"./data/{config_dict['dataset']}"
    batch_size = config_dict['batch_size']

    # Train inputs/labels plus the per-split user text features.
    train_x = load_npz(f'{data_dir}/train_X.npz')
    train_y = load_npz(f'{data_dir}/train_Y.npz')
    train_z = load_npz(f'{data_dir}/user_train_X.npz')
    val_z = load_npz(f'{data_dir}/user_val_X.npz')
    test_z = load_npz(f'{data_dir}/user_test_X.npz')

    data_stats = {
        'num_items': train_x.shape[1],
        'num_words': train_z.shape[1],
    }

    train_loader = create_data_loader(train_x, train_y, z_mat=train_z,
                                      is_train=shuffle, batch_size=batch_size,
                                      num_workers=num_workers)

    val_x = load_npz(f'{data_dir}/val_X.npz')
    val_y = load_npz(f'{data_dir}/val_Y.npz')
    val_loader = create_data_loader(val_x, val_y, z_mat=val_z, is_train=False,
                                    batch_size=batch_size,
                                    num_workers=num_workers)

    test_x = load_npz(f'{data_dir}/test_X.npz')
    test_y = load_npz(f'{data_dir}/test_Y.npz')
    test_loader = create_data_loader(test_x, test_y, z_mat=test_z,
                                     is_train=False, batch_size=batch_size,
                                     num_workers=num_workers)

    return (train_loader, val_loader, test_loader), data_stats
def get_data(config_dict, num_workers, train_shuffle=False):
    """Thin convenience wrapper: return ``(loaders, data_stats)`` from
    :func:`get_data_loader`, forwarding *train_shuffle* as ``shuffle``."""
    return get_data_loader(config_dict, num_workers, shuffle=train_shuffle)