ai_forecast function

Applies to: Databricks SQL

Important

This functionality is in Public Preview. Reach out to your Databricks account team to participate in the preview.

The ai_forecast() function is a table-valued function designed to extrapolate time series data into the future. In its most general form, ai_forecast() accepts grouped, multivariate, mixed-granularity data, and forecasts that data up to some horizon in the future.

Requirements

  • Only available on Pro and Serverless SQL warehouses running Databricks SQL version 2024.35 and above.
  • Check the Databricks SQL pricing page.

Syntax


ai_forecast(
  observed TABLE,
  horizon DATE | TIMESTAMP | STRING,
  time_col STRING,
  value_col STRING | ARRAY<STRING>,
  group_col STRING | ARRAY<STRING> | NULL DEFAULT NULL,
  prediction_interval_width DOUBLE DEFAULT 0.95,
  frequency STRING DEFAULT 'auto',
  seed INTEGER | NULL DEFAULT NULL,
  parameters STRING DEFAULT '{}'
)

Arguments

  • observed is the table-valued input that is used as training data for the forecasting procedure.
  • horizon is a timestamp-castable quantity representing the right-exclusive end time of the forecast results. Within a group (see group_col), forecast results span the time between the last observation and horizon. If horizon is earlier than the last observation time, then no results are generated.
  • time_col is a string referencing the “time column” in observed. The column referenced by time_col should be a DATE or a TIMESTAMP.
  • value_col is a string or an array of strings referencing value columns in observed. The columns referenced by this argument should be castable to DOUBLE.
  • group_col (optional) is a string or an array of strings representing the group columns in observed. If specified, group columns are used as partitioning criteria, and forecasts are generated for each group independently. If unspecified, the full input data is treated as a single group.
  • prediction_interval_width (optional) is a value between 0 and 1 representing the width of the prediction interval. Future values have a prediction_interval_width % probability of falling between {v}_upper and {v}_lower.
  • frequency (optional) is a time unit or pandas offset alias string specifying the time granularity of the forecast results. If unspecified, the forecast granularity is automatically inferred for each group independently. If a frequency value is specified, it is applied equally to all groups.
    • The frequency inferred for a group is the mode of the gaps between its most recent observations. This is a convenience operation that is not tunable by the user.
    • For example, a time series with 99 observations that fall on a Monday and one that falls on a Tuesday is inferred to have a weekly frequency.
  • seed (optional) is a number used to initialize any pseudorandom number generators used in the forecasting procedure.
  • parameters (optional) is a string-encoded JSON or the name of a column identifier that represents the parameterization of the forecasting procedure. Any combination of parameters can be specified in any order, for example, {"weekly_order": 10, "global_cap": 1000}. Any unspecified parameters are automatically determined based on the attributes of the training data. The following parameters are supported (see the sketch after this list):
    • global_cap and global_floor can be used together or independently to define the possible domain of the metric values. {"global_floor": 0}, for example, can be used to constrain a metric like cost to always be positive. These apply globally to the training data and the forecasted data, and cannot be used to provide tight constraints on the forecasted values only.
    • daily_order and weekly_order set the Fourier order of the daily and weekly seasonality components.
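
For illustration, the following minimal sketch combines several of these arguments in one call. It is not part of the official examples: the argument values are arbitrary, and the daily aggregation mirrors the first query in the Examples section.

WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1
)
SELECT * FROM AI_FORECAST(
  TABLE(aggregated),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue',
  prediction_interval_width => 0.8,                         -- 80% prediction interval instead of the default 0.95
  seed => 42,                                               -- fix the pseudorandom seed for repeatable runs
  parameters => '{"global_floor": 0, "weekly_order": 10}'   -- keep forecasts non-negative; set the weekly Fourier order
)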

Returns

A new set of rows containing the forecasted time series data.

The input table must contain one time column, one or more value columns, and optionally one or more group columns. A parameters column can also be included so that model parameterizations can be read from the table contents. See parameters in the Arguments section for more information. All other columns in the input table are ignored by AI_FORECAST.

The output table propagates the time and group columns through with their types unchanged. For example, if the time column is a DATE, the output time column is also a DATE. For each value column, the output contains three columns following the pattern {v}_forecast, {v}_upper, and {v}_lower (see the sketch after the following table). Regardless of the input value types, the forecasted value columns are always of type DOUBLE. The output table contains future values only, spanning the time between the end of the observed data and horizon.

Input table                  Output table
ts: TIMESTAMP                ts: TIMESTAMP
val: DOUBLE                  val_forecast: DOUBLE
                             val_upper: DOUBLE
                             val_lower: DOUBLE

ds: DATE                     ds: DATE
val: BIGINT                  val_forecast: DOUBLE
                             val_upper: DOUBLE
                             val_lower: DOUBLE

ts: TIMESTAMP                ts: TIMESTAMP
dim1: STRING                 dim1: STRING
dollars: DECIMAL(10, 2)      dollars_forecast: DOUBLE
                             dollars_upper: DOUBLE
                             dollars_lower: DOUBLE

ts: TIMESTAMP                ts: TIMESTAMP
dim1: STRING                 dim1: STRING
dim2: BIGINT                 dim2: BIGINT
dollars: DECIMAL(10, 2)      dollars_forecast: DOUBLE
users: BIGINT                dollars_upper: DOUBLE
                             dollars_lower: DOUBLE
                             users_forecast: DOUBLE
                             users_upper: DOUBLE
                             users_lower: DOUBLE
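
The generated columns can be referenced downstream like any other columns. The following is a minimal sketch, not from the official examples, that reuses the daily revenue aggregation from the Examples section and reads the revenue_forecast, revenue_upper, and revenue_lower columns produced for the revenue value column:

WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1
)
SELECT
  ds,
  revenue_forecast,
  revenue_upper - revenue_lower AS interval_width   -- width of the prediction interval
FROM AI_FORECAST(
  TABLE(aggregated),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue'
)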

Examples

The following example forecasts until a specified date:


WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1
)
SELECT * FROM AI_FORECAST(
  TABLE(aggregated),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue'
)

The following is a more complex example:

It is common for tables not to materialize zeros or empty entries. If the values of the missing entries can be inferred (for example, 0 or 100%), coalesce those values before calling the forecast function. If the values are truly missing or unknown, they can be left empty.

For very sparse data (for example, when more than half of the entries are missing), it is best practice to provide a frequency value explicitly (see the sketch after the following example). Two entries that are 35 days apart are inferred as a time series with granularity 35D, rather than as a daily series with 34 missing entries.


WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    dropoff_zip,
    SUM(fare_amount) AS revenue,
    COUNT(*) AS n_trips
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1, 2
),
spine AS (
  SELECT all_dates.ds, all_zipcodes.dropoff_zip
  FROM (SELECT DISTINCT ds FROM aggregated) all_dates
  CROSS JOIN (SELECT DISTINCT dropoff_zip FROM aggregated) all_zipcodes
)
SELECT * FROM AI_FORECAST(
  TABLE(
    SELECT
      spine.*,
      COALESCE(aggregated.revenue, 0) AS revenue,
      COALESCE(aggregated.n_trips, 0) AS n_trips
    FROM spine LEFT JOIN aggregated USING (ds, dropoff_zip)
  ),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => ARRAY('revenue', 'n_trips'),
  group_col => 'dropoff_zip',
  prediction_interval_width => 0.9,
  parameters => '{"global_floor": 0}'
)
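
If the input were very sparse, the frequency could also be pinned explicitly. The following minimal sketch is not part of the official examples; it assumes the 'D' pandas offset alias for a daily grain and reuses the simpler daily aggregation from the first example:

WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1
)
SELECT * FROM AI_FORECAST(
  TABLE(aggregated),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue',
  frequency => 'D'   -- force a daily grain instead of relying on inference
)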

The parameters argument can also be a string matching a column identifier. This can be used to apply different parameterizations to different dimensions. Storing parameter JSONs in a table is also a convenient way to reuse previously found parameterizations on new data.

WITH past AS (
  SELECT
    CASE
      WHEN fare_amount < 30 THEN 'Under $30'
      ELSE '$30 or more'
    END AS revenue_bucket,
    CASE
      WHEN fare_amount < 30 THEN '{"daily_order": 0}'
      ELSE '{"daily_order": "auto"}'
    END AS parameters,
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM samples.nyctaxi.trips
  GROUP BY ALL
)

SELECT * FROM AI_FORECAST(
  TABLE(past),
  horizon => (SELECT MAX(ds) + INTERVAL 30 DAYS FROM past),
  time_col => 'ds',
  value_col => 'revenue',
  group_col => ARRAY('revenue_bucket'),
  parameters => 'parameters'
)

Limitations

The following limitations apply during the preview:

  • The default forecasting procedure is a Prophet-like piecewise linear and seasonality model. This is the only forecasting procedure currently supported.
  • Error messages are delivered through the Python UDTF engine and contain Python traceback information. The actual error message appears at the end of the traceback.