Bagikan melalui


Menginterpretasikan prediksi model menggunakan Kepentingan Fitur Permutasi

Menggunakan Permutation Feature Importance (PFI), pelajari cara menginterpretasikan prediksi model pembelajaran mesin ML.NET. PFI memberikan kontribusi relatif yang dilakukan setiap fitur terhadap prediksi.

Model pembelajaran mesin sering dianggap sebagai kotak buram yang mengambil input dan menghasilkan output. Langkah-langkah perantara atau interaksi di antara fitur yang memengaruhi output jarang dipahami. Karena pembelajaran mesin diperkenalkan ke dalam lebih banyak aspek kehidupan sehari-hari seperti perawatan kesehatan, sangat penting untuk memahami mengapa model pembelajaran mesin membuat keputusan yang dilakukannya. Misalnya, jika diagnostik dibuat oleh model pembelajaran mesin, profesional perawatan kesehatan memerlukan cara untuk melihat faktor-faktor yang masuk ke membuat diagnosa tersebut. Memberikan diagnosis yang tepat dapat membuat perbedaan yang besar tentang apakah pasien memiliki pemulihan cepat atau tidak. Oleh karena itu, semakin tinggi tingkat penjelasan dalam model, semakin besar kepercayaan diri profesional layanan kesehatan harus menerima atau menolak keputusan yang dibuat oleh model.

Berbagai teknik digunakan untuk menjelaskan model, salah satunya PFI. PFI adalah teknik yang digunakan untuk menjelaskan model klasifikasi dan regresi yang terinspirasi oleh kertas Hutan Acak Breiman (lihat bagian 10). Pada tingkat tinggi, cara kerjanya adalah dengan mengacak data satu fitur secara acak pada satu waktu untuk seluruh himpunan data dan menghitung berapa banyak metrik performa minat menurun. Semakin besar perubahan, semakin penting fitur itu.

Selain itu, dengan menyoroti fitur yang paling penting, pembuat model dapat fokus menggunakan subset fitur yang lebih bermakna yang berpotensi mengurangi kebisingan dan waktu pelatihan.

Muat data

Fitur dalam himpunan data yang digunakan untuk sampel ini berada di kolom 1-12. Tujuannya adalah untuk memprediksi Price.

Kolom Fitur Deskripsi
1 Kriminalitas Tingkat kejahatan per kapita
2 Zona Perumahan Zona perumahan di kota
3 CommercialZones Zona non-perumahan di kota
4 NearWater Kedekatan dengan tubuh air
5 ToxicWasteLevels Tingkat toksisitas (PPM)
6 AverageRoomNumber Jumlah rata-rata kamar di rumah
7 Beranda Usia rumah
8 BusinessCenterDistance Jarak ke kawasan bisnis terdekat
9 HighwayAccess Kedekatan dengan jalan raya
10 Laju Pajak Tarif pajak properti
11 StudentTeacherRatio Rasio siswa terhadap guru
12 PercentPopulationBelowPoverty Persentase populasi yang hidup di bawah kemiskinan
13 Harga Harga rumah

Sampel himpunan data ditunjukkan di bawah ini:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Data dalam sampel ini dapat dimodelkan oleh kelas seperti HousingPriceData dan dimuat ke dalam IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Melatih model

Sampel kode di bawah ini menggambarkan proses pelatihan model regresi linier untuk memprediksi harga rumah.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Menjelaskan model dengan Pentingnya Fitur Permutasi (PFI)

Dalam ML.NET gunakan PermutationFeatureImportance metode untuk tugas Anda masing-masing.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Hasil penggunaan PermutationFeatureImportance pada himpunan data pelatihan adalah ImmutableArrayRegressionMetricsStatistics objek. RegressionMetricsStatistics memberikan statistik ringkasan seperti rata-rata dan simpangihan standar untuk beberapa pengamatan RegressionMetrics yang sama dengan jumlah permutasi yang ditentukan oleh permutationCount parameter .

Metrik yang digunakan untuk mengukur kepentingan fitur tergantung pada tugas pembelajaran mesin yang digunakan untuk menyelesaikan masalah Anda. Misalnya, tugas regresi dapat menggunakan metrik evaluasi umum seperti R-kuadrat untuk mengukur kepentingan. Untuk informasi selengkapnya tentang metrik evaluasi model, lihat mengevaluasi model ML.NET Anda dengan metrik.

Pentingnya, atau dalam hal ini, penurunan rata-rata absolut dalam metrik R-kuadrat yang dihitung PermutationFeatureImportance kemudian dapat diurutkan dari yang paling penting hingga paling tidak penting.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Mencetak nilai untuk setiap fitur di featureImportanceMetrics akan menghasilkan output yang mirip dengan yang di bawah ini. Perlu diingat bahwa Anda harus mengharapkan untuk melihat hasil yang berbeda karena nilai-nilai ini bervariasi berdasarkan data yang diberikan.

Fitur Ubah ke R-Kuadrat
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
Laju Pajak -0.008545
AverageRoomNumber -0.003949
Kriminalitas -0.003665
CommercialZones 0.002749
Beranda -0.002426
Zona Perumahan -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Melihat lima fitur terpenting untuk himpunan data ini, harga rumah yang diprediksi oleh model ini dipengaruhi oleh kedekatannya dengan jalan raya, rasio guru siswa sekolah di daerah, kedekatan dengan pusat pekerjaan utama, tarif pajak properti dan jumlah rata-rata kamar di rumah.

Langkah berikutnya