PyTorch Memory Tuning
July 20, 2023
Intro #
In this article we will focus on minimizing GPU memory footprint — for both optimization and inference workloads — and we can largely forget about throughput for once. Usually companies ask me to help reduce inference response time or cost, but reducing memory consumption without making architecture sacrifices or blowing out throughput is often just as valuable.
For any kind of optimization, the systematic approach is to break down the relevant measurement to a level where we have a clear signal as to what part of the system to improve next.
- In the case of throughput optimization we would break down the passage of GPU time, understanding which parts of the model and which operations on the GPU are wasting the most time.
- However, for reducing memory consumption we break down the process-level GPU memory consumption into smaller parts.
Let’s get started with an understanding of the components of memory consumption and how to measure it. Most optimization/training projects can split their GPU memory usage into:
- Data tensors: batched data for input, and eventual output — image tensors, token tensors, embeddings, etc.
- Model parameters: these are the weights (and biases) of the model layers which are used to compute the forward pass, and are passed to the optimizer to be updated based on computed loss.
- Layer activations: during optimization/training, layerwise intermediate output tensors are retained during the forward pass in order to compute gradients in the backward pass.
- Miscellaneous: CUDA kernel-related allocations, optimizer housekeeping data, normalization layer statistics, etc.
Tensor memory consumption is easily calculated from dimensions and datatype, but a simple peak memory consumption measurement (`torch.cuda.max_memory_allocated()`) at various points in the execution provides a good on-device ground truth. For example:
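A minimal sketch of this kind of spot measurement (a torchvision ResNet-152 and a random image batch stand in for your own model and data; the printed numbers will vary with your hardware and workload):

```python
import torch
import torchvision

def mib(num_bytes):
    return num_bytes / 2**20

# a torchvision ResNet-152 stands in for your model; substitute your own workload
model = torchvision.models.resnet152(weights=None).cuda()
print(f"after model load:  {mib(torch.cuda.memory_allocated()):8.0f} MiB")  # ~ model parameters

batch = torch.randn(32, 3, 224, 224, device="cuda")
print(f"after batch alloc: {mib(torch.cuda.memory_allocated()):8.0f} MiB")  # + input data tensor

torch.cuda.reset_peak_memory_stats()
output = model(batch)  # activations are retained here, ready for a possible backward pass
print(f"peak over forward: {mib(torch.cuda.max_memory_allocated()):8.0f} MiB")  # + activations
```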
A process-level measurement like `torch.cuda.max_memory_allocated()` is often good enough to understand your memory breakdown and to track changes, but sometimes we want to understand usage at an individual tensor level. For a long time I have been using a modified version of gpu_mem_track.py to trace allocation and deallocation of GPU-resident tensors as a delta between two points in the code. Check it out if you need to dig deeper.
Enough preamble — let’s get into some ways to reduce memory usage:
Don’t Forget Inference-Mode #
This is a simple gotcha we should get out of the way quickly — if you’re doing inference rather than training/optimization, don’t forget to enable `torch.inference_mode()`.
If you’re not in inference mode during the forward pass PyTorch will record layer activations to enable gradient calculation during a possible backward pass — this means allocating tensors that won’t be used and wasting GPU memory.
Using it is simple:
```python
with torch.inference_mode():
    output = model(input)
```
You should see significant peak memory reductions — here’s an example using a classic image model, but this will apply to all models:
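If you want to measure the effect on your own model, here is a minimal sketch (a torchvision ResNet-152 is assumed purely for illustration):

```python
import torch
import torchvision

# a torchvision ResNet-152 is assumed here purely for illustration
model = torchvision.models.resnet152(weights=None).cuda().eval()
batch = torch.randn(32, 3, 224, 224, device="cuda")

def peak_forward_mib(use_inference_mode):
    # peak memory (MiB) over a single forward pass
    torch.cuda.reset_peak_memory_stats()
    if use_inference_mode:
        with torch.inference_mode():
            model(batch)
    else:
        model(batch)  # activations retained for a potential backward pass
    return torch.cuda.max_memory_allocated() / 2**20

print(f"without inference_mode: {peak_forward_mib(False):.0f} MiB peak")
print(f"with inference_mode:    {peak_forward_mib(True):.0f} MiB peak")
```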
Both `inference_mode` and `no_grad` remove the activation allocations; for the difference between them see Autograd mechanics.
Use Mixed Precision (Quantization) #
By far the easiest way to make substantial improvements to your memory footprint is with mixed precision. This works just as well for training as for inference. PyTorch has excellent native support for this via the `torch.autocast` context manager, turning default float32 computation and tensor storage into float16 or bfloat16.
```python
for input, target in dataloader:
    with torch.autocast("cuda"):
        output = model(input)
        loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```
One subtlety is that `torch.autocast` does not halve memory consumption as you might expect when converting float32 to float16 — this is because network parameters are kept in full precision, and only computation and the resulting activations happen at half precision. One consequence is that larger models with small input/output batches see much smaller memory savings from mixed-precision mode (though they still benefit from higher half-precision throughput).
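You can verify this directly by checking dtypes inside an autocast region (a tiny sketch, with a single linear layer standing in for a real model):

```python
import torch

layer = torch.nn.Linear(1024, 1024, device="cuda")  # parameters are stored in float32
x = torch.randn(8, 1024, device="cuda")

with torch.autocast("cuda"):
    y = layer(x)

print(layer.weight.dtype)  # torch.float32: parameters keep full precision
print(y.dtype)             # torch.float16: computation and activations use half precision
```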
It is possible to exactly halve memory consumption by brute-forcing the cast to half precision rather than relying on `torch.autocast`. This will bypass the protections engineered into the `torch.autocast`/`GradScaler` system, so gradient underflow or overflow may become a problem during optimization.
```python
model = model.to(dtype=torch.float16)  # cast network parameters to float16

for input, target in dataloader:
    output = model(input.to(dtype=torch.float16))
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```
The chart shows an optimization done at float32, float16 with autocast, and brute-forced float16.
Use Activation Checkpointing #
In order to compute gradients (which happens when we call `loss.backward()`) PyTorch retains node activations computed in the forward pass, and these stored activations can consume a lot of memory. During inference we use `torch.inference_mode()` to entirely disable this activation storage, but even during training/optimization we can selectively turn off activation storage to reduce peak memory requirements: this is activation checkpointing.
Activation checkpointing is a straightforward compute vs memory trade-off — in the forward pass we execute some parts of the model without saving activations, and then when we call `loss.backward()` the missing activations are automatically recalculated by repeating the required parts of the forward pass.
A More Convenient API 🐒 #
The PyTorch checkpoint API gives us the basic functionality but doesn’t include a code pattern to use checkpointing effectively with existing models, and also doesn’t advise how best to use it — on which layers or blocks.
The basic PyTorch API looks like this:
```python
from torch.utils import checkpoint as ckpt

# where module is part of a model:
result = module(*args, **kwargs)                   # regular invocation with default activation caching
result = ckpt.checkpoint(module, *args, **kwargs)  # checkpointed invocation
```
In order to use this API in your model you would have to modify the model source code and therefore hard-code the checkpointing behavior. Much more useful would be a non-invasive API whereby you could seamlessly monkey-patch the model, and checkpoint subgraphs easily.
This is how I do it — with a tiny `CheckpointModule` class and a recursive function:
```python
import re

model = torch.hub.load(...)
model = ckpt_monkey(model, re.compile('layer1$'))  # checkpoint all modules named "layer1"
```

```python
import torch
import torch.utils.checkpoint


class CheckpointModule(torch.nn.Module):
    # wraps a module so its forward pass runs under activation checkpointing
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return torch.utils.checkpoint.checkpoint(self.module, *args, **kwargs)


def ckpt_monkey(module, pattern_re, ancestor_vn=''):
    # recursively wrap every submodule whose fully-qualified name matches pattern_re
    for vn, v in module._modules.items():
        full_vn = f'{ancestor_vn}.{vn}' if ancestor_vn else vn
        v = ckpt_monkey(v, pattern_re, full_vn)
        if pattern_re.match(full_vn):
            print('monkey-patching', full_vn)
            setattr(module, vn, CheckpointModule(v))
        else:
            setattr(module, vn, v)
    return module
```
With `ckpt_monkey` 🐒 and `CheckpointModule` there’s no need to rewrite model code, because the selected modules or layers are wrapped with checkpointing logic non-invasively.
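For example, on a Hugging Face BERT-style model whose encoder blocks are registered as encoder.layer.0 through encoder.layer.11, different regexes give different checkpointing granularities. The patterns below are illustrative assumptions; module names vary by model, so check model.named_modules() first:

```python
import re

# illustrative patterns for a bert-base-uncased style module tree; confirm the
# names with model.named_modules() and pick one pattern per experiment
pattern = re.compile(r'encoder\.layer\.0$')          # checkpoint the first block only
# pattern = re.compile(r'encoder\.layer\.(0|1|2)$')  # ...or the first three blocks
# pattern = re.compile(r'encoder\.layer\.\d+$')      # ...or every block
model = ckpt_monkey(model, pattern)
```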
When and Where to Use Activation Checkpointing #
You could checkpoint your entire model or you could checkpoint all individual layers of a specific type, but the ideal strategy is usually somewhere in between and will be model-specific. Because checkpointing incurs an additional compute cost you’ll want to do as little as possible while fitting into your memory budget. However, not all checkpoints have the same ratio of compute to memory trade-off — it’s often possible to save a lot of memory with a small compute cost, just as some checkpoints will incur a large compute cost for small memory savings.
The most practical way to decide which modules to checkpoint is to simply measure the effect at various levels and locations, and then choose a set of checkpointed modules so that you fit within your memory budget but incur the smallest throughput impact. One key guideline is that separate checkpoints are independent and cumulative with respect to memory savings and additional compute incurred — so you can keep adding checkpoints with a good memory-to-compute trade-off until you get within your memory budget.
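A sweep along these lines is easy to script. In the sketch below, build_model and run_training_steps are placeholders for your own model construction and training step, and the candidate patterns are just examples:

```python
import re
import time

import torch

# candidate checkpointing patterns to compare (placeholders; adapt to your model)
candidate_patterns = ['layer1$', 'layer2$', 'layer(1|2)$', 'layer(1|2|3|4)$']

for pattern in candidate_patterns:
    model = build_model().cuda()                     # placeholder: construct a fresh model
    model = ckpt_monkey(model, re.compile(pattern))  # wrap matching modules with checkpointing

    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    run_training_steps(model, steps=20)              # placeholder: forward + backward + optimizer step
    elapsed = time.perf_counter() - start

    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f'{pattern}: peak {peak_gib:.2f} GiB, {20 / elapsed:.1f} steps/s')
```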
Let’s see a couple of examples:
ResNet-152 #
- ResNet-152 as implemented in PyTorch hub has four top-level layers consisting of varying numbers of convolution, batchnorm and ReLU operations — the points in the scatter plot below correspond to subsets of these top-level layers being checkpointed.
- The `mem-ckpt` value (point color) is the regexp pattern applied when selecting checkpoints with `ckpt_monkey`, i.e. `layer(1|2)` means layer1 and layer2 were checkpointed, and `layer(1|2|3|4)` is essentially the whole model checkpointed.
- The best checkpoints in this visualization are down and to the right — `layer1` and `layer2` both have low throughput cost and give significant memory savings. Note that in checkpointing both these layers we can get a ~34% memory saving at the cost of only 10% lower throughput — not bad.
- Not all checkpoints are equal: `layer3` and `layer4` don’t give a good trade-off, and aren’t worthwhile.
BERT #
- This version of BERT (`bert-base-uncased` hosted on Hugging Face) has 12 sequential attention + intermediate + output layers — each consisting of linear, layernorm and GELU operations.
- All 12 layers give the same memory/throughput trade-off — which makes sense as they have the same number of parameters and operations. (I’ve only included a data point for the first and last layer for simplicity, as they all superimpose near perfectly.)
- Checkpointing two, three, or more layers shows predictable, cumulative effects — the chart contains points for zero, one, two, three and twelve layers checkpointed.
Optimizer Choice — Use bitsandbytes #
In practice we cannot swap out optimizers based on simple measures of memory consumption or batch throughput — different optimization algorithms have wildly different rates of convergence which will dwarf the effect of low-level performance differences in reaching a desired network state. This matches one of the long-known truths of optimization work, that high-level algorithm improvements are far more impactful than low-level tweaks.
However, trade-offs do exist — a given optimization algorithm can be implemented in different ways and at different precisions. The `bitsandbytes` library by Tim Dettmers is a great example of this. The library provides 8-bit versions of many algorithms available in `torch.optim`, often with support for paging state tensors between host and GPU memory as required.
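Swapping optimizers is typically a one-line change in the training script. A sketch is below; the exact class names available depend on your installed bitsandbytes version, and the model is a placeholder:

```python
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder for your real model

# 32-bit reference:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# 8-bit optimizer state:
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

# paged 8-bit state (if available in your bitsandbytes version):
# optimizer = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-4)
```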
Let’s look at the effect of different optimization algorithms and the `bitsandbytes` 8-bit implementations, using PyTorch Adam as the reference point against which percentage changes are calculated:
ResNet-152 #
- Some of the more retro algorithms (SGD, RMSprop, Adagrad) implemented in `torch.optim` consume less memory and give higher throughput than Adam. This is hard to care about since Adam converges much faster, and is included as a historical curiosity.
- The `bitsandbytes` 8-bit implementations give substantial memory savings at the cost of non-trivial throughput costs, at least for this model.
- Paged optimizer implementations save even more memory at a minor throughput cost.
- `bitsandbytes` claims ~75% memory savings instead of the 7-10% shown here. This is what you’d expect going from 32-bit to 8-bit precision, and I’ve empirically verified their 75% is correct. They consider memory consumed by the optimizer itself, whereas these charts show peak training-loop memory in the context of these specific models.
BERT #
- For BERT, the `bitsandbytes` 8-bit optimizers improve both memory footprint and throughput. We would need to trace the workload to understand exactly why, but I’ll take it.
Use set_to_none=True #
If you’re still using PyTorch 1.x you can reduce your peak optimization memory by using the `set_to_none=True` option when resetting gradients. PyTorch 2.0 changes the default to `True`, so if you’ve already upgraded you’re probably already getting this benefit.
In detail — in your training loop you will reset gradients after doing `loss.backward()` and `optimizer.step()`. Either `optimizer.zero_grad()` or `model.zero_grad()` works, resetting all parameters sent to the optimizer or all parameters in the model respectively.
In order to save memory by resetting gradient tensors to None instead of updating them to dense 0.0-valued tensors, simply call `optimizer.zero_grad(set_to_none=True)` or `model.zero_grad(set_to_none=True)`.
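In a training loop it looks like this (a minimal sketch, reusing the same model, dataloader, loss_fn and optimizer names as the earlier loops):

```python
for input, target in dataloader:
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # gradients become None rather than zero-filled tensors
```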
Final thoughts #
As usual when writing a higher-level overview article backed by empirical data, the model-specific results (and differences) are interesting but also raise deeper questions as to why. In particular, the differing throughput vs memory trade-offs incurred when checkpointing different subgraphs in different models is something I’ll dig into in a future article.
We can distill all the charts and explanations above into some fairly simple recommendations for reducing PyTorch GPU memory usage without making architectural changes:
- Always use mixed precision unless you find a special case where it causes problems, and even then you can usually determine which subgraph is causing problems and turn off autocast in that specific section.
- During optimization/training you generally shouldn’t brute-force half precision on your models unless you are sure you don’t need the over/underflow protection of the PyTorch autocast/grad-scaler functionality.
- During inference brute-forcing half precision is a perfectly acceptable alternative to autocast, but make sure you test the output.
- Unless doing optimization/training, make sure to use inference mode to save a huge amount of activation memory.
- Prefer the `bitsandbytes` 8-bit optimizers over the `torch.optim` implementations, but test convergence.
- Use `set_to_none=True` when resetting gradients at the end of your training loop. If using PyTorch >= 2.0 this is already the default.
- If you need to save even more memory, use activation checkpointing after empirically finding the most cost-effective subgraphs on a memory-saved per throughput-lost basis.
Thanks for reading, and I hope you found this useful!