次の方法で共有


時系列予測モデルをトレーニングし評価する

このノートブックで、季節周期を持つ時系列データを予測するプログラムを作成します。 NYC オープン データ ポータルで NYC 税務局が発行した 2003 年から 2015 年までの日付を含む NYC 不動産販売データセットを使用します。

前提条件

ノートブックで作業を進める

ノートブックの次の 2 つの方法のいずれかに従うことができます。

  • Data Science エクスペリエンスの組み込みのノートブックを開いて実行します。
  • GitHub から Synapse Data Science 環境にノートブックをアップロードします。

ビルトインのノートブックを開きます

このチュートリアルには、サンプルの時系列ノートブックが付属しています。

チュートリアルの組み込みのサンプル ノートブックを Synapse Data Science エクスペリエンスで開くには、次のようにします。

  1. Synapse Data Science のホーム ページに移動します。

  2. [サンプルの使用] を選択してください。

  3. 対応するサンプルを選択してください。

    • サンプルが Python チュートリアル用の場合は、既定の [エンド ツー エンド ワークフロー (Python)] タブから。
    • サンプルが R チュートリアル用の場合は、[エンド ツー エンド ワークフロー (R)] タブから。
    • サンプルがクイック チュートリアル用の場合は、[クイック チュートリアル] タブから。
  4. コードの実行を開始する前に、[レイクハウスをノートブックにアタッチします]

GitHub からノートブックをインポートする

AIsample - Time Series Forecasting.ipynb は、このチュートリアルに付属するノートブックです。

このチュートリアルが付属するノートブックを開く場合は、「データ サイエンス チュートリアル用にシステムを準備する」内の指示に従い、ノートブックを、お使いのワークスペースにインポートします。

このページからコードをコピーして貼り付ける場合は、[新しいノートブックを作成する] ことができます。

コードの実行を開始する前に、必ずレイクハウスをノートブックにアタッチしてください。

手順 1: カスタム ライブラリをインストールする

機械学習モデルを開発するとき、またはアドホック データ分析を処理する場合は、カスタム ライブラリ (たとえば、このノートブックにおける Apache Spark 用の prophet) をすぐにインストールする必要がある場合があります。 これを行うには、2 つの方法があります。

  1. インライン インストール機能 (%pip%conda など) を使用すると、新しいライブラリの使用をすばやく開始できます。 この方法では、カスタム ライブラリはワークスペースではなく、現在のノートブックにのみインストールされます。
# Use pip to install libraries
%pip install <library name>

# Use conda to install libraries
%conda install <library name>
  1. または、Fabric 環境を作成し、パブリック ソースからライブラリをインストールするか、あるいはカスタム ライブラリをそこにアップロードすると、ワークスペース管理者はその環境をワークスペースの既定としてアタッチできます。 その後、環境内のすべてのライブラリが、ワークスペース内のすべてのノートブックと Spark ジョブ定義で使用できるようになります。 環境の詳細については、「Microsoft Fabric で環境を作成、構成、および使用する」を参照してください。

このノートブックでは、%pip install を使用して prophet ライブラリをインストールします。 PySpark カーネルは、%pip install の後に再起動します。 つまり、他のセルを実行する前にライブラリをインストールする必要があります。

# Use pip to install Prophet
%pip install prophet

手順 2: データを読み込む

データセット

このノートブックでは、NYC Property Sales データのデータセットを使用します。 NYC 財務省によって、NYC オープン データ ポータルで発行された 2003 年から 2015 年までのデータが カバーされています。

このデータセットには、13 年間ニューヨーク市の不動産市場で販売されたすべての建物の販売記録が含まれています。 データセット内の列の定義については、「不動産販売ファイルの用語集」を参照してください。

自治区 近隣 building_class_category tax_class ブロック 区画 eastment building_class_at_present address apartment_number zip_code residential_units commercial_units total_units land_square_feet gross_square_feet year_built tax_class_at_time_of_sale building_class_at_time_of_sale sale_price sale_date
Manhattan ALPHABET CITY 07 RENTALS - WALKUP APARTMENTS 0.0 384.0 17.0 C4 225 EAST 2ND STREET 10009.0 10.0 0.0 10.0 2145.0 6670.0 1900.0 2.0 C4 275000.0 2007-06-19
Manhattan ALPHABET CITY 07 RENTALS - WALKUP APARTMENTS 2.0 405.0 12.0 C7 508 EAST 12TH STREET 10009.0 28.0 2.0 30.0 3872.0 15428.0 1930.0 2.0 C7 7794005.0 2007-05-21

