Condividi tramite


Interpretare le previsioni del modello con Permutation Feature Importance

Informazioni su come interpretare le previsioni dei modelli di Machine Learning ML.NET usando Permutation Feature Importance (PFI). PFI fornisce il contributo relativo di ogni funzionalità a una previsione.

I modelli di Machine Learning sono spesso considerati scatole opache che accettano gli input e generano un output. I passaggi intermedi o le interazioni tra le caratteristiche che influenzano l'output vengono riconosciute raramente. Poiché il Machine Learning viene ora applicato a più aspetti delle attività quotidiane, ad esempio nel settore sanitario, è di importanza fondamentale comprenderne in che modo un modello di Machine Learning prende le decisioni. Ad esempio, se le diagnosi vengono effettuate tramite un modello di Machine Learning, i professionisti del settore sanitario necessitano di un modo per esaminare i fattori che hanno contribuito alle diagnosi. Una diagnosi corretta può fare una grande differenza nella velocità di recupero di un paziente. Più è dettagliato il livello di descrizione di un modello, maggiore sarà la fiducia dei professionisti del settore sanitario nell'accettare o rifiutare le decisioni prese dal modello.

Per descrivere i modelli vengono usate tecniche diverse, tra cui PFI. PFI è una tecnica usata per descrivere i modelli di classificazione e regressione basata sul documento Random Forests di Breiman (vedere la sezione 10). A livello generale, il funzionamento è basato sulla selezione in ordine casuale dei dati una caratteristica alla volta per l'intero set di dati e sul calcolo della diminuzione della metrica delle prestazioni dell'interesse. Maggiore è la modifica, maggiore è l'importanza della funzionalità.

Inoltre, evidenziando le funzionalità più importanti, i generatori di modelli possono concentrarsi sull'uso di un subset di funzionalità più significative che possono potenzialmente ridurre il rumore e i tempi di training.

Caricare i dati

Le funzionalità del set di dati usate per questo esempio si trovano dalla colonna 1 alla colonna 12. L'obiettivo consiste nella previsione di Price.

Colonna Funzionalità Descrizione
1 CrimeRate Tasso di criminalità pro capite
2 ResidentialZones Zone residenziali della città
3 CommercialZones Zone non residenziali della città
4 NearWater Prossimità al corpo idrico
5 ToxicWasteLevels Livelli di tossicità (PPM)
6 AverageRoomNumber Numero medio di locali di un'abitazione
7 HomeAge Età dell'abitazione
8 BusinessCenterDistance Distanza dal quartiere direzionale più vicino
9 HighwayAccess Prossimità alle autostrade
10 TaxRate Aliquota dell'imposta sugli immobili
11 StudentTeacherRatio Rapporto studenti-insegnanti
12 PercentPopulationBelowPoverty Percentuale di popolazione che vive al di sotto della soglia di povertà
13 Price Prezzo dell'abitazione

Di seguito è riportato un esempio del set di dati:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

I dati in questo esempio possono essere modellati in base a una classe come HousingPriceData e caricati in un'interfaccia IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Eseguire il training del modello

L'esempio di codice seguente illustra il processo di training di un modello di regressione lineare per la previsione dei prezzi delle abitazioni.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Descrivere il modello con Permutation Feature Importance (PFI)

In ML.NET usare il metodo PermutationFeatureImportance per la rispettiva attività.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

Se viene usato PermutationFeatureImportance nel training set, viene restituita una ImmutableArray di oggetti RegressionMetricsStatistics. RegressionMetricsStatistics fornisce statistiche di riepilogo come la media e la deviazione standard per più osservazioni di RegressionMetricscorrispondente al numero di permutazioni specificate dal parametro permutationCount.

La metrica usata per misurare l'importanza della funzionalità dipende dall'attività di Machine Learning usata per risolvere il problema. Ad esempio, le attività di regressione possono usare una metrica di valutazione comune, ad esempio R quadrato, per misurare l'importanza. Per altre informazioni sulle metriche di valutazione del modello, vedere Valutare il modello di ML.NET con le metriche.

L'importanza, o in questo caso la diminuzione della media assoluta nella metrica R quadrato calcolata da PermutationFeatureImportance, può quindi essere ordinata dalla più importante alla meno importante.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

La stampa dei valori di ogni funzionalità in featureImportanceMetrics genererà un output simile al seguente. Tenere presente che probabilmente verranno visualizzati risultati diversi poiché questi valori variano in base ai dati ricevuti.

Funzionalità Modifica in R quadrato
HighwayAccess -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistance -0.010491
TaxRate -0.008545
AverageRoomNumber -0.003949
CrimeRate -0.003665
CommercialZones 0.002749
HomeAge -0.002426
ResidentialZones -0.002319
NearWater 0.000203
PercentPopulationLivingBelowPoverty 0.000031
ToxicWasteLevels -0.000019

Esaminando le cinque caratteristiche più importanti di questo set di dati, è possibile osservare che il prezzo di un'abitazione previsto da questo modello è influenzato dalla vicinanza alle autostrade, dal rapporto studenti-insegnanti nelle scuole della zona, dalla vicinanza ai maggiori centri di occupazione, dall'aliquota dell'imposta sugli immobili e dal numero medio di locali dell'abitazione.

Passaggi successivi