Interpretierbarkeit – Tabellarische SHAP-Erklärung
In diesem Beispiel verwenden wir Kernel SHAP, um ein tabellarisches Klassifizierungsmodell zu erklären, das aus dem Dataset Adults Census erstellt wurde.
Zuerst importieren wir die Pakete und definieren einige UDFs, die wir später benötigen.
import pyspark
from synapse.ml.explainers import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import *
from pyspark.sql.functions import *
import pandas as pd
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
from synapse.ml.core.platform import *
vec_access = udf(lambda v, i: float(v[i]), FloatType())
vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))
Nun lesen wir die Daten und trainieren ein binäres Klassifizierungsmodell.
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)
labelIndexer = StringIndexer(
inputCol="income", outputCol="label", stringOrderType="alphabetAsc"
).fit(df)
print("Label index assigment: " + str(set(zip(labelIndexer.labels, [0, 1]))))
training = labelIndexer.transform(df).cache()
display(training)
categorical_features = [
"workclass",
"education",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"native-country",
]
categorical_features_idx = [col + "_idx" for col in categorical_features]
categorical_features_enc = [col + "_enc" for col in categorical_features]
numeric_features = [
"age",
"education-num",
"capital-gain",
"capital-loss",
"hours-per-week",
]
strIndexer = StringIndexer(
inputCols=categorical_features, outputCols=categorical_features_idx
)
onehotEnc = OneHotEncoder(
inputCols=categorical_features_idx, outputCols=categorical_features_enc
)
vectAssem = VectorAssembler(
inputCols=categorical_features_enc + numeric_features, outputCol="features"
)
lr = LogisticRegression(featuresCol="features", labelCol="label", weightCol="fnlwgt")
pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])
model = pipeline.fit(training)
Nachdem das Modell trainiert wurde, wählen wir zufällig einige Beobachtungen aus, die erläutert werden sollen.
explain_instances = (
model.transform(training).orderBy(rand()).limit(5).repartition(200).cache()
)
display(explain_instances)
Wir erstellen einen TabularSHAP-Erklärer, legen die Eingabespalten auf alle Features fest, die das Modell übernimmt, und geben das Modell und die Zielausgabespalte an, die wir erklären möchten. In diesem Fall versuchen wir, die Wahrscheinlichkeitsausgabe zu erklären, die ein Vektor der Länge 2 ist, und wir betrachten nur die Wahrscheinlichkeit der Klasse 1. Geben Sie targetClasses mit an [0, 1]
, wenn Sie die Wahrscheinlichkeit der Klasse 0 und 1 gleichzeitig erklären möchten. Schließlich werden 100 Zeilen aus den Trainingsdaten für Hintergrunddaten verwendet, die für die Integration von Features in Kernel SHAP verwendet werden.
shap = TabularSHAP(
inputCols=categorical_features + numeric_features,
outputCol="shapValues",
numSamples=5000,
model=model,
targetCol="probability",
targetClasses=[1],
backgroundData=broadcast(training.orderBy(rand()).limit(100).cache()),
)
shap_df = shap.transform(explain_instances)
Sobald wir den resultierenden Dataframe haben, extrahieren wir die Wahrscheinlichkeit der Klasse 1 der Modellausgabe, die SHAP-Werte für die Zielklasse, die ursprünglichen Features und die true-Bezeichnung. Anschließend konvertieren wir es zur Visualisierung in einen Pandas-Dataframe. Für jede Beobachtung ist das erste Element im SHAP-Wertevektor der Basiswert (die mittlere Ausgabe des Hintergrunddatasets), und jedes der folgenden Elemente ist die SHAP-Werte für jedes Feature.
shaps = (
shap_df.withColumn("probability", vec_access(col("probability"), lit(1)))
.withColumn("shapValues", vec2array(col("shapValues").getItem(0)))
.select(
["shapValues", "probability", "label"] + categorical_features + numeric_features
)
)
shaps_local = shaps.toPandas()
shaps_local.sort_values("probability", ascending=False, inplace=True, ignore_index=True)
pd.set_option("display.max_colwidth", None)
shaps_local
Wir verwenden plotly subplot, um die SHAP-Werte zu visualisieren.
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
features = categorical_features + numeric_features
features_with_base = ["Base"] + features
rows = shaps_local.shape[0]
fig = make_subplots(
rows=rows,
cols=1,
subplot_titles="Probability: "
+ shaps_local["probability"].apply("{:.2%}".format)
+ "; Label: "
+ shaps_local["label"].astype(str),
)
for index, row in shaps_local.iterrows():
feature_values = [0] + [row[feature] for feature in features]
shap_values = row["shapValues"]
list_of_tuples = list(zip(features_with_base, feature_values, shap_values))
shap_pdf = pd.DataFrame(list_of_tuples, columns=["name", "value", "shap"])
fig.add_trace(
go.Bar(
x=shap_pdf["name"],
y=shap_pdf["shap"],
hovertext="value: " + shap_pdf["value"].astype(str),
),
row=index + 1,
col=1,
)
fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor="black")
fig.update_xaxes(type="category", tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()