Tutorial: Membuat, melatih, dan mengevaluasi model peningkatan

Tutorial ini menyajikan contoh end-to-end dari alur kerja Ilmu Data Synapse, di Microsoft Fabric. Anda mempelajari cara membuat, melatih, dan mengevaluasi model peningkatan dan menerapkan teknik pemodelan peningkatan.

Prasyarat

Ikuti di buku catatan

Anda bisa mengikuti di buku catatan dengan salah satu dari dua cara:

  • Buka dan jalankan notebook bawaan dalam pengalaman Ilmu Data Synapse
  • Unggah buku catatan Anda dari GitHub ke pengalaman Ilmu Data Synapse

Buka buku catatan bawaan

Contoh notebook pemodelan Uplift menyertai tutorial ini. Kunjungi Untuk membuka buku catatan sampel bawaan tutorial dalam pengalaman Ilmu Data Synapse:1. Buka halaman beranda Synapse Ilmu Data. 1. Pilih Gunakan sampel. 1. Pilih sampel yang sesuai:* Dari tab alur kerja End-to-end (Python) default, jika sampelnya adalah untuk tutorial Python. * Dari tab Alur kerja end-to-end (R), jika sampel adalah untuk tutorial R. * Dari tab Tutorial cepat, jika sampel adalah untuk tutorial cepat.1. Lampirkan lakehouse ke buku catatan sebelum Anda mulai menjalankan kode. untuk informasi selengkapnya tentang mengakses buku catatan sampel bawaan untuk tutorial.

Untuk membuka buku catatan sampel bawaan tutorial dalam pengalaman Ilmu Data Synapse:

  1. Buka halaman beranda Synapse Ilmu Data

  2. Pilih Gunakan sampel

  3. Pilih sampel yang sesuai:

    1. Dari tab Alur kerja end-to-end (Python) default, jika sampelnya adalah untuk tutorial Python
    2. Dari tab Alur kerja end-to-end (R), jika sampelnya adalah untuk tutorial R
    3. Dari tab Tutorial cepat, jika sampel adalah untuk tutorial cepat
  4. Melampirkan lakehouse ke buku catatan sebelum Anda mulai menjalankan kode

Mengimpor notebook dari GitHub

Notebook AIsample - Uplift Modeling.ipynb menyertai tutorial ini.

Untuk membuka buku catatan yang menyertai tutorial ini, ikuti instruksi dalam Menyiapkan sistem Anda untuk tutorial ilmu data, untuk mengimpor buku catatan ke ruang kerja Anda.

Anda bisa membuat buku catatan baru jika Anda lebih suka menyalin dan menempelkan kode dari halaman ini.

Pastikan untuk melampirkan lakehouse ke buku catatan sebelum Anda mulai menjalankan kode.

Langkah 1: Muat data

Dataset

Criteo AI Lab membuat himpunan data. Himpunan data tersebut memiliki baris 13M. Setiap baris mewakili satu pengguna. Setiap baris memiliki 12 fitur, indikator perawatan, dan dua label biner yang mencakup kunjungan dan konversi.

f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 konversi perawatan

  • f0 - f11: nilai fitur (nilai padat dan mengambang)
  • perawatan: apakah pengguna secara acak menargetkan perawatan (misalnya, iklan) (1 = perawatan, 0 = kontrol)
  • konversi: apakah konversi terjadi (misalnya, melakukan pembelian) untuk pengguna (biner, label)
  • kunjungi: apakah konversi terjadi (misalnya, melakukan pembelian) untuk pengguna (biner, label)

Kutipan

Himpunan data yang digunakan untuk buku catatan ini memerlukan kutipan BibTex ini:

@inproceedings{Diemert2018,
author = {{Diemert Eustache, Betlei Artem} and Renaudin, Christophe and Massih-Reza, Amini},
title={A Large Scale Benchmark for Uplift Modeling},
publisher = {ACM},
booktitle = {Proceedings of the AdKDD and TargetAd Workshop, KDD, London,United Kingdom, August, 20, 2018},
year = {2018}
}

Tip

Dengan menentukan parameter berikut, Anda dapat menerapkan buku catatan ini pada himpunan data yang berbeda dengan mudah.

IS_CUSTOM_DATA = False  # If True, the user must upload the dataset manually
DATA_FOLDER = "Files/uplift-modelling"
DATA_FILE = "criteo-research-uplift-v2.1.csv"

