A Simple and Flexible Pytorch Video Pipeline
September 23, 2020
Intro #
Taking machine learning models into production for video analytics doesn’t have to be hard. A pipeline with reasonable efficiency can be created very quickly just by plugging together the right libraries. In this post we’ll create a video pipeline with a focus on flexibility and simplicity using two main libraries: Gstreamer and Pytorch.
Performance is out of scope for this first step, but we’ll do a deep dive in a later article. Worthwhile throughput improvements are possible with a little effort. We’ll also ignore black-box serving toolkits (Nvidia Triton/TensorRT, Kubeflow, TorchServe etc.) so we can understand what’s happening end to end.
The main library we’ll be using is Gstreamer, a very flexible and efficient media-processing framework that comes with a huge ecosystem of components. In principle these components can be seamlessly swapped out to support different codecs, transformations and outputs, but in practice constructing a Gstreamer pipeline can be a tricky process with a lot of iteration.
I’ve created a repo with some example code here: https://github.com/pbridger/pytorch-video-pipeline. Besides the code, the repo contains a Dockerfile and a top-level Makefile to make running the scripts easy.
Step 0: Baseline CLI Gstreamer pipeline #
In order to show the basic Gstreamer pipeline components and to validate the container environment, we can run something like this from the CLI:
$ gst-launch-1.0 filesrc location=media/in.mp4 ! decodebin ! progressreport update-freq=1 ! fakesink sync=true
Running this will show the video file being read (by the filesrc element), decoded (the decodebin element) and sent to the Gstreamer equivalent of /dev/null (the fakesink element).
If you don’t have Gstreamer installed, the easiest way to do this is to use the makefile from the repo. Grab the repo from github, then use this make target:
$ make cli.pipeline.png
Alternatively, start the Docker container using the makefile and run the above gst-launch-1.0 command from within:
$ make run-container
...
/app# gst-launch-1.0 filesrc location=media/in.mp4 ! decodebin ! progressreport update-freq=1 ! fakesink sync=true
Output should look something like this:
Output: CLI Gstreamer
Gstreamer is able to generate a representation showing the transformations in the pipeline, see below:
Step 1: Get frames into Python #
Since we want to feed these frames into a Pytorch model running in the Python runtime, we’ll construct a similar pipeline from a script:
import os, sys
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst

frame_format = 'RGBA'

Gst.init()
pipeline = Gst.parse_launch(f'''
    filesrc location=media/in.mp4 num-buffers=200 !
    decodebin !
    fakesink name=s
''')

def on_frame_probe(pad, info):
    buf = info.get_buffer()
    print(f'[{buf.pts / Gst.SECOND:6.2f}]')
    return Gst.PadProbeReturn.OK

pipeline.get_by_name('s').get_static_pad('sink').add_probe(
    Gst.PadProbeType.BUFFER,
    on_frame_probe
)

pipeline.set_state(Gst.State.PLAYING)
try:
    while True:
        msg = pipeline.get_bus().timed_pop_filtered(
            Gst.SECOND,
            Gst.MessageType.EOS | Gst.MessageType.ERROR
        )
        if msg:
            text = msg.get_structure().to_string() if msg.get_structure() else ''
            msg_type = Gst.message_type_get_name(msg.type)
            print(f'{msg.src.name}: [{msg_type}] {text}')
            break
finally:
    open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write(
        Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL)
    )
    pipeline.set_state(Gst.State.NULL)
The above code runs the same filesrc-decode pipeline, monitors the bus for errors and end-of-stream (EOS) messages, and installs a probe callback (on_frame_probe) which will be called for every frame processed. This is about as simple as I could make it. With this code we have video frames/buffers available within the Python callback as Gstreamer buffers.
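If you want to check what decodebin actually negotiated, a small variation of the probe (a sketch, not part of the repo script) can read the caps from the pad it is attached to:

def on_frame_probe(pad, info):
    caps = pad.get_current_caps()             # negotiated caps on the fakesink sink pad
    structure = caps.get_structure(0)
    width = structure.get_value('width')
    height = structure.get_value('height')
    pixel_format = structure.get_value('format')
    buf = info.get_buffer()
    print(f'[{buf.pts / Gst.SECOND:6.2f}] {width}x{height} {pixel_format}')
    return Gst.PadProbeReturn.OK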
To run this using the makefile:
$ make frames_into_python.pipeline.png
As you can see from the Gst.parse_launch call, the constructed pipeline is even simpler than the CLI version since we don’t bother with the progressreport element. Also, because we removed the sync=true parameter from the fakesink element, the pipeline runs as fast as the slowest pipeline element instead of synchronizing with the clock:
Output: frames_into_python.py
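If you do want real-time pacing back (for example, when rendering to a display instead of a fakesink), the same property can be set programmatically; a minimal sketch:

# sketch: re-enable clock synchronization on the sink element by name
sink = pipeline.get_by_name('s')
sink.set_property('sync', True)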
Step 2: Get frames into Pytorch #
Now we’ll add two things to the on_frame_probe callback:
- Reinterpret and copy the decoded Gstreamer buffer into a Pytorch tensor.
- Run some basic object detection on the image tensor using Nvidia’s SSD300.
In the interests of keeping the code short and simple, this sample has some deliberate limitations:
- Incomplete image preprocessing (a fuller transform is sketched after this list).
- No inference post-processing, so we don’t even get bounding boxes to print (see the sketch after the output below).
- No emphasis whatsoever on performance except for running on CUDA/GPU if available.
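For reference, the SSD300 model is normally fed 300x300 inputs normalized with ImageNet statistics, so a fuller preprocessing transform might look like the sketch below. This is my assumption based on the model’s published usage, and it is not what the sample code does:

import torchvision

# sketch of fuller SSD300 preprocessing: resize to the model's input size and
# apply ImageNet normalization (assumption; the sample below only uses ToTensor)
full_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((300, 300)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])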
import os, sys
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst
import numpy as np
import torch, torchvision

frame_format, pixel_bytes = 'RGBA', 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math='fp32').eval().to(device)
preprocess = torchvision.transforms.ToTensor()

Gst.init()
pipeline = Gst.parse_launch(f'''
    filesrc location=media/in.mp4 num-buffers=200 !
    decodebin !
    nvvideoconvert !
    video/x-raw,format={frame_format} !
    fakesink name=s
''')

def on_frame_probe(pad, info):
    buf = info.get_buffer()
    print(f'[{buf.pts / Gst.SECOND:6.2f}]')
    image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps())
    image_batch = image_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        detections = detector(image_batch)[0]
    return Gst.PadProbeReturn.OK

def buffer_to_image_tensor(buf, caps):
    caps_structure = caps.get_structure(0)
    height, width = caps_structure.get_value('height'), caps_structure.get_value('width')
    is_mapped, map_info = buf.map(Gst.MapFlags.READ)
    if is_mapped:
        try:
            image_array = np.ndarray(
                (height, width, pixel_bytes),
                dtype=np.uint8,
                buffer=map_info.data
            ).copy() # extend array lifetime beyond subsequent unmap
            return preprocess(image_array[:,:,:3]) # RGBA -> RGB
        finally:
            buf.unmap(map_info)

pipeline.get_by_name('s').get_static_pad('sink').add_probe(
    Gst.PadProbeType.BUFFER,
    on_frame_probe
)

pipeline.set_state(Gst.State.PLAYING)
try:
    while True:
        msg = pipeline.get_bus().timed_pop_filtered(
            Gst.SECOND,
            Gst.MessageType.EOS | Gst.MessageType.ERROR
        )
        if msg:
            text = msg.get_structure().to_string() if msg.get_structure() else ''
            msg_type = Gst.message_type_get_name(msg.type)
            print(f'{msg.src.name}: [{msg_type}] {text}')
            break
finally:
    open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write(
        Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL)
    )
    pipeline.set_state(Gst.State.NULL)
Some key points:
- In buffer_to_image_tensor we create a read-only mapping on the frame buffer memory, then create a numpy array that points to the mapped memory.
- We need to rearrange the image dimensions, since Gstreamer has decoded to (height, width, channel) and this Pytorch model wants (channel, height, width); see the sketch after this list.
- In the pipeline below, note the caps change from video/x-raw(memory:NVMM) to video/x-raw due to the newly added nvvideoconvert element. This element transfers the decoded video buffers from GPU memory to host (CPU) memory. The container has Gstreamer elements that support hardware video decoding, so the decodebin element will use any compatible GPU to accelerate decoding. To make the decoded frames accessible for Gstreamer to map and read into Numpy we explicitly move them to host memory.
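To make the dimension shuffle concrete: torchvision’s ToTensor does the (height, width, channel) to (channel, height, width) rearrangement and the scaling to [0, 1] for us. An equivalent manual version, sketched with a stand-in array:

import numpy as np
import torch

# stand-in for the mapped RGBA frame produced in buffer_to_image_tensor
image_array = np.zeros((1080, 1920, 4), dtype=np.uint8)

# drop alpha, reorder (H, W, C) -> (C, H, W), scale uint8 to [0, 1] float
rgb = image_array[:, :, :3]
chw = torch.from_numpy(rgb.copy()).permute(2, 0, 1).contiguous().float().div(255)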
Output: frames_into_pytorch.py
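The output above only shows frame timestamps because we skipped post-processing entirely. If you do want boxes, NVIDIA publishes companion utilities alongside the SSD model on the same torchhub repo; the sketch below assumes their decode_results and pick_best helpers, so double-check the names against the current repo before relying on them:

import torch

# assumed companion utilities from the NVIDIA torchhub repo (not used in the sample above)
ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')

def detect_and_decode(detector, image_batch, threshold=0.4):
    # run the model on an (N, 3, 300, 300) batch, decode the raw SSD output into
    # per-image (boxes, classes, confidences) and keep only confident detections
    with torch.no_grad():
        raw_output = detector(image_batch)
    results_per_image = ssd_utils.decode_results(raw_output)
    return [ssd_utils.pick_best(results, threshold) for results in results_per_image]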
Conclusion #
The above samples are nowhere near production ready but they show that the fundamentals of running machine learning inference on video don’t have to be hard. We hooked up a pipeline that can seamlessly process many different video encodings and formats. Further, we are using the regular “research” Pytorch runtime which gives us a lot of flexibility.
To give you a quick feel for performance: this pipeline runs at around 100 FPS on a 2080 Ti at under 80% GPU utilization. Caveats: we are not doing full pre-processing or post-processing, but on the other hand this pipeline is completely unoptimized. In my more advanced tuning post I add realistic preprocessing and postprocessing and make the performance awesome.