Zelfstudie: Gegevens analyseren met glm
Meer informatie over het uitvoeren van lineaire en logistieke regressie met behulp van een gegeneraliseerd lineair model (GLM) in Azure Databricks. glm
past bij een gegeneraliseerd lineair model, vergelijkbaar met R's glm()
.
Syntaxis: glm(formula, data, family...)
Parameters:
formula
: Symbolische beschrijving van het model dat moet worden gemonteerd, bijvoorbeeld:ResponseVariable ~ Predictor1 + Predictor2
. Ondersteunde operators:~
,+
,-
en.
data
: Elk SparkDataFramefamily
: Tekenreeks,"gaussian"
voor lineaire regressie of"binomial"
voor logistieke regressielambda
: Numerieke parameter, Regularisatieparameteralpha
: Numerieke, elastic-net mengparameter
Uitvoer: MLlib PipelineModel
Deze zelfstudie laat zien hoe u lineaire en logistieke regressie uitvoert op de gegevensset diamanten.
Diamantgegevens laden en splitsen in trainings- en testsets
require(SparkR)
# Read diamonds.csv dataset as SparkDataFrame
diamonds <- read.df("/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv",
source = "com.databricks.spark.csv", header="true", inferSchema = "true")
diamonds <- withColumnRenamed(diamonds, "", "rowID")
# Split data into Training set and Test set
trainingData <- sample(diamonds, FALSE, 0.7)
testData <- except(diamonds, trainingData)
# Exclude rowIDs
trainingData <- trainingData[, -1]
testData <- testData[, -1]
print(count(diamonds))
print(count(trainingData))
print(count(testData))
head(trainingData)
Een lineair regressiemodel trainen met behulp van glm()
In deze sectie wordt beschreven hoe u de prijs van een diamant kunt voorspellen op basis van de kenmerken door een lineair regressiemodel te trainen met behulp van de trainingsgegevens.
Er is een combinatie van categorische kenmerken (knippen - Ideaal, Premium, Zeer goed...) en doorlopende kenmerken (diepte, karaat). SparkR codeert deze functies automatisch, zodat u deze functies niet handmatig hoeft te coderen.
# Family = "gaussian" to train a linear regression model
lrModel <- glm(price ~ ., data = trainingData, family = "gaussian")
# Print a summary of the trained model
summary(lrModel)
Gebruik predict()
de testgegevens om te zien hoe goed het model werkt voor nieuwe gegevens.
Syntaxis: predict(model, newData)
Parameters:
model
: MLlib-modelnewData
: SparkDataFrame, meestal uw testset
Uitvoer: SparkDataFrame
# Generate predictions using the trained model
predictions <- predict(lrModel, newData = testData)
# View predictions against mpg column
display(select(predictions, "price", "prediction"))
Evalueer het model.
errors <- select(predictions, predictions$price, predictions$prediction, alias(predictions$price - predictions$prediction, "error"))
display(errors)
# Calculate RMSE
head(select(errors, alias(sqrt(sum(errors$error^2 , na.rm = TRUE) / nrow(errors)), "RMSE")))
Een logistiek regressiemodel trainen met behulp van glm()
In deze sectie wordt beschreven hoe u een logistieke regressie maakt op dezelfde gegevensset om de knipsel van een diamant te voorspellen op basis van een aantal functies.
Logistieke regressie in MLlib ondersteunt binaire classificatie. Als u het algoritme in dit voorbeeld wilt testen, moet u de gegevens subseten om met twee labels te werken.
# Subset data to include rows where diamond cut = "Premium" or diamond cut = "Very Good"
trainingDataSub <- subset(trainingData, trainingData$cut %in% c("Premium", "Very Good"))
testDataSub <- subset(testData, testData$cut %in% c("Premium", "Very Good"))
# Family = "binomial" to train a logistic regression model
logrModel <- glm(cut ~ price + color + clarity + depth, data = trainingDataSub, family = "binomial")
# Print summary of the trained model
summary(logrModel)
# Generate predictions using the trained model
predictionsLogR <- predict(logrModel, newData = testDataSub)
# View predictions against label column
display(select(predictionsLogR, "label", "prediction"))
Evalueer het model.
errorsLogR <- select(predictionsLogR, predictionsLogR$label, predictionsLogR$prediction, alias(abs(predictionsLogR$label - predictionsLogR$prediction), "error"))
display(errorsLogR)