Retrain a model

Learn how to retrain a machine learning model in ML.NET.

The world and its data change constantly, so models need to be updated as well. ML.NET can retrain a model by using its learned model parameters as the starting point. Each new round of training then builds on previous experience instead of starting from scratch.

The following algorithms are retrainable in ML.NET:

Load pretrained model

First, load the pretrained model into your application. To learn more about loading training pipelines and models, see Save and load a trained model.

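// Namespaces used by the snippets in this article
using System;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// Create MLContext
MLContext mlContext = new MLContext();

// Declare variables that receive the input schemas of the data prep pipeline and the trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;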

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);

// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);

Extract pretrained model parameters

Once the model is loaded, extract the learned model parameters by accessing the Model property of the pretrained model. The pretrained model was trained with the OnlineGradientDescentTrainer linear regression trainer, which produces a RegressionPredictionTransformer whose Model property is a LinearRegressionModelParameters object. These parameters contain the model's learned bias and weights (coefficients), and they serve as the starting point for the retrained model.

// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;

Note

The type of model parameters depends on the algorithm used. For example, OnlineGradientDescentTrainer outputs LinearRegressionModelParameters, while LbfgsMaximumEntropyMulticlassTrainer outputs MaximumEntropyModelParameters. When extracting model parameters, cast them to the appropriate type.
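To make the role of the bias and weights concrete, here is a plain C# sketch (no ML.NET dependency) of how a linear regression score is computed from them: the bias plus the dot product of the weights and the feature vector. The parameter values and features below are hypothetical, chosen only so the arithmetic is easy to follow, and the `Predict` helper is illustrative rather than part of the ML.NET API.

```csharp
using System;
using System.Linq;

// Hypothetical learned parameters for a model with four features (illustrative values only).
float bias = 1000f;
float[] weights = { 2.0f, 0.5f, 0.25f, 0.25f };

// Hypothetical feature vector for one example.
float[] features = { 100f, 40f, 80f, 20f };

// Hypothetical helper: a linear regression score is bias + sum(weight_i * feature_i).
float Predict(float b, float[] w, float[] x)
{
    if (w.Length != x.Length)
    {
        throw new ArgumentException("Weights and features must have the same length.");
    }
    return b + w.Zip(x, (wi, xi) => wi * xi).Sum();
}

float prediction = Predict(bias, weights, features);
Console.WriteLine($"Prediction: {prediction}");
```

Retraining changes these numbers while keeping the same scoring rule, which is why the learned values are a meaningful starting point.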

Retrain a model

Retraining a model works the same way as training one. The only difference is that, in addition to the data, the Fit method takes the original learned model parameters as input and uses them as the starting point for retraining.

// New Data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f, 175000f, 210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};

// Load new data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);

// Preprocess Data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);

// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);

At this point, you can save the retrained model and use it in your application. For more information, see the Save and load a trained model and Make predictions with a trained model guides.
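Conceptually, passing originalModelParameters to Fit is a warm start: the optimizer begins from the learned weights and bias instead of from zeros. The following plain C# sketch illustrates that idea with a hand-written epoch of online gradient descent for squared loss. It is a simplified illustration, not ML.NET's OnlineGradientDescentTrainer implementation, which has additional options such as regularization and learning-rate settings. The `TrainEpoch` helper, the learning rate, and the data (generated from y = 2x + 1) are all hypothetical.

```csharp
using System;
using System.Linq;

// Hypothetical helper: one epoch of online (per-example) gradient descent for
// linear regression with squared loss. The caller supplies the starting weights
// and bias, so the same code can train from zeros or continue from learned values.
(float[] Weights, float Bias) TrainEpoch(
    float[][] features, float[] labels,
    float[] initialWeights, float initialBias, float learningRate)
{
    float[] w = (float[])initialWeights.Clone(); // copy so the caller's parameters are not mutated
    float b = initialBias;
    for (int n = 0; n < features.Length; n++)
    {
        float[] x = features[n];
        // Prediction error for this example: (bias + w . x) - label
        float error = b + w.Zip(x, (wi, xi) => wi * xi).Sum() - labels[n];
        // Gradient of 0.5 * error^2 is error * x_i for each weight and error for the bias.
        for (int i = 0; i < w.Length; i++)
        {
            w[i] -= learningRate * error * x[i];
        }
        b -= learningRate * error;
    }
    return (w, b);
}

// Hypothetical data generated from y = 2x + 1.
float[][] oldX = { new[] { 1f }, new[] { 2f }, new[] { 3f } };
float[] oldY = { 3f, 5f, 7f };
float[][] newX = { new[] { 4f }, new[] { 5f } };
float[] newY = { 9f, 11f };

// Initial training starts from zero weights and a zero bias.
var original = TrainEpoch(oldX, oldY, new float[1], 0f, 0.01f);

// Retraining starts from the previously learned parameters instead of zeros.
var retrained = TrainEpoch(newX, newY, original.Weights, original.Bias, 0.01f);

Console.WriteLine($"Original:  w = {original.Weights[0]}, b = {original.Bias}");
Console.WriteLine($"Retrained: w = {retrained.Weights[0]}, b = {retrained.Bias}");
```

Copying the starting weights keeps the original parameters intact, so they can still be compared with the retrained ones afterwards.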

Compare model parameters

How do you know whether retraining actually happened? One way is to check whether the retrained model's parameters differ from those of the original model. The following code sample compares the original and retrained model weights and writes them to the console.

// Extract the model parameters of the retrained model
// (Model is already typed as LinearRegressionModelParameters, so no cast is needed)
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model;

// Compute the change in each weight (Zip and ToArray require System.Linq)
float[] weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();

Console.WriteLine("Original | Retrained | Difference");
for (int i = 0; i < weightDiffs.Length; i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}

The following table shows what the output might look like.

Original | Retrained | Difference
33039.86 | 56293.76 | -23253.9
29099.14 | 49586.03 | -20486.89
28938.38 | 48609.23 | -19670.85
30484.02 | 53745.43 | -23261.41
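If you want a yes/no answer rather than reading the table, a small helper can check whether any weight moved by more than a tolerance. The sketch below is plain C# and uses the sample values from the example output table as input. The `WeightsChanged` helper and the 1e-3 tolerance are assumptions for illustration, not ML.NET APIs or recommended thresholds.

```csharp
using System;
using System.Linq;

// Sample weights copied from the example output table in this article.
float[] originalWeights = { 33039.86f, 29099.14f, 28938.38f, 30484.02f };
float[] retrainedWeights = { 56293.76f, 49586.03f, 48609.23f, 53745.43f };

// Hypothetical helper: returns true when at least one weight changed by more than the tolerance.
bool WeightsChanged(float[] before, float[] after, float tolerance)
{
    if (before.Length != after.Length)
    {
        throw new ArgumentException("Both weight vectors must have the same length.");
    }
    return before.Zip(after, (o, r) => Math.Abs(o - r)).Any(d => d > tolerance);
}

// Largest absolute change across all weights.
float maxChange = originalWeights.Zip(retrainedWeights, (o, r) => Math.Abs(o - r)).Max();

Console.WriteLine($"Largest absolute weight change: {maxChange}");
Console.WriteLine($"Retraining changed the weights: {WeightsChanged(originalWeights, retrainedWeights, 1e-3f)}");
```

An absolute tolerance is the simplest choice. When weights have very different magnitudes, a relative change (the difference divided by the original weight's magnitude) may be a more meaningful signal.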