@@ -86,4 +86,57 @@ def get_states(self):
def set_states(self, states):
    '''Copy the entry stored for ``self.hx`` out of ``states``, then
    delegate to the parent class.

    Args:
        states, a mapping indexed by tensor name; it must contain an
            entry under ``self.hx.name``. The same mapping is forwarded
            unchanged to ``super().set_states``.
    '''
    self.hx.copy_from(states[self.hx.name])
    # NOTE(review): this statement repeats the previous one verbatim, so
    # it has no additional effect. Possibly a second state tensor was
    # meant to be restored here -- confirm against get_states().
    self.hx.copy_from(states[self.hx.name])
    # hand the full mapping to the parent implementation
    super().set_states(states)
90+
class Data(object):
    '''Character-level dataset built from a plain text file.

    Attributes:
        raw_data, the full text of the file.
        vocab_size, number of distinct characters in the text.
        char_to_idx, dict mapping each character to its integer index.
        idx_to_char, dict mapping each integer index to its character.
        train_dat, int32 array of shape (num_train, seq_length + 1).
        val_dat, int32 array holding the remaining sequences.
        num_train_batch, number of full batches in train_dat.
        num_test_batch, number of full batches in val_dat.
    '''

    def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8):
        '''Data object for loading a plain text file.

        Args:
            fpath, path to the text file (decoded as iso-8859-1).
            batch_size, number of sequences per batch; only used to
                compute the number of full batches in each split.
            seq_length, number of input characters per sequence; every
                stored row has seq_length + 1 entries (data + label).
            train_ratio, split the text file into train and test sets, where
                train_ratio of the sequences are in the train set.
        '''
        # the context manager closes the file handle once the text is read
        with open(fpath, 'r', encoding='iso-8859-1') as f:
            self.raw_data = f.read()

        # sorted so that the char <-> index mapping is the same on every
        # run; iterating a bare set of str depends on hash randomisation
        chars = sorted(set(self.raw_data))
        self.vocab_size = len(chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(chars)}
        self.idx_to_char = dict(enumerate(chars))

        encoded = np.asarray([self.char_to_idx[c] for c in self.raw_data],
                             dtype=np.int32)
        # seq_length + 1 for the data + label; trailing characters that do
        # not fill a whole row are dropped
        row_len = seq_length + 1
        nsamples = len(encoded) // row_len
        data = encoded[:nsamples * row_len].reshape((nsamples, row_len))

        # shuffle all sequences (rows) in place before splitting
        np.random.shuffle(data)

        num_train = int(data.shape[0] * train_ratio)
        self.train_dat = data[:num_train]
        self.val_dat = data[num_train:]
        self.num_train_batch = self.train_dat.shape[0] // batch_size
        self.num_test_batch = self.val_dat.shape[0] // batch_size
        print('train dat', self.train_dat.shape)
        print('val dat', self.val_dat.shape)
121+
122+
def numpy2tensors(npx, npy, dev, inputs=None, labels=None):
    '''Convert a batch of numpy arrays into per-step tensors and labels.

    The first two axes of both arrays are swapped
    (batch, seq, dim -- > seq, batch, dim), then npx is split along its
    new first axis into one tensor per step and npy is flattened into a
    single column.

    Args:
        npx, numpy array of input data; axes 0 and 1 are swapped.
        npy, numpy array of labels; axes 0 and 1 are swapped and the
            result is reshaped to (-1, 1).
        dev, device that newly created tensors are moved to; not used
            for tensors that are passed in for reuse.
        inputs, optional list of existing tensors, one per step, that are
            overwritten in place; when None or empty new tensors are
            created.
        labels, optional existing tensor that is overwritten in place;
            when None a new tensor is created.

    Returns:
        (inputs, labels), the list of per-step input tensors and the
        label tensor.
    '''
    label_col = np.swapaxes(npy, 0, 1).reshape((-1, 1))
    # compare against None explicitly: the truth value of a tensor object
    # says nothing about whether the caller supplied one
    if labels is not None:
        labels.copy_from_numpy(label_col)
    else:
        labels = tensor.from_numpy(label_col)
        labels.to_device(dev)

    steps = np.swapaxes(npx, 0, 1)
    if inputs:
        # reuse the caller's tensors; indexing (rather than zip) keeps the
        # IndexError when fewer tensors than steps are supplied
        for t in range(steps.shape[0]):
            inputs[t].copy_from_numpy(steps[t])
    else:
        inputs = []
        for t in range(steps.shape[0]):
            x = tensor.from_numpy(steps[t])
            x.to_device(dev)
            inputs.append(x)
    return inputs, labels
0 commit comments