MLflow experiment

The MLflow experiment data source provides a standard API to load MLflow experiment run data. You can load data from the notebook experiment, or you can use the MLflow experiment name or experiment ID.

Requirements

Databricks Runtime 6.0 ML or above.

Load data from the notebook experiment

To load data from the notebook experiment, use load().

Python

df = spark.read.format("mlflow-experiment").load()
display(df)

Scala

val df = spark.read.format("mlflow-experiment").load()
display(df)

Load data using experiment IDs

To load data from one or more workspace experiments, specify the experiment IDs as shown.

Python

df = spark.read.format("mlflow-experiment").load("3270527066281272")
display(df)

Scala

val df = spark.read.format("mlflow-experiment").load("3270527066281272,953590262154175")
display(df)

Load data using experiment name

You can also pass the experiment name to the load() method.

Python

expId = mlflow.get_experiment_by_name("/Shared/diabetes_experiment/").experiment_id
df = spark.read.format("mlflow-experiment").load(expId)
display(df)

Scala

val expId = mlflow.getExperimentByName("/Shared/diabetes_experiment/").get.getExperimentId
val df = spark.read.format("mlflow-experiment").load(expId)
display(df)

Filter data based on metrics and parameters

The examples in this section show how you can filter data after loading it from an experiment.

Python

df = spark.read.format("mlflow-experiment").load("3270527066281272")
filtered_df = df.filter("metrics.loss < 0.01 AND params.learning_rate > '0.001'")
display(filtered_df)

Scala

val df = spark.read.format("mlflow-experiment").load("3270527066281272")
val filtered_df = df.filter("metrics.loss < 1.85 AND params.num_epochs > '30'")
display(filtered_df)

Schema

The schema of the DataFrame returned by the data source is:

root
|-- run_id: string
|-- experiment_id: string
|-- metrics: map
|    |-- key: string
|    |-- value: double
|-- params: map
|    |-- key: string
|    |-- value: string
|-- tags: map
|    |-- key: string
|    |-- value: string
|-- start_time: timestamp
|-- end_time: timestamp
|-- status: string
|-- artifact_uri: string