# Data schema
FEATURE_COLUMNS = [f"f{i}" for i in range(12)]
TREATMENT_COLUMN = "treatment"
LABEL_COLUMN = "visit"

EXPERIMENT_NAME = "aisample-upliftmodelling"  # MLflow experiment name

Mengimpor pustaka

Sebelum memproses, Anda harus mengimpor pustaka Spark dan SynapseML yang diperlukan. Anda juga harus mengimpor pustaka visualisasi data - misalnya, Seaborn, pustaka visualisasi data Python. Pustaka visualisasi data menyediakan antarmuka tingkat tinggi untuk membangun sumber daya visual pada DataFrame dan array. Pelajari lebih lanjut tentang Spark, SynapseML, dan Seaborn.

import os
import gzip

import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.style as style
import seaborn as sns

%matplotlib inline

from synapse.ml.featurize import Featurize
from synapse.ml.core.spark import FluentAPI
from synapse.ml.lightgbm import *
from synapse.ml.train import ComputeModelStatistics

import mlflow

Mengunduh himpunan data dan mengunggah ke lakehouse

Kode ini mengunduh versi himpunan data yang tersedia untuk umum, lalu menyimpan sumber daya data tersebut di Fabric lakehouse.

Penting

Pastikan Anda Menambahkan lakehouse ke buku catatan sebelum menjalankannya. Kegagalan untuk melakukannya akan mengakibatkan kesalahan.

if not IS_CUSTOM_DATA:
    # Download demo data files into lakehouse if not exist
    import os, requests

    remote_url = "http://go.criteo.net/criteo-research-uplift-v2.1.csv.gz"
    download_file = "criteo-research-uplift-v2.1.csv.gz"
    download_path = f"/lakehouse/default/{DATA_FOLDER}/raw"

    if not os.path.exists("/lakehouse/default"):
        raise FileNotFoundError("Default lakehouse not found, please add a lakehouse and restart the session.")
    os.makedirs(download_path, exist_ok=True)
    if not os.path.exists(f"{download_path}/{DATA_FILE}"):
        r = requests.get(f"{remote_url}", timeout=30)
        with open(f"{download_path}/{download_file}", "wb") as f:
            f.write(r.content)
        with gzip.open(f"{download_path}/{download_file}", "rb") as fin:
            with open(f"{download_path}/{DATA_FILE}", "wb") as fout:
                fout.write(fin.read())
    print("Downloaded demo data files into lakehouse.")

Mulai rekam runtime buku catatan ini.

# Record the notebook running time
import time

ts = time.time()

Menyiapkan pelacakan eksperimen MLflow

Untuk memperluas kemampuan pengelogan MLflow, autologging secara otomatis mengambil nilai parameter input dan metrik output model pembelajaran mesin selama pelatihannya. Informasi ini kemudian dicatat ke ruang kerja, di mana API MLflow atau eksperimen yang sesuai di ruang kerja dapat mengakses dan memvisualisasikannya. Kunjungi sumber daya ini untuk informasi selengkapnya tentang autologging.

# Set up the MLflow experiment
import mlflow

mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.autolog(disable=True)  # Disable MLflow autologging

Catatan

Untuk menonaktifkan autologging Microsoft Fabric dalam sesi notebook, panggil mlflow.autolog() dan atur disable=True.

Membaca data dari lakehouse

Baca data mentah dari bagian Lakehouse Files dan tambahkan lebih banyak kolom untuk bagian tanggal yang berbeda. Informasi yang sama digunakan untuk membuat tabel delta yang dipartisi.

raw_df = spark.read.csv(f"{DATA_FOLDER}/raw/{DATA_FILE}", header=True, inferSchema=True).cache()

Langkah 2: Analisis data eksploratif

display Gunakan perintah untuk melihat statistik tingkat tinggi tentang himpunan data. Anda juga dapat menampilkan tampilan Bagan untuk memvisualisasikan subset himpunan data dengan mudah.

display(raw_df.limit(20))

Periksa persentase pengguna yang mengunjungi, persentase pengguna yang mengonversi, dan persentase pengunjung yang mengonversi.

raw_df.select(
    F.mean("visit").alias("Percentage of users that visit"),
    F.mean("conversion").alias("Percentage of users that convert"),
    (F.sum("conversion") / F.sum("visit")).alias("Percentage of visitors that convert"),
).show()

