Parakeet Multi-talker Speech-to-Text
This example shows how to run a streaming multi-talker speech-to-text service using NVIDIA’s Parakeet Multi-talker model and the Sortformer diarization model. The application transcribes audio in real time while identifying different speakers, with no need to register speakers in advance.
Try it yourself! Click the “View on GitHub” button to see the code. And sign up for a Modal account if you haven’t already.
Setup
We start by importing the necessary dependencies and defining the Modal App and Image. We use a persistent Volume to cache the models.
import asyncio
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import modal
app = modal.App("parakeet-multitalker")
model_cache = modal.Volume.from_name("parakeet-model-cache", create_if_missing=True)
CACHE_PATH = "/cache"
hf_secret = modal.Secret.from_name("huggingface-secret")
image = (
modal.Image.from_registry(
"nvidia/cuda:13.0.1-cudnn-devel-ubuntu22.04", add_python="3.12"
)
.env(
{
"HF_HUB_ENABLE_HF_TRANSFER": "1",
"HF_HOME": CACHE_PATH, # cache directory for Hugging Face models
"CXX": "g++",
"CC": "g++",
"TORCH_HOME": CACHE_PATH,
}
)
.apt_install("git", "libsndfile1", "ffmpeg")
.uv_pip_install(
"hf_transfer==0.1.9",
"huggingface_hub[hf-xet]==0.31.2",
"cuda-python==13.0.1",
"numpy<2",
"fastapi",
"nemo_toolkit[asr]@git+https://github.com/NVIDIA/NeMo.git@main",
)
)
SAMPLE_RATE = 16000
NUM_REQUIRED_BUFFER_FRAMES = 13
BYTES_PER_SAMPLE = 2
FRAME_LEN_SEC = 0.080
PARAKEET_RT_STREAMING_CHUNK_SIZE = (
int(FRAME_LEN_SEC * SAMPLE_RATE) * BYTES_PER_SAMPLE * NUM_REQUIRED_BUFFER_FRAMES
)
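With these values, each chunk is int(0.080 × 16000) = 1280 samples × 2 bytes × 13 frames = 33,280 bytes, i.e. roughly 1.04 seconds of 16 kHz, 16-bit mono audio per inference step.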
def chunk_audio(data: bytes, chunk_size: int):
for i in range(0, len(data), chunk_size):
yield data[i : i + chunk_size]
Configuration
This dataclass holds all the configuration parameters for the transcription and diarization models.
@dataclass
class MultitalkerTranscriptionConfig:
"""
Configuration for Multi-talker transcription with an ASR model and a diarization model.
"""
# Required configs
diar_model: Optional[str] = None # Path to a .nemo file
diar_pretrained_name: Optional[str] = None # Name of a pretrained model
max_num_of_spks: Optional[int] = 4 # maximum number of speakers
parallel_speaker_strategy: bool = True # whether to use parallel speaker strategy
masked_asr: bool = True # whether to use masked ASR
mask_preencode: bool = False # whether to mask preencode or mask features
cache_gating: bool = True # whether to use cache gating
cache_gating_buffer_size: int = 2 # buffer size for cache gating
single_speaker_mode: bool = False # whether to use single speaker mode
# General configs
session_len_sec: float = -1 # End-to-end diarization session length in seconds
num_workers: int = 8
random_seed: Optional[int] = None  # seed passed to seed_everything()
log: bool = True  # If True, logs will be printed
# Streaming diarization configs
streaming_mode: bool = True # If True, streaming diarization will be used.
spkcache_len: int = 188
spkcache_refresh_rate: int = 0
fifo_len: int = 188
chunk_len: int = 0
chunk_left_context: int = 1
chunk_right_context: int = 0
# If `cuda` is a negative number, inference will be on CPU only.
cuda: Optional[int] = None
allow_mps: bool = False  # allow selecting the MPS device (Apple Silicon M-series GPU)
matmul_precision: str = "highest" # Literal["highest", "high", "medium"]
# ASR Configs
asr_model: Optional[str] = None
device: str = "cuda"
audio_file: Optional[str] = None
manifest_file: Optional[str] = None
use_amp: bool = True
debug_mode: bool = True
batch_size: int = 32
chunk_size: int = -1
shift_size: int = -1
left_chunks: int = 2
online_normalization: bool = True
output_path: Optional[str] = None
pad_and_drop_preencoded: bool = False
set_decoder: Optional[str] = None # ["ctc", "rnnt"]
att_context_size: Optional[List[int]] = field(default_factory=lambda: [70, 13])
generate_realtime_scripts: bool = True
word_window: int = 50
sent_break_sec: float = 30.0
fix_prev_words_count: int = 5
update_prev_words_sentence: int = 5
left_frame_shift: int = -1
right_frame_shift: int = 0
min_sigmoid_val: float = 1e-2
discarded_frames: int = 8
print_time: bool = True
print_sample_indices: List[int] = field(default_factory=lambda: [0])
colored_text: bool = True
real_time_mode: bool = True
print_path: Optional[str] = "./"
ignored_initial_frame_steps: int = 5
verbose: bool = True
feat_len_sec: float = 0.01
finetune_realtime_ratio: float = 0.01
spk_supervision: str = "diar" # ["diar", "rttm"]
binary_diar_preds: bool = False
deploy_mode: bool = True
@staticmethod
def init_diar_model(cfg, diar_model):
# Set streaming-mode parameters on the diarization model
diar_model.streaming_mode = cfg.streaming_mode
diar_model.sortformer_modules.chunk_len = (
cfg.chunk_len if cfg.chunk_len > 0 else 6
)
diar_model.sortformer_modules.spkcache_len = cfg.spkcache_len
diar_model.sortformer_modules.chunk_left_context = cfg.chunk_left_context
diar_model.sortformer_modules.chunk_right_context = (
cfg.chunk_right_context if cfg.chunk_right_context > 0 else 7
)
diar_model.sortformer_modules.fifo_len = cfg.fifo_len
diar_model.sortformer_modules.log = cfg.log
diar_model.sortformer_modules.spkcache_refresh_rate = cfg.spkcache_refresh_rate
return diar_model
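The Transcriber below builds this config with OmegaConf.structured and passes it to NeMo. If you want different defaults, override fields when constructing the dataclass; for example (a hypothetical tweak using the fields defined above):

from omegaconf import OmegaConf

# Hypothetical override: cap the diarizer at two speakers and disable colored output.
cfg = OmegaConf.structured(
    MultitalkerTranscriptionConfig(max_num_of_spks=2, colored_text=False)
)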
with image.imports():
import logging
import numpy as np
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from nemo.collections.asr.models import ASRModel, SortformerEncLabelModel
from nemo.collections.asr.parts.utils.multispk_transcribe_utils import (
SpeakerTaggedASR,
)
from omegaconf import OmegaConf
from starlette.websockets import WebSocketState
from .asr_utils import int2float, preprocess_audio
from .cache_aware_buffer import CacheAwareStreamingAudioBuffer
Transcriber Service
We define the main Transcriber class as a Modal Cls.
This class loads the models into GPU memory and handles the streaming inference.
For more on lifecycle management with Cls and reducing cold start penalties on Modal, see this guide. In particular, this model is amenable to GPU snapshots, which can significantly reduce cold start times.
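Enabling snapshots is a small change to the class definition. Here is a minimal sketch, assuming Modal’s memory snapshot API plus the experimental GPU snapshot option (the class name SnapshottingTranscriber is illustrative):

@app.cls(
    volumes={CACHE_PATH: model_cache},
    gpu=["A100"],
    image=image,
    enable_memory_snapshot=True,  # snapshot host memory after load() completes
    experimental_options={"enable_gpu_snapshot": True},  # assumption: experimental flag for GPU state
)
class SnapshottingTranscriber:
    @modal.enter(snap=True)  # run during snapshot creation instead of on every cold start
    async def load(self): ...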
We use a CacheAwareStreamingAudioBuffer to manage the audio stream.
This buffer handles the streaming input and output, ensuring that the model receives
the correct amount of audio data for each inference step.
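Concretely, the transcribe method below relies on a small contract, sketched here under the interface used in this file (float_samples stands in for a chunk of float32 audio):

buffer = CacheAwareStreamingAudioBuffer(
    model=asr_model,
    online_normalization=True,
    pad_and_drop_preencoded=False,
)
buffer.reset_buffer()  # start a fresh stream
# stream_id=-1 asks the buffer to allocate a new stream; the assigned id is returned
_, _, stream_id = buffer.append_audio(float_samples, stream_id=-1)
chunk = buffer.get_next_chunk()  # None until enough audio has accumulated
if chunk is not None:
    audio_chunk, chunk_lengths = chunk  # ready for one inference step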
WebSocket Handling
We use FastAPI’s WebSocket support to handle the audio stream. Incoming audio bytes are buffered and processed in chunks, and transcriptions are sent back to the client as they become available.
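On the client side, any WebSocket client that sends raw 16 kHz, 16-bit mono PCM bytes works. Here is a minimal sketch using the third-party websockets package; it reuses chunk_audio and the streaming constants from above, and stream_pcm is an illustrative helper, not part of the app:

import asyncio
import websockets  # third-party client library, not a dependency of this app

async def stream_pcm(url: str, pcm: bytes):
    async with websockets.connect(url) as ws:
        async def send():
            for chunk in chunk_audio(pcm, PARAKEET_RT_STREAMING_CHUNK_SIZE):
                await ws.send(chunk)
                # pace uploads at roughly real time
                await asyncio.sleep(FRAME_LEN_SEC * NUM_REQUIRED_BUFFER_FRAMES)
        async def recv():
            async for message in ws:  # runs until the server closes the connection
                print(message)  # speaker-tagged transcript updates
        await asyncio.gather(send(), recv())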
@app.cls(
volumes={CACHE_PATH: model_cache},
gpu=["A100"],
image=image,
secrets=[hf_secret] if hf_secret is not None else [],
)
class Transcriber:
@modal.enter()
async def load(self):
# silence chatty logs from nemo
logging.getLogger("nemo_logger").setLevel(logging.CRITICAL)
self.diar_model = (
SortformerEncLabelModel.from_pretrained(
"nvidia/diar_streaming_sortformer_4spk-v2.1"
)
.eval()
.to(torch.device("cuda"))
)
self.asr_model = (
ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1")
.eval()
.to(torch.device("cuda"))
)
self.cfg = OmegaConf.structured(MultitalkerTranscriptionConfig())
self.diar_model = MultitalkerTranscriptionConfig.init_diar_model(
self.cfg, self.diar_model
)
self.multispk_asr_streamer = SpeakerTaggedASR(
self.cfg, self.asr_model, self.diar_model
)
self._chunk_size = PARAKEET_RT_STREAMING_CHUNK_SIZE
# warm up the GPU by transcribing a short sample clip
AUDIO_URL = "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/mono_44100/156550__acclivity__a-dream-within-a-dream.wav"
audio_bytes = preprocess_audio(AUDIO_URL, target_sample_rate=SAMPLE_RATE)
self.streaming_buffer = CacheAwareStreamingAudioBuffer(
model=self.asr_model,
online_normalization=self.cfg.online_normalization,
pad_and_drop_preencoded=self.cfg.pad_and_drop_preencoded,
)
self.streaming_buffer.reset_buffer()
step_num = 0
stream_id = -1
for audio_data in chunk_audio(audio_bytes, PARAKEET_RT_STREAMING_CHUNK_SIZE):
transcript, stream_id = await self.transcribe(
audio_data, step_num, stream_id
)
step_num += 1
stream_id = 0
print(f"transcript: {transcript}")
print(f"stream_id: {stream_id}")
self.streaming_buffer.reset_buffer()
self.web_app = FastAPI()
@self.web_app.websocket("/ws")
async def run_with_websocket(ws: WebSocket):
audio_queue = asyncio.Queue()
transcription_queue = asyncio.Queue()
self.streaming_buffer.reset_buffer()
async def recv_loop(ws, audio_queue):
audio_buffer = bytearray()
while True:
data = await ws.receive_bytes()
audio_buffer.extend(data)
if len(audio_buffer) > self._chunk_size:
print("sending audio data")
await audio_queue.put(audio_buffer)
audio_buffer = bytearray()
async def inference_loop(audio_queue, transcription_queue):
step_num = 0
stream_id = -1
while True:
audio_data = await audio_queue.get()
start_time = time.perf_counter()
print("transcribing audio data")
transcript, stream_id = await self.transcribe(
audio_data, step_num, stream_id
)
step_num += 1
stream_id = 0
print(f"transcript: {transcript}")
if transcript:
await transcription_queue.put(transcript)
end_time = time.perf_counter()
print(
f"time taken to transcribe audio segment: {end_time - start_time} seconds"
)
async def send_loop(transcription_queue, ws):
while True:
transcript = await transcription_queue.get()
print(f"sending transcription data: {transcript}")
await ws.send_text(transcript)
await ws.accept()
try:
tasks = [
asyncio.create_task(recv_loop(ws, audio_queue)),
asyncio.create_task(
inference_loop(audio_queue, transcription_queue)
),
asyncio.create_task(send_loop(transcription_queue, ws)),
]
await asyncio.gather(*tasks)
except WebSocketDisconnect:
print("WebSocket disconnected")
ws = None
except Exception as e:
print("Exception:", e)
finally:
if ws and ws.application_state is WebSocketState.CONNECTED:
await ws.close(code=1011) # internal error
ws = None
for task in tasks:
if not task.done():
try:
task.cancel()
await task
except asyncio.CancelledError:
pass
async def transcribe(self, audio_data, step_num, stream_id=-1) -> tuple[Optional[str], int]:
print(f"transcribing audio data: {len(audio_data)} bytes")
drop_extra_pre_encoded = (
0
if step_num == 0 and not self.cfg.pad_and_drop_preencoded
else self.asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
)
# convert to numpy
audio_data = int2float(np.frombuffer(audio_data, dtype=np.int16))
processed_signal, processed_signal_length, stream_id = (
self.streaming_buffer.append_audio(audio_data, stream_id=stream_id)
)
result = self.streaming_buffer.get_next_chunk()
if result is not None:
audio_chunk, chunk_lengths = result
else:
return None, stream_id
# inference_mode() already disables autograd, so a nested no_grad() is unnecessary
with torch.inference_mode():
with torch.amp.autocast(self.diar_model.device.type, enabled=True):
result = (
self.multispk_asr_streamer.perform_parallel_streaming_stt_spk(
step_num=step_num,
chunk_audio=audio_chunk,
chunk_lengths=chunk_lengths,
is_buffer_empty=False,
drop_extra_pre_encoded=drop_extra_pre_encoded,
)
)
if result:
return result[0], stream_id
return None, stream_id
@modal.asgi_app()
def webapp(self):
return self.web_app
@modal.method()
def ping(self):
return "pong"Frontend Service
We serve a simple HTML/JS frontend to interact with the transcriber. The frontend captures microphone input and streams it to the WebSocket endpoint.
web_image = (
modal.Image.debian_slim(python_version="3.12")
.pip_install("fastapi")
.add_local_dir(Path(__file__).parent / "multitalker-frontend", "/root/frontend")
)
with web_image.imports():
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, Response
from fastapi.staticfiles import StaticFiles
@app.cls(image=web_image)
@modal.concurrent(max_inputs=20)
class WebServer:
@modal.asgi_app()
def web(self):
web_app = FastAPI()
web_app.mount("/static", StaticFiles(directory="frontend"))
@web_app.get("/status")
async def status():
return Response(status_code=200)
# serve frontend
@web_app.get("/")
async def index():
html_content = open("frontend/index.html").read()
# Get the base WebSocket URL (without transcriber parameters)
cls_instance = Transcriber()
ws_base_url = (
cls_instance.webapp.get_web_url().replace("http", "ws") + "/ws"
)
script_tag = f'<script>window.WS_BASE_URL = "{ws_base_url}"; window.TRANSCRIPTION_MODE = "replace";</script>'
html_content = html_content.replace(
'<script src="/static/parakeet.js"></script>',
f'{script_tag}\n<script src="/static/parakeet.js"></script>',
)
return HTMLResponse(content=html_content)
return web_app
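To try it out, serve or deploy the app (for example with modal serve pointed at this module) and open the URL Modal prints for the WebServer endpoint: the page streams your microphone audio to the Transcriber’s /ws endpoint and renders speaker-tagged transcripts as they arrive.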