Trénování modelu strojového učení pomocí křížového ověřování
Naučte se používat křížové ověřování k trénování robustnějších modelů strojového učení v ML.NET.
Křížové ověření je technika trénování a vyhodnocení modelu, která rozdělí data do několika oddílů a trénuje více algoritmů v těchto oddílech. Tato technika zlepšuje odolnost modelu tím, že z trénovacího procesu vydrží data. Kromě zlepšení výkonu u nezoznaných pozorování může být v prostředích s omezenými daty efektivním nástrojem pro trénování modelů s menší datovou sadou.
Data a datový model
Data ze souboru, který má následující formát:
Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41
Data lze modelovat podle třídy jako HousingData
a načíst do objektu IDataView
.
public class HousingData
{
[LoadColumn(0)]
public float Size { get; set; }
[LoadColumn(1, 3)]
[VectorType(3)]
public float[] HistoricalPrices { get; set; }
[LoadColumn(4)]
[ColumnName("Label")]
public float CurrentPrice { get; set; }
}
Příprava dat
Před použitím dat před sestavením modelu strojového učení je předzpracujte. V této ukázce Size
HistoricalPrices
se sloupce zkombinují do jednoho vektoru funkce, který je výstupem do nového sloupce volaného Features
metodou Concatenate
. Kromě získání dat do formátu očekávaného algoritmy ML.NET zřetězení sloupců optimalizuje následné operace v kanálu použitím operace jednou pro zřetězený sloupec místo jednotlivých samostatných sloupců.
Jakmile se sloupce zkombinují do jednoho vektoru, použije se u sloupce, NormalizeMinMax
který získá Size
a HistoricalPrices
ve stejném rozsahu mezi 0–Features
1.
// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
.Append(mlContext.Transforms.NormalizeMinMax("Features"));
// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);
// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);
Trénování modelu s křížovým ověřováním
Po předběžném zpracování dat je čas model vytrénovat. Nejprve vyberte algoritmus, který je nejvíce v souladu s úlohou strojového učení, který se má provést. Vzhledem k tomu, že predikovaná hodnota je číselně souvislá hodnota, je úkol regresní. Jedním z regresních algoritmů implementovaných ML.NET je StochasticDualCoordinateAscentCoordinator
algoritmus. K trénování modelu pomocí křížového ověření použijte metodu CrossValidate
.
Poznámka:
I když tato ukázka používá lineární regresní model, crossValidate se vztahuje na všechny ostatní úlohy strojového učení v ML.NET s výjimkou detekce anomálií.
// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();
// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);
CrossValidate
provádí následující operace:
- Rozdělí data do několika oddílů, které se rovnají hodnotě zadané v parametru
numberOfFolds
. Výsledkem každého oddíluTrainTestData
je objekt. - Model se vytrénuje na každém oddílu pomocí zadaného estimátoru algoritmů strojového učení v trénovací sadě dat.
- Výkon každého modelu se vyhodnocuje pomocí
Evaluate
metody testovací datové sady. - Model spolu s jeho metrikami se vrátí pro každý z těchto modelů.
Výsledek uložený v cvResults
kolekci CrossValidationResult
objektů. Tento objekt zahrnuje vytrénovaný model i metriky, které jsou přístupné jak ve formě, tak Model
Metrics
i vlastnosti. V této ukázce Model
je vlastnost typu ITransformer
a Metrics
vlastnost je typu RegressionMetrics
.
Vyhodnocení modelu
K metrikám pro různé vytrénované modely je možné přistupovat prostřednictvím Metrics
vlastnosti jednotlivého CrossValidationResult
objektu. V tomto případě je metrika R-Squared přístupná a uložená v proměnné rSquared
.
IEnumerable<double> rSquared =
cvResults
.Select(fold => fold.Metrics.RSquared);
Pokud zkontrolujete obsah rSquared
proměnné, měl by mít výstup pět hodnot v rozsahu od 0 do 1, kde je to nejlepší. Pomocí metrik, jako je R-Squared, vyberte modely od nejlepších po nejhorší výkon. Pak výběrem horního modelu proveďte předpovědi nebo proveďte další operace.
// Select all models
ITransformer[] models =
cvResults
.OrderByDescending(fold => fold.Metrics.RSquared)
.Select(fold => fold.Model)
.ToArray();
// Get Top Model
ITransformer topModel = models[0];
Váš názor
https://aka.ms/ContentUserFeedback.
Připravujeme: V průběhu roku 2024 budeme postupně vyřazovat problémy z GitHub coby mechanismus zpětné vazby pro obsah a nahrazovat ho novým systémem zpětné vazby. Další informace naleznete v tématu:Odeslat a zobrazit názory pro