Partilhar via


Interpretar previsões de modelo usando a importância do recurso de permutação

Usando a Importância do Recurso de Permutação (PFI), aprenda a interpretar ML.NET previsões de modelo de aprendizado de máquina. O PFI dá a contribuição relativa que cada recurso faz para uma previsão.

Os modelos de aprendizagem automática são muitas vezes pensados como caixas opacas que recebem entradas e geram uma saída. As etapas intermediárias ou interações entre os recursos que influenciam a saída raramente são compreendidas. À medida que o aprendizado de máquina é introduzido em mais aspetos da vida cotidiana, como a saúde, é de extrema importância entender por que um modelo de aprendizado de máquina toma as decisões que toma. Por exemplo, se os diagnósticos são feitos por um modelo de aprendizado de máquina, os profissionais de saúde precisam de uma maneira de analisar os fatores que levaram a fazer esse diagnóstico. Fornecer o diagnóstico correto pode fazer uma grande diferença sobre se um paciente tem uma recuperação rápida ou não. Portanto, quanto maior o nível de explicabilidade em um modelo, maior a confiança dos profissionais de saúde para aceitar ou rejeitar as decisões tomadas pelo modelo.

Várias técnicas são usadas para explicar modelos, uma das quais é o PFI. O PFI é uma técnica usada para explicar modelos de classificação e regressão inspirada no artigo Random Forests de Breiman (ver secção 10). Em um nível alto, a maneira como ele funciona é embaralhando aleatoriamente os dados, um recurso de cada vez, para todo o conjunto de dados e calculando o quanto a métrica de desempenho de juros diminui. Quanto maior a mudança, mais importante é esse recurso.

Além disso, ao destacar os recursos mais importantes, os construtores de modelos podem se concentrar no uso de um subconjunto de recursos mais significativos que podem potencialmente reduzir o ruído e o tempo de treinamento.

Carregar os dados

Os recursos no conjunto de dados que está sendo usado para este exemplo estão nas colunas 1-12. O objetivo é prever Price.

Coluna Funcionalidade Descrição
5 Taxa de Criminalidade Taxa de criminalidade per capita
2 Zonas Residenciais Zonas residenciais na cidade
3 Zonas Comerciais Zonas não residenciais na cidade
4 NearWater Proximidade da massa de água
5 ToxicWasteLevels Níveis de toxicidade (PPM)
6 MédiaQuartoNúmero Número médio de quartos em casa
7 HomeIdade Idade do lar
8 BusinessCenterDistância Distância até à zona empresarial mais próxima
9 Acesso Rodoviário Proximidade de autoestradas
10 Alíquota Taxa do IPTU
11 StudentTeacherRatio Rácio de alunos/professores
12 PercentPopulaçãoAbaixoda Pobreza Percentagem da população que vive abaixo da pobreza
13 Preço Preço da casa

Um exemplo do conjunto de dados é mostrado abaixo:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Os dados neste exemplo podem ser modelados por uma classe como HousingPriceData e carregados em um IDataViewarquivo .

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Preparar o modelo

O exemplo de código abaixo ilustra o processo de treinamento de um modelo de regressão linear para prever os preços das casas.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Explicar o modelo com a Importância da Característica de Permutação (PFI)

Em ML.NET use o PermutationFeatureImportance método para sua respetiva tarefa.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

O resultado do uso PermutationFeatureImportance no conjunto de dados de treinamento é um ImmutableArray dos RegressionMetricsStatistics objetos. RegressionMetricsStatistics fornece estatísticas resumidas como média e desvio padrão para observações múltiplas iguais RegressionMetrics ao número de permutações especificadas pelo permutationCount parâmetro.

A métrica usada para medir a importância do recurso depende da tarefa de aprendizado de máquina usada para resolver seu problema. Por exemplo, as tarefas de regressão podem usar uma métrica de avaliação comum, como R-quadrado, para medir a importância. Para obter mais informações sobre métricas de avaliação de modelo, consulte avaliar seu modelo de ML.NET com métricas.

A importância, ou neste caso, a diminuição média absoluta na métrica R-quadrado calculada por PermutationFeatureImportance pode então ser ordenada do mais importante para o menos importante.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

Imprimir os valores para cada um dos recursos em featureImportanceMetrics geraria uma saída semelhante à abaixo. Tenha em mente que você deve esperar ver resultados diferentes porque esses valores variam com base nos dados que são fornecidos.

Caraterística Mudar para R-Squared
Acesso Rodoviário -0.042731
StudentTeacherRatio -0.012730
BusinessCenterDistância -0.010491
Alíquota -0.008545
MédiaQuartoNúmero -0.003949
Taxa de Criminalidade -0.003665
Zonas Comerciais 0.002749
HomeIdade -0.002426
Zonas Residenciais -0.002319
NearWater 0.000203
PercentPopulaçãoViverAbaixo da Pobreza 0.000031
ToxicWasteLevels -0.000019

Olhando para as cinco características mais importantes para este conjunto de dados, o preço de uma casa previsto por este modelo é influenciado pela sua proximidade de autoestradas, rácio de alunos professores das escolas da área, proximidade dos principais centros de emprego, taxa de imposto predial e número médio de quartos em casa.

Próximos passos