Skip to content

Commit 3b02249

Browse files
committed
add more params to helper
1 parent ade4f0e commit 3b02249

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

notebooks/lang_model_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,12 @@ def train_lang_model(model_path: int,
122122
trn_indexed: List[int],
123123
val_indexed: List[int],
124124
vocab_size: int,
125+
lr: float,
125126
n_cycle: int = 2,
126-
em_sz: int = 1200,
127-
nh: int = 1200,
127+
cycle_len: int =3,
128+
cycle_mult : int =1,
129+
em_sz: int = 400,
130+
nh: int = 400,
128131
nl: int = 3,
129132
bptt: int = 20,
130133
wd: int = 1e-7,
@@ -185,11 +188,14 @@ def train_lang_model(model_path: int,
185188
dropoute=drops[3],
186189
dropouth=drops[4])
187190

188-
# learning rate is hardcoded, I already ran learning rate finder on this problem.
189-
lrs = 1e-3 / 2
190-
191191
# borrowed these parameters from fastai
192-
learner.fit(lrs, 2, wds=wd, cycle_len=3, use_clr=(32, 10), best_save_name='langmodel_best')
192+
learner.fit(lr,
193+
n_cycle=n_cycle,
194+
wds=wd,
195+
cycle_len=cycle_len,
196+
use_clr=(32, 10),
197+
cycle_mult=cycle_mult,
198+
best_save_name='langmodel_best')
193199

194200
# eval sets model to inference mode (turns off dropout, etc.)
195201
model = learner.model.eval()
@@ -270,10 +276,14 @@ def get_emb_batch(lang_model, np_array, bs, dest_dir):
270276
y_mean = get_mean_emb(raw_emb=y, idx_arr=x.data.cpu().numpy())
271277
# get the last hidden state in the sequence. Returns arr of size (bs, encoder_dim)
272278
y_last = y[:, -1, :]
279+
# get the maximum across timesteps
280+
y_max = y.max(1)
273281

274282
# collect predictions
275283
np.save(destPath/f'lang_model_mean_emb_{i}.npy', y_mean)
276284
np.save(destPath/f'lang_model_last_emb_{i}.npy', y_last)
285+
np.save(destPath/f'lang_model_last_emb_{i}.npy', y_max)
286+
np.save(destPath/f'lang_model_pool_emb_{i}.npy', np.concatenate([y_mean, y_max, y_last], axis=1))
277287

278288
logging.warning(f'Saved {2*len(data_chunked)} files to {str(destPath.absolute())}')
279289

0 commit comments

Comments
 (0)