Interpretare le previsioni del modello con Permutation Feature Importance
Informazioni su come interpretare le previsioni dei modelli di Machine Learning ML.NET usando Permutation Feature Importance (PFI). PFI fornisce il contributo relativo di ogni funzionalità a una previsione.
I modelli di Machine Learning sono spesso considerati scatole opache che accettano gli input e generano un output. I passaggi intermedi o le interazioni tra le caratteristiche che influenzano l'output vengono riconosciute raramente. Poiché il Machine Learning viene ora applicato a più aspetti delle attività quotidiane, ad esempio nel settore sanitario, è di importanza fondamentale comprenderne in che modo un modello di Machine Learning prende le decisioni. Ad esempio, se le diagnosi vengono effettuate tramite un modello di Machine Learning, i professionisti del settore sanitario necessitano di un modo per esaminare i fattori che hanno contribuito alle diagnosi. Una diagnosi corretta può fare una grande differenza nella velocità di recupero di un paziente. Più è dettagliato il livello di descrizione di un modello, maggiore sarà la fiducia dei professionisti del settore sanitario nell'accettare o rifiutare le decisioni prese dal modello.
Per descrivere i modelli vengono usate tecniche diverse, tra cui PFI. PFI è una tecnica usata per descrivere i modelli di classificazione e regressione basata sul documento Random Forests di Breiman (vedere la sezione 10). A livello generale, il funzionamento è basato sulla selezione in ordine casuale dei dati una caratteristica alla volta per l'intero set di dati e sul calcolo della diminuzione della metrica delle prestazioni dell'interesse. Maggiore è la modifica, maggiore è l'importanza della funzionalità.
Inoltre, evidenziando le funzionalità più importanti, i generatori di modelli possono concentrarsi sull'uso di un subset di funzionalità più significative che possono potenzialmente ridurre il rumore e i tempi di training.
Caricare i dati
Le funzionalità del set di dati usate per questo esempio si trovano dalla colonna 1 alla colonna 12. L'obiettivo consiste nella previsione di Price
.
Colonna | Funzionalità | Descrizione |
---|---|---|
1 | CrimeRate | Tasso di criminalità pro capite |
2 | ResidentialZones | Zone residenziali della città |
3 | CommercialZones | Zone non residenziali della città |
4 | NearWater | Prossimità al corpo idrico |
5 | ToxicWasteLevels | Livelli di tossicità (PPM) |
6 | AverageRoomNumber | Numero medio di locali di un'abitazione |
7 | HomeAge | Età dell'abitazione |
8 | BusinessCenterDistance | Distanza dal quartiere direzionale più vicino |
9 | HighwayAccess | Prossimità alle autostrade |
10 | TaxRate | Aliquota dell'imposta sugli immobili |
11 | StudentTeacherRatio | Rapporto studenti-insegnanti |
12 | PercentPopulationBelowPoverty | Percentuale di popolazione che vive al di sotto della soglia di povertà |
13 | Price | Prezzo dell'abitazione |
Di seguito è riportato un esempio del set di dati:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
I dati in questo esempio possono essere modellati in base a una classe come HousingPriceData
e caricati in un'interfaccia IDataView
.
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Eseguire il training del modello
L'esempio di codice seguente illustra il processo di training di un modello di regressione lineare per la previsione dei prezzi delle abitazioni.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Descrivere il modello con Permutation Feature Importance (PFI)
In ML.NET usare il metodo PermutationFeatureImportance
per la rispettiva attività.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
Se viene usato PermutationFeatureImportance
nel training set, viene restituita una ImmutableArray
di oggetti RegressionMetricsStatistics
. RegressionMetricsStatistics
fornisce statistiche di riepilogo come la media e la deviazione standard per più osservazioni di RegressionMetrics
corrispondente al numero di permutazioni specificate dal parametro permutationCount
.
La metrica usata per misurare l'importanza della funzionalità dipende dall'attività di Machine Learning usata per risolvere il problema. Ad esempio, le attività di regressione possono usare una metrica di valutazione comune, ad esempio R quadrato, per misurare l'importanza. Per altre informazioni sulle metriche di valutazione del modello, vedere Valutare il modello di ML.NET con le metriche.
L'importanza, o in questo caso la diminuzione della media assoluta nella metrica R quadrato calcolata da PermutationFeatureImportance
, può quindi essere ordinata dalla più importante alla meno importante.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
La stampa dei valori di ogni funzionalità in featureImportanceMetrics
genererà un output simile al seguente. Tenere presente che probabilmente verranno visualizzati risultati diversi poiché questi valori variano in base ai dati ricevuti.
Funzionalità | Modifica in R quadrato |
---|---|
HighwayAccess | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistance | -0.010491 |
TaxRate | -0.008545 |
AverageRoomNumber | -0.003949 |
CrimeRate | -0.003665 |
CommercialZones | 0.002749 |
HomeAge | -0.002426 |
ResidentialZones | -0.002319 |
NearWater | 0.000203 |
PercentPopulationLivingBelowPoverty | 0.000031 |
ToxicWasteLevels | -0.000019 |
Esaminando le cinque caratteristiche più importanti di questo set di dati, è possibile osservare che il prezzo di un'abitazione previsto da questo modello è influenzato dalla vicinanza alle autostrade, dal rapporto studenti-insegnanti nelle scuole della zona, dalla vicinanza ai maggiori centri di occupazione, dall'aliquota dell'imposta sugli immobili e dal numero medio di locali dell'abitazione.