Check out our new GPU Glossary! Read now

Run Stable Video Diffusion in a Streamlit app

This example runs the Stable Video Diffusion image-to-video model.

import os
import sys

import modal

# Modal app that the functions and the web endpoint below are registered on.
app = modal.App(name="example-stable-video-diffusion-streamlit")
# Named queue (created if it does not already exist) through which
# `run_streamlit` passes its tunnel URL to the `share` endpoint.
q = modal.Queue.from_name(
    "stable-video-diffusion-streamlit", create_if_missing=True
)

# Session length in seconds (15 minutes). Used as the Modal function timeout
# for `run_streamlit` and also passed to the Streamlit script as `--timeout`.
session_timeout = 15 * 60


def download_model():
    """Download the SVD checkpoint and run the repo's model loader once.

    Executed as an image build step (see `svd_image`), so whatever it writes
    ends up in the image. It snapshots the
    `stabilityai/stable-video-diffusion-img2vid` Hugging Face repo into
    `checkpoints/` under `/sgm`, then calls the repo's
    `load_model_from_config` on the "svd" spec (presumably so any remaining
    load-time downloads are captured in the image too; confirm).
    """
    # The cloned repo uses relative paths throughout, so it has to be both the
    # working directory and importable before any of its modules are loaded.
    repo_root = "/sgm"
    sys.path.append(repo_root)
    os.chdir(repo_root)

    from huggingface_hub import snapshot_download
    from omegaconf import OmegaConf
    from scripts.demo.streamlit_helpers import load_model_from_config
    from scripts.demo.video_sampling import VERSION2SPECS

    # Real files rather than symlinks into the Hugging Face cache.
    snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid",
        local_dir="checkpoints/",
        local_dir_use_symlinks=False,
    )

    svd_spec = VERSION2SPECS["svd"]
    load_model_from_config(
        OmegaConf.load(svd_spec["config"]),
        svd_spec["ckpt"],
    )


# Container image for the Streamlit demo: the Stability-AI generative-models
# repo cloned to /sgm, its Python and system dependencies, and the result of
# running `download_model` at build time.
svd_image = (
    # The generative-models repo hardcodes `tokenizers==0.12.1`, for which there is no
    # pre-built python 3.11 wheel.
    modal.Image.debian_slim(python_version="3.10")
    .apt_install("git")
    .run_commands(
        "git clone https://github.com/Stability-AI/generative-models.git /sgm"
    )
    .workdir("/sgm")
    # Install the cloned repo as a package (`.` is presumably resolved against
    # the /sgm workdir set above -- confirm).
    .pip_install(".")
    # Pinned CUDA 11.8 builds of the PyTorch packages, from the PyTorch wheel index.
    .pip_install(
        "torch==2.0.1+cu118",
        "torchvision==0.15.2+cu118",
        "torchaudio==2.0.2+cu118",
        extra_index_url="https://download.pytorch.org/whl/cu118",
    )
    # Remaining requirements listed by the repo itself.
    .run_commands("pip install -r requirements/pt2.txt")
    .apt_install("ffmpeg", "libsm6", "libxext6")  # for CV2
    .pip_install("safetensors")
    # Run `download_model` as a build step, requesting a GPU of any type.
    .run_function(download_model, gpu="any")
)


@app.function(image=svd_image, timeout=session_timeout, gpu="A100")
def run_streamlit(publish_url: bool = False):
    """Serve the repo's Streamlit video-sampling demo through a Modal tunnel.

    Args:
        publish_url: When true, put the tunnel URL on the module-level queue
            `q` before the server starts, so another function can pick it up.

    This call blocks for as long as the Streamlit server is running.
    """
    from streamlit.web import bootstrap

    # TODO: figure out better way to do this with streamlit.
    repo_root = "/sgm"
    os.chdir(repo_root)
    sys.path.append(repo_root)

    # Expose Streamlit's port 8501 for the lifetime of the server.
    with modal.forward(8501) as tunnel:
        if publish_url:
            q.put(tunnel.url)

        # Point Streamlit's browser-facing address at the tunnel instead of
        # the container-local one.
        browser_options = {
            "browser.serverAddress": tunnel.host,
            "browser.serverPort": 443,
        }
        bootstrap.load_config_options(browser_options)

        # Does not return until the server shuts down.
        bootstrap.run(
            main_script_path="/sgm/scripts/demo/video_sampling.py",
            is_hello=False,
            args=["--timeout", str(session_timeout)],
            flag_options={},
        )


# Separate image for the `share` web endpoint: Python 3.10 with pinned
# FastAPI, Pydantic and Starlette, and none of the SVD dependencies.
endpoint_image = modal.Image.debian_slim(python_version="3.10").pip_install(
    "fastapi[standard]==0.115.4",
    "pydantic==2.9.2",
    "starlette==0.41.2",
)


@app.function(image=endpoint_image)
@modal.web_endpoint(method="GET", label="svd")
def share():
    """Start a Streamlit session and redirect the caller to it.

    Spawns `run_streamlit` with `publish_url=True`, waits for a tunnel URL to
    arrive on the module-level queue `q`, and returns a 303 redirect to that
    URL.
    """
    from fastapi.responses import RedirectResponse

    # Fire-and-forget: the session keeps running after this request returns.
    run_streamlit.spawn(publish_url=True)

    # Blocks until a URL is available on the queue.
    session_url = q.get()
    return RedirectResponse(url=session_url, status_code=303)