目標は、履歴データに基づいて月間総売上を予測するモデルを構築することです。 このためには、Facebookによって開発されたオープンソース予測ライブラリである、Prophet を使用します。 Prophet は加法モデルに基づいており、非線形傾向は毎日、毎週、毎年の季節性、休日の影響に適合しています。 Prophet は、強い季節的影響や複数の季節の履歴データを持つ時系列データセットに最適です。 さらに、Prophet は不足しているデータとデータの外れ値を堅牢に処理します。

Prophet は、次の 3 つのコンポーネントで構成される分解可能な時系列モデルを使用します。

  • トレンド: Prophet は自動変化点を選択して、区分ごとの一定の成長率を推測します。
  • 季節性: 既定では、Prophet はフーリエ級数を使用して週単位と年単位の季節性をモデル化します。
  • 祝日: Prophetは過去と将来のすべての祝日を必要とします。 将来繰り返されない休日の場合、Prophet の予測には含まれません。

このノートブックは月単位でデータを集計するため、休日は無視されます。

Prophet のモデリング手法の詳細については、公式の論文を参照してください。

データセットをダウンロードし、レイクハウスにアップロードする

データ ソースは 15 個の .csv ファイルで構成されます。 これらのファイルには、ニューヨークにある 5 つの自治区における 2003 年から 2015 年までの不動産販売記録が含まれています。 便宜上、nyc_property_sales.tar ファイルにこれらの .csv ファイルをすべて保持し、1 つのファイルに圧縮されています。 この .tar ファイルは、パブリックで利用可能な BLOB ストレージによってホストされます。

ヒント

このコード セルに示されているパラメーターを使用すると、このノートブックをさまざまなデータセットに簡単に適用できます。

URL = "https://synapseaisolutionsa.blob.core.windows.net/public/NYC_Property_Sales_Dataset/"
TAR_FILE_NAME = "nyc_property_sales.tar"
DATA_FOLDER = "Files/NYC_Property_Sales_Dataset"
TAR_FILE_PATH = f"/lakehouse/default/{DATA_FOLDER}/tar/"
CSV_FILE_PATH = f"/lakehouse/default/{DATA_FOLDER}/csv/"

EXPERIMENT_NAME = "aisample-timeseries" # MLflow experiment name

このコードは、一般公開されているバージョンのデータセットをダウンロードし、Fabric Lakehouse に格納します。

重要

ノートブックを実行する前に、必ずレイクハウスをノートブックに追加してください。 間違ってインストールすると、エラーが発生します。

import os

if not os.path.exists("/lakehouse/default"):
    # Add a lakehouse if the notebook has no default lakehouse
    # A new notebook will not link to any lakehouse by default
    raise FileNotFoundError(
        "Default lakehouse not found, please add a lakehouse for the notebook."
    )
else:
    # Verify whether or not the required files are already in the lakehouse, and if not, download and unzip
    if not os.path.exists(f"{TAR_FILE_PATH}{TAR_FILE_NAME}"):
        os.makedirs(TAR_FILE_PATH, exist_ok=True)
        os.system(f"wget {URL}{TAR_FILE_NAME} -O {TAR_FILE_PATH}{TAR_FILE_NAME}")

    os.makedirs(CSV_FILE_PATH, exist_ok=True)
    os.system(f"tar -zxvf {TAR_FILE_PATH}{TAR_FILE_NAME} -C {CSV_FILE_PATH}")

このノートブックの実行時間の記録を開始します。

# Record the notebook running time
import time

ts = time.time()

MLflow 実験追跡を設定する

MLflow ログ機能を拡張するために、機械学習モデルの自動ログ機能は自動的に入力パラメーターと出力メトリックの値をトレーニング中にキャプチャします。 その後、この情報はワークスペースに記録され、MLflow API またはワークスペース内の対応する実験によりアクセス、視覚化が可能になります。 自動ログ記録の詳細については、こちらのリソースを参照してください。

# Set up the MLflow experiment
import mlflow

mlflow.set_experiment(EXPERIMENT_NAME)
mlflow.autolog(disable=True)  # Disable MLflow autologging

Note

ノートブック セッションで Microsoft Fabric の自動ログ記録を無効にする場合は、mlflow.autolog() を呼び出して disable=True を設定します。

