---
title: Binary Segmentation
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/course2020/vision/07_Binary_Segmentation.ipynb"
---
{% raw %}
{% endraw %} {% raw %}

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, and wwf currently running at the time of writing this:

  • fastai: 2.1.10
  • fastcore: 1.3.13
  • wwf: 0.0.7

{% endraw %} {% raw %}
file = "https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z"
{% endraw %} {% raw %}
!gdown {file}
Downloading...
From: https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z
To: /content/Portrait.zip
107MB [00:02, 44.4MB/s]
{% endraw %}

We'll unzip the data

{% raw %}
from zipfile import ZipFile
{% endraw %} {% raw %}
with ZipFile('Portrait.zip', 'r') as zip_ref:
  zip_ref.extractall('')
{% endraw %} {% raw %}
from fastai.vision.all import *
{% endraw %}

And grab our ground truth labels and files

{% raw %}
path = Path('')
{% endraw %} {% raw %}
lbl_names = get_image_files(path/'GT_png')
fnames = get_image_files(path/'images_data_crop')
{% endraw %} {% raw %}
img_fn = fnames[10]; img_fn
Path('images_data_crop/00970.jpg')
{% endraw %} {% raw %}
lbl_names[10]
Path('GT_png/02621_mask.png')
{% endraw %} {% raw %}
fn = '00013.jpg'
{% endraw %} {% raw %}
im = PILImage.create(f'images_data_crop/{fn}')
{% endraw %} {% raw %}
msk = PILMask.create(f'GT_png/00013_mask.png')
{% endraw %}
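
To eyeball this pair before going further, we can overlay the mask on top of the image; fastai's show method returns the plotting context, so the mask can be drawn onto the same axes (a quick sketch, the alpha blending is the mask's default):

{% raw %}
# Draw the image, then overlay its mask on the same axes
ctx = im.show(figsize=(4,4))
msk.show(ctx=ctx)
{% endraw %}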

Now, our mask isn't set up the way fastai expects: the pixel values need to be consecutive class indices starting at 0, not arbitrary values like 255. We need to change this:

{% raw %}
len(np.unique(msk))
2
{% endraw %} {% raw %}
np.unique(msk)
array([  0, 255], dtype=uint8)
{% endraw %}

We'll do this through an n_codes function. It will run through our masks and build a set of the unique pixel values present. From there we'll build a dictionary mapping each class index to the pixel value it should replace once we load in the mask:

{% raw %}
def n_codes(fnames, is_partial=True):
  "Gather the codes from a list of `fnames`"
  vals = set()
  if is_partial:
    # Only sample a handful of masks rather than scanning every file
    random.shuffle(fnames)
    fnames = fnames[:10]
  for fname in fnames:
    msk = np.array(PILMask.create(fname))
    for val in np.unique(msk):
      vals.add(val)
  vals = list(vals)
  # Map each class index to the raw pixel value it replaces, e.g. {0: 0, 1: 255}
  p2c = {i: val for i, val in enumerate(vals)}
  return p2c
{% endraw %} {% raw %}
vals = n_codes(lbl_names)
{% endraw %}

So vals in this case tells us that anywhere our mask is 255 should be replaced with a 1:

{% raw %}
vals
{0: 0, 1: 255}
{% endraw %}

So now let's build a get_msk function that will load a mask and overwrite its pixel values based on this dictionary:

{% raw %}
def get_msk(fn, pix2class):
  "Grab a mask from a `filename` and adjust the pixels based on `pix2class`"
  fn = path/'GT_png'/f'{fn.stem}_mask.png'
  msk = np.array(PILMask.create(fn))
  # Replace each raw pixel value with its class index
  for i, val in pix2class.items():
    msk[msk==val] = i
  return PILMask.create(msk)
{% endraw %} {% raw %}
codes = ['Background', 'Face']
{% endraw %}
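
Before wiring this into a DataBlock, we can sanity-check get_msk on a single file; with the dictionary from above, the adjusted mask should only contain 0s and 1s (a quick check, assuming img_fn has a matching mask in GT_png):

{% raw %}
# After the adjustment, only the class indices 0 and 1 should remain
test_msk = get_msk(img_fn, vals)
np.unique(np.array(test_msk))
{% endraw %}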

Now we can build a get_y and a DataBlock!

{% raw %}
get_y = lambda o: get_msk(o, vals)
{% endraw %} {% raw %}
binary = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=get_y,
                   item_tfms=Resize(224),
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])
{% endraw %} {% raw %}
dls = binary.dataloaders(path/'images_data_crop', bs=8)
{% endraw %}
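
It's also worth grabbing one batch to confirm the shapes and that the targets really are integer class indices rather than 0/255 pixel values (a quick sanity check):

{% raw %}
# Expect images of shape [8, 3, 224, 224] and masks of shape [8, 224, 224] with values 0 and 1
xb, yb = dls.one_batch()
xb.shape, yb.shape, yb.unique()
{% endraw %}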

We can see how our masks look by adjusting the colormap along with vmin and vmax:

{% raw %}
dls.show_batch(cmap='Blues', vmin=0, vmax=1)
{% endraw %}

And now we can train!

{% raw %}
learn = unet_learner(dls, resnet34)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/checkpoints/resnet34-333f7ec4.pth

{% endraw %} {% raw %}
learn.fit(1)
epoch train_loss valid_loss time
0 0.287433 0.168146 00:43
{% endraw %}

And we're good :)

{% raw %}
learn.show_results(cmap='Blues', vmin=0, vmax=1)
{% endraw %}

If we want to examine it further we can do:

{% raw %}
preds = learn.get_preds()
{% endraw %} {% raw %}
preds[0][0].shape
torch.Size([2, 224, 224])
{% endraw %} {% raw %}
p = preds[0][0]
{% endraw %} {% raw %}
plt.imshow(p[1])
<matplotlib.image.AxesImage at 0x7f942a6599e8>
{% endraw %} {% raw %}
plt.imshow(p[0])
<matplotlib.image.AxesImage at 0x7f942a5d8240>
{% endraw %}
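
Here p[0] is the Background channel and p[1] the Face channel. Taking an argmax across that first dimension collapses the two into a single predicted mask (a minimal sketch):

{% raw %}
# Collapse the two channels into a hard 0/1 prediction per pixel
pred_mask = p.argmax(dim=0)
plt.imshow(pred_mask, cmap='Blues', vmin=0, vmax=1)
{% endraw %}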