Analisis menunjukkan bahwa 4,9% pengguna dari grup perawatan - pengguna yang menerima perawatan, atau iklan - mengunjungi toko online. Hanya 3,8% pengguna dari grup kontrol - pengguna yang tidak pernah menerima perawatan, atau tidak pernah ditawarkan atau terkena iklan - melakukan hal yang sama. Selain itu, 0,31% dari semua pengguna dari grup perawatan yang dikonversi, atau melakukan pembelian - sementara hanya 0,19% pengguna dari grup kontrol yang melakukannya. Akibatnya, tingkat konversi pengunjung yang melakukan pembelian, yang juga anggota kelompok perawatan, adalah 6,36%, dibandingkan hanya 5,07%** untuk pengguna grup kontrol. Berdasarkan hasil ini, perawatan berpotensi meningkatkan tingkat kunjungan sekitar 1%, dan tingkat konversi pengunjung sekitar 1,3%. Perawatan ini menyebabkan peningkatan yang signifikan.

Langkah 3: Tentukan model untuk pelatihan

Menyiapkan pelatihan dan menguji himpunan data

Di sini, Anda cocok dengan transformator Featurize ke raw_df DataFrame, untuk mengekstrak fitur dari kolom input yang ditentukan dan menghasilkan fitur tersebut ke kolom baru bernama features.

DataFrame yang dihasilkan disimpan dalam DataFrame baru bernama df.

transformer = Featurize().setOutputCol("features").setInputCols(FEATURE_COLUMNS).fit(raw_df)
df = transformer.transform(raw_df)
# Split the DataFrame into training and test sets, with a 80/20 ratio and a seed of 42
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

# Print the training and test dataset sizes
print("Size of train dataset: %d" % train_df.count())
print("Size of test dataset: %d" % test_df.count())

# Group the training dataset by the treatment column, and count the number of occurrences of each value
train_df.groupby(TREATMENT_COLUMN).count().show()

Menyiapkan himpunan data perawatan dan kontrol

Setelah membuat himpunan data pelatihan dan pengujian, Anda juga harus membentuk himpunan data perawatan dan kontrol, untuk melatih model pembelajaran mesin untuk mengukur peningkatan.

# Extract the treatment and control DataFrames
treatment_train_df = train_df.where(f"{TREATMENT_COLUMN} > 0")
control_train_df = train_df.where(f"{TREATMENT_COLUMN} = 0")

Setelah menyiapkan data, Anda dapat melanjutkan untuk melatih model dengan LightGBM.

Pemodelan uplift: T-Learner dengan LightGBM

Meta-learners adalah sekumpulan algoritma, dibangun di atas algoritma pembelajaran mesin seperti LightGBM, Xgboost, dll. Mereka membantu memperkirakan efek perawatan rata-rata bersyarah, atau CATE. T-learner adalah meta-learner yang tidak menggunakan satu model. Sebagai gantinya, T-learner menggunakan satu model per variabel perawatan. Oleh karena itu, dua model dikembangkan dan kami menyebut meta-learner sebagai T-learner. T-learner menggunakan beberapa model pembelajaran mesin untuk mengatasi masalah sepenuhnya membuang perawatan, dengan memaksa pelajar untuk terlebih dahulu membaginya.

mlflow.autolog(exclusive=False)
classifier = (
    LightGBMClassifier()
    .setFeaturesCol("features")  # Set the column name for features
    .setNumLeaves(10)  # Set the number of leaves in each decision tree
    .setNumIterations(100)  # Set the number of boosting iterations
    .setObjective("binary")  # Set the objective function for binary classification
    .setLabelCol(LABEL_COLUMN)  # Set the column name for the label
)

# Start a new MLflow run with the name "uplift"
active_run = mlflow.start_run(run_name="uplift")

# Start a new nested MLflow run with the name "treatment"
with mlflow.start_run(run_name="treatment", nested=True) as treatment_run:
    treatment_run_id = treatment_run.info.run_id  # Get the ID of the treatment run
    treatment_model = classifier.fit(treatment_train_df)  # Fit the classifier on the treatment training data

# Start a new nested MLflow run with the name "control"
with mlflow.start_run(run_name="control", nested=True) as control_run:
    control_run_id = control_run.info.run_id  # Get the ID of the control run
    control_model = classifier.fit(control_train_df)  # Fit the classifier on the control training data
     

