Serverless Qwen 3-8B with SGLang and Modal Snapshots
In this example, we show how to serve SGLang on Modal with ~10x faster cold starts.
Fast cold starts are particularly useful for LLM inference applications that have highly “bursty” workloads, like document processing. See this guide for a breakdown of different LLM inference workloads and how to optimize them.
The key technique is CPU + GPU memory snapshotting, which saves and restores the SGLang server directly from its in-memory state.
This adds some complexity to the deployment. If you just want to get started running a basic LLM server on Modal, see this example.
Set up the container image
Our first order of business is to define the environment our server will run in:
the container Image.
We start from a container image provided by the SGLang team via Dockerhub.
While we’re at it, we import the dependencies we’ll need both remotely and locally (for deployment).
import asyncio
import subprocess
import time
import aiohttp
import modal
import modal.experimental
MINUTES = 60 # seconds
sglang_image = (
    modal.Image.from_registry(
        "lmsysorg/sglang:v0.5.6.post2-cu129-amd64-runtime"
    ).entrypoint([])  # silence chatty logs on container start
)
We also choose a GPU to deploy our inference server onto. We pick the H100, which offers excellent price-performance and supports 8-bit floating point (FP8) operations, the lowest precision that is well-supported by the relevant GPU kernels across a variety of model architectures.
N_GPUS = 1
GPU = f"H100!:{N_GPUS}"
Actual speedups are generally less than what you get from “napkin math” based on available bandwidths: we observed a speedup of about 30% moving from one to two H100s when developing this example. We recommend application-specific benchmarking guided by published generic benchmarks.
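To make that napkin math concrete, here is a rough sketch; the bandwidth and weight-size figures below are approximate assumptions rather than measurements. In the decode phase, each forward pass streams the full set of weights out of HBM, so aggregate memory bandwidth divided by weight size gives an upper bound on single-request decode throughput, and the naive math says a second GPU doubles it.
# Rough napkin math, for illustration only; the figures are assumptions, not measurements.
H100_HBM_GB_PER_S = 3350  # ~3.35 TB/s of HBM bandwidth per H100 SXM (approximate)
QWEN3_8B_FP8_WEIGHTS_GB = 8  # ~8B parameters at one byte each in FP8 (approximate)


def napkin_decode_tokens_per_s(n_gpus: int) -> float:
    # bandwidth-only upper bound on single-request decode throughput (~420 tokens/s on one H100)
    return n_gpus * H100_HBM_GB_PER_S / QWEN3_8B_FP8_WEIGHTS_GB


# This estimate predicts a clean 2x from a second GPU; communication overhead and
# non-GEMM work are part of why the measured speedup was closer to 30%.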
Loading and caching the model weights
We’ll serve Alibaba’s Qwen 3 LLM. For lower latency and faster cold starts, we pick a smaller model (8B params) in a lower precision floating point format (FP8). This reduces the amount of data that needs to be loaded from GPU RAM into SM SRAM in each forward pass.
MODEL_NAME = "Qwen/Qwen3-8B-FP8"
MODEL_REVISION = (
    "220b46e3b2180893580a4454f21f22d3ebb187d3"  # latest commit as of 2026-01
)
We load the model from the Hugging Face Hub, so we’ll need their Python package.
sglang_image = sglang_image.uv_pip_install("huggingface-hub==0.36.0")
We don’t want to load the model from the Hub every time we start the server. We can load it much faster from a Modal Volume. Typical speeds are around one to two GB/s.
HF_CACHE_VOL = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
HF_CACHE_PATH = "/root/.cache/huggingface"
MODEL_PATH = f"{HF_CACHE_PATH}/{MODEL_NAME}"
In addition to pointing the Hugging Face Hub at the path where we mount the Volume, we also turn on “high performance” downloads, which can fully saturate our network bandwidth.
sglang_image = sglang_image.env(
    {"HF_HUB_CACHE": HF_CACHE_PATH, "HF_XET_HIGH_PERFORMANCE": "1"}
)
Caching compilation artifacts
Model weights aren’t the only thing we want to cache.
As a rule, LLM inference servers like SGLang don’t directly provide their own kernels. They draw high-performance kernels from a variety of sources.
As of version 0.5.6, SGLang’s default kernel backend
for FP8 matrix multiplications (fp8-gemm-backend)
on Hopper SM architecture GPUs like the H100 is DeepGEMM by DeepSeek.
The binaries of these kernels are not included in the SGLang Docker image and so must be JIT-compiled. We store these in a Modal Volume as well.
DG_CACHE_VOL = modal.Volume.from_name("deepgemm-cache", create_if_missing=True)
DG_CACHE_PATH = "/root/.cache/deepgemm"
JIT DeepGEMM kernels are on by default, but we explicitly enable them via an environment variable.
sglang_image = sglang_image.env({"SGLANG_ENABLE_JIT_DEEPGEMM": "1"})
We trigger the compilation by running sglang.compile_deep_gemm in a subprocess kicked off from a Python function.
def compile_deep_gemm():
    import os

    if int(os.environ.get("SGLANG_ENABLE_JIT_DEEPGEMM", "1")):
        subprocess.run(
            f"python3 -m sglang.compile_deep_gemm --model-path {MODEL_NAME} --revision {MODEL_REVISION} --tp {N_GPUS}",
            shell=True,
        )
We run this Python function on Modal as part of building the Image so that it has access to the appropriate GPU and the caches for our model and compilation artifacts.
sglang_image = sglang_image.run_function(
    compile_deep_gemm,
    volumes={DG_CACHE_PATH: DG_CACHE_VOL, HF_CACHE_PATH: HF_CACHE_VOL},
    gpu=GPU,
)
Speed up cold starts with GPU snapshotting
Modal is a serverless compute platform, so all of your inference services automatically scale up and down to handle variable load.
Scaling up a new replica requires quite a bit of work — loading up Python and system packages, loading model weights, setting up the inference engine, and so on.
We can skip over and speed up a bunch of this work when spinning up new replicas after the first by directly booting from a memory snapshot, which contains the exact in-memory representation of our server just before it begins taking requests.
Most applications can be snapshotted and see substantial speedups (2x to 10x; see our initial benchmarks here). However, it generally requires some extra work to adapt the application code.
For instance, here we set an environment variable that improves the compatibility of the Torch Inductor compiler with GPU snapshotting.
sglang_image = sglang_image.env({"TORCHINDUCTOR_COMPILE_THREADS": "1"})
Below, we walk through the additional steps required to make an SGLang server compatible with snapshots.
Sleeping and waking an SGLang server
We prepare our SGLang inference server for snapshotting by first sending
a few requests to “warm it up”, ensuring that it is fully ready to process requests.
Then we “put it to sleep”, moving non-essential data out of GPU memory,
with a request to /release_memory_occupation.
At this point, we can take a memory snapshot.
Upon snapshot restoration, we “wake up” the server
with a request to /resume_memory_occupation.
We use the requests library to send ourselves these HTTP requests on localhost/127.0.0.1.
with sglang_image.imports():
    import requests


def warmup():
    payload = {
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "max_tokens": 16,
    }
    for _ in range(3):
        requests.post(
            f"http://127.0.0.1:{PORT}/v1/chat/completions", json=payload, timeout=10
        ).raise_for_status()


def sleep():
    requests.post(
        f"http://127.0.0.1:{PORT}/release_memory_occupation", json={}
    ).raise_for_status()


def wake_up():
    requests.post(
        f"http://127.0.0.1:{PORT}/resume_memory_occupation", json={}
    ).raise_for_status()
Define the inference server and infrastructure
We wrap up all of the choices we made about the infrastructure of our inference server into a number of Python decorators that we apply to a Python class that encapsulates the logic to run our server.
The key decorators are:
- @app.cls to define the core of our service. We attach our Image, request a GPU, attach our cache Volumes, and enable memory snapshotting. See the reference documentation for details.
- @modal.web_server to turn our Python code into an HTTP server. The wrapped code needs to eventually listen for HTTP connections on the provided port.
- @modal.concurrent to specify how many requests our server can handle before we need to scale up.
- @modal.enter and @modal.exit to indicate which methods of the class should be run when starting the server and shutting it down. The enter methods also define what code is run before memory snapshot creation (snap=True) and after memory snapshot restoration (snap=False).
The modal.concurrent decorator and the lifecycle management are particularly important
for bursty workloads and for snapshotting, respectively, so let’s discuss them in detail.
Determining autoscaling policy with @modal.concurrent
To handle bursty workloads, we need to decide how we will scale up and down replicas in response to load. Without autoscaling, users’ requests will queue when the server becomes overloaded.
We can set two values with the @modal.concurrent decorator. max_inputs should be set to the maximum number of inputs a replica can handle concurrently without internal queueing (the --max-running-requests value in SGLang). target_inputs can be left unset or, if per-request latency degrades too much when handling the maximum batch size, set to a lower value.
TARGET_INPUTS = 10
MAX_INPUTS = 1000
Generally, this choice needs to be made as part of LLM inference engine benchmarking in reference to a particular application’s latency and throughput targets.
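As a hypothetical sketch of what that benchmarking might look like, the function below sweeps the number of concurrent requests against a deployed endpoint and reports the median latency at each level; the request payload, concurrency levels, and latency target are placeholders, not part of this deployment. You would set target_inputs near the highest level that still meets your latency target.
# Hypothetical concurrency sweep for choosing TARGET_INPUTS; values are placeholders.
async def sweep_concurrency(url, levels=(1, 4, 16, 64), latency_target_s=2.0):
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
    }
    async with aiohttp.ClientSession(base_url=url) as session:

        async def timed_request():
            start = time.monotonic()
            async with session.post("/v1/chat/completions", json=payload) as resp:
                resp.raise_for_status()
                await resp.json()
            return time.monotonic() - start

        for level in levels:  # fire `level` requests at once and look at the median latency
            latencies = sorted(await asyncio.gather(*[timed_request() for _ in range(level)]))
            median = latencies[len(latencies) // 2]
            print(f"concurrency={level} median_latency={median:.2f}s meets_target={median <= latency_target_s}")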
Controlling container lifecycles with @modal.enter
Modal considers a new replica ready to receive inputs once the @modal.enter methods have exited
and the container accepts connections.
To ensure that we actually finish setting up our server before we are marked ready for inputs,
we define a helper function that polls the server until it reports itself healthy.
def wait_ready(process: subprocess.Popen, timeout: int = 5 * MINUTES):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            check_running(process)
            requests.get(f"http://127.0.0.1:{PORT}/health").raise_for_status()
            return
        except (
            subprocess.CalledProcessError,
            requests.exceptions.ConnectionError,
            requests.exceptions.HTTPError,
        ):
            time.sleep(1)
    raise TimeoutError(f"SGLang server not ready within timeout of {timeout} seconds")


def check_running(p: subprocess.Popen):
    if (rc := p.poll()) is not None:
        raise subprocess.CalledProcessError(rc, cmd=p.args)
With all this in place, we are ready to define our high-performance, low-latency LLM inference server.
app = modal.App(name="example-sglang-snapshot")
PORT = 8000
@app.cls(
    image=sglang_image,
    gpu=GPU,
    volumes={HF_CACHE_PATH: HF_CACHE_VOL, DG_CACHE_PATH: DG_CACHE_VOL},
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
@modal.concurrent(target_inputs=TARGET_INPUTS, max_inputs=MAX_INPUTS)
class SGLang:
    @modal.enter(snap=True)
    def startup(self):
        """Start the SGLang server and block until it is healthy, then warm it up and put it to sleep."""
        cmd = [
            "python",
            "-m",
            "sglang.launch_server",
            "--model-path",
            MODEL_NAME,
            "--revision",
            MODEL_REVISION,
            "--served-model-name",
            MODEL_NAME,
            "--host",
            "0.0.0.0",
            "--port",
            f"{PORT}",
            "--tp",  # use all GPUs to split up tensor-parallel operations
            f"{N_GPUS}",
            "--cuda-graph-max-bs",  # capture CUDA graphs up to batch sizes we're likely to observe
            f"{MAX_INPUTS}",
            "--max-running-requests",
            f"{MAX_INPUTS}",
            "--enable-metrics",  # expose metrics endpoints for telemetry
            "--enable-memory-saver",  # enable offload, for snapshotting
            "--enable-weights-cpu-backup",  # enable offload, for snapshotting
        ]
        self.process = subprocess.Popen(cmd)
        wait_ready(self.process)
        warmup()  # for snapshotting
        sleep()

    @modal.enter(snap=False)
    def wake_up(self):
        wake_up()

    @modal.web_server(
        port=PORT,  # wrapped code must listen on this port
        startup_timeout=10 * MINUTES,  # how long can server startup take?
    )
    def serve(self):
        pass

    @modal.exit()
    def stop(self):
        self.process.terminate()
Deploy the server
To deploy the server on Modal, just run
modal deploy sglang_snapshot.py
This will create a new App on Modal and build the container image for it if it hasn’t been built yet.
Interact with the server
Once it is deployed, you’ll see a URL appear in the command line,
something like https://your-workspace-name--example-sglang-snapshot-sglang.modal.run.
You can find interactive Swagger UI docs at the /docs route of that URL, i.e. https://your-workspace-name--example-sglang-snapshot-sglang.modal.run/docs.
These docs describe each route and indicate the expected input and output
and translate requests into curl commands.
For simple routes, you can even send a request directly from the docs page.
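You can also hit the OpenAI-compatible chat completions route from a short standalone Python script, as in the sketch below. The URL is the placeholder from above, so substitute the one printed when you deploy, and the request body is just an example.
# Standalone client sketch; run it anywhere with the requests package installed.
import requests

url = "https://your-workspace-name--example-sglang-snapshot-sglang.modal.run"  # placeholder
resp = requests.post(
    url + "/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B-FP8",
        "messages": [{"role": "user", "content": "Why are cold starts slow?"}],
        "max_tokens": 64,
    },
    timeout=10 * 60,  # leave room for a cold start on the first request
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])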
Test the server
To make it easier to test the server setup, we also include a local_entrypoint that hits the server with a simple client.
If you execute the command
modal run sglang_snapshot.py
a fresh replica of the server will be spun up on Modal while the code below executes on your local machine.
Think of this like writing simple tests inside of the if __name__ == "__main__" block of a Python script, but for cloud deployments!
@app.local_entrypoint()
async def test(test_timeout=10 * MINUTES, prompt=None, twice=True):
    url = SGLang().serve.get_web_url()
    system_prompt = {
        "role": "system",
        "content": "You are a pirate who can't help but drop sly reminders that he went to Harvard.",
    }
    if prompt is None:
        prompt = "Explain the Singular Value Decomposition."
    content = [{"type": "text", "text": prompt}]
    messages = [  # OpenAI chat format
        system_prompt,
        {"role": "user", "content": content},
    ]
    await probe(url, messages, timeout=test_timeout)
    if twice:
        messages[0]["content"] = "You are Jar Jar Binks."
        print(f"Sending messages to {url}:", *messages, sep="\n\t")
        await probe(url, messages, timeout=1 * MINUTES)
This test relies on the two helper functions below, which ping the server and wait for a valid response.
async def probe(url, messages=None, timeout=5 * MINUTES):
    if messages is None:
        messages = [{"role": "user", "content": "Tell me a joke."}]
    deadline = time.time() + timeout
    async with aiohttp.ClientSession(base_url=url) as session:
        while time.time() < deadline:
            try:
                await _send_request(session, "llm", messages)
                return
            except asyncio.TimeoutError:
                await asyncio.sleep(1)
    raise TimeoutError(f"No response from server within {timeout} seconds")


async def _send_request(
    session: aiohttp.ClientSession,
    model: str,
    messages: list,
    timeout: int | None = None,
) -> None:
    async with session.post(
        "/v1/chat/completions",
        json={"messages": messages, "model": model},
        timeout=timeout,
    ) as resp:
        resp.raise_for_status()
        print((await resp.json())["choices"][0]["message"]["content"])
Test memory snapshotting
Using modal run creates an ephemeral Modal App, rather than a deployed Modal App.
Ephemeral Modal Apps are short-lived, so they turn off memory snapshotting.
To test the memory snapshot version of the server,
first deploy it with modal deploy and then hit it with a client.
You should observe startup improvements after a handful of cold starts (usually less than five). If you want to see the speedup during a test, we recommend heading to the deployed App in your Modal dashboard and manually stopping containers after they have served a request to ensure turnover.
You can use the client code below to test the endpoint.
if __name__ == "__main__":
    # after deployment, we can use the class from anywhere
    SGLang = modal.Cls.from_name("example-sglang-snapshot", "SGLang")

    print("calling inference server")
    try:
        asyncio.run(probe(SGLang().serve.get_web_url()))
    except modal.exception.NotFoundError as e:
        raise Exception(
            f"To take advantage of GPU snapshots, deploy first with modal deploy {__file__}"
        ) from e
It can be run with the command
python sglang_snapshot.py