Dynamic batching (beta)
Modal’s @batched feature allows you to accumulate requests and process them in dynamically-sized batches, rather than one-by-one.
Batching increases throughput at a potential cost to latency. Batched requests can share resources and reuse work, reducing the time and cost per request. Batching is particularly useful for GPU-accelerated machine learning workloads, as GPUs are designed to maximize throughput and are frequently bottlenecked on shareable resources, like weights stored in memory.
Static batching can lead to unbounded latency, because the function waits for a fixed number of requests to arrive. Modal’s dynamic batching instead executes as soon as either a fixed number of requests has accumulated or a fixed amount of time has passed, whichever comes first. This keeps most of the throughput benefit of batching while limiting the latency penalty.
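To make the latency difference concrete, here is a toy model (not Modal's implementation) of how long the first request in a batch waits before its batch executes. It assumes requests arrive at a fixed interval; the function `first_request_wait_ms` and its parameters are hypothetical, and `wait_ms=None` stands in for static batching.

```python
import math
from typing import Optional


def first_request_wait_ms(
    batch_size: int, interval_ms: float, wait_ms: Optional[float]
) -> float:
    """Time the first request in a batch waits before the batch executes.

    Hypothetical model: requests arrive every `interval_ms` (math.inf means
    no further requests ever arrive). `wait_ms=None` models static batching.
    """
    if batch_size <= 1:
        return 0.0  # a batch of one executes immediately
    fill_time = (batch_size - 1) * interval_ms  # time until the batch is full
    if wait_ms is None:
        return fill_time  # static: only a full batch triggers execution
    return min(fill_time, wait_ms)  # dynamic: whichever comes first


print(first_request_wait_ms(8, math.inf, None))  # inf
print(first_request_wait_ms(8, math.inf, 100))  # 100
```

With eight-request batches and no further arrivals (`interval_ms=math.inf`), the static model never executes, while the dynamic model caps the wait at `wait_ms`.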
Enable dynamic batching with @batched
To enable dynamic batching, apply the @modal.batched decorator to the target Python function. Then, wrap it in @app.function() and run it on Modal; the inputs will be accumulated and processed in batches.
Here’s what that looks like:
import modal

app = modal.App()

@app.function()
@modal.batched(max_batch_size=2, wait_ms=1000)
async def batch_add(xs: list[int], ys: list[int]) -> list[int]:
    return [x + y for x, y in zip(xs, ys)]
When you invoke a function decorated with @batched, you invoke it asynchronously on individual inputs, and Modal accumulates those inputs into batches. Each output is returned to the caller that submitted the corresponding input.
For instance, the code below invokes the decorated batch_add function above three times, but batch_add only executes twice:
@app.local_entrypoint()
async def main():
    inputs = [(1, 300), (2, 200), (3, 100)]
    async for result in batch_add.starmap.aio(inputs):
        print(f"Sum: {result}")
    # Sum: 301
    # Sum: 202
    # Sum: 103
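The first execution receives xs batched as [1, 2] and ys batched as [300, 200], because max_batch_size=2 is reached as soon as the second input arrives. The third input opens a new batch; no further inputs arrive, so after about a one-second delay (wait_ms=1000) the second execution receives xs batched as [3] and ys batched as [100].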
The result is an async iterator that yields 301, 202, and 103.
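The batching in this example can be replayed offline with a small simulation of the size-or-timeout rule. This is a sketch under stated assumptions, not Modal's scheduler: `plan_batches` is a hypothetical helper, arrival times are in milliseconds and assumed sorted, and all three calls are modeled as arriving at t=0.

```python
from typing import List, Sequence


def plan_batches(
    arrivals_ms: Sequence[float], max_batch_size: int, wait_ms: float
) -> List[List[int]]:
    """Group input indices into batches using the size-or-timeout rule.

    Hypothetical helper; assumes `arrivals_ms` is sorted and that an input
    arriving exactly at the deadline starts a new batch.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    deadline = 0.0
    for i, t in enumerate(arrivals_ms):
        if current and t >= deadline:
            batches.append(current)  # wait_ms elapsed: the unfilled batch ran
            current = []
        if not current:
            deadline = t + wait_ms  # first input opens a batch and starts the clock
        current.append(i)
        if len(current) == max_batch_size:
            batches.append(current)  # batch is full: it runs immediately
            current = []
    if current:
        batches.append(current)  # leftover batch runs once its deadline passes
    return batches


inputs = [(1, 300), (2, 200), (3, 100)]
for batch in plan_batches([0, 0, 0], max_batch_size=2, wait_ms=1000):
    xs = [inputs[i][0] for i in batch]
    ys = [inputs[i][1] for i in batch]
    print(xs, ys, [x + y for x, y in zip(xs, ys)])
# [1, 2] [300, 200] [301, 202]
# [3] [100] [103]
```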
Use @batched with functions that take and return lists
For a Python function to be compatible with @modal.batched, it must adhere to the following rules:
- The inputs to the function must be lists. In the example above, we pass xs and ys, which are both lists of ints.
- The function must return a list. In the example above, the function returns a list of sums.
- The lengths of all the input lists and the output list must be the same. In the example above, if L == len(xs) == len(ys), then L == len(batch_add(xs, ys)).
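Before deploying, you could check a function against these rules locally. The sketch below uses a hypothetical helper, `check_batched_contract`, which is not part of Modal's API. The length rule plausibly exists because each output must be matched back to the caller whose input sits at the same position; a function that silently drops items, like the hypothetical `add_positive_only` below, would leave some caller without a result.

```python
from typing import Callable, List


def check_batched_contract(fn: Callable[..., list], *input_lists: list) -> list:
    """Call `fn` on one batch and enforce the list rules described above.

    Hypothetical local helper for testing a function before decorating it
    with @modal.batched; it is not part of Modal's API.
    """
    if len({len(xs) for xs in input_lists}) != 1:
        raise ValueError("all input lists must have the same length")
    outputs = fn(*input_lists)
    if not isinstance(outputs, list):
        raise TypeError("a batched function must return a list")
    if len(outputs) != len(input_lists[0]):
        raise ValueError("output list must match the input lists' length")
    return outputs


def add_lists(xs: List[int], ys: List[int]) -> List[int]:
    # Follows the rules: one output per input position.
    return [x + y for x, y in zip(xs, ys)]


def add_positive_only(xs: List[int], ys: List[int]) -> List[int]:
    # Breaks the rules: dropping items leaves some callers without an output.
    return [x + y for x, y in zip(xs, ys) if x + y > 0]


print(check_batched_contract(add_lists, [1, 2], [300, 200]))  # [301, 202]
```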
Modal Cls methods are compatible with dynamic batching
Methods on a Modal Cls also support dynamic batching.
import modal

app = modal.App()

@app.cls()
class BatchedClass:
    @modal.batched(max_batch_size=2, wait_ms=1000)
    async def batch_add(self, xs: list[int], ys: list[int]) -> list[int]:
        return [x + y for x, y in zip(xs, ys)]
One additional rule applies to classes with Batched Methods:
- If a class has a Batched Method, it cannot have any other Batched Methods or other Methods (those decorated with @modal.method).
Configure the wait time and batch size of dynamic batches
The @batched decorator takes two required configuration parameters:
- max_batch_size limits the number of inputs combined into a single batch.
- wait_ms limits the amount of time the Function waits for more inputs after the first input is received.
The first invocation of the Batched Function initiates a new batch, and subsequent calls add requests to this ongoing batch. If max_batch_size is reached, the batch executes immediately. If max_batch_size is not reached but wait_ms has passed since the first request was added to the batch, the unfilled batch is executed.
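The flush rule can be sketched as a small in-process asyncio batcher. `LocalBatcher` is hypothetical and is not how Modal implements @batched; it only illustrates the semantics described above: the first submission starts a wait_ms timer, reaching max_batch_size flushes immediately, and otherwise the timer flushes the unfilled batch. For simplicity it batches a single list argument, so the two-argument add is expressed over tuples.

```python
import asyncio
from typing import Any, Callable, List, Optional, Tuple


class LocalBatcher:
    """Hypothetical in-process batcher illustrating the size-or-timeout rule.

    This is not Modal's implementation of @batched; it only mimics the
    flush semantics for a function that takes and returns a list.
    """

    def __init__(
        self, fn: Callable[[List[Any]], List[Any]], max_batch_size: int, wait_ms: float
    ) -> None:
        self.fn = fn
        self.max_batch_size = max_batch_size
        self.wait_s = wait_ms / 1000
        self._pending: List[Tuple[Any, asyncio.Future]] = []
        self._timer: Optional[asyncio.TimerHandle] = None
        self.batch_sizes: List[int] = []  # size of each executed batch

    async def submit(self, item: Any) -> Any:
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        self._pending.append((item, fut))
        if len(self._pending) == 1:
            # The first input opens a new batch and starts the wait_ms clock.
            self._timer = loop.call_later(self.wait_s, self._flush)
        if len(self._pending) >= self.max_batch_size:
            # The batch is full: execute now instead of waiting for the timer.
            self._flush()
        return await fut

    def _flush(self) -> None:
        if self._timer is not None:
            self._timer.cancel()  # harmless if the timer already fired
            self._timer = None
        batch, self._pending = self._pending, []
        if not batch:
            return
        self.batch_sizes.append(len(batch))
        outputs = self.fn([item for item, _ in batch])
        # Route each output back to the caller at the same position.
        for (_, fut), out in zip(batch, outputs):
            if not fut.done():
                fut.set_result(out)


async def demo() -> Tuple[List[int], List[int]]:
    batcher = LocalBatcher(
        lambda pairs: [x + y for x, y in pairs], max_batch_size=2, wait_ms=100
    )
    results = await asyncio.gather(
        *(batcher.submit(p) for p in [(1, 300), (2, 200), (3, 100)])
    )
    return results, batcher.batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))  # ([301, 202, 103], [2, 1])
```

With the example's three inputs and max_batch_size=2, the first two calls share one batch and the third runs alone once the (shortened, 100 ms) timer fires.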
Selecting a batch configuration
To optimize the batching configurations for your application, consider the following heuristics:
- Set max_batch_size to the largest value your function can handle, so you can amortize and parallelize as much work as possible.
- Set wait_ms to the difference between your targeted latency and the execution time. Most applications have a targeted latency, and this setting keeps a request that waits the full wait_ms within that limit.
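As a worked example of the wait_ms heuristic, assume (hypothetically) a 500 ms latency target and a function that takes 200 ms to execute a full batch. The helpers below are illustrative only and use a simplified latency model that ignores queueing, cold starts, and network time.

```python
def suggest_wait_ms(target_latency_ms: float, execution_ms: float) -> float:
    """Heuristic from the text: wait_ms = target latency - execution time."""
    slack = target_latency_ms - execution_ms
    if slack < 0:
        raise ValueError("execution time alone exceeds the latency target")
    return slack


def worst_case_latency_ms(wait_ms: float, execution_ms: float) -> float:
    # Simplified model: a request waits at most wait_ms, then its batch executes.
    return wait_ms + execution_ms


wait_ms = suggest_wait_ms(target_latency_ms=500, execution_ms=200)
print(wait_ms)  # 300
print(worst_case_latency_ms(wait_ms, 200))  # 500
```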
Serve @modal.web_endpoints with dynamic batching
Here’s a simple example of serving a Function that batches requests dynamically with a @modal.web_endpoint. Run modal serve, submit requests to the endpoint, and the Function will batch your requests on the fly.
import modal

app = modal.App()

@app.function()
@modal.batched(max_batch_size=2, wait_ms=1000)
async def batch_add(xs: list[int], ys: list[int]) -> list[int]:
    return [x + y for x, y in zip(xs, ys)]

@app.function()
@modal.web_endpoint(method="POST", docs=True)
async def add(body: dict[str, int]) -> dict[str, int]:
    result = await batch_add.remote.aio(body["x"], body["y"])
    return {"result": result}
Now, you can submit requests to the web endpoint and process them in batches. For instance, the three requests in the following example, which might be requests from concurrent clients in a real deployment, will be batched into two executions:
import asyncio

import aiohttp

async def send_post_request(session, url, data):
    async with session.post(url, json=data) as response:
        return await response.json()

async def main():
    # Enter the URL of your web endpoint here
    url = "https://workspace--app-name-endpoint-name.modal.run"
    async with aiohttp.ClientSession() as session:
        # Submit three requests concurrently
        tasks = [
            send_post_request(session, url, {"x": 1, "y": 300}),
            send_post_request(session, url, {"x": 2, "y": 200}),
            send_post_request(session, url, {"x": 3, "y": 100}),
        ]
        results = await asyncio.gather(*tasks)
    for result in results:
        print(f"Sum: {result['result']}")

asyncio.run(main())