Utvärdera en maskininlärningsmodell

Slutförd

Så du har tränat en förutsägelsemodell. Hur vet du om det är bra?

För att utvärdera en modell måste du använda valideringsdata som du höll tillbaka. För övervakade maskininlärningsmodeller gör den här metoden att du kan jämföra etiketterna som förutsägs av modellen med de faktiska etiketterna i valideringsdatauppsättningen. Genom att jämföra förutsägelserna med de sanna etikettvärdena kan du beräkna ett intervall med utvärderingsmått för att kvantifiera modellens förutsägelseprestanda.

Utvärdera regressionsmodeller

Regressionsmodeller förutsäger numeriska värden, så alla utvärderingar av modellens förutsägelseprestanda kräver att du överväger skillnaderna mellan de förutsagda värdena och de faktiska etikettvärdena i valideringsdatauppsättningen. Eftersom valideringsdatauppsättningen innehåller flera fall, varav vissa kan ha mer exakta förutsägelser än andra, behöver du något sätt att aggregera skillnaderna och fastställa ett övergripande mått för prestanda. Vanliga mått som används för att utvärdera en regressionsmodell är:

  • Genomsnittligt kvadratfel (MSE): Det här måttet beräknas genom att skillnaderna mellan varje förutsägelse och det faktiska värdet kvadreras och kvadratskillnaderna läggs ihop och medelvärdet beräknas (genomsnitt). Om du quaring värdena gör skillnaderna absoluta (ignorerar om skillnaden är negativ eller positiv) och ger större vikt till större skillnader.
  • RMSE (Root Mean Squared Error): Även om MSE-måttet är en bra indikation på felnivån i modellförutsägelserna relaterar det inte till den faktiska måttenheten för etiketten. I en modell som förutsäger försäljning (i us-dollar) representerar MSE-värdet till exempel faktiskt de kvadratiska dollarvärdena. För att utvärdera hur långt bort förutsägelserna är i termer av dollar måste du beräkna kvadratroten för MSE.
  • Bestämningskoefficient (R2): R2-måttet mäter korrelationen mellan kvadratfunktionen och förutsagda värden. Detta resulterar i ett värde mellan 0 och 1 som mäter mängden varians som kan förklaras av modellen. Ju närmare det här värdet är 1, desto bättre förutsäger modellen.

De flesta maskininlärningsramverk tillhandahåller klasser som beräknar dessa mått åt dig. Spark MLlib-biblioteket tillhandahåller till exempel klassen RegressionEvaluator , som du kan använda enligt det här kodexemplet:

from pyspark.ml.evaluation import RegressionEvaluator

# Inference predicted labels from validation data
predictions_df = model.transform(validation_df)

# Assume predictions_df includes a 'prediction' column with the predicted labels
# and a 'label' column with the actual known label values

# Use an evaluator to get metrics
evaluator = RegressionEvaluator()
evaluator.setPredictionCol("prediction")
mse = evaluator.evaluate(predictions_df, {evaluator.metricName: "mse"})
rmse = evaluator.evaluate(predictions_df, {evaluator.metricName: "rmse"})
r2 = evaluator.evaluate(predictions_df, {evaluator.metricName: "r2"})
print("MSE:", str(mse))
print("RMSE:", str(rmse))
print("R2", str(r2))

Utvärdera klassificeringsmodeller

