Train a machine learning model using cross validation

Learn how to use cross validation to train more robust machine learning models in ML.NET.

Cross-validation is a training and model evaluation technique that splits the data into several partitions and trains multiple models on these partitions. Each model is trained on all but one partition and evaluated on the partition it did not see, so every observation is held out exactly once. This improves the robustness of the model and gives a more reliable estimate of how it will perform on unseen observations. Because every row is used for both training and evaluation, cross-validation is also an effective tool in data-constrained environments, where setting aside a large, separate test set is not practical.
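
To make the partitioning idea concrete, here is a minimal plain C# sketch (no ML.NET dependency) that splits row indices into k contiguous folds and pairs each held-out fold with the remaining rows for training. KFoldSketch and Split are hypothetical names used only for illustration; ML.NET's CrossValidate does its own splitting, and a real split would normally shuffle the rows first.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration of k-fold partitioning; not ML.NET's implementation.
// Assumes the rows have already been shuffled, so contiguous blocks are random samples.
public static class KFoldSketch
{
    // Returns one (Train, Test) pair of row indices per fold.
    public static List<(int[] Train, int[] Test)> Split(int rowCount, int k)
    {
        if (k < 2 || k > rowCount)
            throw new ArgumentOutOfRangeException(nameof(k), "k must be between 2 and rowCount.");

        var folds = new List<(int[] Train, int[] Test)>();
        int start = 0;
        for (int fold = 0; fold < k; fold++)
        {
            // The first (rowCount % k) folds get one extra row, so fold sizes differ by at most 1.
            int size = rowCount / k + (fold < rowCount % k ? 1 : 0);
            int end = start + size;

            // The held-out fold is the test set; every other row is used for training.
            int[] test = Enumerable.Range(start, size).ToArray();
            int[] train = Enumerable.Range(0, rowCount).Where(i => i < start || i >= end).ToArray();

            folds.Add((train, test));
            start = end;
        }
        return folds;
    }
}
```

For example, Split(11, 5) yields test folds of 3, 2, 2, 2 and 2 rows, and every row appears in exactly one test fold.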

The data and data model

Given data from a file that has the following format:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

The data can be modeled by a class like HousingData and loaded into an IDataView.

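using Microsoft.ML.Data;
public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]   // columns 1 through 3 (inclusive) form one vector
    [VectorType(3)]      // fixed-size vector of the three historical prices
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")] // regression trainers read the label from a column named "Label" by default
    public float CurrentPrice { get; set; }
}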
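
ML.NET's text loader reads the file into HousingData for you. To show how the column indices line up with the sample rows, the following plain C# sketch parses one data row: column 0 becomes Size, columns 1 through 3 become the three-element HistoricalPrices vector, and column 4 becomes CurrentPrice. HousingRow and HousingRowParser are hypothetical illustrative names, not ML.NET types. The sketch assumes comma-separated values and no header; the header line in the sample file would need to be skipped, which ML.NET's LoadFromTextFile handles through its hasHeader parameter.

```csharp
using System;
using System.Globalization;
using System.Linq;

// Hypothetical plain-C# mirror of HousingData, used only to illustrate
// how the LoadColumn indices map onto a comma-separated row.
public record HousingRow(float Size, float[] HistoricalPrices, float CurrentPrice);

public static class HousingRowParser
{
    public static HousingRow Parse(string line)
    {
        // Split on commas and trim the space that follows each separator.
        float[] values = line.Split(',')
            .Select(v => float.Parse(v.Trim(), CultureInfo.InvariantCulture))
            .ToArray();

        if (values.Length != 5)
            throw new FormatException($"Expected 5 columns, got {values.Length}.");

        return new HousingRow(
            Size: values[0],                 // column 0
            HistoricalPrices: values[1..4],  // columns 1, 2, 3 (the range end is exclusive)
            CurrentPrice: values[4]);        // column 4
    }
}
```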

Prepare the data

Pre-process the data before using it to build the machine learning model. In this sample, the Size and HistoricalPrices columns are combined into a single feature vector, which is output to a new column called Features using the Concatenate method. In addition to getting the data into the format expected by ML.NET algorithms, concatenating columns optimizes subsequent operations in the pipeline by applying the operation once for the concatenated column instead of each of the separate columns.
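
Conceptually, Concatenate builds each row's Features vector by placing the Size value first, followed by the HistoricalPrices values, in the order the input columns are listed. Here is a minimal plain C# sketch of that per-row result; ConcatenateSketch is a hypothetical name, and ML.NET itself operates on whole IDataView columns rather than on individual arrays like this.

```csharp
using System.Linq;

public static class ConcatenateSketch
{
    // Mirrors, for a single row, what Concatenate("Features", "Size", "HistoricalPrices")
    // produces: the scalar Size followed by each HistoricalPrices value.
    public static float[] Features(float size, float[] historicalPrices) =>
        new[] { size }.Concat(historicalPrices).ToArray();
}
```

For the first sample row, this gives a four-element vector: 620, 148330.32, 140913.81, 136686.39.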

Once the columns are combined into a single vector, NormalizeMinMax is applied to the Features column so that Size and HistoricalPrices are scaled to the same range, between 0 and 1.
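
The textbook form of min-max scaling maps each feature value x to (x - min) / (max - min), with min and max computed per feature across all rows. The plain C# sketch below implements that formula on a jagged array of feature vectors; MinMaxSketch is a hypothetical name. ML.NET's NormalizeMinMax has a fixZero parameter (true by default) that keeps zero mapped to zero, so its exact output can differ from this sketch; treat it as an illustration of the idea rather than of ML.NET's implementation.

```csharp
using System.Linq;

public static class MinMaxSketch
{
    // Scales each feature (column) to [0, 1] with (x - min) / (max - min).
    // rows[i][j] is feature j of row i; a constant feature is mapped to 0.
    public static float[][] Normalize(float[][] rows)
    {
        int featureCount = rows[0].Length;
        float[][] result = rows.Select(_ => new float[featureCount]).ToArray();

        for (int j = 0; j < featureCount; j++)
        {
            // Min and max are taken over the same feature in every row.
            float min = rows.Min(r => r[j]);
            float max = rows.Max(r => r[j]);
            float range = max - min;

            for (int i = 0; i < rows.Length; i++)
                result[i][j] = range == 0 ? 0f : (rows[i][j] - min) / range;
        }
        return result;
    }
}
```

Scaling per feature is what puts Size (hundreds of square feet) and the prices (hundreds of thousands of dollars) on a comparable footing.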

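// Define data prep estimator: concatenate the inputs into Features, then scale them
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));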

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Train model with cross validation

Once the data has been pre-processed, it's time to train the model. First, select the algorithm that most closely aligns with the machine learning task to be performed. Because the predicted value is a continuous numeric value, the task is regression. One of the regression algorithms implemented by ML.NET is Stochastic Dual Coordinate Ascent (SDCA), available through mlContext.Regression.Trainers.Sdca(). To train the model with cross-validation, use the CrossValidate method.


Although this sample uses a linear regression model, CrossValidate is applicable to all other machine learning tasks in ML.NET except Anomaly Detection.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate performs the following operations:

  1. Partitions the data into a number of partitions equal to the value specified in the numberOfFolds parameter. The result of each partition is a TrainTestData object.
  2. For each partition, trains a model with the specified machine learning algorithm estimator on the training set, which contains the data from all the other partitions.
  3. Evaluates each model's performance with the Evaluate method on its test set, the partition that was held out of training.
  4. Returns each model along with its metrics.
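
The four steps can be sketched end to end in plain C#. This is an illustrative stand-in, not ML.NET's implementation: ordinary least squares on a single feature replaces the SDCA trainer, the hypothetical FoldResult record plays the role of CrossValidationResult (a model plus its metric), and rows are assigned to folds by index (i % numberOfFolds) without shuffling. It assumes numberOfFolds is between 2 and the number of rows, and that each test fold contains at least two distinct label values so that R-squared is defined.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for a trained model: y = Slope * x + Intercept.
public record LinearModel(double Slope, double Intercept)
{
    public double Predict(double x) => Slope * x + Intercept;
}

// Hypothetical stand-in for CrossValidationResult: the model plus its metric.
public record FoldResult(int Fold, LinearModel Model, double RSquared);

public static class CrossValidationSketch
{
    public static List<FoldResult> Run(double[] x, double[] y, int numberOfFolds)
    {
        if (numberOfFolds < 2 || numberOfFolds > x.Length)
            throw new ArgumentOutOfRangeException(nameof(numberOfFolds));

        var results = new List<FoldResult>();
        for (int fold = 0; fold < numberOfFolds; fold++)
        {
            // 1. Partition: row i is in the test set of fold (i % numberOfFolds).
            int[] test = Enumerable.Range(0, x.Length).Where(i => i % numberOfFolds == fold).ToArray();
            int[] train = Enumerable.Range(0, x.Length).Where(i => i % numberOfFolds != fold).ToArray();

            // 2. Train on the training rows only.
            LinearModel model = Fit(train.Select(i => x[i]).ToArray(), train.Select(i => y[i]).ToArray());

            // 3. Evaluate on the held-out test rows.
            double r2 = ComputeRSquared(
                test.Select(i => y[i]).ToArray(),
                test.Select(i => model.Predict(x[i])).ToArray());

            // 4. Return the model together with its metric.
            results.Add(new FoldResult(fold, model, r2));
        }
        return results;
    }

    // Ordinary least squares fit of a single-feature linear model.
    static LinearModel Fit(double[] x, double[] y)
    {
        double meanX = x.Average(), meanY = y.Average();
        double sxy = x.Zip(y, (a, b) => (a - meanX) * (b - meanY)).Sum();
        double sxx = x.Sum(a => (a - meanX) * (a - meanX));
        double slope = sxy / sxx;
        return new LinearModel(slope, meanY - slope * meanX);
    }

    // R-squared = 1 - SS_res / SS_tot over the evaluated rows.
    static double ComputeRSquared(double[] actual, double[] predicted)
    {
        double mean = actual.Average();
        double ssRes = actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Sum();
        double ssTot = actual.Sum(a => (a - mean) * (a - mean));
        return 1 - ssRes / ssTot;
    }
}
```

With x = 0..9 and y = 2x + 1, every fold recovers a slope of 2 and scores an R-squared of 1 (up to floating-point rounding), because the data are perfectly linear.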

The result stored in cvResults is a collection of CrossValidationResult objects, one per fold. Each object includes the trained model and its metrics, accessible from the Model and Metrics properties respectively. In this sample, the Model property is of type ITransformer and the Metrics property is of type RegressionMetrics.

Evaluate the model

Metrics for the different trained models can be accessed through the Metrics property of the individual CrossValidationResult object. In this case, the R-Squared metric is accessed and stored in the variable rSquared.

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

If you inspect the contents of the rSquared variable, the output should be five values, one per fold. R-squared is usually between 0 and 1, and values closer to 1 indicate a better fit; it can be negative when a model fits worse than always predicting the mean. Using metrics like R-squared, rank the models from best to worst performing. Then, select the top model to make predictions or perform additional operations with.
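
For reference, R-squared on a fold's n test rows is conventionally defined as follows, where y_i is the actual price, ŷ_i the predicted price and ȳ the mean actual price in that fold. A model that always predicts ȳ scores 0, a perfect model scores 1, and a model that does worse than the mean scores below 0. The worked example checks the mean-predictor case on two hypothetical rows.

```latex
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2}

% Worked example: y = (3, 5), \hat{y} = (4, 4), \bar{y} = 4
R^2 = 1 - \frac{(3-4)^2 + (5-4)^2}{(3-4)^2 + (5-4)^2} = 1 - \frac{2}{2} = 0
```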

// Order the models from best to worst R-squared
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];