Microsoft Fabric에서 scikit-learn을 사용하여 모델을 학습하는 방법

Scikit-learn(scikit-learn.org)은 인기 있는 오픈 소스 기계 학습 프레임워크입니다. 감독 및 감독되지 않은 학습에 자주 사용됩니다. 또한 모델 맞춤, 데이터 전처리, 모델 선택, 모델 평가 등을 위한 다양한 도구를 제공합니다.

이 섹션에서는 Scikit-Learn 모델의 반복을 학습하고 추적하는 방법의 예를 살펴봅니다.

scikit-learn 설치

scikit-learn을 시작하려면 전자 필기장 내에 설치되어 있는지 확인해야 합니다. 다음 명령을 사용하여 환경에 scikit-learn 버전을 설치하거나 업그레이드할 수 있습니다.

%pip install scikit-learn

다음으로, MLFLow API를 사용하여 기계 학습 실험을 만듭니다. MLflow set_experiment() API는 아직 없는 경우 새 기계 학습 실험을 만듭니다.

import mlflow

mlflow.set_experiment("sample-sklearn")

scikit-learn 모델 학습

실험을 만든 후 샘플 데이터 세트를 만들고 로지스틱 회귀 모델을 만듭니다. 또한 MLflow 실행을 시작하고 메트릭, 매개 변수 및 최종 로지스틱 회귀 모델을 추적합니다. 최종 모델을 생성한 후에는 추가 추적을 위해 결과 모델도 저장합니다.

import mlflow.sklearn
import numpy as np
from sklearn.linear_model import LogisticRegression
from mlflow.models.signature import infer_signature

with mlflow.start_run() as run:

    lr = LogisticRegression()
    X = np.array([-2, -1, 0, 1, 2, 1]).reshape(-1, 1)
    y = np.array([0, 0, 1, 1, 1, 0])
    lr.fit(X, y)
    score = lr.score(X, y)
    signature = infer_signature(X, y)

    print("log_metric.")
    mlflow.log_metric("score", score)

    print("log_params.")
    mlflow.log_param("alpha", "alpha")

    print("log_model.")
    mlflow.sklearn.log_model(lr, "sklearn-model", signature=signature)
    print("Model saved in run_id=%s" % run.info.run_id)

    print("register_model.")
    mlflow.register_model(

        "runs:/{}/sklearn-model".format(run.info.run_id), "sample-sklearn"
    )
    print("All done")

샘플 데이터 세트에서 모델 로드 및 평가

모델이 저장되면 추론을 위해 로드할 수도 있습니다. 이렇게 하려면 모델을 로드하고 샘플 데이터 세트에서 유추를 실행합니다.

# Inference with loading the logged model
from synapse.ml.predict import MLflowTransformer

spark.conf.set("spark.synapse.ml.predict.enabled", "true")

model = MLflowTransformer(
    inputCols=["x"],
    outputCol="prediction",
    modelName="sample-sklearn",
    modelVersion=1,
)

test_spark = spark.createDataFrame(
    data=np.array([-2, -1, 0, 1, 2, 1]).reshape(-1, 1).tolist(), schema=["x"]
)

batch_predictions = model.transform(test_spark)

batch_predictions.show()