Klassificeringsmodeller förutsäger kategoriska etiketter (klasser) genom att beräkna ett sannolikhetsvärde för varje möjlig klass och välja klassetiketten med högst sannolikhet. De mått som används för att utvärdera en klassificeringsmodell återspeglar hur ofta dessa klassförutsägelser var korrekta jämfört med de faktiska kända etiketterna i valideringsdatauppsättningen. Vanliga mått som används för att utvärdera en klassificeringsmodell är:

  • Noggrannhet: Ett enkelt mått som anger andelen klassförutsägelser som gjorts av modellen som var korrekt. Även om detta kan verka som det uppenbara sättet att utvärdera prestanda för en klassificeringsmodell, bör du överväga ett scenario där en modell används för att förutsäga om en person kommer att pendla till jobbet med bil, buss eller spårvagn. Anta att 95 % av fallen i valideringsuppsättningen använder en bil, 3 % tar bussen och 2 % tar en spårvagn. En modell som helt enkelt alltid förutsäger en bil kommer att vara 95% korrekt - även om den faktiskt inte har någon förutsägande förmåga att diskriminera mellan de tre klasserna.
  • Mått per klass:
    • Precision: Andelen förutsägelser för den angivna klassen som var korrekt. Detta mäts som antalet sanna positiva identifieringar (korrekta förutsägelser för den här klassen) dividerat med det totala antalet förutsägelser för den här klassen (inklusive falska positiva identifieringar).
    • Kom ihåg: Andelen faktiska instanser av den här klassen som var korrekt förutsagda (sanna positiva identifieringar dividerade med det totala antalet om instanser av den här klassen i valideringsdatauppsättningen, inklusive falska negativa - fall där modellen felaktigt förutsade en annan klass).
    • F1-poäng: Ett kombinerat mått för precision och träffsäkerhet (beräknat som det harmoniska medelvärdet av precision och träffsäkerhet).
  • Kombinerad (viktad) precision, träffsäkerhet och F1-mått för alla klasser.

När det gäller regression omfattar de flesta maskininlärningsramverk klasser som kan beräkna klassificeringsmått. Följande kod använder till exempel MulticlassClassificationEvaluator i Spark MLlib-biblioteket.

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Inference predicted labels from validation data
predictions_df = model.transform(validation_df)

# Assume predictions_df includes a 'prediction' column with the predicted labels
# and a 'label' column with the actual known label values

# Use an evaluator to get metrics
accuracy = evaluator.evaluate(predictions_df, {evaluator.metricName:"accuracy"})
print("Accuracy:", accuracy)

labels = [0,1,2]
print("\nIndividual class metrics:")
for label in sorted(labels):
    print ("Class %s" % (label))
    precision = evaluator.evaluate(predictions_df, {evaluator.metricLabel:label,
                                                    evaluator.metricName:"precisionByLabel"})
    print("\tPrecision:", precision)
    recall = evaluator.evaluate(predictions_df, {evaluator.metricLabel:label,
                                                 evaluator.metricName:"recallByLabel"})
    print("\tRecall:", recall)
    f1 = evaluator.evaluate(predictions_df, {evaluator.metricLabel:label,
                                             evaluator.metricName:"fMeasureByLabel"})
    print("\tF1 Score:", f1)
    
overallPrecision = evaluator.evaluate(predictions_df, {evaluator.metricName:"weightedPrecision"})
print("Overall Precision:", overallPrecision)
overallRecall = evaluator.evaluate(predictions_df, {evaluator.metricName:"weightedRecall"})
print("Overall Recall:", overallRecall)
overallF1 = evaluator.evaluate(predictions_df, {evaluator.metricName:"weightedFMeasure"})
print("Overall F1 Score:", overallF1)

Utvärdera oövervakade klustermodeller

Oövervakade klustermodeller har inte kända sanna etikettvärden. Målet med klustringsmodellen är att gruppera liknande fall i kluster baserat på deras funktioner. För att utvärdera ett kluster behöver du ett mått som anger nivån på separationen mellan kluster. Du kan se de klustrade fallen som ritade punkter i flerdimensionellt utrymme. Punkter i samma kluster bör vara nära varandra och långt borta från punkter i ett annat kluster.

Ett sådant mått är silhuettmåttet, som beräknar kvadratiskt euklidiskt avstånd och ger en indikation på konsekvens i kluster. Silhuettvärden kan vara mellan 1 och -1, med ett värde nära 1 som anger att punkterna i ett kluster ligger nära de andra punkterna i samma kluster och långt från de andra klustrens punkter.

Spark MLlib-biblioteket tillhandahåller klassen ClusteringEvaluator , som beräknar Silhuetten för förutsägelserna från en klustermodell enligt följande:

from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.linalg import Vectors

# Inference predicted labels from validation data
predictions_df = model.transform(validation_df)

# Assume predictions_df includes a 'prediction' column with the predicted cluster

# Use an evaluator to get metrics
evaluator = ClusteringEvaluator(predictionCol="prediction")
silhouetteVal  = evaluator.evaluate(predictions_df)
print(silhouetteVal)