Hosting any LLaMA 3 model with Text Generation Inference (TGI)

In this example, we show how to run an optimized inference server using Text Generation Inference (TGI), which has performance advantages over standard text generation pipelines, including:

  • continuous batching, so multiple generations can run at the same time in a single container
  • PagedAttention, which applies memory paging to the attention mechanism’s key-value cache, increasing throughput
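To see why continuous batching helps, here is a toy model. It is not TGI's actual scheduler. Each request needs a fixed number of decode steps (one token per step), and a batch holds at most batch_size requests. With static batching, a batch runs until its longest request finishes. With continuous batching, a finished request's slot is refilled from the queue after every step. Prefill, memory limits and arrival times are ignored.

```python
from collections import deque
from typing import List


def static_batching_steps(lengths: List[int], batch_size: int) -> int:
    """Decode steps when each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps


def continuous_batching_steps(lengths: List[int], batch_size: int) -> int:
    """Decode steps when finished requests are replaced by queued ones after every step."""
    queue = deque(lengths)
    active: List[int] = []  # tokens still to generate for each request in the batch
    steps = 0
    while queue or active:
        # Admit waiting requests into any free slots before this step.
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        # One decode step produces one token for every active request;
        # requests that just produced their last token leave the batch.
        steps += 1
        active = [remaining - 1 for remaining in active if remaining > 1]
    return steps


print(static_batching_steps([100, 10, 10, 10], batch_size=2))
print(continuous_batching_steps([100, 10, 10, 10], batch_size=2))
```

With one long request and three short ones, static batching takes 110 steps and continuous batching takes 100. The short requests no longer wait behind the long one, so they also finish much earlier.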

This example deployment, accessible here, can serve LLaMA 3 70B with 70-second cold starts, up to 200 tokens/s of throughput, and a per-token latency of 55 ms.
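The latency and throughput figures are related by simple arithmetic. A per-token latency of 55 ms means a single generation produces about 1000 / 55 ≈ 18 tokens/s. Reaching 200 tokens/s in aggregate therefore implies roughly 11 generations running concurrently. This sketch assumes the 55 ms latency holds at peak throughput, which the figures above do not state.

```python
def tokens_per_second_per_stream(per_token_latency_ms: float) -> float:
    """Decode rate of a single generation, given its per-token latency."""
    return 1000.0 / per_token_latency_ms


def implied_concurrent_streams(aggregate_tokens_per_s: float, per_token_latency_ms: float) -> float:
    """Concurrent generations needed for an aggregate throughput (assumes latency is constant under load)."""
    return aggregate_tokens_per_s / tokens_per_second_per_stream(per_token_latency_ms)


def generation_time_s(num_tokens: int, per_token_latency_ms: float) -> float:
    """Wall-clock decode time for one response, ignoring queueing and prefill."""
    return num_tokens * per_token_latency_ms / 1000.0


print(f"{tokens_per_second_per_stream(55):.1f} tokens/s per stream")
print(f"{implied_concurrent_streams(200, 55):.1f} concurrent streams")
print(f"{generation_time_s(1024, 55):.1f} s for a 1024-token response")
```

At that latency, a response that uses the full 1024 new tokens allowed by the code below takes about 56 s to decode. That is close to the 60-second timeout passed to AsyncClient.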

Setup

First we import the modules we need, including modal.

import subprocess
from pathlib import Path

import modal

Next, we set which model to serve. When adapting this example, take care to choose a GPU configuration (GPU_CONFIG, set further down) with enough VRAM to fit the model, and a quantization method (bitsandbytes or gptq) if desired. Note that quantization can significantly degrade token generation performance.

Any model supported by TGI can be chosen here.

MODEL_ID = "NousResearch/Meta-Llama-3-8B"
MODEL_REVISION = "315b20096dc791d381d514deb5f8bd9c8d6d3061"

Add ["--quantize", "gptq"] to LAUNCH_FLAGS when serving GPTQ-quantized models, such as TheBloke's.

LAUNCH_FLAGS = [
    "--model-id",
    MODEL_ID,
    "--port",
    "8000",
    "--revision",
    MODEL_REVISION,
]
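If you switch between models often, it can help to build LAUNCH_FLAGS from parameters. The following is a minimal sketch. build_launch_flags is a hypothetical helper, and it deliberately accepts only the two quantization methods named above.

```python
from typing import List, Optional

# Only the quantization methods this example mentions; extend as needed.
ALLOWED_QUANTIZE = ("bitsandbytes", "gptq")


def build_launch_flags(
    model_id: str,
    revision: str,
    port: int = 8000,
    quantize: Optional[str] = None,
) -> List[str]:
    """Return text-generation-launcher arguments in the same order as LAUNCH_FLAGS."""
    flags = ["--model-id", model_id, "--port", str(port), "--revision", revision]
    if quantize is not None:
        if quantize not in ALLOWED_QUANTIZE:
            raise ValueError(f"unsupported quantization method: {quantize!r}")
        flags += ["--quantize", quantize]
    return flags
```

Calling build_launch_flags(MODEL_ID, MODEL_REVISION) reproduces the list above. Adding quantize="gptq" appends the GPTQ flag.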

Define a container image

We want to create a Modal Image with the Hugging Face model cache pre-populated. That way, the container no longer has to re-download the model from Hugging Face. Instead, it loads the weights from Modal's internal filesystem, which gives faster cold starts. For the largest (70B) model, the 135 GB of weights can be loaded in as little as 70 seconds.
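The cold-start figure implies an effective load rate: 135 GB in 70 s is about 1.9 GB/s. Below is a rough sketch for extrapolating to other model sizes. It assumes bf16 weights (2 bytes per parameter). It also assumes that load time scales linearly with size, which real cold starts will not follow exactly because they include fixed start-up overhead.

```python
def effective_load_rate_gb_s(size_gb: float, load_time_s: float) -> float:
    """Average throughput implied by loading size_gb in load_time_s."""
    return size_gb / load_time_s


def bf16_weights_gb(num_params: float) -> float:
    """Approximate size of bf16 weights: 2 bytes per parameter."""
    return num_params * 2 / 1e9


def estimated_load_time_s(num_params: float, rate_gb_s: float) -> float:
    """Linear extrapolation; ignores container start-up and server initialization."""
    return bf16_weights_gb(num_params) / rate_gb_s


rate = effective_load_rate_gb_s(135, 70)
print(f"{rate:.2f} GB/s")                                 # prints 1.93 GB/s
print(f"{estimated_load_time_s(8e9, rate):.1f} s for 8B")  # prints 8.3 s for 8B
```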

Download the weights

We can use TGI's bundled text-generation-server download-weights command to download the model weights (converting them to safetensors, if necessary) as part of the image build.

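def download_model():
    # Fetch the pinned revision into the image's Hugging Face cache.
    # check=True makes a failed download fail the image build
    # instead of silently producing an image without weights.
    subprocess.run(
        [
            "text-generation-server",
            "download-weights",
            MODEL_ID,
            "--revision",
            MODEL_REVISION,
        ],
        check=True,
    )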

Image definition

We'll start from the official TGI image, published on the GitHub Container Registry (ghcr.io). We override its default ENTRYPOINT so that Modal can run its own, which is what enables seamless serverless deployments.

Next we run the download function above to pre-populate the image with our model weights.

If you adapt this example to run another model, note that for this step to work on a gated model, the HF_TOKEN environment variable must be set. You can do this by passing a Modal Secret to run_function through its secrets parameter.

Finally, we install the text-generation client to interface with TGI’s Rust webserver over localhost.

app = modal.App("example-tgi-" + MODEL_ID.split("/")[-1])

tgi_image = (
    modal.Image.from_registry(
        "ghcr.io/huggingface/text-generation-inference:1.4"
    )
    .dockerfile_commands("ENTRYPOINT []")
    .run_function(
        download_model,
        timeout=3600,
    )
    .pip_install("text-generation")
)

The model class

Inference is best expressed with Modal's class syntax, which splits the logic of a Modal function into two parts:

  1. the @modal.enter() method, which runs once per container when it starts up, and
  2. the @modal.method() methods, which run once per inference request.

This means the TGI backend is launched, and the model is loaded onto the GPUs, just once when each container starts. That state is then reused by every subsequent invocation of the function. Note that on start-up, we must wait for the Rust webserver to accept connections before considering the container ready. When the container shuts down, the @modal.exit() method terminates the launcher process.
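The readiness check amounts to polling a TCP port until it accepts connections. Here is a standalone sketch of that pattern using a hypothetical wait_for_port helper. Unlike the loop in start_server, it also gives up after an overall deadline instead of polling forever. The 10-minute default for timeout_s is an assumption.

```python
import socket
import subprocess
import time
from typing import Optional


def wait_for_port(
    host: str,
    port: int,
    timeout_s: float = 600.0,
    poll_interval_s: float = 1.0,
    process: Optional[subprocess.Popen] = None,
) -> None:
    """Block until host:port accepts a TCP connection.

    Raises RuntimeError if `process` exits first, or TimeoutError once
    timeout_s seconds pass without a successful connection.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            socket.create_connection((host, port), timeout=1).close()
            return
        except OSError:
            # Refused connections and connect timeouts are both OSErrors.
            # If the server process has already exited, it will never listen.
            if process is not None and process.poll() is not None:
                raise RuntimeError(
                    f"server process exited with code {process.returncode}"
                )
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{host}:{port} not ready after {timeout_s} s")
        time.sleep(poll_interval_s)
```

In start_server, the polling loop could then become wait_for_port("127.0.0.1", 8000, process=self.launcher).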

Here, we also

  • for gated models only, specify a Modal Secret so the HUGGING_FACE_HUB_TOKEN environment variable can be set (the decorator below does not pass one)
  • specify how many H100s we need per container (two, set in GPU_CONFIG)
  • specify that each container is allowed to handle up to 15 inputs (i.e. requests) simultaneously
  • keep idle containers for 10 minutes before spinning down
  • increase the timeout limit to one hour

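A quick way to sanity-check the GPU count is to compare weight memory with total VRAM. This rough sketch assumes 80 GB per H100 and counts only the weights. TGI uses much of the remaining memory for the KV cache, and the CUDA context and activations take some of it as well.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory for the weights alone: params x bits / 8 bytes."""
    return num_params * bits_per_param / 8 / 1e9


def leftover_vram_gb(
    num_params: float, bits_per_param: int, num_gpus: int, gpu_memory_gb: float = 80.0
) -> float:
    """Total VRAM minus weight memory; a negative result means the weights don't fit."""
    return num_gpus * gpu_memory_gb - weight_memory_gb(num_params, bits_per_param)


print(leftover_vram_gb(70e9, 16, num_gpus=2))  # 70B in bf16 on 2 x H100: 20.0
print(leftover_vram_gb(70e9, 16, num_gpus=1))  # 70B in bf16 on 1 x H100: -60.0
print(leftover_vram_gb(8e9, 16, num_gpus=2))   # 8B in bf16 on 2 x H100: 144.0
```

By this estimate, a 70B model in bf16 leaves only about 20 GB across two H100s for everything besides the weights. The weights-only estimate says the 8B model would fit on a single H100.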
GPU_CONFIG = modal.gpu.H100(count=2)  # 2 H100s


@app.cls(
    gpu=GPU_CONFIG,
    allow_concurrent_inputs=15,
    container_idle_timeout=60 * 10,
    timeout=60 * 60,
    image=tgi_image,
)
class Model:
    @modal.enter()
    def start_server(self):
        import socket
        import time

        from text_generation import AsyncClient

        self.launcher = subprocess.Popen(
            ["text-generation-launcher"] + LAUNCH_FLAGS,
        )
        self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
        self.template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{user}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        # Poll until webserver at 127.0.0.1:8000 accepts connections before running inputs.
        def webserver_ready():
            try:
                socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
                return True
            except (socket.timeout, ConnectionRefusedError):
                # Check if launcher webserving process has exited.
                # If so, a connection can never be made.
                retcode = self.launcher.poll()
                if retcode is not None:
                    raise RuntimeError(
                        f"launcher exited unexpectedly with code {retcode}"
                    )
                return False

        while not webserver_ready():
            time.sleep(1.0)

        print("Webserver ready!")

    @modal.exit()
    def terminate_server(self):
        self.launcher.terminate()

    @modal.method()
    async def generate(self, question: str):
        prompt = self.template.format(user=question)
        result = await self.client.generate(
            prompt, max_new_tokens=1024, stop_sequences=["<|eot_id|>"]
        )

        return result.generated_text

    @modal.method()
    async def generate_stream(self, question: str):
        prompt = self.template.format(user=question)

        async for response in self.client.generate_stream(
            prompt, max_new_tokens=1024, stop_sequences=["<|eot_id|>"]
        ):
            if (
                not response.token.special
                and response.token.text != "<|eot_id|>"
            ):
                yield response.token.text
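The template in start_server covers a single user turn. The sketch below is a hypothetical helper that also renders an optional system prompt and multiple turns. It assumes an Instruct variant of Llama 3 and uses the same special tokens as the template. For a single user turn, it produces exactly the string that self.template.format(user=...) does.

```python
from typing import List, Optional, Tuple

ROLES = ("system", "user", "assistant")


def build_llama3_prompt(
    turns: List[Tuple[str, str]], system: Optional[str] = None
) -> str:
    """Render (role, content) turns, then open an assistant turn for the model to complete."""
    messages = ([("system", system)] if system else []) + list(turns)
    parts = ["<|begin_of_text|>"]
    for role, content in messages:
        if role not in ROLES:
            raise ValueError(f"unknown role: {role!r}")
        parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>")
    # Leave the assistant header open; generation stops at the next <|eot_id|>.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)
```

In generate, the prompt could then be built with build_llama3_prompt([("user", question)]).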

Run the model

We define a local_entrypoint to invoke our remote function. You can run this script locally with modal run text_generation_inference.py, optionally passing --prompt to override the default question.

@app.local_entrypoint()
def main(prompt: str = None):
    if prompt is None:
        prompt = "Implement a Python function to compute the Fibonacci numbers."
    print(Model().generate.remote(prompt))

Serve the model

Once we deploy this model with modal deploy text_generation_inference.py, we can serve it behind an ASGI app front-end. The front-end code (a single file of Alpine.js) is available here.

You can try our deployment here.

frontend_path = Path(__file__).parent.parent / "llm-frontend"


@app.function(
    mounts=[modal.Mount.from_local_dir(frontend_path, remote_path="/assets")],
    keep_warm=1,
    allow_concurrent_inputs=10,
    timeout=60 * 10,
)
@modal.asgi_app(label="llama3")
def tgi_app():
    import json

    import fastapi
    import fastapi.staticfiles
    from fastapi.responses import StreamingResponse

    web_app = fastapi.FastAPI()

    @web_app.get("/stats")
    async def stats():
        stats = await Model().generate_stream.get_current_stats.aio()
        return {
            "backlog": stats.backlog,
            "num_total_runners": stats.num_total_runners,
            "model": MODEL_ID,
        }

    @web_app.get("/completion/{question}")
    async def completion(question: str):
        from urllib.parse import unquote

        async def generate():
            async for text in Model().generate_stream.remote_gen.aio(
                unquote(question)
            ):
                yield f"data: {json.dumps(dict(text=text), ensure_ascii=False)}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    web_app.mount(
        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
    )
    return web_app
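Each event from /completion/{question} is a data: line carrying a JSON object with a text field. Here is a client-side sketch with two hypothetical helpers. The first builds the request URL. The base URL is an assumption, since it depends on your workspace and the llama3 label. The second extracts the text from the stream's lines.

```python
import json
from typing import Iterable, Iterator
from urllib.parse import quote


def completion_url(base_url: str, question: str) -> str:
    """Percent-encode the question as one path segment (safe="" also encodes "/")."""
    return f"{base_url.rstrip('/')}/completion/{quote(question, safe='')}"


def iter_sse_text(lines: Iterable[str]) -> Iterator[str]:
    """Yield the "text" field of every "data:" event; blank separator lines are skipped."""
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):])
        yield event["text"]


print(completion_url("https://example.modal.run/", "What is 1+1?"))
print("".join(iter_sse_text(['data: {"text": "Hello"}', "", 'data: {"text": " world"}', ""])))
```

With urllib.request, for example, you could iterate over urlopen(completion_url(base_url, question)), decode each line as UTF-8, and feed the lines to iter_sse_text.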

Invoke the model from other apps

Once the model is deployed, we can invoke inference from other apps, sharing the same pool of GPU containers with all other apps we might need.

$ python
>>> import modal
>>> f = modal.Function.lookup("example-tgi-Meta-Llama-3-8B", "Model.generate")
>>> f.remote("What is the story about the fox and grapes?")
'The story about the fox and grapes ...