Mastering TorchScript: Tracing vs Scripting, Device Pinning, Direct Graph Modification
October 29, 2020
Intro #
TorchScript is one of the most important parts of the Pytorch ecosystem, allowing portable, efficient and nearly seamless deployment. With just a few lines of torch.jit code and some simple model changes you can export an asset that runs anywhere libtorch does. It's an important toolset to master if you want to run your models outside the lab at high efficiency.
Good introductory material is already available for starting to work with TorchScript, including execution in the C++ libtorch runtime, and reference material is also provided. This article is a collection of topics going beyond the basics of your first export.
Tracing vs Scripting #
Pytorch provides two methods for generating TorchScript from your model code — tracing and scripting — but which should you use? Let's recap how they work (a short code sketch of both calls follows the list):
- Tracing. When using torch.jit.trace you'll provide your model and sample input as arguments. The input will be fed through the model as in regular inference, and the executed operations will be traced and recorded into TorchScript. Logical structure will be frozen into the path taken during this sample execution.
- Scripting. When using torch.jit.script you'll simply provide your model as an argument. TorchScript will be generated from static inspection of the nn.Module contents (recursively).
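To make the distinction concrete, here is a minimal sketch of both calls; the TinyNet module, input shapes and filenames are mine, not from any particular model:

import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = TinyNet()

# Tracing: feed a sample input through the model and record the executed ops
traced = torch.jit.trace(model, torch.randn(4, 8))

# Scripting: compile the module (and its submodules) by static inspection
scripted = torch.jit.script(model)

# Both produce a ScriptModule that can be saved and loaded by libtorch
traced.save('tiny_traced.pt')
scripted.save('tiny_scripted.pt')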
It’s not obvious from the tutorial documentation, but choosing which method to use is a fairly simple and fluid choice:
Use Scripting by Default #
Because torch.jit.script captures both the operations and full conditional logic of your model, it's a great place to start. If your model doesn't need any unsupported Pytorch functionality and has logic restricted to the supported subset of Python functions and syntax, then torch.jit.script should be all you need.
One major advantage of scripting over tracing is that an export is likely to either fail for a well-defined reason — implying a clear code modification — or succeed without warnings.
Unlike Python, TorchScript is Statically Typed
You will need to be consistent about container element datatypes, and be wary of implicitly typed function signatures (TorchScript assumes unannotated arguments are Tensors). A useful practice is to use type hints in method signatures.
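For example, a scripted method with annotated container and scalar arguments might look like the following sketch (the Gather module is a made-up illustration):

import torch
from typing import List

class Gather(torch.nn.Module):
    def forward(self, xs: List[torch.Tensor], scale: float) -> torch.Tensor:
        # without the hints TorchScript would assume every argument is a Tensor
        return torch.cat(xs, dim=0) * scale

scripted = torch.jit.script(Gather())
out = scripted([torch.ones(2), torch.zeros(3)], 2.0)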
Despite TorchScript's ability to capture conditional logic, it does not allow you to run arbitrary Python within libtorch — a popular misconception.
Use Tracing if You Must #
There are a few special cases in which torch.jit.trace may be useful:
- If you are unable to modify the model code — because you do not have access or ownership — you may find scripting the model simply will not work because it uses unsupported Pytorch/Python functionality.
- In pursuit of performance or to bake in architectural decisions the logic freezing behavior of tracing might be preferable — similar to inlining C/C++ code.
Pay Close Attention to Tracer Warnings
Due to how tracing can simplify model behavior, each warning should be fully understood and only then ignored (or fixed). Also, be sure to trace in eval mode if you are exporting a model for production inference!
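As a sketch of what that looks like in practice (the Classifier module here is hypothetical):

import torch

class Classifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = torch.nn.Dropout(0.5)
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(self.dropout(x))

model = Classifier()
model.eval()  # make dropout (and batch norm, if present) behave as at inference time

with torch.no_grad():
    traced = torch.jit.trace(model, torch.randn(1, 8))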
Use Both Together #
Scripted and traced code can be freely mixed, and this is often a great choice. See the existing pytorch.org documentation for details and examples.
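One common pattern, along the lines of the pytorch.org tutorial's decision-gate example, is to script a submodule with data-dependent control flow and then trace the module that contains it (the names below are mine):

import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # data-dependent branch: scripting preserves it, tracing would freeze it
        if x.sum() > 0:
            return x
        return -x

class Cell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = torch.jit.script(Gate())  # scripted submodule
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        return torch.tanh(self.gate(self.linear(x)) + h)

# tracing the outer module keeps the scripted gate's control flow intact
traced = torch.jit.trace(Cell(), (torch.rand(3, 4), torch.rand(3, 4)))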
Device Pinning #
If you find yourself using torch.jit.trace on some code, you'll have to actively deal with some of the gotchas or face performance and portability consequences. Besides addressing any warnings Pytorch emits, you'll also need to keep an eye out for device pinning. Just like torch.jit.trace records and freezes conditional logic, it will also trace and make constant the values resulting from this logic — and this can include device constants.
Using this sample code:
def forward(X):
    return torch.arange(X.size(0))
Whether we trace while executing on CPU or GPU, we get this TorchScript (scroll to the right on mobile):
def forward(X: Tensor) -> Tensor:
    _0 = ops.prim.NumToTensor(torch.size(X, 0))
    _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
    return _1
You can see that torch.device("cpu") has been inserted as a constant into the generated TorchScript. If we try to get clever with this code:
def forward(X):
    return torch.arange(X.size(0), device=X.device)
Tracing will now result in TorchScript that is pinned to the tracing device. When traced on GPU, we see this:
def forward(self, X: Tensor) -> Tensor:
    _0 = ops.prim.NumToTensor(torch.size(X, 0))
    _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cuda:0"), pin_memory=False)
    return _1
Tensors Created During Tracing Will Have Their Device Pinned
This can be a significant performance and portability problem.
Performance and Portability #
If we later deserialize and run this TorchScript in libtorch, the arange tensor will always be created on the device that is pinned — torch.device("cpu") or torch.device("cuda:0") in the examples above. If the rest of the model is running on a different device, this can result in costly memory transfers and synchronization.
This device pinning issue extends to multi-GPU scenarios as well. If you have traced and exported a model on cuda:0 and then run it on cuda:1, you'll see transfers and synchronization between the devices. Not good. Perhaps even worse, if such a model is run in an environment without any CUDA-capable device it will fail, since cuda:0 doesn't exist.
Replace Tensors Created During Execution With Parameters
Tensors created in the execution path while tracing will have their device pinned. Depending on model logic, these can often be turned into Parameters created during construction.
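A sketch of that kind of refactor, here using register_buffer (a non-learnable fit for this case, rather than a Parameter) so the tensor is created at construction time and follows the module's .to(device) calls; the AddPositions module and seq_len are made up for illustration:

import torch

class AddPositions(torch.nn.Module):
    def __init__(self, seq_len: int = 16):
        super().__init__()
        # was: torch.arange(X.size(0), device=X.device) inside forward
        self.register_buffer('positions', torch.arange(seq_len))

    def forward(self, X):
        # the buffer already lives on the module's device, so nothing gets pinned
        return X + self.positions

traced = torch.jit.trace(AddPositions(), torch.zeros(16))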
An example of the problem looks like this in Nsight Systems:
Tensor Subscript Mask and Indexing Will Pin Devices #
Unlike their more explicit counterparts (masked_select and index_select), using tensor subscripting will pin the mask or indexes to the tracing device:
def forward(X):
    return X[X > 1]
Generates this TorchScript:
def forward(X: Tensor) -> Tensor:
    _0 = torch.to(torch.gt(X, 1), dtype=11, layout=0, device=torch.device("cpu"), pin_memory=False, non_blocking=False, copy=False, memory_format=None)
    _1 = annotate(List[Optional[Tensor]], [_0])
    return torch.index(X, _1)
Whereas:
def forward(X):
    return X.masked_select(X > 1)
Generates this TorchScript:
def forward(X: Tensor) -> Tensor:
    _0 = torch.masked_select(X, torch.gt(X, 1))
    return _0
The same pattern holds for tensor[indexes] and tensor.index_select(0, indexes). This device pinning carries the same performance and portability risks as noted above.
Replace Tensor Subscripting With masked_select and index_select
Subscript-based masking and indexing will always pin the tracing device into generated TorchScript. :(
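For the index-based case, the trace-friendly variant is a one-liner (a minimal sketch; the function name is mine):

import torch

def gather_rows(X, indexes):
    # equivalent to X[indexes] along dim 0, but avoids the device-pinned
    # index conversion that subscripting introduces in a trace
    return X.index_select(0, indexes)

traced = torch.jit.trace(gather_rows, (torch.randn(4, 3), torch.tensor([0, 2])))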
Direct Graph Modification #
Once we've used torch.jit.script or torch.jit.trace to generate a ScriptModule or ScriptFunction, we can use .graph, .inlined_graph or .code to understand exactly what TorchScript has been generated. Though it has an entirely undocumented interface, it is possible (and fun) to access and modify the generated TorchScript AST directly via the .graph attribute.
The most useful parts of the API are defined in torch/csrc/jit/python/python_ir.cpp. As you can see, all the basic functionality is present for finding and changing the graph nodes you want. If you change nodes or arguments and then persist the module, your subsequent TorchScript load and inference will reflect your changes, though modules cannot be changed recursively in this way (torch.jit.freeze can be useful here).
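As a sketch of how torch.jit.freeze can help here: freezing an eval-mode ScriptModule inlines its submodules into a single graph, which you can then inspect or rewrite in one place (the Wrapper module is a made-up example):

import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.inner(x)

scripted = torch.jit.script(Wrapper())
# freeze requires eval mode; submodule calls are inlined into one graph
frozen = torch.jit.freeze(scripted.eval())
print(frozen.graph)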
An example of the kind of graph modification that is possible:
def undevice(tsc):
    # use the ::to variant which does not hardcode device
    for to_node in tsc.graph.findAllNodes('aten::to'):
        i, dtype, layout, device, pin_mem, non_blocking, copy, mem_format = list(to_node.inputs())
        to_node.removeAllInputs()
        for a in [i, dtype, non_blocking, copy, mem_format]:
            to_node.addInput(a)
    # clean up constants (such as the device) that no longer have any uses
    for constant in tsc.graph.findAllNodes('prim::Constant'):
        if not constant.hasUses():
            constant.destroy()
The above code will modify a traced graph, changing aten::to to use an overload which doesn't change memory location.
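A sketch of how this could be applied to the subscripting example from earlier (the tracing and save calls here are mine, not from the article):

import torch

def forward(X):
    return X[X > 1]

traced = torch.jit.trace(forward, torch.arange(5))
undevice(traced)             # the helper defined above
print(traced.graph)          # the aten::to node no longer carries a device constant
traced.save('undeviced.pt')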
But what is this really useful for? Since the API is undocumented, you'd be unwise to use this capability in a production pipeline unless you like maintenance coding. I would only recommend it for research, as in the above example, which I used to understand and profile the transfer/synchronization behavior of tensor subscripting.
Don’t Bother With Direct Graph Modification
For legitimate production use-cases you can almost always find a way to modify your model code to generate the TorchScript you want.
Rewrite for ONNX/TensorRT Export #
You can get some awesome results with TensorRT but exporting a model from Pytorch to TensorRT is far from a sure thing. The export path to ONNX and then to TensorRT can fail due to missing or incompatible operations at either step and this can be frustrating.
After the obligatory Google search, I've found a reasonable hail-mary approach is to rewrite your tensor processing code to avoid unsupported operators. I can't give general advice for this, but let me show you an example of how this can be possible: repeat_interleave.
import torch

class RI(torch.nn.Module):
    def forward(self, X, repeat):
        return X.repeat_interleave(repeat, dim=0)

inputs = (torch.arange(5), torch.tensor(3))
torch.onnx.export(RI(), inputs, 'please_work.onnx', opset_version=11)
Doesn’t work:
RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.
However, the behavior of repeat_interleave with a fixed dim argument can be replicated in a form that will export to ONNX (but good luck passing code review):
class RW(torch.nn.Module):
    def forward(self, X, repeat):
        X = X.reshape(1, *X.size()).expand(repeat, *X.size())
        return torch.cat(torch.unbind(X, dim=1))
N.B. The above code is only equivalent to X.repeat_interleave(repeat, dim=0), though it can be adapted for any fixed dim.
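A quick eager-mode sanity check of the rewrite (using a plain int for the repeat count to keep things simple):

import torch

X = torch.arange(5)
expected = X.repeat_interleave(3, dim=0)
rewritten = torch.cat(torch.unbind(X.reshape(1, -1).expand(3, *X.size()), dim=1))
print(torch.equal(expected, rewritten))  # True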
The same approach can be taken to work around incomplete support in TensorRT, which is far more prevalent in my experience.
Conclusion #
Efficient and portable Pytorch production deployment used to be almost impossible, but the introduction and continued evolution of TorchScript has been great for the ecosystem. There are still a few rough edges and tricks, and I hope you’ve found something new or useful in the topics above.
If there is some major concern or problem you’re having with TorchScript or Pytorch production deployment please get in touch — I’m always looking for new areas to research.