Inferensi ONNX pada Spark
Dalam contoh ini, Anda melatih model LightGBM dan mengonversi model ke format ONNX . Setelah dikonversi, Anda menggunakan model untuk menyimpulkan beberapa data pengujian di Spark.
Contoh ini menggunakan paket dan versi Python berikut:
onnxmltools==1.7.0
lightgbm==3.2.1
Prasyarat
- Lampirkan buku catatan Anda ke lakehouse. Di sisi kiri, pilih Tambahkan untuk menambahkan lakehouse yang ada atau buat lakehouse.
- Anda mungkin perlu menginstal
onnxmltools
dengan menambahkan!pip install onnxmltools==1.7.0
dalam sel kode lalu menjalankan sel.
Memuat data contoh
Untuk memuat data contoh, tambahkan contoh kode berikut ke sel di buku catatan Anda lalu jalankan sel:
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
from synapse.ml.core.platform import *
df = (
spark.read.format("csv")
.option("header", True)
.option("inferSchema", True)
.load(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv"
)
)
display(df)
Output akan terlihat mirip dengan tabel berikut, meskipun nilai dan jumlah baris mungkin berbeda:
Rasio Cakupan Bunga | Bendera Pendapatan Bersih | Ekuitas terhadap Tanggung Jawab |
---|---|---|
0.5641 | 1.0 | 0.0165 |
0.5702 | 1.0 | 0.0208 |
0.5673 | 1.0 | 0.0165 |
Menggunakan LightGBM untuk melatih model
from pyspark.ml.feature import VectorAssembler
from synapse.ml.lightgbm import LightGBMClassifier
feature_cols = df.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
train_data = featurizer.transform(df)["Bankrupt?", "features"]
model = (
LightGBMClassifier(featuresCol="features", labelCol="Bankrupt?")
.setEarlyStoppingRound(300)
.setLambdaL1(0.5)
.setNumIterations(1000)
.setNumThreads(-1)
.setMaxDeltaStep(0.5)
.setNumLeaves(31)
.setMaxDepth(-1)
.setBaggingFraction(0.7)
.setFeatureFraction(0.7)
.setBaggingFreq(2)
.setObjective("binary")
.setIsUnbalance(True)
.setMinSumHessianInLeaf(20)
.setMinGainToSplit(0.01)
)
model = model.fit(train_data)
Mengonversi model ke format ONNX
Kode berikut mengekspor model terlatih ke booster LightGBM lalu mengonversinya ke format ONNX:
import lightgbm as lgb
from lightgbm import Booster, LGBMClassifier
def convertModel(lgbm_model: LGBMClassifier or Booster, input_size: int) -> bytes:
from onnxmltools.convert import convert_lightgbm
from onnxconverter_common.data_types import FloatTensorType
initial_types = [("input", FloatTensorType([-1, input_size]))]
onnx_model = convert_lightgbm(
lgbm_model, initial_types=initial_types, target_opset=9
)
return onnx_model.SerializeToString()
booster_model_str = model.getLightGBMBooster().modelStr().get()
booster = lgb.Booster(model_str=booster_model_str)
model_payload_ml = convertModel(booster, len(feature_cols))
Setelah konversi, muat payload ONNX ke dalam ONNXModel
dan periksa input dan output model:
from synapse.ml.onnx import ONNXModel
onnx_ml = ONNXModel().setModelPayload(model_payload_ml)
print("Model inputs:" + str(onnx_ml.getModelInputs()))
print("Model outputs:" + str(onnx_ml.getModelOutputs()))
Petakan input model ke nama kolom kerangka data input (FeedDict), dan petakan nama kolom dataframe output ke output model (FetchDict).
onnx_ml = (
onnx_ml.setDeviceType("CPU")
.setFeedDict({"input": "features"})
.setFetchDict({"probability": "probabilities", "prediction": "label"})
.setMiniBatchSize(5000)
)
Menggunakan model untuk inferensi
Untuk melakukan inferensi dengan model, kode berikut membuat data pengujian dan mengubah data melalui model ONNX.
from pyspark.ml.feature import VectorAssembler
import pandas as pd
import numpy as np
n = 1000 * 1000
m = 95
test = np.random.rand(n, m)
testPdf = pd.DataFrame(test)
cols = list(map(str, testPdf.columns))
testDf = spark.createDataFrame(testPdf)
testDf = testDf.union(testDf).repartition(200)
testDf = (
VectorAssembler()
.setInputCols(cols)
.setOutputCol("features")
.transform(testDf)
.drop(*cols)
.cache()
)
display(onnx_ml.transform(testDf))
Output akan terlihat mirip dengan tabel berikut, meskipun nilai dan jumlah baris mungkin berbeda:
Indeks | Fitur | Prediksi | Peluang |
---|---|---|---|
1 | "{"type":1,"values":[0.105... |
0 | "{"0":0.835... |
2 | "{"type":1,"values":[0.814... |
0 | "{"0":0.658... |
Konten terkait
Saran dan Komentar
https://aka.ms/ContentUserFeedback.
Segera hadir: Sepanjang tahun 2024 kami akan menghentikan penggunaan GitHub Issues sebagai mekanisme umpan balik untuk konten dan menggantinya dengan sistem umpan balik baru. Untuk mengetahui informasi selengkapnya, lihat:Kirim dan lihat umpan balik untuk