Interpretabilidad: explicación de SHAP tabular

Utilice Kernel SHAP (SHapley Additive exPlanations) para explicar un modelo de clasificación para datos tabulares. El SHAP de kernel es un método independiente del modelo que calcula la contribución de cada característica a la predicción de un modelo. Se entrena un modelo de regresión logística con el conjunto de datos Adult Census Income y luego se usa el transformador SynapseML TabularSHAP para calcular explicaciones a nivel de características.

Prerequisites

  • Cree un nuevo cuaderno en su espacio de trabajo y adjúntelo a un lakehouse. Para obtener más información, consulte Creación de un cuaderno.

SynapseML, PySpark, pandas y plotly están preinstalados en entornos de cuadernos de Fabric. No se requiere ninguna instalación adicional de paquetes.

Importación de paquetes y definición de UDF auxiliares

En el cuaderno de Fabric, pegue el código siguiente en una celda y ejecútelo. Este paso importa las bibliotecas necesarias y define dos funciones definidas por el usuario (UDF) para extraer elementos vectoriales más adelante.

import pyspark
from synapse.ml.explainers import TabularSHAP
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import FloatType, ArrayType
from pyspark.sql.functions import col, lit, rand, broadcast, udf
import pandas as pd

vec_access = udf(lambda v, i: float(v[i]), FloatType())
vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))

Comprobar: ejecute el código siguiente en una nueva celda. Debería ver la salida TabularSHAP imported successfully.

print("TabularSHAP imported successfully")
print(f"PySpark version: {pyspark.__version__}")

Carga de datos y entrenamiento de un modelo de clasificación

Cargue el conjunto de datos Adult Census Income desde Azure Blob Storage, indexe la etiqueta objetivo y entrene una canalización de regresión logística.

df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)

labelIndexer = StringIndexer(
    inputCol="income", outputCol="label", stringOrderType="alphabetAsc"
).fit(df)
print("Label index assignment: " + str(set(zip(labelIndexer.labels, [0, 1]))))

training = labelIndexer.transform(df).cache()

categorical_features = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
categorical_features_idx = [feat + "_idx" for feat in categorical_features]
categorical_features_enc = [feat + "_enc" for feat in categorical_features]
numeric_features = [
    "age",
    "education-num",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
]

strIndexer = StringIndexer(
    inputCols=categorical_features, outputCols=categorical_features_idx
)
onehotEnc = OneHotEncoder(
    inputCols=categorical_features_idx, outputCols=categorical_features_enc
)
vectAssem = VectorAssembler(
    inputCols=categorical_features_enc + numeric_features, outputCol="features"
)
lr = LogisticRegression(featuresCol="features", labelCol="label", weightCol="fnlwgt")
pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])
model = pipeline.fit(training)

Comprobar: ejecute la celda siguiente. Debería ver recuentos de filas para los datos de entrenamiento y la confirmación de las fases de canalización.

print(f"Training rows: {training.count()}")
print(f"Pipeline stages: {[type(s).__name__ for s in model.stages]}")
assert training.count() > 30000, "Dataset should contain over 30,000 rows"
print("Model trained successfully")

# Expected output:
#Training rows: 32561
#Pipeline stages: ['StringIndexerModel', 'OneHotEncoderModel', #'VectorAssembler', 'LogisticRegressionModel']
#Model trained successfully

Selección de observaciones para explicar

Seleccione aleatoriamente cinco observaciones de los datos de entrenamiento puntuados. Estas observaciones son las instancias para las que se generan explicaciones de SHAP.

explain_instances = (
    model.transform(training).orderBy(rand()).limit(5).repartition(200).cache()
)
display(explain_instances)

Comprobar: confirme el tamaño de la muestra.

count = explain_instances.count()
print(f"Explain instances: {count}")
assert count == 5, f"Expected 5 rows, got {count}"
print("Sample selected successfully")

Configuración y ejecución de TabularSHAP

Cree un TabularSHAP explicador y aplíquelo a las observaciones seleccionadas. Los parámetros clave son:

Parámetro Description
inputCols Columnas de características que usa el modelo para la predicción.
outputCol Nombre de la columna que contiene valores de salida SHAP.
numSamples Número de muestras de perturbación para el cálculo de Kernel SHAP. Los valores más altos son más precisos, pero más lentos.
model Modelo de canalización entrenado que se desea explicar.
targetCol Columna de salida del modelo que se va a explicar. En este ejemplo, la columna es probability.
targetClasses Índices de clases para explicar. [1] explica solo la probabilidad de la clase 1. Use [0, 1] para explicar ambas clases.
backgroundData Ejemplo de datos de entrenamiento usados como distribución de referencia para integrar características.
shap = TabularSHAP(
    inputCols=categorical_features + numeric_features,
    outputCol="shapValues",
    numSamples=5000,
    model=model,
    targetCol="probability",
    targetClasses=[1],
    backgroundData=broadcast(training.orderBy(rand()).limit(100).cache()),
)

shap_df = shap.transform(explain_instances)

