Interpretabilitas - Penjelas SHAP Tabular

Dalam contoh ini, kita menggunakan Kernel SHAP untuk menjelaskan model klasifikasi tabular yang dibangun dari himpunan data Adult Census Income (sensus penghasilan orang dewasa).

Pertama, kita mengimpor paket-paket yang diperlukan dan mendefinisikan beberapa UDF yang akan kita butuhkan nanti.

import pyspark
from synapse.ml.explainers import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import *
from pyspark.sql.functions import *
import pandas as pd
from pyspark.sql import SparkSession

# Bootstrap the Spark session: getOrCreate() hands back the session that is
# already active (e.g. inside a notebook) or starts a new one otherwise.
session_builder = SparkSession.builder
spark = session_builder.getOrCreate()

from synapse.ml.core.platform import *




@udf(returnType=FloatType())
def vec_access(v, i):
    """Return element ``i`` of the indexable value ``v`` as a float."""
    return float(v[i])


@udf(returnType=ArrayType(FloatType()))
def vec2array(vec):
    """Convert a vector-like value (anything exposing ``toArray()``) to a list."""
    return vec.toArray().tolist()

Sekarang mari kita baca data dan latih model klasifikasi biner.

# Load the Adult Census Income data from the public blob container.
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)

# Turn the string "income" column into a numeric "label" column; the indexer
# is asked to order the labels alphabetically (ascending).
labelIndexer = StringIndexer(
    inputCol="income", outputCol="label", stringOrderType="alphabetAsc"
).fit(df)

# Derive the label -> index mapping from the fitted labels themselves rather
# than hard-coding [0, 1], and keep it in a dict so the printed order is stable
# (a set prints its elements in arbitrary order).
label_to_index = {label: index for index, label in enumerate(labelIndexer.labels)}
print("Label index assignment: " + str(label_to_index))

# Cache the labelled data: it is reused for fitting, scoring and sampling.
training = labelIndexer.transform(df).cache()
display(training)
# Numeric columns are fed to the model as-is.
numeric_features = [
    "age",
    "education-num",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
]

# String-valued columns; each gets a derived "<name>_idx" and "<name>_enc"
# column name for its indexed and encoded forms.
categorical_features = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
categorical_features_idx = [f"{name}_idx" for name in categorical_features]
categorical_features_enc = [f"{name}_enc" for name in categorical_features]

# Stage 1: index every categorical string column into its "<name>_idx" column.
strIndexer = StringIndexer(
    inputCols=categorical_features,
    outputCols=categorical_features_idx,
)

# Stage 2: one-hot encode the indexed columns into the "<name>_enc" columns.
onehotEnc = OneHotEncoder(
    inputCols=categorical_features_idx,
    outputCols=categorical_features_enc,
)

# Stage 3: assemble encoded categoricals and raw numerics into "features".
assembler_inputs = categorical_features_enc + numeric_features
vectAssem = VectorAssembler(inputCols=assembler_inputs, outputCol="features")

# Stage 4: logistic regression on "features" / "label", weighted by "fnlwgt".
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    weightCol="fnlwgt",
)

# Chain the stages and fit them end-to-end on the labelled training data.
pipeline_stages = [strIndexer, onehotEnc, vectAssem, lr]
pipeline = Pipeline(stages=pipeline_stages)
model = pipeline.fit(training)

Setelah model dilatih, kami secara acak memilih beberapa pengamatan yang akan dijelaskan.

# Score the training data, shuffle it randomly and keep five rows to explain.
scored_training = model.transform(training)
random_five = scored_training.orderBy(rand()).limit(5)

# NOTE(review): the 200-way repartition of only five rows presumably exists to
# spread the downstream SHAP sampling work across executors — confirm.
# The result is cached so the same random rows are reused afterwards.
explain_instances = random_five.repartition(200).cache()
display(explain_instances)

