Re-train a model
Learn how to retrain a machine learning model in ML.NET.
The world and the data around it change constantly, so models need to be updated as well. ML.NET lets you re-train a model using its learned model parameters as a starting point, so that each training run builds on previous experience rather than starting from scratch.
The following algorithms are re-trainable in ML.NET:
- AveragedPerceptronTrainer
- FieldAwareFactorizationMachineTrainer
- LbfgsLogisticRegressionBinaryTrainer
- LbfgsMaximumEntropyMulticlassTrainer
- LbfgsPoissonRegressionTrainer
- LinearSvmTrainer
- OnlineGradientDescentTrainer
- SgdCalibratedTrainer
- SgdNonCalibratedTrainer
- SymbolicSgdLogisticRegressionBinaryTrainer
Load pre-trained model
First, load the pre-trained model into your application. To learn more about loading training pipelines and models, see Save and load a trained model.
// Create MLContext
MLContext mlContext = new MLContext();
// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;
// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);
// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);
Extract pre-trained model parameters
Once the model is loaded, extract the learned model parameters by accessing the Model property of the pre-trained model. The pre-trained model was trained using the linear regression trainer OnlineGradientDescentTrainer, which creates a RegressionPredictionTransformer that outputs LinearRegressionModelParameters. These model parameters contain the learned bias and weights (coefficients) of the model. These values are used as the starting point for the new re-trained model.
// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;
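To make the "bias and weights" concrete, the following is a minimal sketch that loads the model and prints the extracted values. It assumes the Microsoft.ML NuGet package is referenced and that ogd_model.zip (the file name used above) exists in the working directory. The class name InspectParameters is only illustrative.

```csharp
using System;
using Microsoft.ML;
using Microsoft.ML.Trainers;

class InspectParameters
{
    static void Main()
    {
        var mlContext = new MLContext();

        // Load the pre-trained model (same file name as in the section above).
        ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out DataViewSchema modelSchema);

        // Same extraction as above. 'as' yields null when the model holds
        // a different parameter type, so check before using the result.
        var parameters =
            ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;

        if (parameters == null)
        {
            Console.WriteLine("The model does not contain LinearRegressionModelParameters.");
            return;
        }

        // Bias is a single float; Weights holds one coefficient per feature.
        Console.WriteLine($"Bias: {parameters.Bias}");
        for (int i = 0; i < parameters.Weights.Count; i++)
        {
            Console.WriteLine($"Weight[{i}]: {parameters.Weights[i]}");
        }
    }
}
```

The null check matters because the `as` cast does not throw when the parameter type is wrong. Without the check, the mistake would only surface later as a NullReferenceException.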
Note
The type of the model parameters depends on the algorithm used. For example, OnlineGradientDescentTrainer outputs LinearRegressionModelParameters, while LbfgsMaximumEntropyMulticlassTrainer outputs MaximumEntropyModelParameters. When extracting model parameters, cast to the appropriate type.
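Because the parameter type varies by trainer, one option is to wrap the cast in a small generic helper that fails loudly when the type does not match. The sketch below is an illustration, not part of ML.NET: ModelParameterHelper and ExtractParameters are hypothetical names. It relies only on the ISingleFeaturePredictionTransformer<object> cast already shown in this article.

```csharp
using System;
using Microsoft.ML;

// Hypothetical helper (not an ML.NET API): extracts model parameters of the
// requested type and throws a descriptive error instead of returning null.
static class ModelParameterHelper
{
    public static TParameters ExtractParameters<TParameters>(ITransformer model)
        where TParameters : class
    {
        // Pattern matching avoids an InvalidCastException when the loaded
        // transformer is not a single-feature prediction transformer.
        if (model is ISingleFeaturePredictionTransformer<object> predictor &&
            predictor.Model is TParameters parameters)
        {
            return parameters;
        }

        throw new InvalidOperationException(
            $"The model does not expose parameters of type {typeof(TParameters).Name}.");
    }
}
```

With this helper, the extraction above could be written as `ModelParameterHelper.ExtractParameters<LinearRegressionModelParameters>(trainedModel)`. For a model trained with LbfgsMaximumEntropyMulticlassTrainer, the type argument would be MaximumEntropyModelParameters.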
Re-train model
Re-training a model follows the same process as training one. The only difference is that the Fit method takes the original learned model parameters as an additional input, alongside the data, and uses them as the starting point for re-training.
// New data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f, 175000f, 210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};
// Load new data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);
// Preprocess data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);
// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);
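The snippet above uses a HousingData type that this article does not define. The following is a sketch of what such a class might look like. It is an assumption inferred from the property names used above, and the attributes must match whatever schema the saved data preparation pipeline was built against.

```csharp
using Microsoft.ML.Data;

// Assumed input schema for the examples in this article (not shown in the original).
public class HousingData
{
    // Size of the house; a single numeric feature.
    public float Size { get; set; }

    // Fixed-length vector of three historical prices, matching the
    // three-element arrays used in the sample data.
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    // Regression trainers look for a column named "Label" by default,
    // so the current price is mapped to that column name.
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}
```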
At this point, you can save your re-trained model and use it in your application. For more information, see Save and load a trained model and Make predictions with a trained model.
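As a rough sketch of the save step, the helper below wraps the Save method of mlContext.Model. The class name RetrainedModelStore and the idea of a separate helper are illustrative choices, not something the original article prescribes.

```csharp
using Microsoft.ML;

// Hypothetical helper that persists a re-trained model to disk.
static class RetrainedModelStore
{
    public static void Save(
        MLContext mlContext,
        ITransformer retrainedModel,
        IDataView transformedNewData,
        string path)
    {
        // The schema argument describes the input the model expects. Here that
        // is the output of the data preparation pipeline, not the raw data.
        mlContext.Model.Save(retrainedModel, transformedNewData.Schema, path);
    }
}
```

A call might look like `RetrainedModelStore.Save(mlContext, retrainedModel, transformedNewData, "ogd_model_retrained.zip")`, where the file name is only an example. Saving to a new file rather than overwriting ogd_model.zip keeps the original model available for comparison or rollback.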
Compare model parameters
How do you know whether re-training actually happened? One way is to check whether the re-trained model's parameters differ from those of the original model. The following code sample computes the difference between the original and re-trained model weights and writes all three values to the console.
// Extract model parameters of re-trained model
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model as LinearRegressionModelParameters;
// Inspect change in weights
var weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();
Console.WriteLine("Original | Retrained | Difference");
for (int i = 0; i < weightDiffs.Length; i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}
The table below shows what the output might look like.
Original | Retrained | Difference |
---|---|---|
33039.86 | 56293.76 | -23253.9 |
29099.14 | 49586.03 | -20486.89 |
28938.38 | 48609.23 | -19670.85 |
30484.02 | 53745.43 | -23261.41 |
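Reading a table by eye works for a handful of weights. For larger models, a programmatic check may be more practical. The sketch below is plain C# with no ML.NET dependency. WeightComparison, MaxAbsDifference, and HasChanged are hypothetical names, and the default tolerance is an arbitrary assumption that you would tune for your data.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper for checking whether re-training moved the weights.
static class WeightComparison
{
    // Returns the largest absolute element-wise difference between two weight vectors.
    public static float MaxAbsDifference(IReadOnlyList<float> original, IReadOnlyList<float> retrained)
    {
        if (original.Count != retrained.Count)
        {
            throw new ArgumentException("Weight vectors must have the same length.");
        }

        float max = 0f;
        for (int i = 0; i < original.Count; i++)
        {
            max = Math.Max(max, Math.Abs(original[i] - retrained[i]));
        }
        return max;
    }

    // True when at least one weight moved by more than the tolerance.
    public static bool HasChanged(
        IReadOnlyList<float> original,
        IReadOnlyList<float> retrained,
        float tolerance = 1e-6f)
    {
        return MaxAbsDifference(original, retrained) > tolerance;
    }
}
```

With the parameters from this article, the check would be `WeightComparison.HasChanged(originalModelParameters.Weights, retrainedModelParameters.Weights)`.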