FAU Monkey Face ID v1.00¶
In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
from fastai.vision import *
from fastai.metrics import error_rate
from pathlib import Path
from fastai.widgets import *
In [3]:
bs = 64
Looking at the data¶
In [4]:
path = Path('/home/jupyter/monkey/monkey224')
In [5]:
path.ls()
Out[5]:
In [6]:
path_img = path
Create an ImageDataBunch using the folder structure above.¶
Reserve 20% of the images randomly for validation.
In [7]:
data = ImageDataBunch.from_folder(path, valid_pct=0.2, ds_tfms=get_transforms(), size=224)
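To make the 20% hold-out concrete, here is a minimal sketch of a random split in plain Python. It is not fastai's implementation; `random_split`, the seed value, and the file names are hypothetical and only illustrate the idea behind `valid_pct=0.2`.

```python
import random

def random_split(items, valid_pct=0.2, seed=None):
    """Shuffle a copy of items and reserve the first valid_pct share for validation."""
    rng = random.Random(seed)          # a fixed seed makes the split repeatable
    shuffled = list(items)
    rng.shuffle(shuffled)
    cut = int(valid_pct * len(shuffled))
    return shuffled[cut:], shuffled[:cut]   # (train, valid)

# Hypothetical file names, for illustration only.
files = [f"img_{i:03d}.jpg" for i in range(100)]
train_files, valid_files = random_split(files, valid_pct=0.2, seed=42)
print(len(train_files), len(valid_files))
```

Without a fixed seed the validation set changes on every run, so accuracy figures from different runs are not strictly comparable.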
In [8]:
data.show_batch(rows=6, figsize=(24,24))
In [9]:
print(data.classes)
len(data.classes),data.c
Out[9]:
Training: resnet50¶
Train the model using a ResNet-50 backbone, a 50-layer residual network.
In [10]:
learn = cnn_learner(data, models.resnet50, metrics=error_rate)
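The learner reports `error_rate`, while the notes in this notebook quote accuracy. The two are complementary, as this small sketch shows. The function name `fraction_wrong` and the labels are hypothetical.

```python
def fraction_wrong(predicted, actual):
    """Share of predictions that differ from the true labels."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    wrong = sum(p != a for p, a in zip(predicted, actual))
    return wrong / len(actual)

# Hypothetical labels, for illustration only.
actual    = ["monkey_a", "monkey_b", "monkey_a", "monkey_c"]
predicted = ["monkey_a", "monkey_b", "monkey_c", "monkey_c"]

err = fraction_wrong(predicted, actual)
accuracy = 1 - err
print(err, accuracy)
```

An error rate of 0.035 on the validation set therefore corresponds to 96.5% accuracy.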
Run the learning rate finder and plot the loss against the learning rate.
In [18]:
learn.lr_find()
learn.recorder.plot()
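The x-axis of this plot is a sweep of learning rates that grows exponentially from a very small value to a very large one, one value per mini-batch. The sketch below reproduces such a sweep in plain Python. The defaults (`1e-7` to `10` over 100 iterations) reflect our understanding of fastai v1's `lr_find` and should be treated as assumptions, and `lr_sweep` is a hypothetical helper.

```python
def lr_sweep(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates spaced evenly on a log scale between start_lr and end_lr."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

lrs = lr_sweep()
print(lrs[0], lrs[49], lrs[-1])
```

Because the rates are spaced on a log scale, the plot is usually read on a logarithmic x-axis.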
Observe how the loss varies with the learning rate. This plot will guide the choice of learning rates when we unfreeze the model. First, train the frozen model for 4 epochs.
In [19]:
learn.fit_one_cycle(4)
96.5% accuracy! Impressive, but we can do better.
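The `fit_one_cycle(4)` call above does not train at a constant learning rate: the rate warms up to a peak and then anneals to a very small value. Here is a minimal sketch of that shape. The peak of `3e-3`, the divisor of 25, the 30% warm-up share, and the cosine annealing are our understanding of fastai v1's defaults and should be read as assumptions; `one_cycle_lr` is a hypothetical helper.

```python
import math

def cosine_anneal(start, end, pct):
    """Move from start to end along a half cosine as pct goes from 0 to 1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, lr_max=3e-3, div_factor=25.0, pct_start=0.3):
    """Learning rate at a given step of a one-cycle schedule."""
    warmup_steps = int(total_steps * pct_start)
    lr_start = lr_max / div_factor      # starting rate, well below the peak
    lr_end = lr_start / 1e4             # final rate, close to zero
    if step < warmup_steps:
        return cosine_anneal(lr_start, lr_max, step / warmup_steps)
    pct = (step - warmup_steps) / (total_steps - warmup_steps)
    return cosine_anneal(lr_max, lr_end, pct)

schedule = [one_cycle_lr(s, 100) for s in range(100)]
print(schedule[0], max(schedule), schedule[-1])
```

The number passed to `fit_one_cycle` is the number of epochs, and the whole run forms a single cycle of this shape.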
In [4]:
# save or load the model
#learn.load('mf-stage-1-50')
#learn.save('mf-stage-1-50')
Unfreeze the model and train it again, using a range of maximum learning rates determined from the learning rate plot above.
In [14]:
learn.unfreeze()
learn.fit_one_cycle(3, max_lr=slice(1e-4,1e-2))
After 3 more epochs, we reach an impressive 99.44% accuracy.
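The `slice(1e-4,1e-2)` passed to `fit_one_cycle` above gives the earliest layers a small learning rate and the final layers a larger one. To our understanding, fastai v1 spreads the rates geometrically across the learner's layer groups. The sketch below assumes three groups, which we believe is the case for a ResNet built with `cnn_learner`. The helper mirrors that idea but is not fastai's code.

```python
def even_mults(start, stop, n):
    """n values from start to stop, each a constant multiple of the previous one."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

# Assumption: three layer groups (early body, late body, head).
group_lrs = even_mults(1e-4, 1e-2, 3)
print(group_lrs)
```

Under these assumptions the middle group trains at roughly `1e-3`, so the pretrained early layers change slowly while the head adapts faster.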
In [16]:
# Save stage 2
# learn.save('mf-stage-2-50');
Let's look more closely at the classifier to see the few cases it still gets wrong.
In [17]:
interp = ClassificationInterpretation.from_learner(learn)
In [18]:
interp.most_confused(min_val=2)
Out[18]:
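`most_confused` lists the class pairs that are mixed up most often. The idea can be sketched from a confusion matrix in plain Python: take the off-diagonal cells with at least `min_val` errors and sort them by count. The class names and counts below are made up for illustration, and the helper is a sketch rather than fastai's implementation.

```python
def most_confused(confusion, classes, min_val=1):
    """Return (actual, predicted, count) for off-diagonal cells, largest count first."""
    n = len(classes)
    pairs = [
        (classes[i], classes[j], confusion[i][j])
        for i in range(n)
        for j in range(n)
        if i != j and confusion[i][j] >= min_val
    ]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

# Hypothetical data: rows are actual classes, columns are predicted classes.
classes = ["monkey_a", "monkey_b", "monkey_c"]
confusion = [
    [30, 3, 0],
    [1, 28, 2],
    [0, 5, 25],
]
result = most_confused(confusion, classes, min_val=2)
print(result)
# [('monkey_c', 'monkey_b', 5), ('monkey_a', 'monkey_b', 3), ('monkey_b', 'monkey_c', 2)]
```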
In [19]:
losses,idxs = interp.top_losses()
len(data.valid_ds)==len(losses)==len(idxs)
Out[19]:
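The check above confirms that `top_losses` returns one loss and one index per validation image. A minimal sketch of the idea: compute a per-sample loss, then sort the samples from worst to best. Cross-entropy is assumed as the loss here, and the probabilities and targets are made up for illustration.

```python
import math

def cross_entropy(probs, target):
    """Negative log of the probability assigned to the true class."""
    return -math.log(probs[target])

def top_losses(all_probs, targets, k=None):
    """Per-sample losses sorted from largest to smallest, with their indices."""
    losses = [cross_entropy(p, t) for p, t in zip(all_probs, targets)]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    if k is not None:
        order = order[:k]
    return [losses[i] for i in order], order

# Hypothetical predicted probabilities and true class indices.
probs = [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.6, 0.3, 0.1]]
targets = [0, 1, 2]
losses, idxs = top_losses(probs, targets)
print(idxs)
# [2, 1, 0]
```

The sample at index 2 comes first because the model gave its true class only 10% probability.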
In [20]:
interp.plot_top_losses(9, figsize=(15,11))
In [21]:
interp.plot_confusion_matrix(figsize=(12,12), dpi=120)
In [22]:
# Export the trained learner for use on another computer.
#learn.export();
In [14]:
# Export the DataBunch (data pipeline state) for use on another computer.
#data.export()