Menggunakan himpunan data pengujian untuk prediksi

Di sini, Anda menggunakan treatment_model dan control_model, keduanya ditentukan sebelumnya, untuk mengubah himpunan test_df data pengujian. Kemudian, Anda menghitung peningkatan yang diprediksi. Anda menentukan peningkatan yang diprediksi sebagai perbedaan antara hasil perawatan yang diprediksi dan hasil kontrol yang diprediksi. Semakin besar perbedaan peningkatan yang diprediksi ini, semakin besar efektivitas perawatan (misalnya, iklan) pada individu atau subgrup.

getPred = F.udf(lambda v: float(v[1]), FloatType())

# Cache the resulting DataFrame for easier access
test_pred_df = (
    test_df.mlTransform(treatment_model)
    .withColumn("treatment_pred", getPred("probability"))
    .drop("rawPrediction", "probability", "prediction")
    .mlTransform(control_model)
    .withColumn("control_pred", getPred("probability"))
    .drop("rawPrediction", "probability", "prediction")
    .withColumn("pred_uplift", F.col("treatment_pred") - F.col("control_pred"))
    .select(TREATMENT_COLUMN, LABEL_COLUMN, "treatment_pred", "control_pred", "pred_uplift")
    .cache()
)

# Display the first twenty rows of the resulting DataFrame
display(test_pred_df.limit(20))

Melakukan evaluasi model

Karena peningkatan aktual tidak dapat diamati untuk setiap individu, Anda perlu mengukur peningkatan atas sekelompok individu. Anda menggunakan Kurva Uplift yang memplot peningkatan kumulatif nyata di seluruh populasi.

Cuplikan layar bagan yang memperlihatkan kurva model peningkatan yang dinormalisasi versus perawatan acak.

Sumbu x mewakili rasio populasi yang dipilih untuk perawatan. Nilai 0 menunjukkan tidak ada kelompok perawatan - tidak ada yang terpapar, atau ditawarkan, perawatan. Nilai 1 menunjukkan kelompok perawatan penuh - semua orang terpapar, atau ditawarkan, perawatan. Sumbu y menunjukkan ukuran peningkatan. Tujuannya adalah untuk menemukan ukuran kelompok perawatan, atau persentase populasi yang akan ditawarkan atau diekspos ke perawatan (misalnya, iklan). Pendekatan ini mengoptimalkan pemilihan target, untuk mengoptimalkan hasilnya.

Pertama, beri peringkat urutan DataFrame pengujian berdasarkan peningkatan yang diprediksi. Peningkatan yang diprediksi adalah perbedaan antara hasil perawatan yang diprediksi dan hasil kontrol yang diprediksi.

# Compute the percentage rank of the predicted uplift values in descending order, and display the top twenty rows
test_ranked_df = test_pred_df.withColumn("percent_rank", F.percent_rank().over(Window.orderBy(F.desc("pred_uplift"))))

display(test_ranked_df.limit(20))

Selanjutnya, hitung persentase kumulatif kunjungan dalam kelompok perawatan dan kontrol.

# Calculate the number of control and treatment samples
C = test_ranked_df.where(f"{TREATMENT_COLUMN} == 0").count()
T = test_ranked_df.where(f"{TREATMENT_COLUMN} != 0").count()

# Add columns to the DataFrame to calculate the control and treatment cumulative sum
test_ranked_df = (
    test_ranked_df.withColumn(
        "control_label",
        F.when(F.col(TREATMENT_COLUMN) == 0, F.col(LABEL_COLUMN)).otherwise(0),
    )
    .withColumn(
        "treatment_label",
        F.when(F.col(TREATMENT_COLUMN) != 0, F.col(LABEL_COLUMN)).otherwise(0),
    )
    .withColumn(
        "control_cumsum",
        F.sum("control_label").over(Window.orderBy("percent_rank")) / C,
    )
    .withColumn(
        "treatment_cumsum",
        F.sum("treatment_label").over(Window.orderBy("percent_rank")) / T,
    )
)

# Display the first 20 rows of the dataframe
display(test_ranked_df.limit(20))

Akhirnya, pada setiap persentase, hitung peningkatan grup sebagai perbedaan antara persentase kumulatif kunjungan antara kelompok perawatan dan kontrol.

test_ranked_df = test_ranked_df.withColumn("group_uplift", F.col("treatment_cumsum") - F.col("control_cumsum")).cache()
display(test_ranked_df.limit(20))

