Train an SLM from scratch with early-stopping grid search over hyperparameters
When you want a language model that performs well on your task, there are three options, ordered by the degree of customization:
Prompt Engineering: large and capable language models understand tasks in natural language, so you can carefully design a natural language “prompt” to elicit the desired behavior.
Fine-Tuning: those same language models were trained by gradient descent on data sets representing tasks, and they can be further trained by gradient descent on data sets representative of your task.
Training from Scratch: if you have enough data for your task, you can throw the pretrained model away and make your own.
Each option adds engineering complexity, but also leads to a superior cost-performance Pareto frontier for your tasks: fine-tuned models at one-tenth the size regularly outperform more generic models, and models trained from scratch outperform fine-tuned ones.
Because these models are so much smaller than the Large Language Models that power generic assistant chatbots like ChatGPT and Claude, they are often called Small Language Models (SLMs).
In this example, we will explore training an SLM from scratch on Modal.
In fact, we’ll train 8 SLMs in parallel with different hyperparameters and then select the best one for additional training.
We’ll monitor this training live and serve our training and trained models as web endpoints and simple browser UIs.
Along the way we’ll use many features of the Modal platform: distributed volumes, multiple web endpoints, and parallel container execution.
Together, these features give every machine learning and AI team the same infrastructural capabilities that the most sophisticated companies have in their internal platforms.
Basic Setup
We’ll use A10G GPUs for training, which are able to train the model to recognizably improved performance in ~15 minutes while keeping costs under ~$1.
Create a Volume to store data, weights, and logs
Since we’ll be coordinating training across multiple machines we’ll use a distributed Volume to store the data, checkpointed models, and TensorBoard logs.
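A minimal sketch of creating such a Volume is below; the Volume name and mount path are illustrative assumptions, not fixed by this example.

```python
import modal

# Look up the shared Volume by name, creating it if it doesn't exist yet.
# "slm-example-vol" and "/vol" are placeholder names for this sketch.
volume = modal.Volume.from_name("slm-example-vol", create_if_missing=True)
volume_path = "/vol"  # where we'll mount the Volume inside our containers
```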
Define dependencies in container images
The container image for training is based on Modal’s default slim Debian Linux image with torch for defining and running our neural network and tensorboard for monitoring training.
We also have some local dependencies that we’ll need to import in the remote environment, so we add them to the container image as well.
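As a sketch, the training image might be assembled like this. The add_local_python_source call assumes a recent Modal client (older clients attached local code via Mounts), and "helpers" stands in for whatever local module this project actually uses.

```python
import modal

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install("torch", "tensorboard")  # pin exact versions in a real project
    # make local code importable inside the container; "helpers" is a hypothetical module name
    .add_local_python_source("helpers")
)
```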
We’ll serve a simple web endpoint:
And we’ll deploy a web UI for interacting with our trained models using Gradio.
We can also “pre-import” libraries that will be used by the functions we run on Modal in a given image using the image.imports context manager.
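For instance, importing torch this way means it only needs to be installed in the container image, not on the machine you run modal from:

```python
with image.imports():
    # this import only runs inside containers built from `image`
    import torch
```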
Running SLM training on Modal
Here we define the training function, wrapping it in a decorator
that specifies the infrastructural parameters, like the container image we want to use,
which Volume to mount where, the GPU we’re using, and so on.
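Schematically, and building on the image and Volume sketches above, the decorated function might look like this; the app name, timeout, and argument names are assumptions of the sketch.

```python
app = modal.App("example-train-slm")  # hypothetical app name

@app.function(
    image=image,                    # training image defined above
    volumes={volume_path: volume},  # shared Volume for data, checkpoints, and logs
    gpu="A10G",                     # one A10G per training container
    timeout=30 * 60,                # allow enough time for the longer final run
)
def train(hparams, experiment_name, stop_early=False):
    ...  # load data, build the model, run the training loop (see the Addenda)
```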
Training consists of specifying optimization parameters, loading the dataset, building the model, setting up TensorBoard logging and checkpointing, and finally executing the training_loop itself.
Launch a hyperparameter sweep from a local_entrypoint
The main entry point coordinates the hyperparameter optimization.
First we specify the default hyperparameters for the model, taken from Andrej Karpathy’s walkthrough.
For better performance, you can increase the context_size and scale up the GPU accordingly.
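For concreteness, the defaults might be grouped in a small dataclass like the one below; the field names mirror those discussed here, but the specific values are illustrative (roughly those from Karpathy's walkthrough), not the exact defaults of this example.

```python
from dataclasses import dataclass

@dataclass
class ModelHyperparameters:
    n_heads: int = 6         # attention heads per transformer block
    n_embed: int = 384       # embedding width
    n_blocks: int = 6        # number of transformer blocks
    context_size: int = 256  # "block size" in Karpathy's terminology
    dropout: float = 0.2
```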
Next we define the local entrypoint: the code we run locally to coordinate training.
It will train 8 models in parallel across 8 containers, each
with different hyperparameters, varying the number of heads (n_heads), the context_size (called the “block size” by Karpathy), and the dropout rate (dropout). To run in
parallel we need to use the starmap method.
We train all of the models until the first checkpoint and then stop early so we can compare the validation losses.
Then we restart training for the best model and train it to completion.
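Before looking at the full code, here is a condensed sketch of that two-phase logic, reusing the train function and ModelHyperparameters sketches above. The candidate values, the experiment-naming scheme, and the assumption that train returns a validation loss are all illustrative.

```python
from dataclasses import replace
from itertools import product

def run_sweep(default_hparams: ModelHyperparameters):
    # eight combinations: two values each for n_heads, context_size, and dropout
    grid = product((4, 8), (128, 256), (0.1, 0.2))
    candidates = [
        replace(default_hparams, n_heads=h, context_size=c, dropout=d) for h, c, d in grid
    ]
    names = [f"sweep-{i}" for i in range(len(candidates))]

    # phase 1: train every candidate in its own container, stopping at the first checkpoint;
    # we assume train() returns the validation loss at that checkpoint
    losses = list(train.starmap(zip(candidates, names, [True] * len(candidates))))

    # phase 2: resume the best candidate and train it to completion
    best = min(range(len(losses)), key=losses.__getitem__)
    train.remote(candidates[best], names[best], stop_early=False)
```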
You can kick off training with the following command:
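Assuming this example lives in a file called train.py (the filename is an assumption), that looks like:

```bash
modal run train.py
```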
The output will look something like this:
The local_entrypoint code is below. Note that the arguments to it can also be passed via the command line.
Use --help for details.
Monitor experiments with TensorBoard
To monitor our training we will create a TensorBoard WSGI web app, which will display the progress of our training across all 8 models. We’ll use the latest logs for the most recent experiment written to the Volume.
To ensure we have the latest data, we add some WSGI middleware that checks the Modal Volume for updates when the page is reloaded.
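A minimal sketch of such middleware is below. It assumes the TensorBoard WSGI application has already been constructed (here called tb_app); that construction is omitted.

```python
class VolumeReloadingMiddleware:
    """Reload the Volume before each request so TensorBoard sees the freshest logs."""

    def __init__(self, wsgi_app, volume):
        self.wsgi_app = wsgi_app
        self.volume = volume

    def __call__(self, environ, start_response):
        self.volume.reload()  # pull in log files written by the training containers
        return self.wsgi_app(environ, start_response)

# returned from the function decorated with @modal.wsgi_app(), e.g.:
#   return VolumeReloadingMiddleware(tb_app, volume)
```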
To ensure a unique color per experiment you can click the palette (🎨) icon
under TensorBoard > Time Series > Run and use the Regex: E(\d{4})-(\d{2})-(\d{2})-(\d{6})\.(\d{6})
You can deploy this TensorBoard service with modal deploy and visit it at the URL that ends with -monitor-training.modal.run.
After training finishes, your TensorBoard UI will look something like this:
You can also find some sample text generated by the model in the “Text” tab.
Notice that there are 8 models training, and the one with the lowest validation loss at step 600 continues training to 3000 steps.
Serving SLMs on Modal during and after training
Because our weights are stored in a distributed Volume, we can deploy an inference endpoint based on them without any extra work, and we can even check in on models while we’re still training them! For more on storing model weights on Modal, see this guide.
Remote inference with a Modal Cls
We wrap our inference in a Modal Cls called ModelInference.
The user of ModelInference can control which model is used by providing the experiment_name. Each unique choice creates a separate auto-scaling deployment.
If the user does not specify an experiment_name, the latest experiment
is used.
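In outline, and using the newer class-parameter syntax (older Modal clients pass parameters through __init__), the class might look like this. The checkpoint layout on the Volume and the helper functions are assumptions of the sketch.

```python
@app.cls(image=image, volumes={volume_path: volume}, gpu="A10G")
class ModelInference:
    experiment_name: str = modal.parameter(default="")

    @modal.enter()
    def load(self):
        # fall back to the most recent experiment directory on the Volume
        name = self.experiment_name or latest_experiment_name()  # hypothetical helper
        self.model = load_checkpoint(f"{volume_path}/{name}/ckpt.pt")  # hypothetical helper

    @modal.method()
    def generate(self, prompt: str) -> str:
        return self.model.generate_from_text(prompt)  # hypothetical model API
```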
Adding a simple web endpoint
The ModelInference class above is available for use
from any other Python environment with the right Modal credentials
and the modal package installed — just use lookup.
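For example, reusing the hypothetical app name from the sketches above:

```python
import modal

ModelInference = modal.Cls.lookup("example-train-slm", "ModelInference")

# omit experiment_name (or leave it empty) to use the latest experiment
generator = ModelInference()
print(generator.generate.remote("MACBETH:"))
```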
But we can also expose it as a web endpoint for easy access from anywhere, including other programming languages or the command line.
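Sketched out, the endpoint might look like this. fastapi_endpoint is the current name of the decorator (older Modal clients call it web_endpoint), and the image and query parameter names are assumptions.

```python
web_image = modal.Image.debian_slim().pip_install("fastapi[standard]")

@app.function(image=web_image)
@modal.fastapi_endpoint()
def generate_web(prompt: str, experiment_name: str = ""):
    # query parameters map onto the function's arguments; forward them to the inference Cls
    return ModelInference(experiment_name=experiment_name).generate.remote(prompt)
```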
This endpoint can be deployed on Modal with modal deploy.
That will allow us to generate text via a simple curl command like this:
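Assuming a workspace named your-workspace and the endpoint function sketched above (both assumptions), the call would look roughly like:

```bash
curl "https://your-workspace--example-train-slm-generate-web.modal.run?prompt=MACBETH:"
```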
which will return something like:
It’s not exactly Shakespeare, but at least it shows our model learned something!
You can choose which model to use by specifying the experiment_name in the query parameters of the request URL.
Serving a Gradio UI with asgi_app
Second, we create a Gradio web app for generating text via a graphical user interface in the browser. That way our fellow team members and stakeholders can easily interact with the model and give feedback, even when we’re still training the model.
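A sketch of how that might be wired up, assuming a gradio_image with gradio and fastapi installed and reusing the ModelInference Cls from above:

```python
gradio_image = modal.Image.debian_slim().pip_install("gradio", "fastapi[standard]")

@app.function(image=gradio_image)
@modal.asgi_app()
def ui():
    import gradio as gr
    from fastapi import FastAPI

    def generate(prompt: str) -> str:
        # forward the prompt to the deployed inference Cls
        return ModelInference().generate.remote(prompt)

    demo = gr.Interface(fn=generate, inputs="text", outputs="text", title="SLM text generation")
    return gr.mount_gradio_app(app=FastAPI(), blocks=demo, path="/")
```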
You should see the URL for this UI in the output of modal deploy or on your Modal app dashboard for this app.
The Gradio UI will look something like this:
Addenda
The remainder of this code is boilerplate.
Training Loop
There’s quite a lot of code for just the training loop! If you’d rather not write this stuff yourself, consider a training framework like PyTorch Lightning or Hugging Face.
Miscellaneous
The remaining code includes small helper functions for training the model.