Create a training run using the Foundation Model Training API


This feature is in Public Preview. Reach out to your Databricks account team to enroll in the Public Preview.

This page describes how to create and configure a training run using the Foundation Model Training API, and it describes all of the parameters used in the API call. You can also create a run using the UI; for instructions, see Create a training run using the Foundation Model Training UI.


See Requirements.

Create a training run

To create training runs programmatically, use the create() function. This function trains a model on the provided dataset and converts the final Composer checkpoint to a Hugging Face formatted checkpoint for inference.

The required inputs are the model you want to train, the location of your training dataset, and where to register your model. There are also optional fields that allow you to perform evaluation and change the hyperparameters of your run. After you create a run, the checkpoints are saved to the MLflow run, and the final checkpoint is registered to Unity Catalog for easy deployment.

See Configure a training run for details about arguments for the create() function.

from databricks.model_training import foundation_model as fm

run = fm.create(
  model='meta-llama/Llama-2-7b-chat-hf',  # required: the model to train; see Supported models
  train_data_path='dbfs:/Volumes/main/mydirectory/ift/train.jsonl', # UC Volume with JSONL formatted data
  # Public HF dataset is also supported
  # train_data_path='mosaicml/dolly_hhrlhf/train'
  register_to='main.mydirectory', # UC catalog and schema to register the model to
)

After the run completes, the completed run and final checkpoints are saved, and the model is registered to Unity Catalog.
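For INSTRUCTION_FINETUNE, the training file is JSONL in which each row contains a prompt field and a response field. The following is a minimal sketch, assuming your examples are already available as Python (prompt, response) pairs, of writing and re-reading a file in that shape with only the standard library. The helper names write_instruction_jsonl and read_jsonl and the sample records are hypothetical; after writing the file you would upload it to a Unity Catalog volume such as the one used in train_data_path above.

```python
import json
from pathlib import Path


def write_instruction_jsonl(pairs, path):
    """Write (prompt, response) pairs as JSONL, one {"prompt", "response"} object per line.

    Hypothetical helper; the field names follow the INSTRUCTION_FINETUNE row format.
    """
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        for prompt, response in pairs:
            # ensure_ascii=False keeps non-ASCII text human-readable in the file
            record = {"prompt": prompt, "response": response}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


def read_jsonl(path):
    """Read a JSONL file back into a list of dicts, skipping blank lines (quick sanity check)."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    import tempfile

    examples = [  # hypothetical sample data
        ("What is the capital of France?", "Paris."),
        ("Translate 'hello' to Spanish.", "Hola."),
    ]
    out = Path(tempfile.mkdtemp()) / "train.jsonl"
    write_instruction_jsonl(examples, out)
    print(read_jsonl(out))
```

Reading the file back before uploading is a cheap way to catch malformed rows locally rather than after the run is submitted.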

Configure a training run

The following table summarizes the fields for the create() function.

| Field | Required | Type | Description |
|---|---|---|---|
| model | Yes | str | The name of the model to use. See Supported models. |
| train_data_path | Yes | str | The location of your training data. This can be a location in Unity Catalog (<catalog>.<schema>.<table> or dbfs:/Volumes/<catalog>/<schema>/<volume>/<dataset>.jsonl) or a Hugging Face dataset. For INSTRUCTION_FINETUNE, each row of the data should contain a prompt field and a response field. For CONTINUED_PRETRAIN, this is a folder of .txt files. See Prepare data for Foundation Model Training for accepted data formats and Recommended data size for model training for data size recommendations. |
| register_to | Yes | str | The Unity Catalog catalog and schema (<catalog>.<schema> or <catalog>.<schema>.<custom-name>) where the model is registered after training for easy deployment. If custom-name is not provided, the model name defaults to the run name. |
| data_prep_cluster_id | No | str | The ID of the cluster to use for Spark data processing. This is required for supervised training tasks where the training data is in a Delta table. For information on how to find the cluster ID, see Get cluster id. |
| experiment_path | No | str | The path to the MLflow experiment where the training run output (metrics and checkpoints) is saved. Defaults to the run name within the user's personal workspace (that is, /Users/<username>/<run_name>). |
| task_type | No | str | The type of task to run: INSTRUCTION_FINETUNE (default), CHAT_COMPLETION, or CONTINUED_PRETRAIN. |
| eval_data_path | No | str | The remote location of your evaluation data, if any. Must follow the same format as train_data_path. |
| eval_prompts | No | list[str] | A list of prompt strings used to generate responses during evaluation. Default is None (no responses are generated). Responses are generated at every model checkpoint and logged to the experiment, using the following generation parameters: max_new_tokens: 100, temperature: 1, top_k: 50, top_p: 0.95, do_sample: true. |
| custom_weights_path | No | str | The remote location of a custom model checkpoint for training. Default is None, meaning the run starts from the original pretrained weights of the chosen model. If custom weights are provided, they are used instead of the model's original pretrained weights. These weights must be a Composer checkpoint and must match the architecture of the model specified in model. |
| training_duration | No | str | The total duration of your run. Default is one epoch (1ep). Can be specified in epochs (10ep) or tokens (1000000tok). |
| learning_rate | No | str | The learning rate for model training. Default is 5e-7. The optimizer is DecoupledLionW with betas of 0.99 and 0.95 and no weight decay. The learning rate scheduler is LinearWithWarmupSchedule with a warmup of 2% of the total training duration and a final learning rate multiplier of 0. |
| context_length | No | str | The maximum sequence length of a data sample. This is used to truncate data that is too long and to pack shorter sequences together for efficiency. Defaults to the provided model's default context length. Increasing the context length beyond a model's default is not supported. See Supported models for the context length of each model. |
| validate_inputs | No | Boolean | Whether to validate access to the input paths before submitting the training job. Default is True. |
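The learning_rate field describes a linear warmup over the first 2% of training followed by a linear decay to a final multiplier of 0. The function below is a minimal sketch of that schedule shape, assuming training progress is counted in optimizer steps and that warmup rises linearly from 0 to the peak rate, which is how Composer's LinearWithWarmupScheduler is documented to behave. It is an illustration for reasoning about the schedule, not the service's own code; lr_at_step is a hypothetical name.

```python
def lr_at_step(step, total_steps, peak_lr=5e-7, warmup_frac=0.02, final_mult=0.0):
    """Learning rate at `step` for linear warmup followed by linear decay (illustrative sketch).

    Warmup: the multiplier rises linearly from 0 to 1 over warmup_frac * total_steps.
    Decay: the multiplier falls linearly from 1 to final_mult over the remaining steps.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    warmup_steps = max(1, round(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the post-warmup phase that has elapsed, clamped to at most 1.
    decay_steps = max(1, total_steps - warmup_steps)
    frac = min(1.0, (step - warmup_steps) / decay_steps)
    mult = 1.0 + (final_mult - 1.0) * frac
    return peak_lr * mult


if __name__ == "__main__":
    total = 1000  # with 2% warmup, the peak is reached at step 20
    for s in (0, 10, 20, 510, 1000):
        print(s, lr_at_step(s, total))
```

Because the final multiplier is 0, the rate reaches zero exactly at the end of training_duration; a longer duration stretches both the warmup and the decay proportionally.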
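The context_length field says long samples are truncated and shorter ones are packed together. The packing algorithm the service uses is not described here, so the following is only a sketch of the general idea, assuming samples are already tokenized into lists of token IDs: each sample is truncated to context_length, then placed with a simple first-fit rule into the first "bin" that still has room. The function name truncate_and_pack is hypothetical.

```python
def truncate_and_pack(sequences, context_length):
    """Truncate token sequences to context_length, then first-fit pack them into bins.

    Each returned bin is a list of sequences whose total length is at most context_length.
    Illustrative sketch only; not the service's actual packing implementation.
    """
    if context_length <= 0:
        raise ValueError("context_length must be positive")
    bins = []       # each bin holds the sequences packed together
    bin_sizes = []  # running token count for each bin
    for seq in sequences:
        seq = list(seq)[:context_length]  # truncate overly long samples
        # Put the sequence into the first bin that still has room for it.
        for i, size in enumerate(bin_sizes):
            if size + len(seq) <= context_length:
                bins[i].append(seq)
                bin_sizes[i] += len(seq)
                break
        else:
            # No existing bin fits, so start a new one.
            bins.append([seq])
            bin_sizes.append(len(seq))
    return bins


if __name__ == "__main__":
    samples = [[1] * 5, [2] * 3, [3] * 10, [4] * 2]  # hypothetical token IDs
    for packed in truncate_and_pack(samples, 8):
        print([len(s) for s in packed])
```

Packing reduces the number of padding tokens per batch, which is why shorter samples are combined rather than each being padded out to the full context length.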

Build on custom model weights

Foundation Model Training supports training any of the supported models starting from custom weights using the optional parameter custom_weights_path.

For example, you can create a domain-specific model with your custom data and then pass the desired checkpoint as an input for further training.

You can provide the remote location of the Composer checkpoint from your previous run for training. Checkpoint paths can be found in the Artifacts tab of a previous MLflow run and have the form dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<run_name>/checkpoints/<checkpoint_folder>[.symlink], where the .symlink extension is optional. Each checkpoint folder name corresponds to the epoch and batch of a particular snapshot, such as ep29-ba30/. The final snapshot is accessible with the symlink latest-sharded-rank0.symlink.
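Because the checkpoint path follows a fixed pattern, you can assemble it from the IDs shown in MLflow instead of copying it by hand. The following is a minimal sketch, assuming you already know the experiment ID, run ID, and run name; the helper names checkpoint_path and parse_checkpoint_folder are hypothetical. The second helper splits a folder name such as ep29-ba30 into its epoch and batch numbers.

```python
import re

# Path template taken from the checkpoint path format described above.
_TEMPLATE = (
    "dbfs:/databricks/mlflow-tracking/{experiment_id}/{run_id}"
    "/artifacts/{run_name}/checkpoints/{folder}"
)
LATEST = "latest-sharded-rank0.symlink"  # symlink to the final snapshot


def checkpoint_path(experiment_id, run_id, run_name, folder=LATEST):
    """Build a Composer checkpoint path suitable for custom_weights_path (hypothetical helper)."""
    return _TEMPLATE.format(
        experiment_id=experiment_id, run_id=run_id, run_name=run_name, folder=folder
    )


def parse_checkpoint_folder(folder):
    """Return (epoch, batch) from names like 'ep29-ba30', 'ep29-ba30/', or 'ep29-ba30.symlink'."""
    m = re.fullmatch(r"ep(\d+)-ba(\d+)(?:\.symlink|/)?", folder)
    if m is None:
        raise ValueError(f"not an epoch/batch checkpoint folder: {folder!r}")
    return int(m.group(1)), int(m.group(2))


if __name__ == "__main__":
    # Hypothetical IDs; use the values from your own MLflow run.
    print(checkpoint_path("1234567890", "abcdef123456", "my-run"))
    print(parse_checkpoint_folder("ep29-ba30/"))
```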

[Image: Artifacts tab for a previous MLflow run]

The path can then be passed to the custom_weights_path parameter in your configuration.

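run = fm.create(
  model='meta-llama/Llama-2-7b-chat-hf',  # must match the architecture of the checkpoint
  train_data_path='dbfs:/Volumes/main/mydirectory/ift/train.jsonl', register_to='main.mydirectory',
  custom_weights_path='your/checkpoint/path')  # Composer checkpoint path from a previous run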

Get cluster id

To retrieve the cluster id:

  1. In the left nav bar of the Databricks workspace, click Compute.

  2. In the table, click the name of your cluster.

  3. Click the More button in the upper-right corner and select View JSON from the drop-down menu.

  4. The Cluster JSON file appears. Copy the cluster ID, which is on the first line of the file.

    [Image: cluster id]
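If you have saved or copied the Cluster JSON text, you can also read the ID programmatically. The following is a minimal sketch, assuming the JSON has a top-level "cluster_id" key, as cluster specifications in the Databricks Clusters API do; the function name cluster_id_from_json and the sample values are hypothetical. The returned string is what you would pass as data_prep_cluster_id.

```python
import json


def cluster_id_from_json(cluster_json: str) -> str:
    """Return the cluster ID from the text of a Cluster JSON.

    Assumes the ID is stored under a top-level "cluster_id" key.
    """
    data = json.loads(cluster_json)
    try:
        return data["cluster_id"]
    except KeyError:
        raise KeyError("no top-level 'cluster_id' field in the cluster JSON") from None


if __name__ == "__main__":
    # Hypothetical Cluster JSON excerpt; real files contain many more fields.
    sample = '{"cluster_id": "0123-456789-abcdefgh", "cluster_name": "my-cluster"}'
    print(cluster_id_from_json(sample))
```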

Get status of a run

You can track the progress of a run using the Experiment page in the Databricks UI or using the API command get_events(). For details, see View, manage, and analyze Foundation Model Training runs.

Example output from get_events():

[Image: Use API to get run status]

Sample run details on the Experiment page:

[Image: Get run status from the experiments UI]

Next steps

After your training run is complete, you can review metrics in MLflow and deploy your model for inference. See steps 5 through 7 of Tutorial: Create and deploy a Foundation Model Training run.

Additional resources