Llama 2 inference with MLC

Machine Learning Compilation (MLC) is a high-performance tool for serving LLMs, including Llama 2. We will use the mlc_chat package and pre-compiled Llama 2 binaries to run inference on a Modal GPU.

This example is adapted from this MLC chat Colab notebook.

Imports and global settings

import queue
import threading
import time
from typing import Dict, Generator, List

import modal

Determine which GPU you want to use.

GPU: str = "a10g"

Choose the model size. At the time of writing, MLC chat only provides compiled binaries for Llama 2 7B and 13B.

LLAMA_MODEL_SIZE: str = "13b"

Define the container image and the Modal Stub. We use an official NVIDIA CUDA 12.1 image to match MLC's CUDA requirements.

image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04",
        add_python="3.11",
    ).run_commands(
        "apt-get update",
        "apt-get install -y curl git",
        # Install git lfs
        "curl -sSf https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash",
        "apt-get install -y git-lfs",
        "pip3 install --pre --force-reinstall mlc-ai-nightly-cu121 mlc-chat-nightly-cu121 -f https://mlc.ai/wheels",
    )
    # "These commands will download many prebuilt libraries as well as the chat
    # configuration for Llama-2-7b that mlc_chat needs" [...]
    .run_commands(
        "mkdir -p dist/prebuilt",
        "git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib",
        f"cd dist/prebuilt && git clone https://huggingface.co/mlc-ai/mlc-chat-Llama-2-{LLAMA_MODEL_SIZE}-chat-hf-q4f16_1",
    )
)
stub = modal.Stub("mlc-inference", image=image)
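
Before running the full example, it can help to sanity-check that the image built correctly. The small function below is not part of the original example; it simply imports mlc_chat and shells out to nvidia-smi inside a GPU container (nvidia-smi is normally available in Modal GPU containers, but treat this as an untested sketch). On recent Modal versions it can be selected with modal run's function selector, e.g. modal run mlc_inference.py::check.

@stub.function(gpu=GPU)
def check():
    """Hypothetical sanity check: verify the mlc_chat install and GPU visibility."""
    import subprocess

    import mlc_chat  # noqa: F401  # fails loudly if the wheel install above broke

    # Print GPU details; assumes nvidia-smi is on the PATH inside the container.
    subprocess.run(["nvidia-smi"], check=True)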


LOADING_MESSAGE: str = f"""

                      #%%%%%%%%%%%%(         #%%%%%%%%%%%%#
                    ,%##%,         %%      .%#(%/         %%.
                   %%.  .%#         (%*   (%*   %%         *%/
                 .%#      %%.        .%% %%      (%*         %%
                #%,        ,%%%%%%%%%%%%%/        .%%         (%*
               %%         (%*         %%*%(         #%,         %%
             (%*         %%         *%(   %%         .%#         #%,
            %%         *%/         %%.     *%(         #%,        .%#
          /%(         %%         .%#         %%         .%%%%%%%%%%%%%.
           (%/      ,%#         #%,           /%(      .%%         (%,
             %%    %%.        .%%               %%    (%*         %%
              (%*.%#         (%*                 /%/ %%         /%/
                %%%%%%%%%%%%%%                     %%%%%%%%%%%%%%

                      LOADING => Llama 2 ({LLAMA_MODEL_SIZE}) [{GPU}]


"""

Define the Modal function

The generate function loads MLC chat and the compiled model into memory, then runs inference on an input prompt. It is a generator, streaming tokens back to the client as they are produced.

@stub.function(gpu=GPU)
def generate(prompt: str) -> Generator[Dict[str, str], None, None]:
    from mlc_chat import ChatModule
    from mlc_chat.callback import DeltaCallback

    yield {
        "type": "loading",
        "message": LOADING_MESSAGE + "\n\n",
    }

    class QueueCallback(DeltaCallback):
        """Stream the output of the chat module to client."""

        def __init__(self, callback_interval: float):
            super().__init__()
            self.queue: queue.Queue = queue.Queue()
            self.stopped = False
            self.callback_interval = callback_interval

        def delta_callback(self, delta_message: str):
            self.stopped = False
            self.queue.put(delta_message)

        def stopped_callback(self):
            self.stopped = True

    cm = ChatModule(
        model=f"/dist/prebuilt/mlc-chat-Llama-2-{LLAMA_MODEL_SIZE}-chat-hf-q4f16_1",
        lib_path=f"/dist/prebuilt/lib/Llama-2-{LLAMA_MODEL_SIZE}-chat-hf-q4f16_1-cuda.so",
    )
    queue_callback = QueueCallback(callback_interval=1)

    # Generate tokens in a background thread so we can yield tokens
    # to caller as a generator.
    def _generate():
        cm.generate(
            prompt=prompt,
            progress_callback=queue_callback,
        )

    background_thread = threading.Thread(target=_generate)
    background_thread.start()

    # Yield tokens back to the caller as they arrive, until the
    # chat module signals that generation has stopped.
    while not queue_callback.stopped:
        yield {"type": "output", "message": queue_callback.queue.get()}

Run the model

Create a local Modal entrypoint that calls the generate function. It uses the curses module to render tokens as they are streamed back from Modal.

Run this locally with modal run -q mlc_inference.py --prompt "What is serverless computing?"

@stub.local_entrypoint()
def main(prompt: str):
    import curses

    def _generate(stdscr):
        buffer: List[str] = []

        def _buffered_message():
            return "".join(buffer) + ("\n" * 4)

        start = time.time()
        for payload in generate.remote_gen(prompt):
            message = payload["message"]
            if payload["type"] == "loading":
                stdscr.clear()
                stdscr.addstr(0, 0, message)
                stdscr.refresh()
            else:
                buffer.append(message)
                stdscr.clear()
                stdscr.addstr(0, 0, _buffered_message())
                stdscr.refresh()

        n_tokens = len(buffer)
        elapsed = time.time() - start
        # Render the final stats inside the curses window; a plain print() would
        # not be visible until the curses session exits.
        stdscr.clear()
        stdscr.addstr(
            0,
            0,
            _buffered_message()
            + f"[DONE] {n_tokens} tokens generated in {elapsed:.2f}s "
            + f"({n_tokens / elapsed:.0f} tok/s). Press any key to exit.",
        )
        stdscr.refresh()
        stdscr.getch()

    curses.wrapper(_generate)
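
Once the app has been deployed with modal deploy, you can also consume the token stream from any other Python process, without the curses UI. The snippet below is a minimal sketch, not part of the original example: it assumes the app name "mlc-inference" defined above and uses Modal's Function.lookup and remote_gen calls to stream tokens.

import modal

# Look up the deployed generator function by app and function name.
# (Assumes you have already run `modal deploy` on this file.)
generate = modal.Function.lookup("mlc-inference", "generate")

for payload in generate.remote_gen("What is serverless computing?"):
    # Skip the ASCII-art loading banner and print tokens as they stream in.
    if payload["type"] == "output":
        print(payload["message"], end="", flush=True)
print()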