Modellelőjelzések értelmezése permutációs funkció fontosságával

A permutációs funkció fontossága (PFI) használatával megtudhatja, hogyan értelmezheti ML.NET gépi tanulási modell előrejelzéseit. A PFI az egyes funkciók által az előrejelzéshez adott relatív hozzájárulást adja.

A gépi tanulási modelleket gyakran átlátszatlan dobozoknak tekintik, amelyek bemeneteket vesznek fel és kimenetet hoznak létre. A kimenetet befolyásoló funkciók köztes lépéseit vagy interakcióit ritkán értjük. Mivel a gépi tanulás a mindennapi élet több aspektusába, például az egészségügybe kerül, rendkívül fontos megérteni, hogy egy gépi tanulási modell miért hozza meg a döntéseket. Ha például egy gépi tanulási modell alapján történik a diagnózis, az egészségügyi szakembereknek meg kell vizsgálniuk azokat a tényezőket, amelyek a diagnózisok készítésébe mentek. A megfelelő diagnózis megadása nagy különbséget tehet annak eldöntésében, hogy a betegnek gyors felépülése van-e, vagy sem. Ezért minél magasabb a modell magyarázhatósági szintje, annál nagyobb a bizalom, hogy az egészségügyi szakembereknek el kell fogadniuk vagy el kell utasítaniuk a modell által hozott döntéseket.

A modellek magyarázatára különböző technikákat használnak, amelyek közül az egyik a PFI. A PFI a Breiman Random Forests-tanulmánya által inspirált besorolási és regressziós modellek magyarázatára szolgáló technika (lásd a 10. szakaszt). Magas szinten úgy működik, hogy véletlenszerűen összeadja az adatokat egyszerre egy funkcióval a teljes adatkészlethez, és kiszámítja, hogy mennyivel csökken az érdeklődési kör metrikája. Minél nagyobb a változás, annál fontosabb a funkció.

Emellett a legfontosabb funkciók kiemelésével a modellkészítők az értelmesebb funkciók egy részhalmazára összpontosíthatnak, ami csökkentheti a zajt és a betanítási időt.

Az adatok betöltése

A mintához használt adathalmaz funkciói az 1–12. oszlopban találhatók. A cél az előrejelzés Price.

Oszlop Szolgáltatás Leírás
0 Bűnözési arány Egy főre jutó bűnözési arány
2 Lakókörnyezetek Lakózónák a városban
3 Kereskedelmi zónák Nem lakossági zónák a városban
4 NearWater Víztest közelsége
5 ToxicWasteLevels Toxicitási szintek (PPM)
6 AverageRoomNumber A ház helyiségeinek átlagos száma
7 HomeAge Az otthon kora
8 BusinessCenterDistance Távolság a legközelebbi üzleti körzethez
9 HighwayAccess Autópályák közelsége
10 TaxRate Ingatlanadó mértéke
11 StudentTeacherRatio Diákok és tanárok aránya
12 PercentPopulationBelowPoverty A szegénység alatt élő népesség százaléka
13 Ár Az otthon ára

Az adathalmaz mintája az alábbiakban látható:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Az ebben a mintában szereplő adatokat egy osztályhoz hasonló HousingPriceData osztály modellezheti, és betöltheti egy IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

A modell betanítása

Az alábbi kódminta egy lineáris regressziós modell betanításának folyamatát mutatja be a lakásárak előrejelzéséhez.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

A modell ismertetése a permutációs funkció fontosságával (PFI)

A ML.NET a megfelelő feladathoz használja a PermutationFeatureImportance metódust.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

A betanítási adatkészleten való használat PermutationFeatureImportance eredménye objektumokból állRegressionMetricsStatistics.ImmutableArray RegressionMetricsStatisticsolyan összefoglaló statisztikákat biztosít, mint a középérték és a szórás a paraméter által permutationCount megadott permutációk számával RegressionMetrics egyenlő több megfigyelés esetében.

A funkció fontosságának mérésére használt metrika a probléma megoldásához használt gépi tanulási feladattól függ. A regressziós feladatok például használhatnak egy általános értékelési metrikát, például az R-négyzetet a fontosság mérésére. A modellértékelési metrikákkal kapcsolatos további információkért tekintse meg a ML.NET modell kiértékelését metrikákkal.

A fontosság, vagy ebben az esetben az R-négyzetes metrikák abszolút átlagos csökkenése, amelyet PermutationFeatureImportance kiszámítottunk, a legfontosabbtól a legkevésbé fontosig rendezhető.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Az egyes funkciók featureImportanceMetrics értékeinek nyomtatása az alábbihoz hasonló kimenetet eredményezne. Ne feledje, hogy különböző eredményeket kell látnia, mert ezek az értékek a megadott adatoktól függően változnak.

Szolgáltatás Váltás r-négyzetre
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
TaxRate -0.008545
AverageRoomNumber -0.003949
Bűnözési arány -0.003665
Kereskedelmi zónák 0.002749
HomeAge -0.002426
Lakókörnyezetek -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Az adathalmaz öt legfontosabb jellemzőjét vizsgálva a modell által előrejelzett ház árát befolyásolja az autópályák közelsége, a környéken található iskolák tanulóinak aránya, a nagyobb munkaügyi központok közelsége, az ingatlanadó mértéke és az otthonban lévő szobák átlagos száma.

Következő lépések