Interpretar previsões de modelo usando a importância do recurso de permutação
Usando a Importância do Recurso de Permutação (PFI), aprenda a interpretar ML.NET previsões de modelo de aprendizado de máquina. O PFI dá a contribuição relativa que cada recurso faz para uma previsão.
Os modelos de aprendizagem automática são muitas vezes pensados como caixas opacas que recebem entradas e geram uma saída. As etapas intermediárias ou interações entre os recursos que influenciam a saída raramente são compreendidas. À medida que o aprendizado de máquina é introduzido em mais aspetos da vida cotidiana, como a saúde, é de extrema importância entender por que um modelo de aprendizado de máquina toma as decisões que toma. Por exemplo, se os diagnósticos são feitos por um modelo de aprendizado de máquina, os profissionais de saúde precisam de uma maneira de analisar os fatores que levaram a fazer esse diagnóstico. Fornecer o diagnóstico correto pode fazer uma grande diferença sobre se um paciente tem uma recuperação rápida ou não. Portanto, quanto maior o nível de explicabilidade em um modelo, maior a confiança dos profissionais de saúde para aceitar ou rejeitar as decisões tomadas pelo modelo.
Várias técnicas são usadas para explicar modelos, uma das quais é o PFI. O PFI é uma técnica usada para explicar modelos de classificação e regressão inspirada no artigo Random Forests de Breiman (ver secção 10). Em um nível alto, a maneira como ele funciona é embaralhando aleatoriamente os dados, um recurso de cada vez, para todo o conjunto de dados e calculando o quanto a métrica de desempenho de juros diminui. Quanto maior a mudança, mais importante é esse recurso.
Além disso, ao destacar os recursos mais importantes, os construtores de modelos podem se concentrar no uso de um subconjunto de recursos mais significativos que podem potencialmente reduzir o ruído e o tempo de treinamento.
Carregar os dados
Os recursos no conjunto de dados que está sendo usado para este exemplo estão nas colunas 1-12. O objetivo é prever Price
.
Coluna | Funcionalidade | Descrição |
---|---|---|
5 | Taxa de Criminalidade | Taxa de criminalidade per capita |
2 | Zonas Residenciais | Zonas residenciais na cidade |
3 | Zonas Comerciais | Zonas não residenciais na cidade |
4 | NearWater | Proximidade da massa de água |
5 | ToxicWasteLevels | Níveis de toxicidade (PPM) |
6 | MédiaQuartoNúmero | Número médio de quartos em casa |
7 | HomeIdade | Idade do lar |
8 | BusinessCenterDistância | Distância até à zona empresarial mais próxima |
9 | Acesso Rodoviário | Proximidade de autoestradas |
10 | Alíquota | Taxa do IPTU |
11 | StudentTeacherRatio | Rácio de alunos/professores |
12 | PercentPopulaçãoAbaixoda Pobreza | Percentagem da população que vive abaixo da pobreza |
13 | Preço | Preço da casa |
Um exemplo do conjunto de dados é mostrado abaixo:
1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12
Os dados neste exemplo podem ser modelados por uma classe como HousingPriceData
e carregados em um IDataView
arquivo .
class HousingPriceData
{
[LoadColumn(0)]
public float CrimeRate { get; set; }
[LoadColumn(1)]
public float ResidentialZones { get; set; }
[LoadColumn(2)]
public float CommercialZones { get; set; }
[LoadColumn(3)]
public float NearWater { get; set; }
[LoadColumn(4)]
public float ToxicWasteLevels { get; set; }
[LoadColumn(5)]
public float AverageRoomNumber { get; set; }
[LoadColumn(6)]
public float HomeAge { get; set; }
[LoadColumn(7)]
public float BusinessCenterDistance { get; set; }
[LoadColumn(8)]
public float HighwayAccess { get; set; }
[LoadColumn(9)]
public float TaxRate { get; set; }
[LoadColumn(10)]
public float StudentTeacherRatio { get; set; }
[LoadColumn(11)]
public float PercentPopulationBelowPoverty { get; set; }
[LoadColumn(12)]
[ColumnName("Label")]
public float Price { get; set; }
}
Preparar o modelo
O exemplo de código abaixo ilustra o processo de treinamento de um modelo de regressão linear para prever os preços das casas.
// 1. Get the column name of input features.
string[] featureColumnNames =
data.Schema
.Select(column => column.Name)
.Where(columnName => columnName != "Label").ToArray();
// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
mlContext.Transforms.Concatenate("Features", featureColumnNames)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
.Append(mlContext.Regression.Trainers.Sdca());
// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);
Explicar o modelo com a Importância da Característica de Permutação (PFI)
Em ML.NET use o PermutationFeatureImportance
método para sua respetiva tarefa.
// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);
// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
mlContext
.Regression
.PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);
O resultado do uso PermutationFeatureImportance
no conjunto de dados de treinamento é um ImmutableArray
dos RegressionMetricsStatistics
objetos. RegressionMetricsStatistics
fornece estatísticas resumidas como média e desvio padrão para observações múltiplas iguais RegressionMetrics
ao número de permutações especificadas pelo permutationCount
parâmetro.
A métrica usada para medir a importância do recurso depende da tarefa de aprendizado de máquina usada para resolver seu problema. Por exemplo, as tarefas de regressão podem usar uma métrica de avaliação comum, como R-quadrado, para medir a importância. Para obter mais informações sobre métricas de avaliação de modelo, consulte avaliar seu modelo de ML.NET com métricas.
A importância, ou neste caso, a diminuição média absoluta na métrica R-quadrado calculada por PermutationFeatureImportance
pode então ser ordenada do mais importante para o menos importante.
// Order features by importance
var featureImportanceMetrics =
permutationFeatureImportance
.Select((metric, index) => new { index, metric.RSquared })
.OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));
Console.WriteLine("Feature\tPFI");
foreach (var feature in featureImportanceMetrics)
{
Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}
Imprimir os valores para cada um dos recursos em featureImportanceMetrics
geraria uma saída semelhante à abaixo. Tenha em mente que você deve esperar ver resultados diferentes porque esses valores variam com base nos dados que são fornecidos.
Caraterística | Mudar para R-Squared |
---|---|
Acesso Rodoviário | -0.042731 |
StudentTeacherRatio | -0.012730 |
BusinessCenterDistância | -0.010491 |
Alíquota | -0.008545 |
MédiaQuartoNúmero | -0.003949 |
Taxa de Criminalidade | -0.003665 |
Zonas Comerciais | 0.002749 |
HomeIdade | -0.002426 |
Zonas Residenciais | -0.002319 |
NearWater | 0.000203 |
PercentPopulaçãoViverAbaixo da Pobreza | 0.000031 |
ToxicWasteLevels | -0.000019 |
Olhando para as cinco características mais importantes para este conjunto de dados, o preço de uma casa previsto por este modelo é influenciado pela sua proximidade de autoestradas, rácio de alunos professores das escolas da área, proximidade dos principais centros de emprego, taxa de imposto predial e número médio de quartos em casa.