This module provides both a standalone class and a callback for registering PyTorch hooks (the callback also deregisters them automatically), along with some predefined hooks. Hooks can be attached to any nn.Module, for either the forward or the backward pass.
We'll start by looking at the predefined hook ActivationStats, then we'll see how to create our own.
from fastai.gen_doc.nbdoc import *
from fastai.callbacks.hooks import *
from fastai import *
from fastai.train import *
from fastai.vision import *
show_doc(ActivationStats)
class ActivationStats
ActivationStats(learn:Learner,modules:Sequence[Module]=None,do_remove:bool=True) ::HookCallback
Callback that records the activations.
ActivationStats saves statistics (the mean and standard deviation) of the layer activations in self.stats for all modules passed to it. By default, it saves them for all modules. For instance:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, callback_fns=ActivationStats)
learn.fit(1)
Total time: 00:13
epoch  train loss  valid loss
0      0.077055    0.049985    (00:13)
The saved stats are a FloatTensor of shape (2, num_modules, num_batches). The first axis holds (mean, stdev).
len(learn.data.train_dl),len(learn.activation_stats.modules)
(194, 44)
learn.activation_stats.stats.shape
torch.Size([2, 44, 194])
So this shows the standard deviation (index 1 on axis 0) of the fifth-to-last tracked module (index -5 on axis 1) for each batch (axis 2):
plt.plot(learn.activation_stats.stats[1][-5].numpy());
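To make the (2, num_modules, num_batches) layout concrete, here is a minimal sketch in plain PyTorch (no fastai) of one way to collect per-batch activation means and standard deviations with forward hooks and arrange them in that shape. It is not necessarily how ActivationStats works internally; the toy model, the choice to track only modules with a weight, and the names per_module and make_hook are all hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for the CNN created by create_cnn.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Assumption: track only modules that have a weight (here the two Linear layers).
modules = [m for m in model if hasattr(m, "weight")]

per_module = [[] for _ in modules]  # per_module[i] collects (mean, std) pairs, one per batch

def make_hook(i):
    "Hypothetical helper: build a forward hook that records stats for module i."
    def hook(module, inputs, output):
        out = output.detach()
        per_module[i].append((out.mean().item(), out.std().item()))
    return hook

handles = [m.register_forward_hook(make_hook(i)) for i, m in enumerate(modules)]

num_batches = 5
with torch.no_grad():
    for _ in range(num_batches):
        model(torch.randn(32, 8))  # one dummy "batch"

for h in handles:
    h.remove()  # deregister so later forward passes are not recorded

# (num_modules, num_batches, 2) -> (2, num_modules, num_batches)
stats = torch.tensor(per_module).permute(2, 0, 1)
print(stats.shape)  # torch.Size([2, 2, 5])
std_last_module = stats[1][-1]  # std (index 1 on axis 0) of the last tracked module, per batch
```

Indexing then works exactly as in the plot above: the first index picks mean or std, the second picks the module, and what remains is one value per batch.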
show_doc(Hook)
Registers a PyTorch hook that you deregister manually (with remove). Your hook_func is called automatically whenever the forward or backward pass (depending on is_forward) of your module m runs, and its return value is placed in self.stored.
show_doc(Hook.remove)
remove
remove()
Deregister the hook, if not called already.
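The behavior described for Hook can be sketched directly on top of PyTorch's own hook registration. This is a minimal illustration under stated assumptions, not fastai's implementation: the class name MiniHook is hypothetical, and the backward case assumes a PyTorch version that provides nn.Module.register_full_backward_hook.

```python
import torch
import torch.nn as nn

class MiniHook:
    "Hypothetical, minimal stand-in for Hook: registers hook_func and stores its result."
    def __init__(self, m, hook_func, is_forward=True):
        self.hook_func, self.stored, self.removed = hook_func, None, False
        if is_forward:
            self.handle = m.register_forward_hook(self.hook_fn)
        else:
            self.handle = m.register_full_backward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        # PyTorch passes (module, input, output) to forward hooks and
        # (module, grad_input, grad_output) to backward hooks.
        self.stored = self.hook_func(module, inp, out)

    def remove(self):
        # Deregister only once, mirroring "if not called already".
        if not self.removed:
            self.handle.remove()
            self.removed = True

lin = nn.Linear(3, 2)
fwd = MiniHook(lin, lambda m, i, o: o.detach().shape)
bwd = MiniHook(lin, lambda m, gi, go: go[0].shape, is_forward=False)
lin(torch.randn(4, 3, requires_grad=True)).sum().backward()
print(fwd.stored, bwd.stored)  # torch.Size([4, 2]) torch.Size([4, 2])
fwd.remove()
bwd.remove()
```

Keeping the handle returned by PyTorch is what makes manual deregistration possible; the removed flag simply makes a second call to remove harmless.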
show_doc(Hooks)
class Hooks
Hooks(ms:ModuleList,hook_func:HookFunc,is_forward:bool=True)
Create several hooks.
Acts as a Collection (i.e. len(hooks) and hooks[i]) and an Iterator (i.e. for hook in hooks) of a group of hooks, one for each module in ms, with the ability to remove all of them as a group. Use stored to get all hook results. hook_func and is_forward behave the same as for Hook. See the source code for HookCallback for a simple example.
show_doc(Hooks.remove)
remove
remove()
Deregister all hooks created by this class, if not previously called.
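The group behavior of Hooks (len, indexing, iteration, a combined stored, and removing everything at once) can be sketched in plain PyTorch as below. The names MiniHooks and _OneHook are hypothetical, only forward hooks are covered, and this is an illustration of the idea rather than fastai's code.

```python
import torch
import torch.nn as nn

class _OneHook:
    "Hypothetical helper: a single forward hook plus its latest result."
    def __init__(self, m, hook_func):
        self.hook_func, self.stored = hook_func, None
        self.handle = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        self.stored = self.hook_func(module, inp, out)

class MiniHooks:
    "Hypothetical group of hooks, one per module in ms, removable as a group."
    def __init__(self, ms, hook_func):
        self.hooks = [_OneHook(m, hook_func) for m in ms]
        self.removed = False

    def __len__(self): return len(self.hooks)         # len(hooks)
    def __getitem__(self, i): return self.hooks[i]    # hooks[i]
    def __iter__(self): return iter(self.hooks)       # for hook in hooks

    @property
    def stored(self):
        return [h.stored for h in self.hooks]

    def remove(self):
        if not self.removed:
            for h in self.hooks:
                h.handle.remove()
            self.removed = True

model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 2))
hooks = MiniHooks(list(model), lambda m, i, o: tuple(o.shape))
model(torch.randn(3, 4))
print(len(hooks), hooks.stored)  # 3 [(3, 6), (3, 6), (3, 2)]
hooks.remove()
```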
hook_output: function that creates a Hook for module that simply stores the output of the layer.
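A typical use of an output-storing hook is grabbing an intermediate activation without changing the model's forward code. The sketch below does this in plain PyTorch with a hypothetical OutputHook class; detaching the stored tensor is a choice made here so that it does not keep the autograd graph alive, and is an assumption rather than a documented detail of hook_output.

```python
import torch
import torch.nn as nn

class OutputHook:
    "Hypothetical sketch of an output-storing hook: keep the module's last output."
    def __init__(self, module):
        self.stored = None
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        # Detach so the stored tensor does not hold on to the autograd graph.
        self.stored = out.detach()

    def remove(self):
        self.handle.remove()

body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                     nn.Conv2d(8, 16, 3, stride=2, padding=1))
hook = OutputHook(body[0])            # watch the first conv layer only
out = body(torch.randn(2, 3, 32, 32))
print(hook.stored.shape, out.shape)   # torch.Size([2, 8, 32, 32]) torch.Size([2, 16, 16, 16])
hook.remove()
```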
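hook_outputs: function that creates a Hooks object for all passed modules, which simply stores the outputs of the layers. For example, the (slightly simplified) source code of model_sizes is:
def model_sizes(m, size):
    hooks = hook_outputs(m)  # hooks must be registered *before* the forward pass
    x = m(torch.zeros(1, in_channels(m), *size))  # dummy batch containing one input
    sizes = [o.stored.shape for o in hooks]
    hooks.remove()  # Hooks are not removed automatically
    return sizes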
show_doc(model_sizes)
It can be useful to get the size of each layer of a model (e.g. for printing a summary, or for generating cross-connections for a DynamicUnet); however, these sizes depend on the size of the input. This function calculates the layer sizes by passing in a minimal tensor (a batch of one) of the given size.
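The same idea works without fastai: register a forward hook on each child, run one dummy input, read off the shapes, and deregister. In this minimal sketch the function name layer_sizes and its in_ch parameter are hypothetical (fastai infers the input channels with in_channels instead), and it only looks at the direct children of the model. Note that it also switches the model to eval mode.

```python
import torch
import torch.nn as nn

def layer_sizes(model, size, in_ch=3):
    "Hypothetical model_sizes-style helper: output shape of each child of model."
    sizes = []
    handles = [m.register_forward_hook(lambda mod, inp, out: sizes.append(tuple(out.shape)))
               for m in model.children()]
    try:
        model.eval()
        with torch.no_grad():
            model(torch.zeros(1, in_ch, *size))  # minimal batch: a single dummy input
    finally:
        for h in handles:
            h.remove()  # always deregister, even if the forward pass fails
    return sizes

net = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
                    nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.AdaptiveAvgPool2d(1))
print(layer_sizes(net, (64, 64)))
# [(1, 16, 32, 32), (1, 16, 32, 32), (1, 32, 16, 16), (1, 32, 1, 1)]
```

Because forward hooks fire in execution order, the list follows the order in which layers run, which for an nn.Sequential is the order of its children.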
show_doc(HookCallback)
class HookCallback
HookCallback(learn:Learner,modules:Sequence[Module]=None,do_remove:bool=True) ::LearnerCallback
Callback that registers given hooks.
For each module in modules, uses a callback to automatically register a method self.hook (which you must define in an inherited class) as a hook. This method must have the signature:
def hook(self, m:Model, input:Tensors, output:Tensors)
If do_remove then the hook is automatically deregistered at the end of training. See ActivationStats for a simple example of inheriting from this class.
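The HookCallback pattern (register self.hook on every module when training begins, deregister when it ends if do_remove, and let subclasses supply hook) can be sketched without a fastai Learner. Everything below is hypothetical: MiniHookCallback and CountCalls are illustrative names, the training loop is faked by calling the callback methods by hand, and defaulting to modules that have a weight is an assumption of this sketch.

```python
import torch
import torch.nn as nn

class MiniHookCallback:
    "Hypothetical sketch of the HookCallback pattern (no fastai Learner involved)."
    def __init__(self, model, modules=None, do_remove=True):
        self.model, self.modules, self.do_remove = model, modules, do_remove
        self.handles = []

    def on_train_begin(self):
        # Assumption: default to every module that has a weight.
        if not self.modules:
            self.modules = [m for m in self.model.modules() if hasattr(m, "weight")]
        self.handles = [m.register_forward_hook(self.hook) for m in self.modules]

    def on_train_end(self):
        if self.do_remove:
            self.remove()

    def remove(self):
        for h in self.handles:
            h.remove()
        self.handles = []

    def hook(self, m, input, output):
        raise NotImplementedError("subclasses must define hook")

class CountCalls(MiniHookCallback):
    "Example subclass: count how often each hooked module ran."
    def on_train_begin(self):
        super().on_train_begin()
        self.counts = [0] * len(self.modules)

    def hook(self, m, input, output):
        self.counts[self.modules.index(m)] += 1

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
cb = CountCalls(model)
cb.on_train_begin()
for _ in range(3):
    model(torch.randn(2, 4))  # stand-in for three training batches
cb.on_train_end()
model(torch.randn(2, 4))      # hooks are gone, so this call is not counted
print(cb.counts)              # [3, 3]
```

Passing the bound method self.hook to register_forward_hook is what lets a subclass see its own state (here self.counts) from inside the hook.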
show_doc(HookCallback.remove)
remove
remove()
show_doc(HookCallback.on_train_begin)
show_doc(HookCallback.on_train_end)
show_doc(ActivationStats.hook)
show_doc(ActivationStats.on_batch_end)
show_doc(ActivationStats.on_train_begin)
show_doc(ActivationStats.on_train_end)
show_doc(Hook.hook_fn)