Fastai Course DL from the Foundations: LSTM Training
Fastai: training the network on the Wikitext dataset (Lesson 5, Part 7)
Fastai Pretraining on Wikitext-103
- This post is based on the notebook from the fastai course, Part 2.
#collapse
%load_ext autoreload
%autoreload 2
%matplotlib inline
#collapse
from exp.nb_12a import *
One-time download (the cell below is left commented out after the first run):
#path = datasets.Config().data_path()
#version = '103' #2
#! wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-{version}-v1.zip -P {path}
#! unzip -q -n {path}/wikitext-{version}-v1.zip -d {path}
#! mv {path}/wikitext-{version}/wiki.train.tokens {path}/wikitext-{version}/train.txt
#! mv {path}/wikitext-{version}/wiki.valid.tokens {path}/wikitext-{version}/valid.txt
#! mv {path}/wikitext-{version}/wiki.test.tokens {path}/wikitext-{version}/test.txt
Split the articles: WT103 is provided as one big text file, so we need to chunk it into separate articles if we want to be able to shuffle them at the beginning of each epoch.
#collapse
path = datasets.Config().data_path()/'wikitext-103'
#collapse_show
def istitle(line):
    # A top-level article title in WT103 looks like " = Title = " (sub-headings use more '=' signs).
    return len(re.findall(r'^ = [^=]* = $', line)) != 0
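Each article starts with a title line of that form, which is exactly what the regex matches while rejecting the doubled-`=` sub-headings. A quick check with made-up example lines:
#collapse_show
# Hypothetical sample lines: a top-level title matches, a sub-heading does not.
assert istitle(' = Gold dollar = \n')
assert not istitle(' = = Early history = = \n')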
#collapse_show
def read_wiki(filename):
    articles = []
    with open(filename, encoding='utf8') as f:
        lines = f.readlines()
    current_article = ''
    for i,line in enumerate(lines):
        current_article += line
        # An article ends when the line after the next blank line is a new title.
        if i < len(lines)-2 and lines[i+1] == ' \n' and istitle(lines[i+2]):
            current_article = current_article.replace('<unk>', UNK)
            articles.append(current_article)
            current_article = ''
    current_article = current_article.replace('<unk>', UNK)
    articles.append(current_article)
    return articles
#collapse_show
train = TextList(read_wiki(path/'train.txt'), path=path) #+read_wiki(path/'test.txt')
valid = TextList(read_wiki(path/'valid.txt'), path=path)
#collapse_show
len(train), len(valid)
#collapse_show
sd = SplitData(train, valid)
#collapse_show
proc_tok,proc_num = TokenizeProcessor(),NumericalizeProcessor()
#collapse_show
# A language model needs no labels, so every item gets a dummy 0 label.
ll = label_by_func(sd, lambda x: 0, proc_x=[proc_tok, proc_num])
#collapse_show
pickle.dump(ll, open(path/'ld.pkl', 'wb'))
#collapse_show
ll = pickle.load(open(path/'ld.pkl', 'rb'))
#collapse_show
bs,bptt = 128,70   # batch size and backpropagation-through-time sequence length
data = lm_databunchify(ll, bs, bptt)
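A quick sanity check on the batching — a sketch assuming the returned DataBunch exposes `train_dl` as in the course notebooks; each input batch should have shape `(bs, bptt)`, with the targets being the inputs shifted one token to the right.
#collapse_show
# Peek at one language-model batch (shapes expected to be (bs, bptt)).
xb, yb = next(iter(data.train_dl))
xb.shape, yb.shape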
#collapse_show
vocab = ll.train.proc_x[-1].vocab
len(vocab)
#collapse_show
# Dropout probabilities for the different parts of the AWD-LSTM, scaled by 0.2:
# a corpus as large as WT103 needs relatively little regularisation.
dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2
tok_pad = vocab.index(PAD)   # index of the padding token in the vocab
#collapse_show
emb_sz, nh, nl = 300, 300, 2   # embedding size, hidden size, number of LSTM layers
model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)
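`get_language_model` in the course notebooks assembles an AWD-LSTM encoder followed by a linear decoder. Printing the model (plain PyTorch, nothing else assumed) is a cheap way to confirm the layer sizes and dropout values, and to count parameters before committing to a long training run.
#collapse_show
# Inspect the module tree and count trainable parameters.
print(model)
sum(p.numel() for p in model.parameters() if p.requires_grad)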
#collapse_show
cbs = [partial(AvgStatsCallback, accuracy_flat),
       CudaCallback, Recorder,
       partial(GradientClipping, clip=0.1),
       partial(RNNTrainer, α=2., β=1.),   # α: activation regularization (AR), β: temporal AR (TAR)
       ProgressCallback]
#collapse_show
learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())
#collapse_show
lr = 5e-3
sched_lr = combine_scheds([0.3,0.7], cos_1cycle_anneal(lr/10., lr, lr/1e5))
sched_mom = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.8, 0.7, 0.8))
cbsched = [ParamScheduler('lr', sched_lr), ParamScheduler('mom', sched_mom)]
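With the `[0.3, 0.7]` split, the first 30% of training warms the learning rate up from lr/10 to lr while momentum dips from 0.8 to 0.7, and the remaining 70% anneals the learning rate down to lr/1e5 while momentum recovers. A sketch to visualise this, assuming the callables returned by `combine_scheds` take a position in [0, 1) as in the course's annealing notebook:
#collapse_show
# Plot the learning-rate and momentum schedules over the fraction of training.
import matplotlib.pyplot as plt
xs = torch.linspace(0., 0.99, 100)   # stop just short of 1.0 to stay inside the schedule
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
ax1.plot(xs, [sched_lr(x) for x in xs]);  ax1.set_title('lr')
ax2.plot(xs, [sched_mom(x) for x in xs]); ax2.set_title('mom')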
#collapse_show
learn.fit(10, cbs=cbsched)
#collapse_show
torch.save(learn.model.state_dict(), path/'pretrained.pth')
pickle.dump(vocab, open(path/'vocab.pkl', 'wb'))
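To fine-tune on a downstream task later, the pretrained weights and the vocab can be loaded back. A minimal sketch, assuming the model is rebuilt with the same `get_language_model` arguments as above:
#collapse_show
# Rebuild the architecture and restore the pretrained weights and the vocab.
vocab = pickle.load(open(path/'vocab.pkl', 'rb'))
model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)
model.load_state_dict(torch.load(path/'pretrained.pth'))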