Sdílet prostřednictvím


Opětovné trénování modelu

Naučte se přetrénovat model strojového učení v ML.NET.

Svět a její data se neustále mění. Modely se proto musí také měnit a aktualizovat. ML.NET poskytuje funkce pro přetrénování modelů, které používají naučené parametry modelu jako výchozí bod pro neustálé sestavování na předchozích zkušenostech, a ne vždy od začátku.

V ML.NET se dají přetrénovat následující algoritmy:

Načtení předem natrénovaného modelu

Nejprve načtěte předem natrénovaný model do aplikace. Další informace o načítání trénovacích kanálů a modelů najdete v tématu Uložení a načtení natrénovaného modelu.

// Create MLContext
MLContext mlContext = new MLContext();

// Define DataViewSchema of data prep pipeline and trained model
DataViewSchema dataPrepPipelineSchema, modelSchema;

// Load data preparation pipeline
ITransformer dataPrepPipeline = mlContext.Model.Load("data_preparation_pipeline.zip", out dataPrepPipelineSchema);

// Load trained model
ITransformer trainedModel = mlContext.Model.Load("ogd_model.zip", out modelSchema);

Extrahování předtrénovaných parametrů modelu

Po načtení modelu extrahujte naučené parametry modelu přístupem k Model vlastnosti předtrénovaného modelu. Předtrénovaný model byl trénován pomocí lineární regresního modelu OnlineGradientDescentTrainer, který vytvoří RegressionPredictionTransformer výstup LinearRegressionModelParameters. Tyto parametry modelu obsahují naučené předsudky a váhy nebo koeficienty modelu. Tyto hodnoty se používají jako výchozí bod nového přetrénovaného modelu.

// Extract trained model parameters
LinearRegressionModelParameters originalModelParameters =
    ((ISingleFeaturePredictionTransformer<object>)trainedModel).Model as LinearRegressionModelParameters;

Poznámka:

Výstup parametrů modelu závisí na použitém algoritmu. Například používá , zatímco LbfgsMaximumEntropyMulticlassTrainer výstupy MaximumEntropyModelParameters.LinearRegressionModelParametersOnlineGradientDescentTrainer Při extrahování parametrů modelu přetypujte na příslušný typ.

Opětovné trénování modelu

Proces opětovného trénování modelu se neliší od trénování modelu. Jediným rozdílem je, že Fit metoda kromě dat přebírá jako vstup i původní naučené parametry modelu a používá je jako výchozí bod v procesu opětovného trénování.

// New Data
HousingData[] housingData = new HousingData[]
{
    new HousingData
    {
        Size = 850f,
        HistoricalPrices = new float[] { 150000f,175000f,210000f },
        CurrentPrice = 205000f
    },
    new HousingData
    {
        Size = 900f,
        HistoricalPrices = new float[] { 155000f, 190000f, 220000f },
        CurrentPrice = 210000f
    },
    new HousingData
    {
        Size = 550f,
        HistoricalPrices = new float[] { 99000f, 98000f, 130000f },
        CurrentPrice = 180000f
    }
};

//Load New Data
IDataView newData = mlContext.Data.LoadFromEnumerable<HousingData>(housingData);

// Preprocess Data
IDataView transformedNewData = dataPrepPipeline.Transform(newData);

// Retrain model
RegressionPredictionTransformer<LinearRegressionModelParameters> retrainedModel =
    mlContext.Regression.Trainers.OnlineGradientDescent()
        .Fit(transformedNewData, originalModelParameters);

V tomto okamžiku můžete model znovu natrénovaný uložit a použít ho ve své aplikaci. Další informace najdete v tématu uložení a načtení natrénovaného modelu a vytváření předpovědí pomocí trénovaných průvodců modelu .

Porovnání parametrů modelu

Jak poznáte, jestli k přetrénování skutečně došlo? Jedním ze způsobů by bylo porovnat, jestli se parametry přetrénovaného modelu liší od parametrů původního modelu. Následující ukázka kódu porovná původní hodnoty s přetrénovanými hmotnostmi modelu a vypíše je do konzoly.

// Extract Model Parameters of re-trained model
LinearRegressionModelParameters retrainedModelParameters = retrainedModel.Model as LinearRegressionModelParameters;

// Inspect Change in Weights
var weightDiffs =
    originalModelParameters.Weights.Zip(
        retrainedModelParameters.Weights, (original, retrained) => original - retrained).ToArray();

Console.WriteLine("Original | Retrained | Difference");
for(int i=0;i < weightDiffs.Count();i++)
{
    Console.WriteLine($"{originalModelParameters.Weights[i]} | {retrainedModelParameters.Weights[i]} | {weightDiffs[i]}");
}

Následující tabulka ukazuje, jak může výstup vypadat.

Původní Znovu natrénováno Rozdíl
33039.86 56293.76 -23253.9
29099.14 49586.03 -20486.89
28938.38 48609.23 -19670.85
30484.02 53745.43 -23261.41