Интерпретация прогнозов модели с помощью функции PFI

Использование функции PFI для интерпретации прогнозов модели машинного обучения ML.NET. PFI оценивает относительный вклад каждого признака в прогнозе.

Модели машинного обучения часто считаются "черными ящиками", которые принимают входные данные и создают выходные данные. Промежуточные этапы или взаимодействие между компонентами, которые влияют на выходные данные, распознаются редко. Машинное обучение проникает во все аспекты нашей повседневной жизни, такие как здравоохранение, а значит, очень важно понимать, почему модель машинного обучения принимает именно те решения, которые она принимает. Например, если диагностика осуществляется на основе модели машинного обучения, у медицинских работников должен быть способ изучить факторы, которые на нее повлияли. Правильная диагностика намного повышает шансы пациента на быстрое выздоровление. Таким образом, чем объяснение модели лучше, тем увереннее медицинские работники будут принимать или отклонять принятые моделью решения.

Для объяснения моделей применяются различные приемы, одной из которых является PFI. PFI — это прием, который используется для объяснения моделей классификации и регрессии и основан на работе Бреймана Random Forests (Случайные леса) (см. раздел 10). В общем смысле он случайным образом тасует весь набор данных по одному компоненту за раз и рассчитывает, на сколько снижается метрика эффективности, отражающая интерес. Чем больше изменение, тем важнее компонент.

Выделив самые важные компоненты, сборщики модели могут сосредоточиться на работе с подмножеством более значимых компонентов и таким образом уменьшить шум и время обучения.

Загрузка данных

Компоненты набора данных, которые используются в этом примере, перечислены в столбцах 1–12. Цель — спрогнозировать Price.

Column Возможность Описание
1 CrimeRate Уровень преступности на душу населения
2 ResidentialZones Жилые районы в городе
3 CommercialZones Нежилые районы в городе
4 NearWater Близость к водоему
5 ToxicWasteLevels Уровень токсичности (PPM)
6 AverageRoomNumber Среднее количество комнат в доме
7 HomeAge Возраст дома
8 BusinessCenterDistance Расстояние до ближайшего делового района
9 HighwayAccess Близость к автомагистралям
10 TaxRate Ставка налога на имущество
11 StudentTeacherRatio Соотношение учащихся и учителей
12 PercentPopulationBelowPoverty Процент населения, живущего за чертой бедности
13 Цена, Цена на дом

Пример набора данных:

1,24,13,1,0.59,3,96,11,23,608,14,13,32
4,80,18,1,0.37,5,14,7,4,346,19,13,41
2,98,16,1,0.25,10,5,1,8,689,13,36,12

Данные в этом примере можно моделировать с помощью класса, например, HousingPriceData, и загружать в IDataView.

class HousingPriceData
{
    [LoadColumn(0)]
    public float CrimeRate { get; set; }

    [LoadColumn(1)]
    public float ResidentialZones { get; set; }

    [LoadColumn(2)]
    public float CommercialZones { get; set; }

    [LoadColumn(3)]
    public float NearWater { get; set; }

    [LoadColumn(4)]
    public float ToxicWasteLevels { get; set; }

    [LoadColumn(5)]
    public float AverageRoomNumber { get; set; }

    [LoadColumn(6)]
    public float HomeAge { get; set; }

    [LoadColumn(7)]
    public float BusinessCenterDistance { get; set; }

    [LoadColumn(8)]
    public float HighwayAccess { get; set; }

    [LoadColumn(9)]
    public float TaxRate { get; set; }

    [LoadColumn(10)]
    public float StudentTeacherRatio { get; set; }

    [LoadColumn(11)]
    public float PercentPopulationBelowPoverty { get; set; }

    [LoadColumn(12)]
    [ColumnName("Label")]
    public float Price { get; set; }
}

Обучение модели

Код в приведенном ниже примере иллюстрирует процесс обучения модели линейной регрессии для прогнозирования цен на дома.

// 1. Get the column name of input features.
string[] featureColumnNames =
    data.Schema
        .Select(column => column.Name)
        .Where(columnName => columnName != "Label").ToArray();

// 2. Define training pipeline
IEstimator<ITransformer> sdcaEstimator =
    mlContext.Transforms.Concatenate("Features", featureColumnNames)
        .Append(mlContext.Transforms.NormalizeMinMax("Features"))
        .Append(mlContext.Regression.Trainers.Sdca());

// 3. Train machine learning model
var sdcaModel = sdcaEstimator.Fit(data);

Объяснение модели с помощью функции PFI

В ML.NET используйте метод PermutationFeatureImportance для решения соответствующей задачи.

// Use the model to make predictions
var transformedData = sdcaModel.Transform(data);

// Calculate feature importance
ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
    mlContext
        .Regression
        .PermutationFeatureImportance(sdcaModel, transformedData, permutationCount:3);

В результате применения метода PermutationFeatureImportance в проверочном наборе данных создается ImmutableArray из объектов RegressionMetricsStatistics. RegressionMetricsStatistics предоставляет сводные статистические данные, такие как среднее и стандартное отклонение, для нескольких значений параметра RegressionMetrics, количество которых равно числу перестановок, которое определяет параметр permutationCount.

Метрика, используемая для измерения важности компонентов, зависит от задачи машинного обучения, используемой для решения проблемы. Например, задачи регрессии могут использовать общую метрику оценки, такую как R-квадрат, для измерения важности. Дополнительные сведения о метриках оценки модели см. в статье Оценка модели ML.NET с помощью метрик.

Важность, или в данном случае среднее абсолютное снижение метрики R-квадрат, рассчитанное методом PermutationFeatureImportance, можно отсортировать по убыванию.

// Order features by importance
var featureImportanceMetrics =
    permutationFeatureImportance
        .Select((metric, index) => new { index, metric.RSquared })
        .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

Console.WriteLine("Feature\tPFI");

foreach (var feature in featureImportanceMetrics)
{
    Console.WriteLine($"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6}");
}

При печати значений для каждого параметра в featureImportanceMetrics выдается результат представленного ниже вида. У вас результаты будут другими, поскольку эти значения зависят от предоставленных данных.

Функция Изменение в R-квадрат
HighwayAccess –0,042731
StudentTeacherRatio –0,012730
BusinessCenterDistance –0,010491
TaxRate –0,008545
AverageRoomNumber –0,003949
CrimeRate –0,003665
CommercialZones 0,002749
HomeAge –0,002426
ResidentialZones –0,002319
NearWater 0,000203
PercentPopulationLivingBelowPoverty 0,000031
ToxicWasteLevels –0,000019

Рассмотрим пять самых важных компонентов для этого набора данных. Цена на дом, спрогнозированная этой моделью, зависит от его близости к автомагистралям, соотношения учащихся и учителей в окрестных школах, близости к основным центрам трудоустройства, ставки налога на имущество и среднего количества комнат.

Следующие шаги