レイクハウスから日付の生データを読み取る

df = (
    spark.read.format("csv")
    .option("header", "true")
    .load("Files/NYC_Property_Sales_Dataset/csv")
)

手順 3: 探索的データ分析の開始

データセットを確認するには、より深く理解するためにデータのサブセットを手動で調べることもできます。 display 関数を使用して DataFrame を出力できます。 また、チャート ビューを表示すると、データセットのサブセットを簡単に視覚化できます。

display(df)

データセットを手動で確認することで、いくつかの初期の知見が得られます。

  • 0.00 ドルの販売価格のインスタンス。 用語集では、これは現金を考慮しない所有権の譲渡を暗示します。 つまり、トランザクションにキャッシュ フローはありません。 データセットから 0.00 ドルの sales_price 値を持つ売上を削除する必要があります。

  • データセットには、さまざまなビルド クラスが含まれます。 ただし、このノートブックは、用語集ではタイプ "A" としてマークされている住宅の建物に焦点を当てます。 データセットをフィルター処理して、住宅用建物のみを含める必要があります。 これを行うには、building_class_at_time_of_sale または building_class_at_present 列を含めます。 building_class_at_time_of_sale データのみを含める必要があります。

  • データセットには、total_units 値が 0、または gross_square_feet 値が 0 のインスタンスが含まれます。 total_units または gross_square_units 値が 0 であるすべてのインスタンスを削除する必要があります。

  • 一部の列 (たとえば、apartment_numbertax_classbuild_class_at_present) には、欠損値または NULL 値があります。 不足しているデータに事務的なエラーまたは存在しないデータが含まれているとします。 分析はこれらの欠損値に依存しないため、無視できます。

  • sale_price 列は文字列として格納され、先頭に "$" 文字が付加されます。 分析を続行するために、この列を数値として表します。 sale_price 列は整数としてキャストする必要があります。

型変換とフィルター処理

特定された問題の一部を解決するには、必要なライブラリをインポートします。

# Import libraries
import pyspark.sql.functions as F
from pyspark.sql.types import *

売上データを文字列から整数にキャストする

正規表現を使用して、文字列の数値部分をドル記号から区切り (たとえば、文字列 "$300,000" で "$" と "300,000" を分割)、数値部分を整数としてキャストします。

次に、次の条件をすべて満たすインスタンスのみを含むようにデータをフィルター処理します。

  1. sales_price が 0 より大きい
  2. total_units が 0 より大きい
  3. gross_square_feet が 0 より大きい
  4. building_class_at_time_of_sale はタイプ A です。
df = df.withColumn(
    "sale_price", F.regexp_replace("sale_price", "[$,]", "").cast(IntegerType())
)
df = df.select("*").where(
    'sale_price > 0 and total_units > 0 and gross_square_feet > 0 and building_class_at_time_of_sale like "A%"'
)

月単位での集計

データ リソースはプロパティの売上を日ごとに追跡しますが、この方法では、このノートブックでは細かすぎます。 代わりに、月単位でデータを集計します。

まず、月と年のデータのみを表示するように日付値を変更します。 日付の値には依然として、年データが含まれます。 2005 年 12 月と 2006 年 12 月などは、引き続き区別できます。

さらに、分析に関連する列のみを保持します。 これには、sales_pricetotal_unitsgross_square_feetsales_date が含まれます。 また、sales_datemonth にリネームする必要があります。

monthly_sale_df = df.select(
    "sale_price",
    "total_units",
    "gross_square_feet",
    F.date_format("sale_date", "yyyy-MM").alias("month"),
)
display(monthly_sale_df)

sale_pricetotal_unitsgross_square_feet 値を月別に集計します。 次に、データを month でグループ化し、各グループ内のすべての値を合計します。

summary_df = (
    monthly_sale_df.groupBy("month")
    .agg(
        F.sum("sale_price").alias("total_sales"),
        F.sum("total_units").alias("units"),
        F.sum("gross_square_feet").alias("square_feet"),
    )
    .orderBy("month")
)

display(summary_df)

Pyspark から Pandas への換算

Pyspark DataFrames は、大規模なデータセットを適切に処理します。 ただし、データ集計のため、DataFrame のサイズは小さくなります。 これは、pandas DataFrames を使用できるようになったことを示唆しています。

このコードは、pyspark DataFrame から pandas DataFrame にデータセットをキャストします。