Sekarang, plot kurva uplift untuk prediksi himpunan data pengujian. Anda harus mengonversi PySpark DataFrame ke Pandas DataFrame sebelum merencanakan.

def uplift_plot(uplift_df):
    """
    Plot the uplift curve
    """
    gain_x = uplift_df.percent_rank
    gain_y = uplift_df.group_uplift
    # Plot the data
    fig = plt.figure(figsize=(10, 6))
    mpl.rcParams["font.size"] = 8

    ax = plt.plot(gain_x, gain_y, color="#2077B4", label="Normalized Uplift Model")

    plt.plot(
        [0, gain_x.max()],
        [0, gain_y.max()],
        "--",
        color="tab:orange",
        label="Random Treatment",
    )
    plt.legend()
    plt.xlabel("Porportion Targeted")
    plt.ylabel("Uplift")
    plt.grid()

    return fig, ax


test_ranked_pd_df = test_ranked_df.select(["pred_uplift", "percent_rank", "group_uplift"]).toPandas()
fig, ax = uplift_plot(test_ranked_pd_df)

mlflow.log_figure(fig, "UpliftCurve.png")

Cuplikan layar bagan yang memperlihatkan kurva model peningkatan yang dinormalisasi versus perawatan acak.

Sumbu x mewakili rasio populasi yang dipilih untuk perawatan. Nilai 0 menunjukkan tidak ada kelompok perawatan - tidak ada yang terpapar, atau ditawarkan, perawatan. Nilai 1 menunjukkan kelompok perawatan penuh - semua orang terpapar, atau ditawarkan, perawatan. Sumbu y menunjukkan ukuran peningkatan. Tujuannya adalah untuk menemukan ukuran kelompok perawatan, atau persentase populasi yang akan ditawarkan atau diekspos ke perawatan (misalnya, iklan). Pendekatan ini mengoptimalkan pemilihan target, untuk mengoptimalkan hasilnya.

Pertama, beri peringkat urutan DataFrame pengujian berdasarkan peningkatan yang diprediksi. Peningkatan yang diprediksi adalah perbedaan antara hasil perawatan yang diprediksi dan hasil kontrol yang diprediksi.

# Compute the percentage rank of the predicted uplift values in descending order, and display the top twenty rows
test_ranked_df = test_pred_df.withColumn("percent_rank", F.percent_rank().over(Window.orderBy(F.desc("pred_uplift"))))

display(test_ranked_df.limit(20))

Selanjutnya, hitung persentase kumulatif kunjungan dalam kelompok perawatan dan kontrol.

# Calculate the number of control and treatment samples
C = test_ranked_df.where(f"{TREATMENT_COLUMN} == 0").count()
T = test_ranked_df.where(f"{TREATMENT_COLUMN} != 0").count()

# Add columns to the DataFrame to calculate the control and treatment cumulative sum
test_ranked_df = (
    test_ranked_df.withColumn(
        "control_label",
        F.when(F.col(TREATMENT_COLUMN) == 0, F.col(LABEL_COLUMN)).otherwise(0),
    )
    .withColumn(
        "treatment_label",
        F.when(F.col(TREATMENT_COLUMN) != 0, F.col(LABEL_COLUMN)).otherwise(0),
    )
    .withColumn(
        "control_cumsum",
        F.sum("control_label").over(Window.orderBy("percent_rank")) / C,
    )
    .withColumn(
        "treatment_cumsum",
        F.sum("treatment_label").over(Window.orderBy("percent_rank")) / T,
    )
)

# Display the first 20 rows of the dataframe
display(test_ranked_df.limit(20))

Akhirnya, pada setiap persentase, hitung peningkatan grup sebagai perbedaan antara persentase kumulatif kunjungan antara kelompok perawatan dan kontrol.

test_ranked_df = test_ranked_df.withColumn("group_uplift", F.col("treatment_cumsum") - F.col("control_cumsum")).cache()
display(test_ranked_df.limit(20))

Sekarang, plot kurva uplift untuk prediksi himpunan data pengujian. Anda harus mengonversi PySpark DataFrame ke Pandas DataFrame sebelum merencanakan.

