Interpretar previsões do modelo usando Importância do Recurso de Permutação

Usando a PFI (Importância do Recurso de Permutação), saiba como interpretar ML.NET previsões de modelo de machine learning. A PFI informa a contribuição relativa que cada recurso faz a uma previsão.

Modelos de machine learning geralmente são considerados caixas opacas que pegam entradas e geram uma saída. As etapas intermediárias ou as interações entre os recursos que influenciam a saída raramente são compreendidas. Conforme o aprendizado de máquina é introduzido em mais aspectos da vida diária, como serviços de saúde, é de extrema importância entender por que um modelo de machine learning toma as decisões que ele toma. Por exemplo, se os diagnósticos forem feitos por um modelo de machine learning, os profissionais de saúde precisarão de uma maneira de examinar os fatores que contribuíram para esse diagnóstico. Fornecer o diagnóstico certo pode fazer uma grande diferença em se um paciente tem uma recuperação rápida ou não. Portanto, quanto maior o nível de capacidade de explicação de um modelo, mais confiança os profissionais de saúde terão em aceitar ou rejeitar as decisões tomadas pelo modelo.

Várias técnicas são usadas para explicar os modelos, uma delas é a PFI. PFI é uma técnica usada para explicar os modelos de classificação e regressão inspirados pelo artigo de Breiman chamado Random Forests (Florestas aleatórias) (confira a seção 10). Em um alto nível, a maneira como eles funcionam é embaralhando aleatoriamente um recurso de dados por vez para todo o conjunto de dados e calculando o quanto a métrica de desempenho de interesse diminui. Quanto maior a alteração, mais importante é esse recurso.

Além disso, ao realçar os recursos mais importantes, construtores de modelo podem se concentrar no uso de um subconjunto de recursos mais significativos que pode reduzir o ruído e tempo de treinamento.

Carregar os dados

Os recursos no conjunto de dados que está sendo usado para este exemplo estão nas colunas 1 a 12. A meta é prever Price.

Coluna Recurso Descrição
1 CrimeRate Taxa de criminalidade per capita
2 ResidentialZones Zonas residenciais da cidade
3 CommercialZones Zonas não residenciais da cidade
4 NearWater Proximidade a recursos hídricos
5 ToxicWasteLevels Níveis de toxicidade (PPM)
6 AverageRoomNumber Número médio de ambientes na casa
7 HomeAge Idade da casa
8 BusinessCenterDistance Distância até o bairro comercial mais próximo
9 HighwayAccess Proximidade de rodovias
10 TaxRate Taxa de imposto sobre propriedade
11 StudentTeacherRatio Taxa de alunos para professores
12 PercentPopulationBelowPoverty Percentual da população vivendo abaixo da linha de pobreza
13 Preço Preço da casa

Um exemplo do conjunto de dados é mostrado abaixo:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Os dados desta amostra podem ser modelados por uma classe como HousingPriceData e carregados em uma IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Treinar o modelo

O exemplo de código a seguir ilustra o processo de treinamento de um modelo de regressão linear para prever preços de casa.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Explicar o modelo com PFI (Importância de Recurso de Permutação)

No ML.NET, use o método PermutationFeatureImportance para suas respectivas tarefas.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

O resultado de usar PermutationFeatureImportance no conjunto de dados de treinamento é um ImmutableArray de objetos RegressionMetricsStatistics. RegressionMetricsStatistics fornece estatísticas resumidas, como média e desvio padrão para várias observações de RegressionMetrics igual ao número de permutações especificado pelo parâmetro permutationCount.

A métrica usada para medir a importância do recurso depende da tarefa de machine learning usada para resolver o problema. Por exemplo, as tarefas de regressão podem usar uma métrica de avaliação comum, como R ao quadrado, para medir a importância. Para mais informações sobre métricas de avaliação de modelo, confira Avaliar seu modelo de ML.NET com métricas.

A importância ou, neste caso, a redução média absoluta de métrica R ao quadrado calculada por PermutationFeatureImportance pode ser ordenada da mais importante para a menos importante.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Imprimir os valores para cada um dos recursos em featureImportanceMetrics geraria saída semelhante à abaixo. Lembre-se de que você deve esperar ver resultados diferentes, pois esses valores variam conforme os dados apresentados.

Recurso Alterar para R ao quadrado
HighwayAccess -0,042731
StudentTeacherRatio -0,012730
BusinessCenterDistance -0,010491
TaxRate -0,008545
AverageRoomNumber -0,003949
CrimeRate -0,003665
CommercialZones 0,002749
HomeAge -0,002426
ResidentialZones -0,002319
NearWater 0,000203
PercentPopulationLivingBelowPoverty 0,000031
ToxicWasteLevels -0,000019

Vamos analisar os cinco recursos mais importantes para este conjunto de dados, o preço de uma casa previsto por esse modelo é influenciado pela sua proximidade a rodovias, pela proporção de alunos para professor das escolas na área, pela proximidade com centros de emprego importantes, pela taxa de impostos sobre propriedade e pelo número médio de ambientes na casa.

Próximas etapas