Big Models Hate This One Weird Trick! (Quantization, T5, & PyTorch 1.4)

Here’s the notebook in Google Colab

As we know, models can be big lumbering beasts, composed of millions of parameters (weights and biases) that require lots of matrix multiplications to take an input and arrive at an answer. And for most of our work so far, that’s been fine! We have mighty GPUs that can handle these burdens with ease.

But what if we didn’t? We often package a model up for production inference so that it only runs on the CPU. And what if we wanted to run our model on a smaller embedded platform? Suddenly, both the size of the model and all those floating-point operations become a little more problematic. Thankfully, there’s a trick we can perform that makes our model smaller and faster, normally at the cost of some accuracy. Even better, PyTorch allows us to perform this one weird trick with just one line of code, with some other approaches available for squeezing out even more performance. Let’s have a quick look at quantization.

Quantization

Every parameter in our model is a 32-bit floating-point number, taking up 4 bytes of memory. That’s not a lot on its own, but it soon adds up. Let’s have a look at Google’s recent T5 transformer-based model, which has a t5-small variant available in the transformers library.

import torch
from transformers import pipeline, T5ForConditionalGeneration

def count_parameters(model):
  return sum(p.numel() for p in model.parameters())

base_model = T5ForConditionalGeneration.from_pretrained("t5-small")

param_count = count_parameters(base_model)

# Each parameter is a 32-bit float (4 bytes); report the total size in megabytes
memory = (param_count * 4) / (1024 * 1024)
memory

230.8154296875

Even with the smallest pre-trained T5 weights, our model has roughly 60 million parameters and weighs in at a whopping 230MB!

However, what if we decided that we didn’t need the full precision of our floating-point parameters? If our parameters could be restricted to a certain range of values, then we could use a smaller number representation to store them. This quantization is the key to speeding up our inference time and reducing the memory footprint of our models. What we normally aim for is to quantize down from 32-bit floating point to 8-bit integers. The basic idea is:

$$x_{int8} = \frac{x_{float32}}{x_{scale}} + x_{offset}$$

Which is essentially just fitting the potential values of the parameters of a network to a line of $y = mx + c$, although because of the reduced resolution of the 8-bit integer, a parameter can now only take one of a small set of values instead of the huge range a float32 value could cover. PyTorch does its quantizing in a slightly more complicated way that ensures zero is always represented exactly, but the basic idea is the same - we take the range of values that our parameters can span, and then find an appropriate pair $x_{scale}$ and $x_{offset}$ that provides 256 gradations to represent that range - or 255, if you account for PyTorch always keeping zero around.
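To make the mapping concrete, here’s a minimal sketch (the scale and zero point are just illustrative values, not ones PyTorch would pick for you) that quantizes a float32 tensor down to 8-bit integers with torch.quantize_per_tensor and dequantizes it again, so you can see the rounding error the reduced resolution introduces:

import torch

x = torch.randn(4) * 3                                   # some float32 values
xq = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)   # scale = 0.1, offset (zero point) = 0
print(xq.int_repr())        # the underlying 8-bit integers
print(xq.dequantize())      # back to float32, now snapped to the available gradations
print(x - xq.dequantize())  # the quantization error we've accepted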

At the moment (PyTorch 1.5), quantization support is strongest for convolutional and Linear layers. Thankfully, if we have a look at the model structure of T5, we can see a happy coincidence:

base_model
T5Model(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (1): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (2): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (3): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (4): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (5): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (final_layer_norm): T5LayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (1): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (2): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (3): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (4): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (5): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (final_layer_norm): T5LayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

Yes, that’s right, look at all those Linear layers! We should be able to get some benefit out of quantizing this model.

One Weird Trick — Dynamic Quantization

import torch.quantization

# Swap every torch.nn.Linear in the model for a dynamically quantized 8-bit equivalent
quantized_model = torch.quantization.quantize_dynamic(base_model, {torch.nn.Linear}, dtype=torch.qint8)

No, really, that’s it. Chapter done. Bye!

Oh, okay, if you really insist. But honestly, there’s not much more to it. First, a caveat: quantize_dynamic only quantizes the weights ahead of time; the activations are still read and written in floating point and only quantized on the fly during the computation itself (which is what makes it “dynamic”). Beyond that, all we need to do is pass in the model we wish to quantize and a set of layer types we wish to replace with quantized versions, in this case Linear. The function returns a new model, though you can pass the optional parameter inplace=True to mutate the original model rather than make a copy.
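If you want to convince yourself that anything actually changed, you can peek at one of the attention projections from the structure we printed above; the exact repr varies between PyTorch versions, but the Linear modules in the quantized copy should now report themselves as dynamically quantized Linear layers:

# The q projection of the first encoder self-attention block, as seen in the dump above
print(base_model.encoder.block[0].layer[0].SelfAttention.q)       # plain nn.Linear
print(quantized_model.encoder.block[0].layer[0].SelfAttention.q)  # its dynamically quantized replacement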

Let’s save the model and take a look at the quantized size:

!mkdir -p t5
quantized_model.save_pretrained("t5")
!du -m t5
121	t5

Almost a 50% reduction in size! We don’t get the full 4× reduction because only the Linear layers have been quantized; the shared embeddings and the various layer norms are still stored as 32-bit floats. Still, we’ve done pretty well for one line of code.
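As a quick sanity check, here’s a sketch that assumes quantize_dynamic packs the quantized Linear weights away so that they no longer show up in parameters(); what remains is the float32 portion of the model, with the rest of the on-disk size being the Linear weights stored at roughly one byte each:

# Parameters still stored as float32 after dynamic quantization
remaining = count_parameters(quantized_model)
print(remaining, "params,", (remaining * 4) / (1024 * 1024), "MB still in float32")

With the size savings banked, let’s do a very simple microbenchmark using both models in the transformers library’s summarization pipeline.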

base_summarizer = pipeline("summarization", model=base_model, tokenizer="t5-small")
quantized_summarizer = pipeline("summarization", model=quantized_model, tokenizer="t5-small")
%timeit base_summarizer("From the very beginning, Regan was seen as having series potential. After the television film scored highly in the ratings, work began on the development of the series proper. Ian Kennedy Martin's idea was for the series to be mainly studio-based, with more dialogue and less action, but producer Ted Childs disagreed, and in consequence Ian Kennedy Martin parted company with the project. Childs produced it on 16mm film, a format that allowed for a much smaller film unit than videotape at that time. This made it possible to shoot almost entirely on location which helped give the series a startling degree of realism and to use film editing techniques which enabled him to give the show a heavy bias toward action sequences. The television play and the subsequent series were commissioned by Thames Television and produced by its film division Euston Films. It was originally broadcast on ITV between 2 January 1975 and 28 December 1978 at 21:00–22:00 on weekdays (usually Mondays), with repeated screenings at the same time until the early 1980s. The writers were given strict guidelines to follow: \"Each show will have an overall screen time (minus titles) of 48 minutes 40 seconds. Each film will open with a teaser of up to 3 minutes, which will be followed by the opening titles. The story will be played across three acts, each being no more than 19 minutes and no less than 8 minutes in length. Regan will appear in every episode, Carter in approximately 10 out of 13 episodes. In addition to these main characters, scripts should be based around three major speaking parts, with up to ten minor speaking parts")
1 loop, best of 3: 29.4 s per loop
%timeit quantized_summarizer("From the very beginning, Regan was seen as having series potential. After the television film scored highly in the ratings, work began on the development of the series proper. Ian Kennedy Martin's idea was for the series to be mainly studio-based, with more dialogue and less action, but producer Ted Childs disagreed, and in consequence Ian Kennedy Martin parted company with the project. Childs produced it on 16mm film, a format that allowed for a much smaller film unit than videotape at that time. This made it possible to shoot almost entirely on location which helped give the series a startling degree of realism and to use film editing techniques which enabled him to give the show a heavy bias toward action sequences. The television play and the subsequent series were commissioned by Thames Television and produced by its film division Euston Films. It was originally broadcast on ITV between 2 January 1975 and 28 December 1978 at 21:00–22:00 on weekdays (usually Mondays), with repeated screenings at the same time until the early 1980s. The writers were given strict guidelines to follow: \"Each show will have an overall screen time (minus titles) of 48 minutes 40 seconds. Each film will open with a teaser of up to 3 minutes, which will be followed by the opening titles. The story will be played across three acts, each being no more than 19 minutes and no less than 8 minutes in length. Regan will appear in every episode, Carter in approximately 10 out of 13 episodes. In addition to these main characters, scripts should be based around three major speaking parts, with up to ten minor speaking parts")
1 loop, best of 3: 16.6 s per loop

In addition to being almost half the size, the quantized model is almost twice as fast! So…why don’t we do this all the time? Are there no downsides? Well…it depends. We lose information at inference time in a quantized model, because our parameters can no longer take every floating-point value they could in the original model, so the long chain of multiplications will be less accurate than before. You’ll need to check the new model against a reference dataset to determine the accuracy loss and whether that loss is an acceptable trade-off for the reduced storage demands and faster execution.
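There’s no single right way to do that check - it depends on your task and metric - but as a rough sketch (reference_texts and reference_summaries are hypothetical placeholders for your own held-out data), you might compare the two pipelines along these lines:

# Hypothetical held-out evaluation data; swap in your own documents and summaries
reference_texts = ["<a held-out document>", "<another held-out document>"]
reference_summaries = ["<its reference summary>", "<another reference summary>"]

def summarize_all(summarizer, texts):
  return [summarizer(t)[0]["summary_text"] for t in texts]

base_outputs = summarize_all(base_summarizer, reference_texts)
quantized_outputs = summarize_all(quantized_summarizer, reference_texts)

# Score both sets of outputs against reference_summaries with your metric of
# choice (ROUGE is the usual one for summarization) and compare the two models.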

Other Quantizing Options Are Available

In addition to dynamic quantization, PyTorch also offers static quantization, where a trained model is modified to include observer modules and a selection of representative data is fed through it. During inference on this calibration data, the observers record the distributions of the activations, which are then used to choose the quantization parameters that best fit the observed values. This can produce even further space and time savings, especially with vision models like ResNet.
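As a rough sketch of what that looks like in code (using torchvision’s quantization-ready resnet18 purely as an example, with random tensors standing in for a real calibration set), the eager-mode workflow is prepare, calibrate, convert:

import torch
import torchvision

# torchvision ships quantization-ready variants of some vision models, which
# already include the quant/dequant stubs that static quantization needs
model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False).eval()
model.fuse_model()  # fuse conv/bn/relu so they're quantized as single units

# Attach observers that will watch the weights and activations
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)

# Calibration: run representative data through the model so the observers can
# record activation ranges (random tensors here, purely as a stand-in)
with torch.no_grad():
  for _ in range(10):
    model(torch.randn(1, 3, 224, 224))

# Replace the observed modules with real int8 quantized ones
torch.quantization.convert(model, inplace=True)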

However, for best-in-class accuracy in your smaller model, you’ll want to investigate quantization-aware training (QAT). In this approach, the model “fakes” quantization during both the forward and backward passes of the training loop: all the computation still takes place with standard floats, but the values are rounded to the levels an 8-bit representation could hold, so the network learns to compensate for the reduced precision. You end up with a quantized model after training is finished, but one with a higher accuracy than you can achieve with the dynamic or static approaches.
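The plumbing for QAT is similar to static quantization, just with the QAT flavours of the helpers. Here’s a sketch (again leaning on torchvision’s quantization-ready resnet18, with a throwaway loop on random data where your real training code would go):

import torch
import torchvision

model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
model.fuse_model()

# The QAT config inserts fake-quantization modules rather than plain observers
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Stand-in training loop on random data - substitute your real dataset,
# loss, and schedule here
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
model.train()
for _ in range(3):
  inputs, labels = torch.randn(8, 3, 224, 224), torch.randint(0, 1000, (8,))
  optimizer.zero_grad()
  loss_fn(model(inputs), labels).backward()
  optimizer.step()

# After training, convert the fake-quantized model into an actual int8 model
model.eval()
torch.quantization.convert(model, inplace=True)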

Is It Worth It?

You might be wondering if you’re just better off training a smaller model rather than going to all this effort to compress larger ones. In the recent paper Train Large, Then Compress, there’s a good deal of evidence that transformer-based models really do benefit from this approach: because larger models converge faster than smaller ones, you will likely get more accurate results by training a large model and then compressing it than by spending the same compute budget on a smaller model. So go forth and compress!

(and we’ll see you back here in the future for pruning models)

Further Reading

https://pytorch.org/docs/stable/quantization.html

https://arxiv.org/abs/2002.11794