#collapse
%load_ext autoreload
%autoreload 2

%matplotlib inline

#collapse

from exp.nb_11 import *

Serializing the model

First we train a model on Imagewoof so that we have something worth saving and transferring.

#collapse

path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)

#collapse

size = 128
bs = 64

tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]  # training transforms, with augmentation
val_tfms = [make_rgb, CenterCrop(size), np_to_float]  # deterministic transforms for validation
il = ImageList.from_files(path, tfms=tfms)
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
ll.valid.x.tfms = val_tfms
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)

#collapse

len(il)
12954

#collapse_show
loss_func = LabelSmoothingCrossEntropy()
opt_func = adam_opt(mom=0.9, sqr_mom=0.99, eps=1e-6, wd=1e-2)  # the Adam stepper reads sqr_mom; the original mom_sqr kwarg only created an unused hyper (still visible in the debug output later)

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

#collapse_show

def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    # 1cycle: lr warms up from lr/10 to lr, then anneals down to lr/1e5;
    # momentum goes the other way (high -> low -> high)
    phases = create_phases(pct_start)
    sched_lr  = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]
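To see the shape this produces, we can sample the combined lr schedule directly. A quick sketch, assuming (as in the earlier annealing notebook) that combine_scheds returns a function of training progress in [0, 1); positions are kept strictly below 1.0, where that implementation runs off the end of its phase table.

#collapse_show

# hypothetical inspection code, not part of the training loop
phases = create_phases(0.3)                                  # -> [0.3, 0.7]
sched  = combine_scheds(phases, cos_1cycle_anneal(3e-4, 3e-3, 3e-8))
for pos in (0.0, 0.15, 0.3, 0.65, 0.99):
    print(f'{pos:.2f} -> {sched(pos):.2e}')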

#collapse_show

lr = 3e-3
pct_start = 0.5
cbsched = sched_1cycle(lr, pct_start)

#collapse_show

learn.fit(40, cbsched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 2.130161 0.248089 2.085567 0.276152 00:15
1 1.993560 0.317562 2.078038 0.288878 00:13
2 1.907229 0.362105 1.908168 0.367269 00:13
3 1.827499 0.405319 1.952905 0.351234 00:13
4 1.774419 0.423823 1.992134 0.336472 00:13
5 1.722007 0.449197 1.837551 0.401883 00:13
6 1.686701 0.470803 1.984103 0.358361 00:13
7 1.677165 0.473684 2.270435 0.356834 00:13
8 1.638770 0.499280 2.088254 0.376686 00:13
9 1.610834 0.505042 2.130411 0.363451 00:13
10 1.567782 0.529086 1.999561 0.408755 00:13
11 1.524520 0.548255 3.583080 0.271061 00:13
12 1.484032 0.571080 1.894512 0.450496 00:13
13 1.446350 0.597341 1.581483 0.526852 00:14
14 1.407801 0.603546 1.724309 0.484856 00:13
15 1.367980 0.626814 2.098016 0.422499 00:13
16 1.339640 0.637784 1.608337 0.543904 00:13
17 1.307783 0.650416 1.544712 0.570629 00:13
18 1.273955 0.669252 1.716685 0.533469 00:13
19 1.243974 0.685429 1.656472 0.565538 00:13
20 1.215433 0.698393 1.512875 0.574701 00:13
21 1.175064 0.715125 1.317181 0.658183 00:13
22 1.149006 0.728310 1.489734 0.611606 00:13
23 1.107240 0.743712 1.476818 0.602952 00:13
24 1.080577 0.756787 1.456552 0.625095 00:13
25 1.058459 0.768089 1.338278 0.660473 00:13
26 1.028792 0.782271 1.194180 0.711886 00:13
27 0.997885 0.795789 1.130299 0.744464 00:13
28 0.960315 0.816067 1.161608 0.731738 00:13
29 0.935516 0.829363 1.069859 0.773479 00:13
30 0.906349 0.841219 1.066289 0.776533 00:13
31 0.878236 0.849307 1.067750 0.777552 00:13
32 0.850478 0.868255 1.027464 0.790277 00:12
33 0.828395 0.877895 1.023876 0.789005 00:12
34 0.800545 0.892078 1.019160 0.796386 00:12
35 0.790883 0.899280 0.999217 0.803512 00:12
36 0.774763 0.904377 1.007500 0.801985 00:12
37 0.768861 0.902936 0.998095 0.804530 00:12
38 0.757434 0.912133 0.995890 0.802749 00:12
39 0.762366 0.909252 0.995066 0.803003 00:12

#collapse_show

st = learn.model.state_dict()

#collapse_show

type(st)
collections.OrderedDict

#collapse_show

', '.join(st.keys())
'0.0.weight, 0.1.weight, 0.1.bias, 0.1.running_mean, 0.1.running_var, 0.1.num_batches_tracked, 1.0.weight, 1.1.weight, 1.1.bias, 1.1.running_mean, 1.1.running_var, 1.1.num_batches_tracked, 2.0.weight, 2.1.weight, 2.1.bias, 2.1.running_mean, 2.1.running_var, 2.1.num_batches_tracked, 4.0.convs.0.0.weight, 4.0.convs.0.1.weight, 4.0.convs.0.1.bias, 4.0.convs.0.1.running_mean, 4.0.convs.0.1.running_var, 4.0.convs.0.1.num_batches_tracked, 4.0.convs.1.0.weight, 4.0.convs.1.1.weight, 4.0.convs.1.1.bias, 4.0.convs.1.1.running_mean, 4.0.convs.1.1.running_var, 4.0.convs.1.1.num_batches_tracked, 4.1.convs.0.0.weight, 4.1.convs.0.1.weight, 4.1.convs.0.1.bias, 4.1.convs.0.1.running_mean, 4.1.convs.0.1.running_var, 4.1.convs.0.1.num_batches_tracked, 4.1.convs.1.0.weight, 4.1.convs.1.1.weight, 4.1.convs.1.1.bias, 4.1.convs.1.1.running_mean, 4.1.convs.1.1.running_var, 4.1.convs.1.1.num_batches_tracked, 5.0.convs.0.0.weight, 5.0.convs.0.1.weight, 5.0.convs.0.1.bias, 5.0.convs.0.1.running_mean, 5.0.convs.0.1.running_var, 5.0.convs.0.1.num_batches_tracked, 5.0.convs.1.0.weight, 5.0.convs.1.1.weight, 5.0.convs.1.1.bias, 5.0.convs.1.1.running_mean, 5.0.convs.1.1.running_var, 5.0.convs.1.1.num_batches_tracked, 5.0.idconv.0.weight, 5.0.idconv.1.weight, 5.0.idconv.1.bias, 5.0.idconv.1.running_mean, 5.0.idconv.1.running_var, 5.0.idconv.1.num_batches_tracked, 5.1.convs.0.0.weight, 5.1.convs.0.1.weight, 5.1.convs.0.1.bias, 5.1.convs.0.1.running_mean, 5.1.convs.0.1.running_var, 5.1.convs.0.1.num_batches_tracked, 5.1.convs.1.0.weight, 5.1.convs.1.1.weight, 5.1.convs.1.1.bias, 5.1.convs.1.1.running_mean, 5.1.convs.1.1.running_var, 5.1.convs.1.1.num_batches_tracked, 6.0.convs.0.0.weight, 6.0.convs.0.1.weight, 6.0.convs.0.1.bias, 6.0.convs.0.1.running_mean, 6.0.convs.0.1.running_var, 6.0.convs.0.1.num_batches_tracked, 6.0.convs.1.0.weight, 6.0.convs.1.1.weight, 6.0.convs.1.1.bias, 6.0.convs.1.1.running_mean, 6.0.convs.1.1.running_var, 6.0.convs.1.1.num_batches_tracked, 6.0.idconv.0.weight, 6.0.idconv.1.weight, 6.0.idconv.1.bias, 6.0.idconv.1.running_mean, 6.0.idconv.1.running_var, 6.0.idconv.1.num_batches_tracked, 6.1.convs.0.0.weight, 6.1.convs.0.1.weight, 6.1.convs.0.1.bias, 6.1.convs.0.1.running_mean, 6.1.convs.0.1.running_var, 6.1.convs.0.1.num_batches_tracked, 6.1.convs.1.0.weight, 6.1.convs.1.1.weight, 6.1.convs.1.1.bias, 6.1.convs.1.1.running_mean, 6.1.convs.1.1.running_var, 6.1.convs.1.1.num_batches_tracked, 7.0.convs.0.0.weight, 7.0.convs.0.1.weight, 7.0.convs.0.1.bias, 7.0.convs.0.1.running_mean, 7.0.convs.0.1.running_var, 7.0.convs.0.1.num_batches_tracked, 7.0.convs.1.0.weight, 7.0.convs.1.1.weight, 7.0.convs.1.1.bias, 7.0.convs.1.1.running_mean, 7.0.convs.1.1.running_var, 7.0.convs.1.1.num_batches_tracked, 7.0.idconv.0.weight, 7.0.idconv.1.weight, 7.0.idconv.1.bias, 7.0.idconv.1.running_mean, 7.0.idconv.1.running_var, 7.0.idconv.1.num_batches_tracked, 7.1.convs.0.0.weight, 7.1.convs.0.1.weight, 7.1.convs.0.1.bias, 7.1.convs.0.1.running_mean, 7.1.convs.0.1.running_var, 7.1.convs.0.1.num_batches_tracked, 7.1.convs.1.0.weight, 7.1.convs.1.1.weight, 7.1.convs.1.1.bias, 7.1.convs.1.1.running_mean, 7.1.convs.1.1.running_var, 7.1.convs.1.1.num_batches_tracked, 10.weight, 10.bias'
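The keys follow the module hierarchy, and each one maps to a plain tensor; for example, peeking at the first few entries (the stem's conv weights and batch norm statistics):

#collapse_show

for k in list(st)[:6]: print(k, tuple(st[k].shape))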

#collapse_show

st['10.bias']
tensor([-0.0070,  0.0070, -0.0086, -0.0081,  0.0253,  0.0061,  0.0274,  0.0104,
        -0.0421, -0.0088], device='cuda:0')

#collapse_show

mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)

It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly.

#collapse_show

torch.save(st, mdl_path/'iw5')
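For contrast, here is what the whole-model route looks like in plain PyTorch; a sketch of why it's fiddly rather than something we use below (the 'iw5_whole' filename is purely illustrative):

#collapse_show

# Not recommended: this pickles the Python class itself, so loading later
# requires exactly the same source code (module paths, class definitions).
torch.save(learn.model, mdl_path/'iw5_whole')
model2 = torch.load(mdl_path/'iw5_whole')
# Preferred: save only the parameters, recreate the model in code, then:
# learn.model.load_state_dict(torch.load(mdl_path/'iw5'))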

Pets

Now let's transfer the Imagewoof model to the Oxford-IIIT Pet dataset.

#collapse_show

pets = datasets.untar_data(datasets.URLs.PETS)

#collapse_show

pets.ls()
[PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/annotations'),
 PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images')]

#collapse_show

pets_path = pets/'images'

#collapse_show

il = ImageList.from_files(pets_path, tfms=tfms)

#collapse_show

il
ImageList (7390 items)
[PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/British_Shorthair_45.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Siamese_128.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_185.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/basset_hound_98.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/basset_hound_136.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Birman_136.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/scottish_terrier_40.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/saint_bernard_96.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/german_shorthaired_27.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/chihuahua_5.jpg')...]
Path: /home/cedric/.fastai/data/oxford-iiit-pet/images

#collapse_show

def random_splitter(fn, p_valid): return random.random() < p_valid  # True -> validation set; note the filename is ignored

#collapse_show

random.seed(42)

#collapse_show

sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

#collapse_show

sd
SplitData
Train: ImageList (6667 items)
[PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/British_Shorthair_45.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_185.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/basset_hound_98.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/basset_hound_136.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Birman_136.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/scottish_terrier_40.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/german_shorthaired_27.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Egyptian_Mau_62.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/yorkshire_terrier_71.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Bombay_113.jpg')...]
Path: /home/cedric/.fastai/data/oxford-iiit-pet/images
Valid: ImageList (723 items)
[PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/Siamese_128.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/saint_bernard_96.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/chihuahua_5.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/keeshond_163.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_86.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/yorkshire_terrier_113.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/japanese_chin_94.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_115.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/american_pit_bull_terrier_71.jpg'), PosixPath('/home/cedric/.fastai/data/oxford-iiit-pet/images/great_pyrenees_195.jpg')...]
Path: /home/cedric/.fastai/data/oxford-iiit-pet/images

#collapse_show

n = il.items[0].name; n
'British_Shorthair_45.jpg'

#collapse_show

re.findall(r'^(.*)_\d+\.jpg$', n)[0]
'British_Shorthair'

#collapse_show

def pet_labeler(fn): return re.findall(r'^(.*)_\d+\.jpg$', fn.name)[0]  # the breed is everything before the final _<number>.jpg

#collapse_show

proc = CategoryProcessor()

#collapse_show

ll = label_by_func(sd, pet_labeler, proc_y=proc)

#collapse_show

', '.join(proc.vocab)
'British_Shorthair, staffordshire_bull_terrier, basset_hound, Birman, scottish_terrier, german_shorthaired, Egyptian_Mau, yorkshire_terrier, Bombay, great_pyrenees, english_cocker_spaniel, leonberger, Siamese, american_bulldog, japanese_chin, Maine_Coon, newfoundland, Abyssinian, pug, Russian_Blue, beagle, samoyed, havanese, wheaten_terrier, Bengal, boxer, american_pit_bull_terrier, miniature_pinscher, Sphynx, chihuahua, shiba_inu, english_setter, saint_bernard, pomeranian, Persian, keeshond, Ragdoll'

#collapse_show

ll.valid.x.tfms = val_tfms

#collapse_show

c_out = len(proc.vocab)

#collapse_show

data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

#collapse_show

learn.fit(5, cbsched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 3.460010 0.087896 3.508121 0.081604 00:09
1 3.291101 0.138443 4.057820 0.084371 00:09
2 3.074502 0.194390 3.341131 0.146611 00:09
3 2.764267 0.287986 2.808986 0.251729 00:09
4 2.467706 0.386681 2.570934 0.344398 00:09

Custom head

The saved weights come from a 10-class Imagewoof model, so we first rebuild the network with c_out=10 so that the state dict loads cleanly; then we cut off the old head and attach a new one sized for the 37 pet breeds.

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

#collapse_show

st = torch.load(mdl_path/'iw5')

#collapse_show

m = learn.model

#collapse_show

m.load_state_dict(st)
<All keys matched successfully>

#collapse_show

cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))  # index of the first pooling layer
m_cut = m[:cut]  # keep just the convolutional body
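To see where the cut lands, we can list the model's top-level children; everything before the first AdaptiveAvgPool2d is the convolutional body:

#collapse_show

[type(o).__name__ for o in m.children()]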

#collapse_show

xb,yb = get_batch(data.valid_dl, learn)

#collapse_show

pred = m_cut(xb)

#collapse_show

pred.shape
torch.Size([128, 512, 4, 4])

#collapse_show

ni = pred.shape[1]  # number of channels the body outputs

#collapse_show

class AdaptiveConcatPool2d(nn.Module):
    "Concatenate adaptive max pooling and adaptive average pooling, doubling the channels."
    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
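A quick check of the new module: it collapses the spatial grid to 1x1 and doubles the channel dimension.

#collapse_show

x = torch.randn(2, 512, 4, 4)
assert AdaptiveConcatPool2d()(x).shape == torch.Size([2, 1024, 1, 1])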

#collapse_show

m_new = nn.Sequential(
    m_cut, AdaptiveConcatPool2d(), Flatten(),
    nn.Linear(ni*2, data.c_out))  # ni*2 because concat pooling doubled the channels

#collapse_show

learn.model = m_new

#collapse_show

learn.fit(5, cbsched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 2.869728 0.286336 2.255571 0.448133 00:09
1 2.109874 0.496625 2.265567 0.448133 00:09
2 1.930034 0.561272 2.033231 0.504841 00:09
3 1.674324 0.657567 1.723327 0.641770 00:09
4 1.474969 0.736463 1.574200 0.699862 00:09

adapt_model and gradual unfreezing

Let's package the head surgery into a function, then try the classic recipe: freeze the body, train the new head for a few epochs, then unfreeze and fine-tune the whole model.

#collapse_show

def adapt_model(learn, data):
    # cut the pretrained model at the first pooling layer to keep the body...
    cut = next(i for i,o in enumerate(learn.model.children())
               if isinstance(o,nn.AdaptiveAvgPool2d))
    m_cut = learn.model[:cut]
    # ...run one batch through it to find how many channels it outputs...
    xb,yb = get_batch(data.valid_dl, learn)
    pred = m_cut(xb)
    ni = pred.shape[1]
    # ...then attach a fresh head sized for the new task
    m_new = nn.Sequential(
        m_cut, AdaptiveConcatPool2d(), Flatten(),
        nn.Linear(ni*2, data.c_out))
    learn.model = m_new

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
<All keys matched successfully>

#collapse_show

adapt_model(learn, data)

#collapse_show

for p in learn.model[0].parameters(): p.requires_grad_(False)  # freeze the body; only the new head will train
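A quick way to confirm the freeze took: count how many parameters still require gradients (only the new head should).

#collapse_show

n_train = sum(p.numel() for p in learn.model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in learn.model.parameters())
print(f'{n_train:,} of {n_total:,} parameters are trainable')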

#collapse_show

learn.fit(3, sched_1cycle(1e-2, 0.5))
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 2.782441 0.298035 2.739205 0.340249 00:05
1 2.570200 0.401830 2.523606 0.439834 00:05
2 2.133533 0.512674 2.141393 0.496542 00:05

#collapse_show

for p in learn.model[0].parameters(): p.requires_grad_(True)  # unfreeze everything again

#collapse_show

learn.fit(5, cbsched, reset_opt=True)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 1.907010 0.585271 1.969608 0.571231 00:09
1 1.861008 0.595770 2.103309 0.510373 00:09
2 1.801710 0.607920 2.016537 0.508990 00:09
3 1.609392 0.691615 1.749493 0.634855 00:09
4 1.424160 0.761212 1.586875 0.695712 00:09

Batch norm transfer

Freezing the whole body also freezes the batch norm layers, whose statistics come from the pretraining data. It often works better to freeze everything except the batch norm layers (and the head), so they can adapt to the new data.

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

#collapse_show

def apply_mod(m, f):
    # apply f to a module and, recursively, to all of its children
    f(m)
    for l in m.children(): apply_mod(l, f)

def set_grad(m, b):
    # leave Linear and BatchNorm layers trainable; (un)freeze everything else with weights
    if isinstance(m, (nn.Linear,nn.BatchNorm2d)): return
    if hasattr(m, 'weight'):
        for p in m.parameters(): p.requires_grad_(b)
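To see the effect in isolation, here is set_grad applied to a toy module (the shapes don't matter, we never run it): the conv is frozen, while the batch norm and linear layers stay trainable.

#collapse_show

tiny = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.Linear(4, 2))
apply_mod(tiny, partial(set_grad, b=False))
print([p.requires_grad for p in tiny.parameters()])  # conv: False, bn/linear: True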

#collapse_show

apply_mod(learn.model, partial(set_grad, b=False))

#collapse_show

learn.fit(3, sched_1cycle(1e-2, 0.5))
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 2.690954 0.320084 2.427597 0.403873 00:06
1 2.192524 0.474726 2.123527 0.484094 00:06
2 1.914695 0.569222 1.958007 0.557400 00:06

#collapse_show

apply_mod(learn.model, partial(set_grad, b=True))

#collapse_show

learn.fit(5, cbsched, reset_opt=True)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 1.819283 0.609569 1.912626 0.580913 00:09
1 1.793303 0.617219 2.100043 0.493776 00:09
2 1.751841 0.628919 2.360935 0.394191 00:09
3 1.569761 0.704065 1.744866 0.626556 00:09
4 1.408960 0.767062 1.579932 0.692946 00:08

PyTorch already has an apply method on nn.Module that does this recursion for us:

#collapse_show

learn.model.apply(partial(set_grad, b=False));

Discriminative LR and param groups

Instead of freezing, we can give different parts of the model different learning rates. That needs a splitter to divide the parameters into groups, plus one lr schedule per group.

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

#collapse_show

learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

#collapse_show

def bn_splitter(m):
    def _bn_splitter(l, g1, g2):
        # batch norm parameters go to g2; other weighted layers to g1
        if isinstance(l, nn.BatchNorm2d): g2 += l.parameters()
        elif hasattr(l, 'weight'): g1 += l.parameters()
        for ll in l.children(): _bn_splitter(ll, g1, g2)

    g1,g2 = [],[]
    _bn_splitter(m[0], g1, g2)  # split the body

    g2 += m[1:].parameters()  # the whole new head joins the batch norm group
    return g1,g2

#collapse_show

a,b = bn_splitter(learn.model)

#collapse_show

test_eq(len(a)+len(b), len(list(learn.model.parameters())))  # the split must cover every parameter of the model we split
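Counting isn't quite enough, though: the two groups should also be disjoint. A quick extra check:

#collapse_show

assert not {id(p) for p in a} & {id(p) for p in b}  # no parameter lands in both groups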

#collapse_show

Learner.ALL_CBS
{'after_backward',
 'after_batch',
 'after_cancel_batch',
 'after_cancel_epoch',
 'after_cancel_train',
 'after_epoch',
 'after_fit',
 'after_loss',
 'after_pred',
 'after_step',
 'begin_batch',
 'begin_epoch',
 'begin_fit',
 'begin_validate'}

#collapse_show

from types import SimpleNamespace
cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})

#collapse_show

cb_types.after_backward
'after_backward'

#collapse_show

class DebugCallback(Callback):
    _order = 999  # run after all other callbacks
    def __init__(self, cb_name, f=None): self.cb_name,self.f = cb_name,f
    def __call__(self, cb_name):
        if cb_name==self.cb_name:
            if self.f: self.f(self.run)
            else:      set_trace()  # no f given: drop into the debugger

#collapse_show

def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    # same as before, but takes a list of lrs: one lr schedule per parameter group
    phases = create_phases(pct_start)
    sched_lr  = [combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
                 for lr in lrs]
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]

#collapse_show

disc_lr_sched = sched_1cycle([0,3e-2], 0.5)  # body effectively frozen via lr 0; batch norm and head train at 3e-2

#collapse_show

learn = cnn_learner(xresnet18, data, loss_func, opt_func,
                    c_out=10, norm=norm_imagenette, splitter=bn_splitter)

learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

#collapse_show

def _print_det(o):
    print(len(o.opt.param_groups), o.opt.hypers)
    raise CancelTrainException()  # one batch is enough to inspect the hypers

learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])
2 [{'mom': 0.9499999999999997, 'mom_sqr': 0.99, 'eps': 1e-06, 'wd': 0.01, 'lr': 0.0, 'sqr_mom': 0.99}, {'mom': 0.9499999999999997, 'mom_sqr': 0.99, 'eps': 1e-06, 'wd': 0.01, 'lr': 0.0030000000000000512, 'sqr_mom': 0.99}]

#collapse_show
learn.fit(3, disc_lr_sched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 2.585809 0.358032 2.342050 0.409405 00:08
1 2.312127 0.433328 2.434426 0.412172 00:08
2 2.032101 0.523324 1.978259 0.539419 00:08

#collapse_show

disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)  # now the body gets a small lr too

#collapse_show

learn.fit(5, disc_lr_sched)
epoch train_loss train_accuracy valid_loss valid_accuracy time
0 1.862455 0.596520 2.062223 0.511757 00:08
1 1.927012 0.558422 2.068517 0.508990 00:08
2 1.780360 0.621869 1.958512 0.542185 00:08
3 1.633498 0.674516 1.755347 0.605809 00:08
4 1.534966 0.709615 1.697863 0.641770 00:08