Sdílet prostřednictvím


Úlohy klasifikace pomocí SynapseML

V tomto článku provedete stejnou úlohu klasifikace dvěma různými způsoby: jednou pomocí prostého pyspark a jednoho použití synapseml knihovny. Dvě metody přinášejí stejný výkon, ale zvýrazňují jednoduchost použití synapseml v porovnání s pyspark.

Úkolem je předpovědět, jestli je hodnocení knihy prodané na Amazonu dobré (hodnocení > 3) nebo špatné na základě textu recenze. Toho dosáhnete trénováním výukových aplikací LogisticRegression s různými hyperparametry a výběrem nejlepšího modelu.

Požadavky

Připojte poznámkový blok k jezeru. Na levé straně vyberte Přidat a přidejte existující jezerní dům nebo vytvořte jezero.

Nastavení

Naimportujte potřebné knihovny Pythonu a získejte relaci Sparku.

from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Čtení dat

Stáhněte si data a přečtěte si je.

rawData = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/BookReviewsFromAmazon10K.parquet"
)
rawData.show(5)

Extrakce funkcí a zpracování dat

Skutečná data jsou složitější než výše uvedená datová sada. U datové sady je běžné mít funkce více typů, jako je text, číselná a kategorická. Pokud chcete ilustrovat, jak obtížné je pracovat s těmito datovými sadami, přidejte do datové sady dvě číselné funkce: počet slov recenze a střední délku slova.

from pyspark.sql.functions import udf
from pyspark.sql.types import *


def wordCount(s):
    return len(s.split())


def wordLength(s):
    import numpy as np

    ss = [len(w) for w in s.split()]
    return round(float(np.mean(ss)), 2)


wordLengthUDF = udf(wordLength, DoubleType())
wordCountUDF = udf(wordCount, IntegerType())
from synapse.ml.stages import UDFTransformer

wordLength = "wordLength"
wordCount = "wordCount"
wordLengthTransformer = UDFTransformer(
    inputCol="text", outputCol=wordLength, udf=wordLengthUDF
)
wordCountTransformer = UDFTransformer(
    inputCol="text", outputCol=wordCount, udf=wordCountUDF
)
from pyspark.ml import Pipeline

data = (
    Pipeline(stages=[wordLengthTransformer, wordCountTransformer])
    .fit(rawData)
    .transform(rawData)
    .withColumn("label", rawData["rating"] > 3)
    .drop("rating")
)
data.show(5)

Klasifikace pomocí pysparku

Pokud chcete zvolit nejlepší klasifikátor LogisticRegression pomocí pyspark knihovny, musíte explicitně provést následující kroky:

  1. Zpracování funkcí:
    • Tokenizace textového sloupce
    • Hash tokenizovaného sloupce do vektoru pomocí hashování
    • Sloučení číselných prvků s vektorem
  2. Zpracovat sloupec popisku: přetypujte ho na správný typ.
  3. Trénování několika algoritmů LogistickéRegrese v train datové sadě pomocí různých hyperparametrů
  4. Vypočítá oblast pod křivkou ROC pro každý natrénovaný model a vybere model s nejvyšší metrikou vypočítanou v test datové sadě.
  5. Vyhodnocení nejlepšího modelu v validation sadě
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml.feature import VectorAssembler

# Featurize text column
tokenizer = Tokenizer(inputCol="text", outputCol="tokenizedText")
numFeatures = 10000
hashingScheme = HashingTF(
    inputCol="tokenizedText", outputCol="TextFeatures", numFeatures=numFeatures
)
tokenizedData = tokenizer.transform(data)
featurizedData = hashingScheme.transform(tokenizedData)

# Merge text and numeric features in one feature column
featureColumnsArray = ["TextFeatures", "wordCount", "wordLength"]
assembler = VectorAssembler(inputCols=featureColumnsArray, outputCol="features")
assembledData = assembler.transform(featurizedData)

# Select only columns of interest
# Convert rating column from boolean to int
processedData = assembledData.select("label", "features").withColumn(
    "label", assembledData.label.cast(IntegerType())
)
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.classification import LogisticRegression

# Prepare data for learning
train, test, validation = processedData.randomSplit([0.60, 0.20, 0.20], seed=123)

# Train the models on the 'train' data
lrHyperParams = [0.05, 0.1, 0.2, 0.4]
logisticRegressions = [
    LogisticRegression(regParam=hyperParam) for hyperParam in lrHyperParams
]
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction", metricName="areaUnderROC"
)
metrics = []
models = []

# Select the best model
for learner in logisticRegressions:
    model = learner.fit(train)
    models.append(model)
    scoredData = model.transform(test)
    metrics.append(evaluator.evaluate(scoredData))
bestMetric = max(metrics)
bestModel = models[metrics.index(bestMetric)]

# Get AUC on the validation dataset
scoredVal = bestModel.transform(validation)
print(evaluator.evaluate(scoredVal))

Klasifikace pomocí SynapseML

Potřebné kroky synapseml jsou jednodušší:

  1. TrainClassifier Estimátor interně ztěžuje data, pokud sloupce vybrané v traindatové sadě , testvalidation představují funkce.

  2. FindBestModel Estimátor najde nejlepší model z fondu natrénovaných modelů vyhledáním modelu, který se v datové sadě nejlépe hodí test pro danou zadanou metriku.

  3. ComputeModelStatistics Transformer vypočítá různé metriky pro vyhodnocenou datovou sadu (v našem případě validation datovou sadu) najednou.

from synapse.ml.train import TrainClassifier, ComputeModelStatistics
from synapse.ml.automl import FindBestModel

# Prepare data for learning
train, test, validation = data.randomSplit([0.60, 0.20, 0.20], seed=123)

# Train the models on the 'train' data
lrHyperParams = [0.05, 0.1, 0.2, 0.4]
logisticRegressions = [
    LogisticRegression(regParam=hyperParam) for hyperParam in lrHyperParams
]
lrmodels = [
    TrainClassifier(model=lrm, labelCol="label", numFeatures=10000).fit(train)
    for lrm in logisticRegressions
]

# Select the best model
bestModel = FindBestModel(evaluationMetric="AUC", models=lrmodels).fit(test)


# Get AUC on the validation dataset
predictions = bestModel.transform(validation)
metrics = ComputeModelStatistics().transform(predictions)
print(
    "Best model's AUC on validation set = "
    + "{0:.2f}%".format(metrics.first()["AUC"] * 100)
)