Preparing a Model for Quantization
Note: If you just want a run-down of the required modifications to make sure a model is properly quantized in Distiller, you can skip this part and head right to the next section.
Distiller provides an automatic mechanism to convert a "vanilla" FP32 PyTorch model to a quantized counterpart (for quantization-aware training and post-training quantization). This mechanism works at the PyTorch "Module" level. By "Module" we refer to any sub-class of the `torch.nn.Module` class. The Distiller Quantizer can detect modules and replace them with other modules.
However, it is not a requirement in PyTorch that all operations be defined as modules. Operations are often executed via direct overloaded tensor operators (`+`, `-`, etc.) and functions under the `torch` namespace (e.g. `torch.cat()`). There is also the `torch.nn.functional` namespace, which provides functional equivalents to the modules provided in `torch.nn`. When an operation does not maintain any state, even if it has a dedicated `nn.Module`, it will often be invoked via its functional counterpart - for example, calling `nn.functional.relu()` instead of creating an instance of `nn.ReLU` and invoking that. Such non-module operations are called directly from the module's `forward` function. There are ways to discover these operations up-front, which are used in Distiller for different purposes. Even so, we cannot replace these operations without resorting to rather "dirty" Python tricks, which we would rather avoid for numerous reasons.
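To make the distinction concrete, here is a minimal illustration of the two invocation styles for the same stateless operation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4)

# Functional invocation - no module instance exists, so there is
# nothing for a module-level mechanism to detect and replace
y_functional = F.relu(x)

# Module invocation - this instance can be detected and swapped out
relu = nn.ReLU()
y_module = relu(x)

assert torch.equal(y_functional, y_module)
```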
In addition, there might be cases where the same module instance is re-used multiple times in the `forward` function. This is also a problem for Distiller. There are several flows that will not work as expected if each call to an operation is not "tied" to a dedicated module instance. For example:
- When collecting statistics, each invocation of a re-used module will overwrite the statistics collected for the previous invocation. We end up with statistics missing for all invocations except the last one (see the sketch after this list).
- "Net-aware" quantization relies on a 1:1 mapping from each operation executed in the model to a module which invoked it. With re-used modules, this mapping is not 1:1 anymore.
Hence, to make sure all supported operations in a model are properly quantized by Distiller, it might be necessary to modify the model code before passing it to the quantizer. Note that the exact set of supported operations might vary between the different available quantizers.
Model Preparation To-Do List
The steps required to prepare a model for quantization can be summarized as follows:
1. Replace direct tensor operations with modules
2. Replace re-used modules with dedicated instances
3. Replace `torch.nn.functional` calls with equivalent modules
4. Special cases - replace modules that aren't quantize-able with quantize-able variants
In the next section we'll see an example of items 1-3 in this list. As for the "special cases", at the moment the only such case is LSTM. See the section after the example for details.
Model Preparation Example
We'll use the following simple module as an example. This module is loosely based on the ResNet implementation in torchvision, with some changes that don't make much sense and are meant to demonstrate the different modifications that might be required.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicModule(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size):
        super(BasicModule, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # (1) Overloaded tensor addition operation
        # Alternatively, could be called via a tensor function: out.add_(identity)
        out += identity

        # (2) ReLU module re-used
        out = self.relu(out)

        # (3) Using operation from the 'torch' namespace
        out = torch.cat([identity, out], dim=1)

        # (4) Using function from torch.nn.functional
        out = F.sigmoid(out)

        return out
```
Replace direct tensor operations with modules
The addition (1) and concatenation (3) operations in the `forward` function are examples of direct tensor operations. These operations do not have equivalent modules defined in `torch.nn`. Hence, if we want to quantize these operations, we must implement modules that will call them. In Distiller we've implemented a few simple wrapper modules for common operations. These are defined in the `distiller.modules` namespace. Specifically, the addition operation should be replaced with the `EltWiseAdd` module, and the concatenation operation with the `Concat` module. Check out the code here to see the available modules.
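For reference, a wrapper of this kind is just a thin `nn.Module` around the tensor operation. The following is a minimal sketch; the actual implementation in `distiller.modules` may differ in its details:

```python
import torch.nn as nn

class EltWiseAdd(nn.Module):
    """Module wrapper around element-wise tensor addition."""
    def __init__(self, inplace=False):
        super(EltWiseAdd, self).__init__()
        self.inplace = inplace

    def forward(self, *inputs):
        # Accumulate all inputs into a single result, element-wise
        res = inputs[0]
        if self.inplace:
            for t in inputs[1:]:
                res += t
        else:
            for t in inputs[1:]:
                res = res + t
        return res
```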
Replace re-used modules with dedicated instances
The ReLU operation above is called via a module, but the same instance is used for both calls (2). We need to create a second instance of `nn.ReLU` in `__init__` and use it for the second call during `forward`.
Replace `torch.nn.functional` calls with equivalent modules
The sigmoid (4) operation is invoked using the functional interface. Luckily, operations in `torch.nn.functional` have equivalent modules, so we can just use those. In this case we need to create an instance of `torch.nn.Sigmoid` in `__init__` and use it in `forward`.
Putting it all together
After making all of the changes detailed above, we end up with:
```python
import torch.nn as nn
import distiller.modules

class BasicModule(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size):
        super(BasicModule, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Fixes start here
        # (1) Replace '+=' with an inplace module
        self.add = distiller.modules.EltWiseAdd(inplace=True)
        # (2) Separate instance for each relu call
        self.relu2 = nn.ReLU()
        # (3) Dedicated module instead of tensor op
        self.concat = distiller.modules.Concat(dim=1)
        # (4) Dedicated module instead of functional call
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.add(out, identity)
        out = self.relu2(out)
        out = self.concat(identity, out)
        out = self.sigmoid(out)
        return out
```
Special Case: LSTM (a "compound" module)
LSTMs present a special case. An LSTM block is comprised of building blocks, such as fully-connected layers and sigmoid/tanh non-linearities, all of which have dedicated modules in `torch.nn`. However, the LSTM implementation provided in PyTorch does not use these building blocks. For optimization purposes, all of the internal operations are implemented at the C++ level. The only parts of the model exposed at the Python level are the parameters of the fully-connected layers. Hence, all we can do with the PyTorch LSTM module is quantize the inputs/outputs of the entire block and quantize the FC layers' parameters. We cannot quantize the internal stages of the block at all. Beyond just quantizing the internal stages, we'd also like the option to control the quantization parameters of each internal stage separately.
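To see how little of the block is exposed, we can list the parameters of a PyTorch LSTM (with `hidden_size=20`, each weight stacks the 4 gates, hence the leading dimension of 80):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

# Only the fully-connected (gate) parameters are exposed; the gate
# non-linearities and element-wise operations run inside a fused
# C++ kernel and cannot be intercepted from Python
for name, param in lstm.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (80, 10)
# weight_hh_l0 (80, 20)
# bias_ih_l0 (80,)
# bias_hh_l0 (80,)
```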
What to do
Distiller provides a "modular" implementation of LSTM, comprised entirely of operations defined at the Python level. We provide implementations of `DistillerLSTM` and `DistillerLSTMCell`, paralleling the `LSTM` and `LSTMCell` modules provided by PyTorch. See the implementation here.
A function to convert all LSTM instances in the model to the Distiller variant is also provided:

```python
model = distiller.modules.convert_model_to_distiller_lstm(model)
```
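For instance, applied to a toy model of our own making (the model here is purely illustrative):

```python
import torch.nn as nn
import distiller.modules

# A toy word-level language model (hypothetical, for illustration only)
class ToyLM(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size):
        super(ToyLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.rnn = nn.LSTM(emb_size, hidden_size, num_layers=2)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        out, hidden = self.rnn(self.embed(x), hidden)
        return self.decoder(out), hidden

model = ToyLM(vocab_size=10000, emb_size=200, hidden_size=200)
# Replace every nn.LSTM in the hierarchy with the Distiller variant;
# all other modules are left untouched
model = distiller.modules.convert_model_to_distiller_lstm(model)
```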
To see an example of this conversion, and of mixed-precision quantization within an LSTM block, check out our tutorial on word-language model quantization here.