def uplift_plot(uplift_df):
    """
    Plot the uplift curve
    """
    gain_x = uplift_df.percent_rank
    gain_y = uplift_df.group_uplift
    # Plot the data
    fig = plt.figure(figsize=(10, 6))
    mpl.rcParams["font.size"] = 8

    ax = plt.plot(gain_x, gain_y, color="#2077B4", label="Normalized Uplift Model")

    plt.plot(
        [0, gain_x.max()],
        [0, gain_y.max()],
        "--",
        color="tab:orange",
        label="Random Treatment",
    )
    plt.legend()
    plt.xlabel("Porportion Targeted")
    plt.ylabel("Uplift")
    plt.grid()

    return fig, ax


test_ranked_pd_df = test_ranked_df.select(["pred_uplift", "percent_rank", "group_uplift"]).toPandas()
fig, ax = uplift_plot(test_ranked_pd_df)

mlflow.log_figure(fig, "UpliftCurve.png")

Cuplikan layar bagan yang memperlihatkan kurva model peningkatan yang dinormalisasi versus perawatan acak.

Analisis dan kurva peningkatan keduanya menunjukkan bahwa populasi 20% teratas, seperti yang diberi peringkat oleh prediksi, akan memiliki keuntungan besar jika mereka menerima perawatan. Ini berarti bahwa 20% populasi teratas mewakili grup yang dapat dibujuk. Oleh karena itu, Anda kemudian dapat mengatur skor cutoff untuk ukuran grup perawatan yang diinginkan pada 20%, untuk mengidentifikasi pelanggan pilihan target untuk dampak terbesar.

cutoff_percentage = 0.2
cutoff_score = test_ranked_pd_df.iloc[int(len(test_ranked_pd_df) * cutoff_percentage)][
    "pred_uplift"
]

print("Uplift scores that exceed {:.4f} map to Persuadables.".format(cutoff_score))
mlflow.log_metrics(
    {"cutoff_score": cutoff_score, "cutoff_percentage": cutoff_percentage}
)

Langkah 4: Daftarkan Model ML akhir

Anda menggunakan MLflow untuk melacak dan mencatat semua eksperimen untuk grup perawatan dan kontrol. Pelacakan dan pengelogan ini mencakup parameter, metrik, dan model yang sesuai. Informasi ini dicatat di bawah nama eksperimen, di ruang kerja, untuk digunakan nanti.

# Register the model
treatment_model_uri = "runs:/{}/model".format(treatment_run_id)
mlflow.register_model(treatment_model_uri, f"{EXPERIMENT_NAME}-treatmentmodel")

control_model_uri = "runs:/{}/model".format(control_run_id)
mlflow.register_model(control_model_uri, f"{EXPERIMENT_NAME}-controlmodel")

mlflow.end_run()

Untuk melihat eksperimen Anda:

  1. Di panel kiri, pilih ruang kerja Anda.
  2. Temukan dan pilih nama eksperimen, dalam hal ini aisample-upliftmodelling.

Cuplikan layar yang memperlihatkan hasil eksperimen pemodelan peningkatan aisample.

Langkah 5: Simpan hasil prediksi

Microsoft Fabric menawarkan PREDICT - fungsi yang dapat diskalakan yang mendukung penilaian batch di mesin komputasi apa pun. Ini memungkinkan pelanggan untuk mengoprasionalkan model pembelajaran mesin. Pengguna dapat membuat prediksi batch langsung dari buku catatan atau halaman item untuk model tertentu. Kunjungi sumber daya ini untuk mempelajari selengkapnya tentang PREDICT, dan untuk mempelajari cara menggunakan PREDICT di Microsoft Fabric.

# Load the model back
loaded_treatmentmodel = mlflow.spark.load_model(treatment_model_uri, dfs_tmpdir="Files/spark")
loaded_controlmodel = mlflow.spark.load_model(control_model_uri, dfs_tmpdir="Files/spark")

# Make predictions
batch_predictions_treatment = loaded_treatmentmodel.transform(test_df)
batch_predictions_control = loaded_controlmodel.transform(test_df)
batch_predictions_treatment.show(5)
# Save the predictions in the lakehouse
batch_predictions_treatment.write.format("delta").mode("overwrite").save(
    f"{DATA_FOLDER}/predictions/batch_predictions_treatment"
)
batch_predictions_control.write.format("delta").mode("overwrite").save(
    f"{DATA_FOLDER}/predictions/batch_predictions_control"
)
# Determine the entire runtime
print(f"Full run cost {int(time.time() - ts)} seconds.")