import pandas as pd

df_pandas = summary_df.toPandas()
display(df_pandas)

視覚化

ニューヨーク市の不動産取引の傾向を調べて、データをより深く理解することができます。 これにより、潜在的なパターンと季節性の傾向に関する分析情報が得られます。 このリソースでの Microsoft Fabric データ可視化の詳細について説明します。

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

f, (ax1, ax2) = plt.subplots(2, 1, figsize=(35, 10))
plt.sca(ax1)
plt.xticks(np.arange(0, 15 * 12, step=12))
plt.ticklabel_format(style="plain", axis="y")
sns.lineplot(x="month", y="total_sales", data=df_pandas)
plt.ylabel("Total Sales")
plt.xlabel("Time")
plt.title("Total Property Sales by Month")

plt.sca(ax2)
plt.xticks(np.arange(0, 15 * 12, step=12))
plt.ticklabel_format(style="plain", axis="y")
sns.lineplot(x="month", y="square_feet", data=df_pandas)
plt.ylabel("Total Square Feet")
plt.xlabel("Time")
plt.title("Total Property Square Feet Sold by Month")
plt.show()

探索的データ分析からの観測の概要

  • このデータは、毎年の周期での、明確な定期的パターンを示しています。つまり、データには年単位の季節性があります。
  • 夏の月は、冬の月に比べて販売量が多いようです。
  • 売上が高い年と売上の低い年の比較では、高い販売年において、高い販売月との低い販売月の収益差は常に、低い販売年における高い販売月と低い販売月の収益差より大きいです。

たとえば、2004 年の売上が最も高い月と最も低い販売月の収益差は次のようになります。

$900,000,000 - $500,000,000 = $400,000,000

2011 年の場合、その収益差異の計算は次のようになります。

$400,000,000 - $300,000,000 = $100,000,000

これは後に、多重度加法間で季節性の影響を決定する必要があるときに重要になります。

手順 4: モデルのトレーニングと追跡

モデル フィッティング

Prophet の入力は常に 2 列の DataFrame です。 入力列の 1 つは ds という時間列で、もう 1 つの入力列は y という値列です。 時間列には、日付、時刻、または datetime データ形式 (たとえば YYYY_MM) を指定する必要があります。 このデータセットは、その条件を満たしています。 値列は数値データ形式である必要があります。

モデル フィッティングについては、必要なことは、時間列の名前を ds、値列を y にリネームし、データを Prophet に渡すことのみです。 詳細については、Prophet Python API ドキュメントを参照してください。

df_pandas["ds"] = pd.to_datetime(df_pandas["month"])
df_pandas["y"] = df_pandas["total_sales"]

Prophetは、scikit-learn 規約に従います。 まず、Prophet の新しいインスタンスを作成し、特定のパラメーター (たとえば seasonality_mode) を設定し、そのインスタンスをデータセットに合わせて調整します。

  • 定数の加法係数は Prophet の既定の季節性効果ですが、季節性効果パラメーターには "乗法" 季節性を使用する必要があります。 前のセクションの分析では、季節性の振幅の変化により、単純な加法季節性がデータにまったく適合しないことを示していました。

  • データが月単位で集計されたため、weekly_seasonality パラメーターを off に設定します。 これにより、週単位のデータは使用できなくなります。

  • Markov Chain Monte Carlo (MCMC) メソッドを使用して、季節性の不確実性の推定をキャプチャします。 既定では、Prophet から傾向と観測ノイズに関する不確実性の推定を得られますが、季節性に関しては得られません。 MCMC はより多くの処理時間を必要としますが、これにより、アルゴリズムから季節性と傾向と観測ノイズに関する不確実性の推定を得られます。 詳細については、Prophet の不確定間隔に関するドキュメントを参照してください。

  • changepoint_prior_scale パラメーターを使用して、自動変更ポイント検出の感度を調整します。 Prophet アルゴリズムは、軌道が突然変化するデータ内のインスタンスを自動的に見つけようとします。 正しい値を見つけるのが難しくなる可能性があります。 これを解決するには、さまざまな値を試してから、最適なパフォーマンスのモデルを選択します。 詳しくは、Prophet トレンド チェンジポイントのドキュメント をご覧ください。

from prophet import Prophet

