Interpret model predictions using Permutation Feature Importance
Learn how to interpret ML.NET machine learning model predictions using Permutation Feature Importance (PFI). PFI gives the relative contribution each feature makes to a prediction.
Machine learning models are often thought of as opaque boxes that take inputs and generate an output. The intermediate steps, and the interactions among the features that influence the output, are rarely understood. As machine learning is introduced into more aspects of everyday life, such as healthcare, it's of utmost importance to understand why a machine learning model makes the decisions it does. For example, if diagnoses are made by a machine learning model, healthcare professionals need a way to look into the factors that went into making each diagnosis. The right diagnosis can make a great difference in whether a patient recovers quickly. Therefore, the more explainable a model is, the more confidently healthcare professionals can accept or reject the decisions it makes.
Various techniques are used to explain models, one of which is PFI. PFI is a technique for explaining classification and regression models that's inspired by Breiman's Random Forests paper (see section 10). At a high level, it works by randomly shuffling the values of one feature at a time across the entire dataset and measuring how much the performance metric of interest decreases. The larger the change, the more important that feature is.
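The shuffle-and-rescore loop described above can be sketched in plain C# without ML.NET. This is a minimal illustration under stated assumptions, not how ML.NET implements PFI internally: the toy dataset, the hard-coded linear Predict function standing in for a trained model, and the helper names RSquared and PermutationImportance are all hypothetical. Importance is measured as the average change in R-squared (permuted minus baseline) over permutationCount shuffles of each feature.

```csharp
using System;
using System.Linq;

// Hypothetical toy data: the label depends strongly on feature 0,
// weakly on feature 1, and not at all on feature 2.
var rng = new Random(42);
int n = 500;
double[][] X = Enumerable.Range(0, n)
    .Select(_ => new[] { rng.NextDouble(), rng.NextDouble(), rng.NextDouble() })
    .ToArray();
double[] y = X.Select(r => 3.0 * r[0] + 0.5 * r[1]).ToArray();

// Stand-in for a trained model (assumption): the known linear function.
double Predict(double[] row) => 3.0 * row[0] + 0.5 * row[1] + 0.0 * row[2];

// R-squared = 1 - SS_res / SS_tot for the stand-in model's predictions.
double RSquared(double[][] rows, double[] labels)
{
    double mean = labels.Average();
    double ssRes = 0, ssTot = 0;
    for (int i = 0; i < labels.Length; i++)
    {
        double e = labels[i] - Predict(rows[i]);
        ssRes += e * e;
        double d = labels[i] - mean;
        ssTot += d * d;
    }
    return 1.0 - ssRes / ssTot;
}

// PFI sketch: shuffle one feature column at a time, re-score, and average
// the change in R-squared (permuted minus baseline) over several shuffles.
double[] PermutationImportance(double[][] rows, double[] labels, int permutationCount, Random random)
{
    double baseline = RSquared(rows, labels);
    int featureCount = rows[0].Length;
    var importance = new double[featureCount];
    for (int j = 0; j < featureCount; j++)
    {
        double total = 0;
        for (int k = 0; k < permutationCount; k++)
        {
            // Fisher-Yates shuffle of column j only.
            double[] column = rows.Select(r => r[j]).ToArray();
            for (int i = column.Length - 1; i > 0; i--)
            {
                int swap = random.Next(i + 1);
                (column[i], column[swap]) = (column[swap], column[i]);
            }
            // Copy the rows so the original data stays untouched.
            double[][] permuted = rows.Select(r => (double[])r.Clone()).ToArray();
            for (int i = 0; i < permuted.Length; i++) permuted[i][j] = column[i];
            total += RSquared(permuted, labels) - baseline; // negative = metric dropped
        }
        importance[j] = total / permutationCount;
    }
    return importance;
}

double[] pfi = PermutationImportance(X, y, permutationCount: 3, random: rng);
for (int f = 0; f < pfi.Length; f++)
    Console.WriteLine($"Feature {f}: change in R-squared = {pfi[f]:F4}");
```

Because the stand-in model ignores feature 2, shuffling it leaves R-squared unchanged, while shuffling feature 0 causes by far the largest drop.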
Additionally, by highlighting the most important features, model builders can focus on using a subset of more meaningful features, which can potentially reduce noise and training time.
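Selecting such a subset can be as simple as ranking features by the absolute size of their importance score and keeping the top k. The sketch below is illustrative only: the scores are a few values copied from the sample output later in this article, and SelectTopFeatures and the choice of k are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative importance scores (change in R-squared), taken from part of
// the sample output shown later in this article.
var importance = new Dictionary<string, double>
{
    ["HighwayAccess"] = -0.042731,
    ["StudentTeacherRatio"] = -0.012730,
    ["BusinessCenterDistance"] = -0.010491,
    ["CommercialZones"] = 0.002749,
    ["NearWater"] = 0.000203,
    ["ToxicWasteLevels"] = -0.000019,
};

// Keep the k features whose permutation changed the metric the most,
// regardless of the sign of the change.
string[] SelectTopFeatures(IReadOnlyDictionary<string, double> scores, int k) =>
    scores.OrderByDescending(kv => Math.Abs(kv.Value))
          .Take(k)
          .Select(kv => kv.Key)
          .ToArray();

string[] selected = SelectTopFeatures(importance, 3);
Console.WriteLine(string.Join(", ", selected));
```

In practice, a model retrained on the reduced feature set should be re-evaluated to confirm that dropping the remaining features doesn't hurt accuracy.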
Load the data
The features in the dataset used for this sample are in columns 1-12. The goal is to predict Price.
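In ML.NET, this layout is normally described with a class whose properties carry [LoadColumn] attributes. To show just the split between feature columns and the label without an ML.NET dependency, here is a minimal plain-C# sketch. It assumes comma-separated rows with the 12 features first and Price as the 13th column; the sample rows, FeatureCount, and ParseRow are hypothetical.

```csharp
using System;
using System.Globalization;
using System.Linq;

// Hypothetical, clearly synthetic rows: 12 feature values followed by Price.
// Price being the 13th column is an assumption for illustration.
string[] lines =
{
    "1,2,3,4,5,6,7,8,9,10,11,12,240000",
    "12,11,10,9,8,7,6,5,4,3,2,1,310000",
};

const int FeatureCount = 12;

// Split one row into its feature vector and its Price label.
(float[] Features, float Price) ParseRow(string line)
{
    float[] values = line.Split(',')
        .Select(v => float.Parse(v, CultureInfo.InvariantCulture))
        .ToArray();
    if (values.Length != FeatureCount + 1)
        throw new FormatException($"Expected {FeatureCount + 1} columns but found {values.Length}.");
    // Columns 1-12 hold the features; the last column is the Price label.
    return (values.Take(FeatureCount).ToArray(), values[FeatureCount]);
}

var rows = lines.Select(ParseRow).ToArray();
Console.WriteLine($"Loaded {rows.Length} rows with {rows[0].Features.Length} features each.");
```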
C#
using System.Collections.Immutable;
using Microsoft.ML;
using Microsoft.ML.Data;

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance, shuffling each feature 3 times
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount: 3);
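Because permutationCount is greater than 1, each feature is shuffled several times, and the result for each feature summarizes the repeated changes in R-squared (the RSquared statistic exposes values such as Mean and StandardError). The sketch below shows one way such a summary can be computed from hypothetical per-permutation deltas. It is not ML.NET's code, and using the sample standard deviation (n - 1 in the denominator) is an assumption.

```csharp
using System;
using System.Linq;

// Hypothetical changes in R-squared for one feature over 3 shuffles.
double[] deltas = { -0.041, -0.045, -0.042 };

double mean = deltas.Average();
// Sample standard deviation; the n - 1 denominator is an assumption.
double std = Math.Sqrt(deltas.Sum(d => (d - mean) * (d - mean)) / (deltas.Length - 1));
// Standard error of the mean.
double stdErr = std / Math.Sqrt(deltas.Length);

Console.WriteLine($"Mean: {mean:F6}, StdDev: {std:F6}, StdErr: {stdErr:F6}");
```

Comparing a feature's mean with its standard error helps judge whether a very small importance value is distinguishable from zero.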
The metric used to measure feature importance depends on the machine learning task you use to solve your problem. For example, regression tasks might use a common evaluation metric such as R-squared to measure importance. For more information on model evaluation metrics, see Evaluate your ML.NET model with metrics.
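For regression with R-squared, the quantities involved can be written as follows, where $y_i$ are the labels, $\hat{y}_i$ the model's predictions, $\bar{y}$ the mean label, $K$ the permutationCount, and $R^2_{j,k}$ the R-squared after the $k$-th shuffle of feature $j$. Reading the importance $\Delta_j$ as "permuted minus baseline" is an assumption, but it matches the sign convention of the sample output in this article, where important features show negative values.

```latex
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2}

\Delta_j = \frac{1}{K} \sum_{k=1}^{K} \left( R^2_{j,k} - R^2 \right)
```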
The importance values calculated by PermutationFeatureImportance (in this case, the average change in the R-squared metric) can then be ordered from most to least important by their absolute value.
C#
using System;
using System.Linq;

// Order features by importance (largest absolute change in R-squared first)
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

// featureColumnNames is assumed to hold the feature names in the order they were concatenated into the Features column
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Printing the value for each feature in featureImportanceMetrics generates output similar to the following. Expect to see different results, because these values vary based on the data they're given.
Feature                                Change to R-Squared
HighwayAccess                          -0.042731
StudentTeacherRatio                    -0.012730
BusinessCenterDistance                 -0.010491
TaxRate                                -0.008545
AverageRoomNumber                      -0.003949
CrimeRate                              -0.003665
CommercialZones                         0.002749
HomeAge                                -0.002426
ResidentialZones                       -0.002319
NearWater                               0.000203
PercentPopulationLivingBelowPoverty     0.000031
ToxicWasteLevels                       -0.000019
Looking at the five most important features for this dataset, the price this model predicts for a house is influenced most by its proximity to highways, the student-teacher ratio of schools in the area, its proximity to major employment centers, the property tax rate, and the average number of rooms in the home.