Bagikan melalui


Melatih model pembelajaran mesin menggunakan validasi silang

Pelajari cara menggunakan validasi silang untuk melatih model pembelajaran mesin yang lebih kuat di ML.NET.

Validasi silang adalah teknik evaluasi pelatihan dan model yang membagi data menjadi beberapa partisi dan melatih beberapa algoritma pada partisi ini. Teknik ini meningkatkan ketahanan model dengan mengecualikan data dari proses pelatihan. Selain meningkatkan performa pada pengamatan yang tidak terlihat, di lingkungan yang dibatasi data dapat menjadi alat yang efektif untuk melatih model dengan himpunan data yang lebih kecil.

Data dan model data

Data yang diberikan dari file yang memiliki format berikut:

Size (Sq. ft.), HistoricalPrice1 ($), HistoricalPrice2 ($), HistoricalPrice3 ($), Current Price ($)
620.00, 148330.32, 140913.81, 136686.39, 146105.37
550.00, 557033.46, 529181.78, 513306.33, 548677.95
1127.00, 479320.99, 455354.94, 441694.30, 472131.18
1120.00, 47504.98, 45129.73, 43775.84, 46792.41

Data dapat dimodelkan oleh kelas seperti HousingData dan dimuat ke dalam IDataView.

public class HousingData
{
    [LoadColumn(0)]
    public float Size { get; set; }

    [LoadColumn(1, 3)]
    [VectorType(3)]
    public float[] HistoricalPrices { get; set; }

    [LoadColumn(4)]
    [ColumnName("Label")]
    public float CurrentPrice { get; set; }
}

Menyiapkan data

Pra-proses data sebelum menggunakannya untuk membangun model pembelajaran mesin. Dalam sampel ini, kolom Size dan HistoricalPrices digabungkan ke dalam vektor fitur tunggal, yang merupakan output ke kolom baru yang disebut Features menggunakan metode Concatenate. Selain memasukkan data ke dalam format yang diharapkan oleh algoritma ML.NET, menggabungkan kolom mengoptimalkan operasi berikutnya dalam alur dengan menerapkan operasi sekali untuk kolom yang digabungkan alih-alih masing-masing kolom terpisah.

Setelah kolom digabungkan ke dalam satu vektor, NormalizeMinMax diterapkan ke kolom Features untuk mendapatkan Size dan HistoricalPrices dalam rentang yang sama antara 0-1.

// Define data prep estimator
IEstimator<ITransformer> dataPrepEstimator =
    mlContext.Transforms.Concatenate("Features", new string[] { "Size", "HistoricalPrices" })
        .Append(mlContext.Transforms.NormalizeMinMax("Features"));

// Create data prep transformer
ITransformer dataPrepTransformer = dataPrepEstimator.Fit(data);

// Transform data
IDataView transformedData = dataPrepTransformer.Transform(data);

Melatih model dengan validasi silang

Setelah data diproseksi sebelumnya, saatnya untuk melatih model. Pertama, pilih algoritma yang paling selaras dengan tugas pembelajaran mesin yang akan dilakukan. Karena nilai yang diprediksi adalah nilai berkelanjutan numerik, tugasnya adalah regresi. Salah satu algoritma regresi yang diterapkan oleh ML.NET adalah algoritma StochasticDualCoordinateAscentCoordinator. Untuk melatih model dengan validasi silang, gunakan metode CrossValidate.

Nota

Meskipun sampel ini menggunakan model regresi linier, CrossValidate berlaku untuk semua tugas pembelajaran mesin lainnya di ML.NET kecuali Deteksi Anomali.

// Define StochasticDualCoordinateAscent algorithm estimator
IEstimator<ITransformer> sdcaEstimator = mlContext.Regression.Trainers.Sdca();

// Apply 5-fold cross validation
var cvResults = mlContext.Regression.CrossValidate(transformedData, sdcaEstimator, numberOfFolds: 5);

CrossValidate melakukan operasi berikut:

  1. Mempartisi data ke dalam sejumlah partisi yang sama dengan nilai yang ditentukan dalam parameter numberOfFolds. Hasil dari setiap partisi adalah objek TrainTestData.
  2. Model dilatih pada setiap partisi menggunakan estimator algoritma pembelajaran mesin yang ditentukan pada himpunan data pelatihan.
  3. Performa setiap model dievaluasi menggunakan metode Evaluate pada himpunan data pengujian.
  4. Model bersama dengan metriknya dikembalikan untuk setiap model.

Hasil yang disimpan dalam cvResults adalah kumpulan objek CrossValidationResult. Objek ini mencakup model terlatih serta metrik yang dapat diakses melalui properti Model dan Metrics masing-masing. Dalam sampel ini, properti Model berjenis ITransformer dan properti Metrics berjenis RegressionMetrics.

Evaluasi model

Metrik untuk masing-masing model yang dilatih dapat diakses melalui properti Metrics dari objek CrossValidationResult individu. Dalam hal ini, metrik R-Squared diakses dan disimpan dalam variabel rSquared.

IEnumerable<double> rSquared =
    cvResults
        .Select(fold => fold.Metrics.RSquared);

Jika Anda memeriksa konten variabel rSquared, output harus lima nilai mulai dari 0-1 di mana lebih dekat ke 1 berarti yang terbaik. Menggunakan metrik seperti R-Squared, pilih model dari performa terbaik hingga terburuk. Kemudian, pilih model teratas untuk membuat prediksi atau melakukan operasi tambahan dengan.

// Select all models
ITransformer[] models =
    cvResults
        .OrderByDescending(fold => fold.Metrics.RSquared)
        .Select(fold => fold.Model)
        .ToArray();

// Get Top Model
ITransformer topModel = models[0];