Fastai Course: Deep Learning from the Foundations, Transfer Learning
Transfer learning: train a network on ImageWoof, then reuse it as a pretrained model for the Pets dataset (Lesson 5 Part 4)
- Serializing the model
- Pets
- Custom head
- adapt_model and gradual unfreezing
- Batch norm transfer
- Discriminative LR and param groups
#collapse
%load_ext autoreload
%autoreload 2
%matplotlib inline
#collapse
from exp.nb_11 import *
#collapse
path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)
#collapse
size = 128
bs = 64
tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]
val_tfms = [make_rgb, CenterCrop(size), np_to_float]
il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
ll.valid.x.tfms = val_tfms
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)
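As a quick sanity check (not in the original notebook), we can pull one training batch and confirm the pipeline yields 64×3×128×128 float tensors:
#collapse_show
xb, yb = next(iter(data.train_dl))
xb.shape, yb.shape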
#collapse
len(il)
#collapse_show
loss_func = LabelSmoothingCrossEntropy()
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)
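LabelSmoothingCrossEntropy comes from an earlier notebook in the course; roughly, it blends the usual cross entropy with a uniform target. A minimal sketch of the idea (the function name and the eps=0.1 default here are just illustrative):
#collapse_show
import torch.nn.functional as F

def label_smoothing_ce(pred, target, eps=0.1):
    # (1-eps) * normal cross entropy + eps * mean of -log p over all classes
    log_preds = F.log_softmax(pred, dim=-1)
    nll = F.nll_loss(log_preds, target)
    return (1-eps)*nll + eps*(-log_preds.mean(dim=-1)).mean()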
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)
#collapse_show
def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    phases = create_phases(pct_start)
    sched_lr  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]
#collapse_show
lr = 3e-3
pct_start = 0.5
cbsched = sched_1cycle(lr, pct_start)
#collapse_show
learn.fit(40, cbsched)
#collapse_show
st = learn.model.state_dict()
#collapse_show
type(st)
#collapse_show
', '.join(st.keys())
#collapse_show
st['10.bias']
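The state dict maps parameter names to tensors. In this sequential xresnet the names are simply '<layer index>.<parameter>', so '10.bias' is the bias of the final linear layer, one value per ImageWoof class. A hedged way to see this (not in the original notebook) is to look at the last few entries, which should end with that layer's weight and bias:
#collapse_show
for k in list(st.keys())[-4:]: print(k, st[k].shape)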
#collapse_show
mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)
It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly.
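For completeness, a hedged sketch of the two options (file names here are just placeholders): pickling the whole module drags the architecture and class definitions along, while saving the state dict keeps only the tensors and requires rebuilding the model before loading.
#collapse_show
# not recommended: pickles the full nn.Module, architecture included
# torch.save(learn.model, mdl_path/'whole_model.pkl')

# recommended: save the parameters only, rebuild the model, then load them back
# torch.save(learn.model.state_dict(), mdl_path/'params_only')
# learn.model.load_state_dict(torch.load(mdl_path/'params_only'))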
#collapse_show
torch.save(st, mdl_path/'iw5')
#collapse_show
pets = datasets.untar_data(datasets.URLs.PETS)
#collapse_show
pets.ls()
#collapse_show
pets_path = pets/'images'
#collapse_show
il = ImageList.from_files(pets_path, tfms=tfms)
#collapse_show
il
#collapse_show
def random_splitter(fn, p_valid): return random.random() < p_valid
#collapse_show
random.seed(42)
#collapse_show
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))
#collapse_show
sd
#collapse_show
n = il.items[0].name; n
#collapse_show
re.findall(r'^(.*)_\d+.jpg$', n)[0]
#collapse_show
def pet_labeler(fn): return re.findall(r'^(.*)_\d+.jpg$', fn.name)[0]
#collapse_show
proc = CategoryProcessor()
#collapse_show
ll = label_by_func(sd, pet_labeler, proc_y=proc)
#collapse_show
', '.join(proc.vocab)
#collapse_show
ll.valid.x.tfms = val_tfms
#collapse_show
c_out = len(proc.vocab)
#collapse_show
data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)
#collapse_show
learn.fit(5, cbsched)
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
#collapse_show
st = torch.load(mdl_path/'iw5')
#collapse_show
m = learn.model
#collapse_show
m.load_state_dict(st)
#collapse_show
cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))
m_cut = m[:cut]
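The cut keeps everything up to (but not including) the adaptive average pooling, i.e. the convolutional body, and throws away the old 10-class head. A hedged sanity check (not in the original notebook) to see exactly what gets dropped:
#collapse_show
for i,o in enumerate(m.children()):
    if i >= cut: print(i, o.__class__.__name__)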
#collapse_show
xb,yb = get_batch(data.valid_dl, learn)
#collapse_show
pred = m_cut(xb)
#collapse_show
pred.shape
#collapse_show
ni = pred.shape[1]
#collapse_show
class AdaptiveConcatPool2d(nn.Module):
    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
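Concatenating max and average pooling doubles the number of channels, which is why the new head's linear layer takes ni*2 inputs. A quick hedged shape check (not part of the original notebook):
#collapse_show
x = torch.randn(2, ni, 4, 4)
AdaptiveConcatPool2d()(x).shape   # twice the input channels, pooled to 1x1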
#collapse_show
nh = 40
m_new = nn.Sequential(
    m_cut, AdaptiveConcatPool2d(), Flatten(),
    nn.Linear(ni*2, data.c_out))
#collapse_show
learn.model = m_new
#collapse_show
learn.fit(5, cbsched)
#collapse_show
def adapt_model(learn, data):
    cut = next(i for i,o in enumerate(learn.model.children())
               if isinstance(o,nn.AdaptiveAvgPool2d))
    m_cut = learn.model[:cut]
    xb,yb = get_batch(data.valid_dl, learn)
    pred = m_cut(xb)
    ni = pred.shape[1]
    m_new = nn.Sequential(
        m_cut, AdaptiveConcatPool2d(), Flatten(),
        nn.Linear(ni*2, data.c_out))
    learn.model = m_new
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
#collapse_show
adapt_model(learn, data)
#collapse_show
for p in learn.model[0].parameters(): p.requires_grad_(False)
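learn.model[0] is the pretrained body (m_cut), so this loop freezes everything except the new head. A hedged check of how many parameters remain trainable (not in the original notebook):
#collapse_show
n_train = sum(p.numel() for p in learn.model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in learn.model.parameters())
n_train, n_total   # only the head's linear layer should still be trainable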
#collapse_show
learn.fit(3, sched_1cycle(1e-2, 0.5))
#collapse_show
for p in learn.model[0].parameters(): p.requires_grad_(True)
#collapse_show
learn.fit(5, cbsched, reset_opt=True)
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
#collapse_show
def apply_mod(m, f):
    f(m)
    for l in m.children(): apply_mod(l, f)

def set_grad(m, b):
    if isinstance(m, (nn.Linear,nn.BatchNorm2d)): return
    if hasattr(m, 'weight'):
        for p in m.parameters(): p.requires_grad_(b)
#collapse_show
apply_mod(learn.model, partial(set_grad, b=False))
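set_grad returns early for nn.Linear and nn.BatchNorm2d, so this freezes the convolutional weights while leaving the batch-norm layers and the head trainable. A hedged check of which layer types still require gradients (not in the original notebook):
#collapse_show
{type(l).__name__ for l in learn.model.modules()
 if getattr(l, 'weight', None) is not None and l.weight.requires_grad}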
#collapse_show
learn.fit(3, sched_1cycle(1e-2, 0.5))
#collapse_show
apply_mod(learn.model, partial(set_grad, b=True))
#collapse_show
learn.fit(5, cbsched, reset_opt=True)
PyTorch already has an apply method we can use:
#collapse_show
learn.model.apply(partial(set_grad, b=False));
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
#collapse_show
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
#collapse_show
def bn_splitter(m):
    def _bn_splitter(l, g1, g2):
        if isinstance(l, nn.BatchNorm2d): g2 += l.parameters()
        elif hasattr(l, 'weight'): g1 += l.parameters()
        for ll in l.children(): _bn_splitter(ll, g1, g2)

    g1,g2 = [],[]
    _bn_splitter(m[0], g1, g2)

    g2 += m[1:].parameters()
    return g1,g2
#collapse_show
a,b = bn_splitter(learn.model)
#collapse_show
test_eq(len(a)+len(b), len(list(m.parameters())))
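The idea behind the split (the batch norm transfer trick from the lesson): the body's batch-norm layers were fitted to ImageWoof statistics, so their parameters go in the second group together with the new head and get the full learning rate, while the rest of the body goes in the first group. Two hedged assertions to confirm the grouping (not in the original notebook):
#collapse_show
assert not set(map(id, a)) & set(map(id, b))              # no parameter in both groups
bn_params = [p for l in learn.model.modules() if isinstance(l, nn.BatchNorm2d)
             for p in l.parameters()]
assert all(id(p) in set(map(id, b)) for p in bn_params)   # all batch-norm params in group 2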
#collapse_show
Learner.ALL_CBS
#collapse_show
from types import SimpleNamespace
cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})
#collapse_show
cb_types.after_backward
#collapse_show
class DebugCallback(Callback):
    _order = 999
    def __init__(self, cb_name, f=None): self.cb_name,self.f = cb_name,f
    def __call__(self, cb_name):
        if cb_name==self.cb_name:
            if self.f: self.f(self.run)
            else:      set_trace()
#collapse_show
def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    phases = create_phases(pct_start)
    sched_lr  = [combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
                 for lr in lrs]
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]
#collapse_show
disc_lr_sched = sched_1cycle([0,3e-2], 0.5)
#collapse_show
learn = cnn_learner(xresnet18, data, loss_func, opt_func,
c_out=10, norm=norm_imagenette, splitter=bn_splitter)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)
#collapse_show
def _print_det(o):
    print(len(o.opt.param_groups), o.opt.hypers)
    raise CancelTrainException()
learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])
#collapse_show
learn.fit(3, disc_lr_sched)
#collapse_show
disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)
#collapse_show
learn.fit(5, disc_lr_sched)