
Transform images with SDXL Turbo

In this example, we run the SDXL Turbo model in image-to-image mode: the model takes in a prompt and an image and transforms the image to better match the prompt.

For example, the model transformed the image on the left into the image on the right based on the prompt dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k.

SDXL Turbo is a distilled model designed for fast, interactive image synthesis. Learn more about it here.

Define a container image

First, we define the environment the model inference will run in: the container image.

from io import BytesIO
from pathlib import Path

import modal

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate~=0.25.0",  # Allows `device_map="auto"``, for computation of optimized device_map
        "diffusers~=0.24.0",  # Provides model libraries
        "huggingface-hub[hf-transfer]~=0.25.2",  # Lets us download models from Hugging Face's Hub
        "Pillow~=10.1.0",  # Image manipulation in Python
        "safetensors~=0.4.1",  # Enables safetensor format as opposed to using unsafe pickle format
        "transformers~=4.35.2",  # This is needed for `import torch`
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # Allows faster model downloads
            "HF_HUB_CACHE_DIR": CACHE_DIR,  # Points the Hugging Face cache to a Volume
        }
    )
)

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)

app = modal.App(
    "image-to-image", image=image, volumes={CACHE_DIR: cache_volume}
)

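# These imports only run inside containers built from `image`,
# so we don't need the packages installed locally.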
with image.imports():
    import torch
    from diffusers import AutoPipelineForImage2Image
    from diffusers.utils import load_image
    from huggingface_hub import snapshot_download
    from PIL import Image

Downloading, setting up, and running SDXL Turbo

The Modal Cls defined below contains all the logic to download, set up, and run SDXL Turbo.

The container lifecycle decorator (@modal.enter()) ensures that the model is loaded into memory when a container starts, before it picks up any inputs.

The inference method runs the actual model inference. It takes in an image as a collection of bytes and a string prompt and returns a new image (also as a collection of bytes).

To avoid excessive cold-starts, we set the container_idle_timeout to 240 seconds, meaning once a GPU has loaded the model it will stay online for 4 minutes before spinning down.

We also provide a function that will download the model weights to the cache Volume ahead of time. You can run this function directly with modal run. Otherwise, the weights will be cached after the first container cold start.
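For example, assuming this file is saved as image_to_image.py (as in the commands later in this example), you can target the download function by name:

modal run image_to_image.py::download_models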

@app.function()
def download_models():
    # Ignore files that we don't need, to speed up download time.
    ignore = [
        "*.bin",
        "*.onnx_data",
        "*/diffusion_pytorch_model.safetensors",
    ]

    snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore)


@app.cls(gpu=modal.gpu.A10G(), container_idle_timeout=240)
class Model:
    @modal.enter()
    def enter(self):
        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            device_map="auto",
        )

    @modal.method()
    def inference(
        self, image_bytes: bytes, prompt: str, strength: float = 0.9
    ) -> bytes:
        init_image = load_image(Image.open(BytesIO(image_bytes))).resize(
            (512, 512)
        )
        num_inference_steps = 4
        # "When using SDXL-Turbo for image-to-image generation, make sure that num_inference_steps * strength is larger or equal to 1"
        # See: https://huggingface.co/stabilityai/sdxl-turbo
        assert num_inference_steps * strength >= 1

        image = self.pipe(
            prompt,
            image=init_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=0.0,
        ).images[0]

        byte_stream = BytesIO()
        image.save(byte_stream, format="PNG")
        image_bytes = byte_stream.getvalue()

        return image_bytes

Running the model from the command line

You can run the model from the command line with

modal run image_to_image.py

Use --help for additional details.
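For instance, you can override the prompt and strength; Modal derives the flag names from the parameters of the entrypoint below:

modal run image_to_image.py --prompt "an oil painting of a dog" --strength 0.7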

@app.local_entrypoint()
def main(
    image_path=Path(__file__).parent / "demo_images/dog.png",
    prompt="dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k",
    strength: float = 0.9,  # increase to favor the prompt over the baseline image
):
    print(f"🎨 reading input image from {image_path}")
    input_image_bytes = Path(image_path).read_bytes()
    print(f"🎨 editing image with prompt {prompt}")
    output_image_bytes = Model().inference.remote(
        input_image_bytes, prompt, strength=strength
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    output_path = output_dir / "output.png"
    print(f"🎨 saving output image to {output_path}")
    output_path.write_bytes(output_image_bytes)