Share via


用戶定義的純量函式 - Python

本文包含 Python 使用者定義函式 (UDF) 範例。 它示範如何註冊 UDF、如何叫用 UDF,並提供 Spark SQL 中子運算式評估順序的注意事項。

在 Databricks Runtime 14.0 和更新版本中,您可以使用 Python 使用者定義數據表函式 (UDF) 來註冊傳回整個關聯而非純量值的函式。 請參閱 什麼是 Python 使用者定義數據表函式?

注意

在 Databricks Runtime 12.2 LTS 和以下版本中,Unity 目錄中不支援使用共用存取模式的 Python UDF 和 Pandas UDF。 Databricks Runtime 13.3 LTS 和更新版本支援純量 Python UDF 和純量 Pandas UDF,適用於所有存取模式。

在 Databricks Runtime 13.3 LTS 和更新版本中,您可以使用 SQL 語法向 Unity 目錄註冊純量 Python UDF。 請參閱 Unity 目錄中的使用者定義函式 (UDF)。

將函式註冊為UDF

def squared(s):
  return s * s
spark.udf.register("squaredWithPython", squared)

您可以選擇性地設定 UDF 的傳回類型。 預設傳回型態為 StringType

from pyspark.sql.types import LongType
def squared_typed(s):
  return s * s
spark.udf.register("squaredWithPython", squared_typed, LongType())

在Spark SQL中呼叫 UDF

spark.range(1, 20).createOrReplaceTempView("test")
%sql select id, squaredWithPython(id) as id_squared from test

搭配 DataFrame 使用 UDF

from pyspark.sql.functions import udf
from pyspark.sql.types import LongType
squared_udf = udf(squared, LongType())
df = spark.table("test")
display(df.select("id", squared_udf("id").alias("id_squared")))

或者,您可以使用註釋語法來宣告相同的 UDF:

from pyspark.sql.functions import udf
@udf("long")
def squared_udf(s):
  return s * s
df = spark.table("test")
display(df.select("id", squared_udf("id").alias("id_squared")))

評估順序和 Null 檢查

Spark SQL(包括 SQL 和 DataFrame 和數據集 API)不保證子表達式的評估順序。 特別是,運算子或函式的輸入不一定以左至右或任何其他固定順序進行評估。 例如,邏輯 ANDOR 表達式沒有由左至右的「縮短」語意。

因此,依賴布爾表達式評估的副作用或順序,以及 和 HAVING 子句的順序WHERE是危險的,因為這類表達式和子句可以在查詢優化和規劃期間重新排序。 具體來說,如果 UDF 依賴 SQL 中的短期語意進行 Null 檢查,則不保證在叫用 UDF 之前會發生 Null 檢查。 例如,

spark.udf.register("strlen", lambda s: len(s), "int")
spark.sql("select s from test1 where s is not null and strlen(s) > 1") # no guarantee

這個 WHERE 子句不保證 strlen 在篩選出 Null 之後叫用 UDF。

若要執行適當的 Null 檢查,建議您執行下列其中一項:

  • 讓UDF本身成為 Null 感知,並在 UDF 本身內執行 Null 檢查
  • 使用 IFCASE WHEN 表達式執行 Null 檢查,並在條件式分支中叫用 UDF
spark.udf.register("strlen_nullsafe", lambda s: len(s) if not s is None else -1, "int")
spark.sql("select s from test1 where s is not null and strlen_nullsafe(s) > 1") // ok
spark.sql("select s from test1 where if(s is not null, strlen(s), null) > 1")   // ok