Kita membuat penjelas TabularSHAP, mengatur kolom input ke semua fitur yang diterima model, lalu menentukan model dan kolom output target yang ingin kita jelaskan. Dalam hal ini, kita menjelaskan output "probability", yaitu vektor dengan panjang 2, dan kita hanya melihat probabilitas kelas 1. Atur targetClasses ke [0, 1] jika Anda ingin menjelaskan probabilitas kelas 0 dan 1 secara bersamaan. Terakhir, kita mengambil sampel 100 baris dari data pelatihan sebagai data latar belakang, yang digunakan untuk mengintegrasikan fitur di Kernel SHAP.

# Background data: 100 randomly chosen training rows, cached and marked for
# broadcast.
background_rows = training.orderBy(rand()).limit(100).cache()
background_data = broadcast(background_rows)

# Explain the class-1 entry of the model's "probability" output using the raw
# (pre-encoding) feature columns as inputs.
shap = TabularSHAP(
    model=model,
    inputCols=categorical_features + numeric_features,
    targetCol="probability",
    targetClasses=[1],
    outputCol="shapValues",
    numSamples=5000,
    backgroundData=background_data,
)

shap_df = shap.transform(explain_instances)

Setelah kita memiliki DataFrame hasilnya, kita mengekstrak probabilitas kelas 1 dari output model, nilai SHAP untuk kelas target, fitur asli, dan label yang sebenarnya. Kemudian kita mengonversinya menjadi DataFrame pandas untuk visualisasi. Untuk setiap pengamatan, elemen pertama dalam vektor nilai SHAP adalah nilai dasar (output rata-rata himpunan data latar belakang), dan masing-masing elemen berikutnya adalah nilai SHAP untuk setiap fitur.

# Column expressions: the class-1 probability as a scalar, and the SHAP vector
# of the first (only) target class as a plain array.
class1_probability = vec_access(col("probability"), lit(1))
target_shap_values = vec2array(col("shapValues").getItem(0))

selected_columns = (
    ["shapValues", "probability", "label"] + categorical_features + numeric_features
)

shaps = (
    shap_df.withColumn("probability", class1_probability)
    .withColumn("shapValues", target_shap_values)
    .select(selected_columns)
)

# Collect to pandas and order by descending probability; ignore_index=True
# renumbers the rows 0..n-1 after sorting.
shaps_local = shaps.toPandas().sort_values(
    "probability", ascending=False, ignore_index=True
)
pd.set_option("display.max_colwidth", None)
shaps_local

Kita menggunakan subplot Plotly untuk memvisualisasikan nilai SHAP.

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

features = categorical_features + numeric_features
# The first SHAP entry of every row is paired with the extra "Base" label.
features_with_base = ["Base"] + features

rows = shaps_local.shape[0]

# One title per explained row. make_subplots documents subplot_titles as a
# list of strings, so the pandas Series is converted explicitly.
subplot_titles = (
    "Probability: "
    + shaps_local["probability"].apply("{:.2%}".format)
    + "; Label: "
    + shaps_local["label"].astype(str)
).tolist()

fig = make_subplots(rows=rows, cols=1, subplot_titles=subplot_titles)

# Use the positional row number (1-based) for the subplot instead of the
# DataFrame index label, so the plot does not depend on the index being 0..n-1.
for subplot_row, (_, row) in enumerate(shaps_local.iterrows(), start=1):
    # "Base" has no feature value of its own; 0 is used as a placeholder.
    feature_values = [0] + [row[feature] for feature in features]
    shap_pdf = pd.DataFrame(
        list(zip(features_with_base, feature_values, row["shapValues"])),
        columns=["name", "value", "shap"],
    )
    fig.add_trace(
        go.Bar(
            x=shap_pdf["name"],
            y=shap_pdf["shap"],
            hovertext="value: " + shap_pdf["value"].astype(str),
        ),
        row=subplot_row,
        col=1,
    )

# Same fixed y-range on every subplot so bars are comparable across rows.
fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor="black")
fig.update_xaxes(type="category", tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()