Fast Whisper inference using dynamic batching
In this example, we demonstrate how to run dynamically batched inference for OpenAI’s speech recognition model, Whisper, on Modal. Batching multiple audio samples together or batching chunks of a single audio sample can help to achieve a 2.8x increase in inference throughput on an A10G!
We will be running the Whisper Large V3 model.
To run any of the other Hugging Face Whisper models,
simply replace the MODEL_NAME and MODEL_REVISION variables.
Setup
Let’s start by importing the Modal client and defining the model that we want to serve.
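A minimal sketch of that setup might look like this (the app name is arbitrary, and you may want to pin MODEL_REVISION to a specific commit hash instead of "main" for reproducible deployments):

```python
import modal

# Whisper checkpoint to serve. Swap these out to use another Hugging Face Whisper model.
MODEL_NAME = "openai/whisper-large-v3"
MODEL_REVISION = "main"  # pin to a commit hash for reproducibility

app = modal.App("example-batched-whisper")
```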
Define a container image
We’ll start with Modal’s baseline debian_slim image and install the relevant libraries.
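A sketch of the image definition, assuming a fairly standard dependency set for running Whisper with transformers (the exact packages, versions, and Python version are illustrative):

```python
CACHE_DIR = "/cache"  # where the model-weights volume will be mounted

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "torch",
        "transformers",
        "huggingface_hub",
        "hf-transfer",   # accelerated Hugging Face downloads
        "accelerate",
        "librosa",       # audio decoding
        "soundfile",
        "datasets",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # turn on accelerated downloads
            "HF_HOME": CACHE_DIR,              # point the Hugging Face cache at the volume
        }
    )
)
```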
Caching the model weights
We’ll define a function to download the model and cache it in a volume.
You can modal run against this function prior to deploying the App.
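One way to do this, using a modal.Volume mounted at the cache path configured above (the volume name is arbitrary):

```python
model_cache = modal.Volume.from_name("whisper-model-cache", create_if_missing=True)


@app.function(image=image, volumes={CACHE_DIR: model_cache})
def download_model():
    from huggingface_hub import snapshot_download

    # Download the weights into the Hugging Face cache, which lives on the volume
    # because HF_HOME points at CACHE_DIR.
    snapshot_download(MODEL_NAME, revision=MODEL_REVISION)

    # Persist the downloaded files so that later containers can read them.
    model_cache.commit()
```

You can then run modal run batched_whisper.py::download_model once before deploying.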
The model class
The inference function is best represented using Modal’s class syntax.
We define a @modal.enter method to load the model when the container starts, before it picks up any inputs.
The weights will be loaded from the Hugging Face cache volume so that we don’t need to download them when
we start a new container. For more on storing model weights on Modal, see this guide.
We also define a transcribe method that uses the @modal.batched decorator to enable dynamic batching.
This allows us to invoke the function with individual audio samples, and the function will automatically batch them
together before running inference. Batching is critical for making good use of the GPU, since GPUs are designed
for running parallel operations at high throughput.
The max_batch_size parameter limits the maximum number of audio samples combined into a single batch.
We use a max_batch_size of 64, the largest power-of-two batch size that fits within the A10G's 24 GB of GPU memory.
This number will vary depending on the model and the GPU you are using.
The wait_ms parameter sets the maximum time to wait for more inputs before running the batched transcription.
To tune this parameter, set it to your application's target latency minus the execution time of a single inference batch.
This keeps the latency of any request within your target.
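Putting the pieces together, the class might look roughly like the sketch below. It builds on the app, image, volume, and model constants defined above; the GPU choice and the 1000 ms wait_ms value are placeholders to adjust for your own latency target:

```python
@app.cls(image=image, gpu="a10g", volumes={CACHE_DIR: model_cache})
class Model:
    @modal.enter()
    def load_model(self):
        import torch
        from transformers import pipeline

        # Runs once when the container starts: load the cached weights onto the GPU
        # before any inputs arrive.
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=MODEL_NAME,
            revision=MODEL_REVISION,
            torch_dtype=torch.float16,
            device="cuda",
        )

    @modal.batched(max_batch_size=64, wait_ms=1000)
    def transcribe(self, audio_samples: list) -> list[str]:
        # Callers invoke this with a single audio sample; Modal collects up to 64
        # of them (or waits at most wait_ms) and passes them in as one list.
        results = self.pipeline(audio_samples, batch_size=len(audio_samples))
        return [result["text"] for result in results]
```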
Transcribe a dataset
In this example, we use the librispeech_asr_dummy dataset from Hugging Face’s Datasets library to test the model.
We use map.aio to asynchronously map over the audio files.
This allows us to invoke the batched transcription method on each audio sample in parallel.
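A sketch of that step, assuming the datasets library is installed locally (the helper name transcribe_hf_dataset is our own):

```python
async def transcribe_hf_dataset(dataset_name: str):
    from datasets import load_dataset

    # The "clean" validation split of the dummy dataset holds a handful of short
    # LibriSpeech samples -- enough to exercise the batched endpoint.
    ds = load_dataset(dataset_name, "clean", split="validation")

    model = Model()
    # .map.aio() fans individual audio samples out to the transcribe method
    # concurrently; Modal groups them into batches on the server side.
    async for transcription in model.transcribe.map.aio(ds["audio"]):
        yield transcription
```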
Run the model
We define a local_entrypoint to run the transcription. You can run this locally with modal run batched_whisper.py.
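A minimal entrypoint along these lines, with the dummy LibriSpeech dataset mentioned above as the default argument:

```python
@app.local_entrypoint()
async def main(dataset_name: str = "hf-internal-testing/librispeech_asr_dummy"):
    # Stream transcriptions back and print them as they complete.
    async for transcription in transcribe_hf_dataset(dataset_name):
        print(transcription)
```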