def fit_model(dataframe, seasonality_mode, weekly_seasonality, chpt_prior, mcmc_samples):
    m = Prophet(
        seasonality_mode=seasonality_mode,
        weekly_seasonality=weekly_seasonality,
        changepoint_prior_scale=chpt_prior,
        mcmc_samples=mcmc_samples,
    )
    m.fit(dataframe)
    return m

クロス検証

Prophet には、クロス検証ツールが組み込まれています。 このツールでは、予測エラーを見積もり、最適なパフォーマンスでモデルを見つけることができます。

クロス検証手法では、モデルの効率を検証できます。 この手法では、データセットのサブセットに対してモデルをトレーニングし、データセットの以前には見えなかったサブセットに対してテストを実行します。 この手法では、統計モデルが独立したデータセットにどの程度一般化されるかをチェックできます。

クロス検証の場合は、トレーニング データセットには含まれなかったデータセットの特定のサンプルを予約します。 次に、デプロイの前に、そのサンプルでトレーニング済みのモデルをテストします。 ただし、モデルで 2005 年 1 月から 2005 年 3 月のデータを確認し、2005 年 2 月の予測を試みると、モデルではデータ傾向の結果を把握できる場合があるため、モデルがチートを行う可能性があります。そのため、このアプローチは時系列にはうまくいきません。 実際の用途では、未知の地域として将来を予測することが目的です。

これを処理し、テストを信頼できるものとするために、日付に基づいてデータセットを分割します。 トレーニングには、特定の日付までのデータセット (たとえば、最初の 11 年間のデータ) を使用してから、残りの未知のデータを予測に使用します。

このシナリオでは、11 年間のトレーニング データから始めて、1 年間の期間を使用して毎月の予測を行います。 具体的には、トレーニング データには、2003 年から 2013 年までのすべてが含まれています。 最初の実行では、2014 年 1 月から 2015 年 1 月までの予測が処理します。 次の実行では、2014 年 2 月から 2015 年 2 月までの予測などを処理します。

トレーニング済みの 3 つのモデルごとにこのプロセスを繰り返して、最適なパフォーマンスを発揮するモデルを確認します。 次に、これらの予測を実際の値と比較して、最適なモデルの予測品質を確立します。

from prophet.diagnostics import cross_validation
from prophet.diagnostics import performance_metrics

def evaluation(m):
    df_cv = cross_validation(m, initial="4017 days", period="30 days", horizon="365 days")
    df_p = performance_metrics(df_cv, monthly=True)
    future = m.make_future_dataframe(periods=12, freq="M")
    forecast = m.predict(future)
    return df_p, future, forecast

MLflow を使用してモデルをログする

モデルをログに記録してパラメーターを追跡し、後で使用できるようにモデルを保存します。 関連するすべてのモデル情報は、ワークスペースで実験名に記録されます。 モデル、パラメーター、メトリック、および MLflow 自動ログ記録項目が、1 回の MLflow 実行時に保存されます。

# Setup MLflow
from mlflow.models.signature import infer_signature

実験を実施する

機械学習の実験は、関連するすべての機械学習の実行を編成および制御するための主要な単位として機能します。 "実行" はモデル コードの 1 回の実行に対応します。 機械学習の実験追跡とは、さまざまな実験とそのコンポーネントをすべて管理することを指します。 これには、パラメーター、メトリック、モデル、その他の成果物が含まれ、特定の機械学習実験の必要なコンポーネントを整理するのに役立ちます。 機械学習実験の追跡により、保存された実験を使用して過去の結果を簡単に複製することもできます。 Microsoft Fabric での機械学習の実験の詳細について説明します。 含める手順 (このノートブックでの Prophet モデルのフィッティングと評価など) を決定したら、実験を実行できます。

model_name = f"{EXPERIMENT_NAME}-prophet"

models = []
df_metrics = []
forecasts = []
seasonality_mode = "multiplicative"
weekly_seasonality = False
changepoint_priors = [0.01, 0.05, 0.1]
mcmc_samples = 100

for chpt_prior in changepoint_priors:
    with mlflow.start_run(run_name=f"prophet_changepoint_{chpt_prior}"):
        # init model and fit
        m = fit_model(df_pandas, seasonality_mode, weekly_seasonality, chpt_prior, mcmc_samples)
        models.append(m)
        # Validation
        df_p, future, forecast = evaluation(m)
        df_metrics.append(df_p)
        forecasts.append(forecast)
        # Log model and parameters with MLflow
        mlflow.prophet.log_model(
            m,
            model_name,
            registered_model_name=model_name,
            signature=infer_signature(future, forecast),
        )
        mlflow.log_params(
            {
                "seasonality_mode": seasonality_mode,
                "mcmc_samples": mcmc_samples,
                "weekly_seasonality": weekly_seasonality,
                "changepoint_prior": chpt_prior,
            }
        )
        metrics = df_p.mean().to_dict()
        metrics.pop("horizon")
        mlflow.log_metrics(metrics)

