Skip to content

Commit 02cd50e

Browse files
authored
Merge pull request #1346 from gzrp/dev-postgresql
Update the Data class for the peft example
2 parents b8e867b + 9b7d370 commit 02cd50e

1 file changed

Lines changed: 54 additions & 1 deletion

File tree

examples/singa_peft/examples/model/char_rnn.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,57 @@ def get_states(self):
8686
def set_states(self, states):
8787
self.hx.copy_from(states[self.hx.name])
8888
self.hx.copy_from(states[self.hx.name])
89-
super().set_states(states)
89+
super().set_states(states)
90+
91+
class Data(object):
92+
93+
def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
94+
'''Data object for loading a plain text file.
95+
96+
Args:
97+
fpath, path to the text file.
98+
train_ratio, split the text file into train and test sets, where
99+
train_ratio of the characters are in the train set.
100+
'''
101+
self.raw_data = open(fpath, 'r',
102+
encoding='iso-8859-1').read() # read text file
103+
chars = list(set(self.raw_data))
104+
self.vocab_size = len(chars)
105+
self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
106+
self.idx_to_char = {i: ch for i, ch in enumerate(chars)}
107+
data = [self.char_to_idx[c] for c in self.raw_data]
108+
# seq_length + 1 for the data + label
109+
nsamples = len(data) // (1 + seq_length)
110+
data = data[0:nsamples * (1 + seq_length)]
111+
data = np.asarray(data, dtype=np.int32)
112+
data = np.reshape(data, (-1, seq_length + 1))
113+
# shuffle all sequences
114+
np.random.shuffle(data)
115+
self.train_dat = data[0:int(data.shape[0] * train_ratio)]
116+
self.num_train_batch = self.train_dat.shape[0] // batch_size
117+
self.val_dat = data[self.train_dat.shape[0]:]
118+
self.num_test_batch = self.val_dat.shape[0] // batch_size
119+
print('train dat', self.train_dat.shape)
120+
print('val dat', self.val_dat.shape)
121+
122+
123+
def numpy2tensors(npx, npy, dev, inputs=None, labels=None):
124+
'''batch, seq, dim -- > seq, batch, dim'''
125+
tmpy = np.swapaxes(npy, 0, 1).reshape((-1, 1))
126+
if labels:
127+
labels.copy_from_numpy(tmpy)
128+
else:
129+
labels = tensor.from_numpy(tmpy)
130+
labels.to_device(dev)
131+
tmpx = np.swapaxes(npx, 0, 1)
132+
inputs_ = []
133+
for t in range(tmpx.shape[0]):
134+
if inputs:
135+
inputs[t].copy_from_numpy(tmpx[t])
136+
else:
137+
x = tensor.from_numpy(tmpx[t])
138+
x.to_device(dev)
139+
inputs_.append(x)
140+
if not inputs:
141+
inputs = inputs_
142+
return inputs, labels

0 commit comments

Comments
 (0)