Fast Whisper inference using dynamic batching

In this example, we demonstrate how to run dynamically batched inference for OpenAI’s speech recognition model, Whisper, on Modal. Batching multiple audio samples together or batching chunks of a single audio sample can help to achieve a 2.8x increase in inference throughput on an A10G!
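Batching chunks of a single sample works because Whisper processes 16 kHz audio in fixed 30-second windows, so a long recording can be cut into windows that run through the model as one batch. The sketch below is a simplified illustration of that splitting step. It assumes raw audio is a flat list of samples, and `chunk_audio` is a hypothetical helper. It uses hard cuts with no overlap, whereas the Transformers pipeline's `chunk_length_s` option overlaps neighboring chunks (`stride_length_s`) and merges the text afterwards.

```python
# Whisper consumes 16 kHz mono audio in 30-second windows.
SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30


def chunk_audio(samples, sample_rate=SAMPLE_RATE, chunk_seconds=CHUNK_SECONDS):
    """Hypothetical helper: split a flat sequence of samples into fixed-length chunks.

    Simplified on purpose: no overlap between chunks, and the last chunk may be shorter.
    """
    chunk_len = int(sample_rate * chunk_seconds)
    return [samples[i : i + chunk_len] for i in range(0, len(samples), chunk_len)]


# 75 seconds of silence -> two full 30 s chunks plus one 15 s remainder.
audio = [0.0] * (75 * SAMPLE_RATE)
chunks = chunk_audio(audio)
print([len(c) / SAMPLE_RATE for c in chunks])  # [30.0, 30.0, 15.0]
```

Each chunk would then become one element of the batch handed to the model.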

We will be running the Whisper Large V3 model. To run any of the other Hugging Face Whisper models, change the MODEL_NAME and MODEL_REVISION variables.

Setup

Let’s start by importing the Modal client and defining the model that we want to serve.

import modal

MODEL_DIR = "/model"
MODEL_NAME = "openai/whisper-large-v3"
MODEL_REVISION = "afda370583db9c5359511ed5d989400a6199dfe1"

Define a container image

We’ll start with Modal’s baseline debian_slim image and install the relevant libraries.

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "torch==2.5.1",
        "transformers==4.47.1",
        "hf-transfer==0.1.8",
        "huggingface_hub==0.27.0",
        "librosa==0.10.2",
        "soundfile==0.12.1",
        "accelerate==1.2.1",
        "datasets==3.2.0",
    )
    # Use the barebones `hf-transfer` package for maximum download speeds. No progress bar, but expect 700MB/s.
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": MODEL_DIR})
)

model_cache = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
app = modal.App(
    "example-whisper-batched-inference",
    image=image,
    volumes={MODEL_DIR: model_cache},
)

Caching the model weights

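We’ll define a function that downloads the model weights and caches them in the Volume. Run it once with modal run batched_whisper.py::download_model before deploying the App, so that new containers never have to download the weights at startup.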

@app.function()
def download_model():
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    snapshot_download(
        MODEL_NAME,
        ignore_patterns=["*.pt", "*.bin"],  # Using safetensors
        revision=MODEL_REVISION,
    )
    move_cache()
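Because the image sets HF_HUB_CACHE to MODEL_DIR, snapshot_download stores the files in the Hugging Face Hub cache layout inside the mounted Volume and returns the path of that snapshot folder. The helper below is a hypothetical illustration, not part of huggingface_hub, that computes where the pinned revision should land, which can be handy when inspecting the Volume's contents.

```python
from pathlib import PurePosixPath

MODEL_DIR = "/model"
MODEL_NAME = "openai/whisper-large-v3"
MODEL_REVISION = "afda370583db9c5359511ed5d989400a6199dfe1"


def expected_snapshot_dir(cache_dir, repo_id, revision):
    """Hypothetical helper: mirror the Hub cache layout for a model repo.

    Layout: <cache_dir>/models--<org>--<name>/snapshots/<commit hash>/
    """
    repo_folder = "models--" + repo_id.replace("/", "--")
    return PurePosixPath(cache_dir) / repo_folder / "snapshots" / revision


print(expected_snapshot_dir(MODEL_DIR, MODEL_NAME, MODEL_REVISION))
# /model/models--openai--whisper-large-v3/snapshots/afda370583db9c5359511ed5d989400a6199dfe1
```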

The model class

The inference function is best represented using Modal’s class syntax, because a class lets each container load the model once and reuse it across many calls.

We define a @modal.enter method to load the model when the container starts, before it picks up any inputs. The weights will be loaded from the Hugging Face cache volume so that we don’t need to download them when we start a new container.
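The payoff of @modal.enter is that the expensive load runs once per container and is then reused by every input that container serves. The plain-Python stand-in below only mimics that lifecycle to make the cost model concrete; FakeContainer and its methods are hypothetical and are not part of Modal's API.

```python
class FakeContainer:
    """Hypothetical stand-in for one container running the Model class (not Modal's API)."""

    def __init__(self):
        self.loads = 0
        self.model = None

    def enter(self):
        # Like a @modal.enter hook: runs once, before any input is processed.
        self.loads += 1
        self.model = lambda audio: f"transcript of {audio}"  # pretend this load is slow

    def handle(self, audio):
        # Every input reuses the already-loaded model instead of reloading it.
        return self.model(audio)


container = FakeContainer()
container.enter()
outputs = [container.handle(f"clip-{i}") for i in range(100)]
print(container.loads, len(outputs))  # 1 100
```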

We also define a transcribe method that uses the @modal.batched decorator to enable dynamic batching. This allows us to invoke the function with individual audio samples, and the function will automatically batch them together before running inference. Batching is critical for making good use of the GPU, since GPUs are designed for running parallel operations at high throughput.
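To make the dynamic batching behavior concrete, here is a toy asyncio batcher built on the same idea: callers submit one item at a time, and a background loop groups waiting items into a batch that closes when it reaches max_batch_size or when wait_ms has elapsed since its first item arrived. This is an illustrative sketch under those assumptions, not Modal's implementation of @modal.batched; DynamicBatcher, submit, and run are hypothetical names.

```python
import asyncio


class DynamicBatcher:
    """Toy dynamic batcher (hypothetical; not Modal's implementation)."""

    def __init__(self, batch_fn, max_batch_size, wait_ms):
        self.batch_fn = batch_fn  # takes a list of inputs, returns a list of outputs
        self.max_batch_size = max_batch_size
        self.wait_s = wait_ms / 1000
        self.queue = asyncio.Queue()
        self.batch_sizes = []  # recorded so we can see how inputs were grouped

    async def submit(self, item):
        # Called once per input; resolves when that input's batch has run.
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut

    async def run(self):
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]  # block until the first input arrives
            deadline = loop.time() + self.wait_s
            # Keep collecting until the batch is full or the wait window closes.
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            outputs = self.batch_fn([item for item, _ in batch])
            for (_, fut), output in zip(batch, outputs):
                fut.set_result(output)


async def demo():
    batcher = DynamicBatcher(lambda xs: [x * 2 for x in xs], max_batch_size=4, wait_ms=50)
    worker = asyncio.create_task(batcher.run())
    # Ten independent calls, as if ten clients each sent one audio sample.
    results = await asyncio.gather(*(batcher.submit(i) for i in range(10)))
    worker.cancel()
    return results, batcher.batch_sizes


results, batch_sizes = asyncio.run(demo())
print(results)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
print(batch_sizes)
```

In the real example, Modal plays the role of run(): each call to transcribe sends a single audio sample, and the decorated method receives the grouped samples as a list.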

The max_batch_size parameter limits the maximum number of audio samples combined into a single batch. We used a max_batch_size of 64, the largest power-of-2 batch size that fits in the A10G’s 24 GB of GPU memory. The right value depends on the model and the GPU you are using.
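One way to find that largest power-of-2 batch size empirically is to double the batch until a trial run no longer fits on the GPU. The sketch below assumes you can write a fits(batch_size) predicate; make_gpu_probe is a hypothetical helper that runs the Transformers pipeline on copies of one sample and treats torch.cuda.OutOfMemoryError as "does not fit". Memory use also depends on how many tokens the decoder generates, so probe with representative inputs.

```python
def largest_power_of_two_batch(fits, upper_limit=1024):
    """Return the largest power of 2 <= upper_limit for which fits(n) is True.

    Assumes memory use grows with batch size, so the first failure ends the search.
    Returns None if even a batch of 1 does not fit.
    """
    best = None
    size = 1
    while size <= upper_limit and fits(size):
        best = size
        size *= 2
    return best


def make_gpu_probe(pipe, sample):
    """Hypothetical helper: build a fits() predicate from an ASR pipeline.

    `sample` is a dict with "array" and "sampling_rate" keys, as in ds["audio"].
    """
    import torch  # imported lazily so the search logic above runs without a GPU

    def fits(batch_size):
        # Copy the dict for each slot, because the pipeline may pop keys from its inputs.
        batch = [dict(sample) for _ in range(batch_size)]
        try:
            pipe(batch, batch_size=batch_size)
            return True
        except torch.cuda.OutOfMemoryError:
            return False
        finally:
            torch.cuda.empty_cache()

    return fits


# Toy predicate standing in for the GPU: each sample "costs" 300 units of a 10,000-unit budget.
print(largest_power_of_two_batch(lambda n: n * 300 <= 10_000))  # 32
```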

The wait_ms parameter sets the maximum time to wait for more inputs before running the batched transcription. To tune this parameter, you can set it to the target latency of your application minus the execution time of an inference batch. This allows the latency of any request to stay within your target latency.
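As a worked example of that rule with hypothetical numbers: if your application targets 2.5 seconds end to end and a full batch takes about 1.8 seconds to transcribe, a wait_ms of 700 keeps even a request that waits the full window within budget. The helper name below is an assumption for illustration. This simple budget ignores network overhead and time spent queued behind a batch that is already running, so leave some headroom.

```python
def choose_wait_ms(target_latency_ms, batch_exec_ms):
    """Hypothetical helper: longest batching window that keeps a request within target.

    Worst case for one request: it waits the full window, then rides a full batch.
    """
    budget = target_latency_ms - batch_exec_ms
    if budget <= 0:
        raise ValueError("a single batch already exceeds the latency target")
    return budget


# Hypothetical numbers: 2.5 s latency target, ~1.8 s to transcribe a full batch.
print(choose_wait_ms(2_500, 1_800))  # 700
```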

@app.cls(
    gpu="a10g",  # Try using an A100 or H100 if you've got a large model or need big batches!
    concurrency_limit=10,  # default max GPUs for Modal's free tier
)
class Model:
    @modal.enter()
    def load_model(self):
        import torch
        from transformers import (
            AutoModelForSpeechSeq2Seq,
            AutoProcessor,
            pipeline,
        )

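        # Pin the revision so we load the same snapshot that download_model cached
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME, revision=MODEL_REVISION
        )
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_NAME,
            revision=MODEL_REVISION,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        ).to("cuda")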

        self.model.generation_config.language = "<|en|>"

        # Create a pipeline for preprocessing and transcribing speech data
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            torch_dtype=torch.float16,
            device="cuda",
        )

    @modal.batched(max_batch_size=64, wait_ms=1000)
    def transcribe(self, audio_samples):
        import time

        start = time.monotonic_ns()
        print(f"Transcribing {len(audio_samples)} audio samples")
        transcriptions = self.pipeline(
            audio_samples, batch_size=len(audio_samples)
        )
        end = time.monotonic_ns()
        print(
            f"Transcribed {len(audio_samples)} samples in {round((end - start) / 1e9, 2)}s"
        )
        return transcriptions

Transcribe a dataset

In this example, we test the model on the hf-internal-testing/librispeech_asr_dummy dataset from the Hugging Face Hub, loaded with the Datasets library.

We use map.aio to asynchronously map over the audio samples. This submits every sample to the batched transcribe method concurrently, so Modal can group them into batches, and streams the transcriptions back to the caller.

@app.function()
async def transcribe_hf_dataset(dataset_name):
    from datasets import load_dataset

    print("📂 Loading dataset", dataset_name)
    ds = load_dataset(dataset_name, "clean", split="validation")
    print("📂 Dataset loaded")
    batched_whisper = Model()
    print("📣 Sending data for transcription")
    async for transcription in batched_whisper.transcribe.map.aio(ds["audio"]):
        yield transcription

Run the model

We define a local_entrypoint that starts the transcription on Modal and prints each transcript as it streams back. Run it from your local machine with modal run batched_whisper.py.

@app.local_entrypoint()
def main(dataset_name: str = "hf-internal-testing/librispeech_asr_dummy"):
    # remote_gen runs the generator on Modal and yields each result as it arrives
    for result in transcribe_hf_dataset.remote_gen(dataset_name):
        print(result["text"])