[プロパティ] パネルのスクリーンショット。

Prophet を使用してモデルを視覚化する

Prophet には、モデルのフィッティング結果を表示するビルトイン視覚化機能があります。

黒い点は、モデルのトレーニングに使用されるデータ ポイントを表します。 青い線は予測を表し、水色の領域は不確実性区間を示します。 異なる changepoint_prior_scale 値を持つ 3 つのモデルを構築しました。 このコード ブロックの結果には、これら 3 つのモデルの予測が表示されます。

for idx, pack in enumerate(zip(models, forecasts)):
    m, forecast = pack
    fig = m.plot(forecast)
    fig.suptitle(f"changepoint = {changepoint_priors[idx]}")

最初のグラフの最小 changepoint_prior_scale 値は、傾向の変化のアンダーフィットにつながります。 3 番目のグラフでの最大 changepoint_prior_scale 値をでは、オーバーフィットが発生する可能性があります。 2 番目のグラフが最適な選択肢のようです。 これは、2 番目のモデルが最も適していることを示唆しています。

Prophet では、基になる傾向と季節性を簡単に視覚化することもできます。 このコード ブロックの結果には、2 番目のモデルの視覚化が表示されます。

BEST_MODEL_INDEX = 1  # Set the best model index according to the previous results
fig2 = models[BEST_MODEL_INDEX].plot_components(forecast)

価格データの年単位の傾向グラフのスクリーンショット。

これらのグラフでは、薄い青色のシェーディングが不確定性を反映しています。 上のグラフは、強い、長い期間の揺れる傾向を示しています。 数年周期で、販売量は増加と減少を繰り返しています。 下のグラフは、売上が 2 月と 9 月にピークに達し、その月の年の最大値に達する傾向があることを示しています。 これらの月の直後である 3 月と 10 月は、年の最小値に分類されます。

次に例を示すさまざまなメトリックを使用して、モデルのパフォーマンスを評価します。

  • 平均二乗誤差 (MSE)
  • 二乗平均平方根誤差 (RMSE)
  • 平均絶対誤差 (MAE)
  • 平均絶対パーセント誤差 (MAPE)
  • 絶対パーセント誤差の中央値 (MDAPE)
  • 対称平均絶対誤差率 (sMAPE)

yhat_loweryhat_upper の見積もりを使用してカバレッジを評価します。 将来 1 年を 12 回予測する期間の違いに注意してください。

display(df_metrics[BEST_MODEL_INDEX])

MAPE メトリックでは、この予測モデルでは、将来 1 か月間の予測には通常、約 8% の誤差が含まれます。 ただし、1 年後の予測では、誤差は約 10% に増加します。

手順 5: モデルをスコア付けし、予測結果を保存する

次に、モデルにスコアを付け、予測結果を保存します。

Predict Transformer を使用して予測を行う

これで、モデルを読み込み、それを使用して予測を行うことができます。 スケーラブルな Microsoft Fabric 関数である PREDICT を使用して機械学習モデルを操作できます。これは、任意のコンピューティング エンジンでのバッチ スコアリングをサポートするものです。 このリソースでPREDICT および、Microsoft Fabric でそれを使用する方法の詳細について説明しています。

from synapse.ml.predict import MLFlowTransformer

spark.conf.set("spark.synapse.ml.predict.enabled", "true")

model = MLFlowTransformer(
    inputCols=future.columns.values,
    outputCol="prediction",
    modelName=f"{EXPERIMENT_NAME}-prophet",
    modelVersion=BEST_MODEL_INDEX,
)

test_spark = spark.createDataFrame(data=future, schema=future.columns.to_list())

batch_predictions = model.transform(test_spark)

display(batch_predictions)
# Code for saving predictions into lakehouse
batch_predictions.write.format("delta").mode("overwrite").save(
    f"{DATA_FOLDER}/predictions/batch_predictions"
)
# Determine the entire runtime
print(f"Full run cost {int(time.time() - ts)} seconds.")