Interpretación de las predicciones del modelo mediante la importancia de características de permutación

Con la importancia de características de permutación (PFI), aprenda a interpretar las predicciones del modelo de aprendizaje automático de ML.NET. PFI proporciona la contribución relativa que cada característica aporta a una predicción.

A menudo se piensa en los modelos de Machine Learning como cajas opacas que toman entradas y generan salidas. Rara vez se entienden los pasos intermedios o las interacciones entre las características que afectan a la salida. A medida que el aprendizaje automático se introduce en otros aspectos de la vida diaria como la asistencia sanitaria, es de vital importancia comprender por qué un modelo de Machine Learning toma esas decisiones. Por ejemplo, si un modelo de Machine Learning realiza diagnósticos, los profesionales sanitarios necesitan una forma de buscar en los factores que se incluyeron en la realización de dicho diagnóstico. Proporcionar el diagnóstico correcto puede marcar una gran diferencia en el hecho de que un paciente tenga una recuperación rápida o no. Por lo tanto, cuanto mayor sea el nivel de explicación en un modelo, mayor será la confianza que tengan los profesionales sanitarios para aceptar o rechazar las decisiones tomadas por el modelo.

Se utilizan diversas técnicas para explicar modelos, una de los cuales es PFI. PFI es una técnica utilizada para explicar los modelos de clasificación y regresión que se inspira en el artículo Random Forests de Breiman (consulte la sección 10). En un nivel alto, esto funciona de manera es revolviendo los datos de manera aleatoria en una característica a la vez para todo el conjunto de datos y calculando cuánto se reduce la métrica de rendimiento de interés. Cuanto mayor sea el cambio, más importante será esa característica.

Además, al resaltar las características más importantes, los compiladores del modelo pueden centrarse en el uso de un subconjunto de características más significativas que potencialmente pueden reducir el ruido y el tiempo de entrenamiento.

Carga de los datos

Las características del conjunto de datos que se usa para este ejemplo están en las columnas 1 a 12. El objetivo es predecir Price.

Columna Característica Descripción
1 CrimeRate Tasa de criminalidad per cápita
2 ResidentialZones Zonas residenciales en la ciudad
3 CommercialZones Zonas no residenciales en la ciudad
4 NearWater Proximidad a un cuerpo de agua
5 ToxicWasteLevels Niveles de toxicidad (PPM)
6 AverageRoomNumber Número promedio de salas en casa
7 HomeAge Antigüedad de la casa
8 BusinessCenterDistance Distancia al distrito comercial más cercano
9 HighwayAccess Proximidad a las autopistas
10 TaxRate Tasa de impuestos sobre la propiedad
11 StudentTeacherRatio Proporción de alumnos por profesores
12 PercentPopulationBelowPoverty Porcentaje de la población que vive por debajo de pobreza
13 Price Precio de la vivienda

A continuación, se muestra un ejemplo del conjunto de datos:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Los datos en este ejemplo se pueden modelar mediante una clase como HousingPriceData y se pueden cargar en IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Entrenar el modelo

El ejemplo de código siguiente ilustra el proceso de entrenamiento de un modelo de regresión lineal para predecir los precios de la vivienda.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Explicar el modelo con la importancia de características de permutación (PFI)

En ML.NET, use el método PermutationFeatureImportance para la tarea correspondiente.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

El resultado del uso de PermutationFeatureImportance en el conjunto de datos de entrenamiento es un ImmutableArray de objetos RegressionMetricsStatistics. RegressionMetricsStatistics proporciona estadísticas de resumen, como desviación media y estándar para diversas observaciones de RegressionMetrics equivalentes al número de permutaciones especificadas por el parámetro permutationCount.

La métrica que se usa para medir la importancia de las características depende de la tarea de aprendizaje automático que se usa para resolver el problema. Por ejemplo, las tareas de regresión pueden usar una métrica de evaluación común, como R cuadrado, para medir la importancia. Para más información sobre las métricas de evaluación de modelos, vea Evaluación de su modelo de ML.NET con métricas.

La importancia o, en este caso, la disminución de la media absoluta de la métrica de R cuadrado calculada con PermutationFeatureImportance, se puede ordenar de las más importante a la menos importante.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Imprimir los valores para cada una de las características de featureImportanceMetrics generaría resultados similares a los siguientes. Tenga en cuenta que debe esperar para ver resultados diferentes porque estos valores varían en función de los datos que se proporcionan.

Característica Cambiar a R cuadrado
HighwayAccess -0,042731
StudentTeacherRatio -0,012730
BusinessCenterDistance -0,010491
TaxRate -0,008545
AverageRoomNumber -0,003949
CrimeRate -0,003665
CommercialZones 0,002749
HomeAge -0,002426
ResidentialZones -0,002319
NearWater 0,000203
PercentPopulationLivingBelowPoverty 0,000031
ToxicWasteLevels -0,000019

Al echar un vistazo a las cinco características más importantes de este conjunto de datos, el precio de una casa previsto por este modelo viene determinado por su proximidad a las autopistas, la proporción de alumnos por profesores en la zona, la proximidad a los principales centros de empleo, la tasa de impuestos sobre la propiedad y el número promedio de dormitorios en una casa.

Pasos siguientes