CNN Interpretation with CAM

#hide
! [ -e /content ] && pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
#hide
from fastbook import *

CAM and Hooks

path = untar_data(URLs.PETS)/'images'
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=21,
    label_func=is_cat, item_tfms=Resize(224))
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
img = PILImage.create(image_cat())
x, = first(dls.test_dl([img]))
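class Hook():
    # Forward hook: keep a detached copy of the hooked module's output.
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
hook_output = Hook()
# Hook the body of the model (learn.model[0]) to capture its last activations.
hook = learn.model[0].register_forward_hook(hook_output.hook_func)
# eval() disables dropout and batchnorm updates; no_grad() skips building the graph.
with torch.no_grad(): output = learn.model.eval()(x)
act = hook_output.stored[0]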
F.softmax(output, dim=-1)
dls.vocab
x.shape
# Contract over channels k: (classes c, channels k) x (channels k, rows i, cols j) -> (c, i, j).
cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)
cam_map.shape
x_dec = TensorImage(dls.train.decode((x,))[0][0])
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');
hook.remove()
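The einsum contraction used for the CAM map can be checked in isolation. The following is a minimal sketch on random tensors, not the notebook's trained model: the sizes and the names `weight`, `n_classes`, `n_channels`, and `cam_map_mm` are hypothetical stand-ins for the last linear layer's weight and the hooked activations.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes; in the notebook these come from the model and the image.
n_classes, n_channels, h, w = 2, 8, 7, 7
weight = torch.randn(n_classes, n_channels)   # stand-in for the last linear layer's weight
act = torch.randn(n_channels, h, w)           # stand-in for the hooked activations

# Same contraction as above: sum over the channel axis k at every grid location.
cam_map = torch.einsum('ck,kij->cij', weight, act)

# Equivalent formulation: flatten the grid, take a matrix product, reshape back.
cam_map_mm = (weight @ act.view(n_channels, -1)).view(n_classes, h, w)

print(cam_map.shape)
print(torch.allclose(cam_map, cam_map_mm, atol=1e-5))
```

Each `cam_map[c]` is one score per grid location for class `c`, which is why it can be drawn over the image as a heatmap.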
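class Hook():
    # Context-manager version: the hook is registered on creation
    # and removed automatically when the `with` block exits.
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
with Hook(learn.model[0]) as hook:
    # x is already on the model's device (it was passed to the model directly
    # earlier), so no explicit .cuda() call is needed.
    with torch.no_grad(): output = learn.model.eval()(x)
    act = hook.stored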
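To see that the context manager really cleans up after itself, here is a small self-contained sketch. The two-part `model` is a hypothetical stand-in for `learn.model` (a tiny "body" and "head"), and the check relies on `_forward_hooks`, a private `nn.Module` attribute, so treat it as an inspection aid rather than a supported API.

```python
import torch
from torch import nn

class Hook():
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()

# Hypothetical stand-in for learn.model: a small body followed by a small head.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU()),
    nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2)),
)
x = torch.randn(1, 3, 8, 8)

with Hook(model[0]) as hook:
    with torch.no_grad(): output = model.eval()(x)
    act = hook.stored

# After the with block, the hook has been removed from the body.
# _forward_hooks is a private attribute, used here only for inspection.
n_hooks_after = len(model[0]._forward_hooks)
print(act.shape, output.shape, n_hooks_after)
```

Because `__exit__` runs even when the body of the `with` block raises, the hook cannot be left attached by accident.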

Gradient CAM

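class HookBwd():
    def __init__(self, m):
        # Recent PyTorch releases deprecate register_backward_hook in favor of
        # register_full_backward_hook, which accepts the same hook signature.
        self.hook = m.register_backward_hook(self.hook_func)
    # go[0] is the gradient with respect to the hooked module's output.
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()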
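cls = 1
with HookBwd(learn.model[0]) as hookg:
    with Hook(learn.model[0]) as hook:
        # No torch.no_grad() here: backward() needs the computation graph.
        # x is already on the model's device, so no .cuda() call is needed.
        output = learn.model.eval()(x)
        act = hook.stored
    # Backpropagate only the score of class `cls`, which is a scalar.
    output[0,cls].backward()
    grad = hookg.stored
# Grad-CAM weights: average each channel's gradient over the spatial grid.
w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');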
# Same recipe, hooked on the second-to-last group of the body
# instead of the body's final output.
with HookBwd(learn.model[0][-2]) as hookg:
    with Hook(learn.model[0][-2]) as hook:
        # x is already on the model's device, so no .cuda() call is needed.
        output = learn.model.eval()(x)
        act = hook.stored
    output[0,cls].backward()
    grad = hookg.stored
w = grad[0].mean(dim=[1,2], keepdim=True)
cam_map = (w * act[0]).sum(0)
_,ax = plt.subplots()
x_dec.show(ctx=ax)
ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,224,224,0),
          interpolation='bilinear', cmap='magma');
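The Grad-CAM steps above can be sketched end to end on a toy network. This is an illustration under stated assumptions, not the notebook's model: `model` is a hypothetical two-part network with an average-pool and linear head, and `HookBwd` here uses `register_full_backward_hook` (the non-deprecated variant in recent PyTorch releases) instead of the `register_backward_hook` call used in the notebook.

```python
import torch
from torch import nn

class Hook():
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()

class HookBwd():
    def __init__(self, m):
        # Assumption: full backward hook swapped in for register_backward_hook.
        self.hook = m.register_full_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()

torch.manual_seed(0)
# Hypothetical stand-in for learn.model: body, then average pool + linear head.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU()),
    nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2)),
)
x = torch.randn(1, 3, 8, 8)
cls = 1

with HookBwd(model[0]) as hookg:
    with Hook(model[0]) as hook:
        output = model.eval()(x)
        act = hook.stored
    output[0, cls].backward()     # scalar score of one class
    grad = hookg.stored

w = grad[0].mean(dim=[1, 2], keepdim=True)   # one weight per channel: (channels, 1, 1)
cam_map = (w * act[0]).sum(0)                # weighted sum over channels: (rows, cols)
print(w.shape, cam_map.shape)

# With an average-pool + linear head, each gradient entry is the linear weight
# divided by the number of grid cells (8 * 8 = 64), so w is proportional to it.
head_weight = model[1][-1].weight[cls].detach()
print(torch.allclose(w.flatten() * 64, head_weight, atol=1e-6))
```

The final check shows why Grad-CAM generalizes CAM: for this particular head the gradient-derived weights match the last layer's weights up to a constant factor, while for inner layers only the gradient version is available.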

Conclusion

Questionnaire

  1. What is a “hook” in PyTorch?
  2. Which layer does CAM use the outputs of?
  3. Why does CAM require a hook?
  4. Look at the source code of the ActivationStats class and see how it uses hooks.
  5. Write a hook that stores the activations of a given layer in a model (without peeking, if possible).
  6. Why do we call eval before getting the activations? Why do we use no_grad?
  7. Use torch.einsum to compute the “dog” or “cat” score of each of the locations in the last activation of the body of the model.
  8. How do you check which order the categories are in (i.e., the correspondence of index->category)?
  9. Why are we using decode when displaying the input image?
  10. What is a “context manager”? What special methods need to be defined to create one?
  11. Why can’t we use plain CAM for the inner layers of a network?
  12. Why do we need to register a hook on the backward pass in order to do Grad-CAM?
  13. Why can’t we call output.backward() when output is a rank-2 tensor of output activations per image per class?

Further Research

  1. Try removing keepdim and see what happens. Look up this parameter in the PyTorch docs. Why do we need it in this notebook?
  2. Create a notebook like this one, but for NLP, and use it to find which words in a movie review are most significant in assessing its sentiment.