Note

Este paso puede tardar varios minutos en función de numSamples y del tamaño del clúster. Con numSamples=5000 y cinco observaciones, calcule entre 3 y 10 minutos en un clúster predeterminado de Fabric Spark.

Comprobar: compruebe que existe la columna de salida SHAP.

assert "shapValues" in shap_df.columns, "shapValues column missing"
print(f"SHAP output columns: {shap_df.columns}")
print("TabularSHAP transform completed")

Extracción de valores SHAP

Extraiga la probabilidad de la clase 1 y los valores SHAP del DataFrame de resultados. Para cada observación, el vector de valores SHAP comienza con el valor base (salida media del conjunto de datos en segundo plano), seguido de un valor por característica.

shaps = (
    shap_df.withColumn("probability", vec_access(col("probability"), lit(1)))
    .withColumn("shapValues", vec2array(col("shapValues").getItem(0)))
    .select(
        ["shapValues", "probability", "label"] + categorical_features + numeric_features
    )
)

shaps_local = shaps.toPandas()
shaps_local.sort_values("probability", ascending=False, inplace=True, ignore_index=True)
pd.set_option("display.max_colwidth", None)
display(shaps_local)

Verifique: Confirme la estructura del DataFrame de pandas.

expected_cols = len(categorical_features) + len(numeric_features) + 3
print(f"DataFrame shape: {shaps_local.shape}")
print(f"Expected columns: {expected_cols}, Actual: {shaps_local.shape[1]}")
assert shaps_local.shape == (5, expected_cols), f"Unexpected shape: {shaps_local.shape}"
print("SHAP values extracted successfully")

Visualización de valores SHAP

Cree un gráfico de barras para cada observación que muestre cómo contribuye cada característica a la probabilidad prevista.

from plotly.subplots import make_subplots
import plotly.graph_objects as go

features = categorical_features + numeric_features
features_with_base = ["Base"] + features

rows = shaps_local.shape[0]

fig = make_subplots(
    rows=rows,
    cols=1,
    subplot_titles="Probability: "
    + shaps_local["probability"].apply("{:.2%}".format)
    + "; Label: "
    + shaps_local["label"].astype(str),
)

for index, row in shaps_local.iterrows():
    feature_values = [0] + [row[feature] for feature in features]
    shap_values = row["shapValues"]
    list_of_tuples = list(zip(features_with_base, feature_values, shap_values))
    shap_pdf = pd.DataFrame(list_of_tuples, columns=["name", "value", "shap"])
    fig.add_trace(
        go.Bar(
            x=shap_pdf["name"],
            y=shap_pdf["shap"],
            hovertext="value: " + shap_pdf["value"].astype(str),
        ),
        row=index + 1,
        col=1,
    )

fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor="black")
fig.update_xaxes(type="category", tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()

Verifique: Confirme que se creó el objeto de trazado.

print(f"Figure traces: {len(fig.data)}")
print(f"Figure height: {fig.layout.height}px")
assert len(fig.data) == 5, f"Expected 5 traces, got {len(fig.data)}"
print("Visualization created successfully")

Interpretación de los resultados

Cada sublot representa una observación. Las barras muestran:

  • Base: la salida media del modelo en el conjunto de datos en segundo plano (probabilidad de línea base).
  • Valores SHAP positivos: características que impulsan la predicción hacia la clase 1 (ingresos mayores que 50 000).
  • Valores SHAP negativos: características que impulsan la predicción hacia la clase 0 (ingresos inferiores o iguales a 50 000).

La suma del valor base y todos los valores SHAP de características son iguales a la probabilidad prevista del modelo para esa observación.

Solución de problemas

Cuestión Causa Solución
OutOfMemoryError durante TabularSHAP numSamples es demasiado grande para la memoria disponible. Reduzca numSamples, por ejemplo a 1000 o aumente la memoria del ejecutor de Spark.
La transformación SHAP es lenta Un valor alto de numSamples con muchas funciones aumenta el tiempo de cómputo. Reduzca numSamples a 1000-2000 para obtener resultados exploratorios más rápidos. Aumento para el análisis final.
FileNotFoundException para parquet El acceso de red a mmlspark.blob.core.windows.net está bloqueado. Compruebe que el área de trabajo de Fabric tiene acceso saliente a Internet. Como alternativa, cargue el conjunto de datos en lakehouse.
shapValues column contiene valores NULL Es posible que se produzca un error en algunas observaciones si los valores de características están fuera de la distribución de entrenamiento. Compruebe si hay valores NULL o inesperados en las características de entrada. Filtre los valores NULL de los resultados.
display() no muestra ninguna salida El código se ejecuta fuera de un entorno de cuaderno de Fabric. Use shaps_local.head() o print(shaps_local) en entornos de Python estándar.

Limpieza

Si subió el conjunto de datos a un lakehouse para este tutorial, elimínelo para liberar espacio de almacenamiento:

# Remove cached DataFrames from memory
training.unpersist()
explain_instances.unpersist()
print("Cached DataFrames released")