Python 用户定义表函数 (UDTF)
重要
此功能在 Databricks Runtime 14.3 LTS 及更高版本中作为公共预览版提供。
用户定义的表函数 (UDTF) 允许注册返回表而不是标量值的函数。 与每次调用都返回单个结果值的标量值函数不同,每个 UDTF 都在 SQL 语句的 FROM
子句中调用并返回整个表作为输出。
每个 UDTF 调用都可以接受零个或多个参数。 这些参数可以是标量表达式或代表整个输入表的表参数。
Apache Spark 通过必需的 eval
方法使用 yield
发出输出行来将 Python UDDF 实现为 Python 类。
若要将类用作 UDTF,必须导入 PySpark udtf
函数。 Databricks 建议将此函数用作修饰器,并且使用 returnType
选项显式指定字段名称和类型(除非类定义 analyze
方法,如后面的部分所述)。
以下 UDTF 使用两个整数参数的固定列表创建一个表:
from pyspark.sql.functions import lit, udtf
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, x: int, y: int):
yield x + y, x - y
GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
| 3| -1|
+----+-----+
UDTF 将注册到本地 SparkSession
,并在笔记本或作业级别隔离。
不能将 UDTF 注册为 Unity Catalog 中的对象,UDDF 不能用于 SQL 仓库。
可以将 UDTF 注册到当前 SparkSession
,以便使用函数 spark.udtf.register()
进行 SQL 查询。 提供 SQL 函数和 Python UDTF 类的名称。
spark.udtf.register("get_sum_diff", GetSumDiff)
注册后,可以使用 %sql
magic 命令或 spark.sql()
函数在 SQL 中使用 UDTF:
spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);")
%sql
SELECT * FROM get_sum_diff(1,2);
如果 UDTF 接收少量数据作为输入,但输出大型表,则 Databricks 建议使用 Apache Arrow。 可以通过在声明 UDTF 时指定 useArrow
参数来启用它:
@udtf(returnType="c1: int, c2: int", useArrow=True)
可以使用 Python *args
或 **kwargs
语法并实现逻辑来处理未指定数量的输入值。
以下示例会返回相同的结果,同时显式检查参数的输入长度和类型:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, *args):
assert(len(args) == 2)
assert(isinstance(arg, int) for arg in args)
x = args[0]
y = args[1]
yield x + y, x - y
GetSumDiff(lit(1), lit(2)).show()
下面是相同的示例,但使用了关键字参数:
@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
def eval(self, **kwargs):
x = kwargs["x"]
y = kwargs["y"]
yield x + y, x - y
GetSumDiff(x=lit(1), y=lit(2)).show()
UDTF 返回带有输出架构的行,该架构由列名和类型的有序序列组成。 如果 UDTF 架构对于所有查询都应始终保持不变,则可以在 @udtf
装饰器后指定静态固定架构。 它必须是 StructType
:
StructType().add("c1", StringType())
或表示结构类型的 DDL 字符串:
c1: string
UDDF 还可以根据输入参数的值以编程方式计算每个调用的输出架构。 为此,请定义一个名为 analyze
的静态方法,该方法接受与提供给特定 UDTF 调用的参数对应的零个或多个参数。
analyze
方法的每个参数都是 AnalyzeArgument
类的实例,其中包含以下字段:
AnalyzeArgument 类字段 |
说明 |
---|---|
dataType |
作为 DataType 的输入参数的类型。 对于输入表参数,这是表示表列的 StructType 。 |
value |
作为 Optional[Any] 的输入参数的值。 对于非常数表参数或文本标量参数,这是 None 。 |
isTable |
输入参数是否是作为 BooleanType 的表。 |
isConstantExpression |
输入参数是否是作为 BooleanType 的常数可折叠表达式。 |
analyze
方法返回 AnalyzeResult
类的实例,其中包括作为 StructType
的结果表架构以及一些可选字段。 如果 UDTF 接受输入表参数,则 AnalyzeResult
还可以包含一种请求的方法,用于在多个 UDTF 调用中对输入表的行进行分区和排序,如下文所述。
AnalyzeResult 类字段 |
说明 |
---|---|
schema |
作为 StructType 的结果表的架构。 |
withSinglePartition |
是否将所有输入行发送到作为 BooleanType 的 UDTF 类实例。 |
partitionBy |
如果设置为非空,则具有分区表达式的每个唯一值组合的所有行都将由 UDTF 类的单独实例使用。 |
orderBy |
如果设置为非空,则指定每个分区中的行的顺序。 |
select |
如果设置为非空,则这是 UDTF 为 Catalyst 指定的表达式序列,用于根据输入 TABLE 参数中的列进行评估。 UDTF 按列出的顺序接收列表中的每个名称的一个输入属性。 |
analyze
示例为输入字符串参数中的每个单词返回一个输出列。
@udtf
class MyUDTF:
@staticmethod
def analyze(text: AnalyzeArgument) -> AnalyzeResult:
schema = StructType()
for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
schema = schema.add(f"word_{index}", IntegerType())
return AnalyzeResult(schema=schema)
def eval(self, text: str):
counts = {}
for word in text.split(" "):
if word not in counts:
counts[word] = 0
counts[word] += 1
result = []
for word in sorted(list(set(text.split(" ")))):
result.append(counts[word])
yield result
['word_0', 'word_1']
analyze
方法可用作执行初始化的便捷位置,然后将结果转发给同一 UDTF 调用的未来 eval
方法调用。
为此,请创建 AnalyzeResult
的子类,并从 analyze
方法返回该子类的实例。
然后,将附加参数添加到 __init__
方法以接受该实例。
analyze
示例返回常数输出架构,但在结果元数据中添加自定义信息,供未来 __init__
方法调用使用:
@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
buffer: str = ""
@udtf
class TestUDTF:
def __init__(self, analyze_result=None):
self._total = 0
if analyze_result is not None:
self._buffer = analyze_result.buffer
else:
self._buffer = ""
@staticmethod
def analyze(argument, _) -> AnalyzeResult:
if (
argument.value is None
or argument.isTable
or not isinstance(argument.value, str)
or len(argument.value) == 0
):
raise Exception("The first argument must be a non-empty string")
assert argument.dataType == StringType()
assert not argument.isTable
return AnalyzeResultWithBuffer(
schema=StructType()
.add("total", IntegerType())
.add("buffer", StringType()),
withSinglePartition=True,
buffer=argument.value,
)
def eval(self, argument, row: Row):
self._total += 1
def terminate(self):
yield self._total, self._buffer
self.spark.udtf.register("test_udtf", TestUDTF)
spark.sql(
"""
WITH t AS (
SELECT id FROM range(1, 21)
)
SELECT total, buffer
FROM test_udtf("abc", TABLE(t))
"""
).show()
+-------+-------+
| count | buffer|
+-------+-------+
| 20 | "abc"|
+-------+-------+
eval
方法针对输入表参数的每一行运行一次(如果未提供任何表参数,则只运行一次),然后在末尾调用一次 terminate
方法。 方法通过生成元组、列表或 pyspark.sql.Row
对象来输出符合结果架构的零行或更多行。
此示例通过提供三个元素的元组返回行:
def eval(self, x, y, z):
yield (x, y, z)
还可以省略括号:
def eval(self, x, y, z):
yield x, y, z
添加尾随逗号以返回仅包含一列的行:
def eval(self, x, y, z):
yield x,
还可以生成 pyspark.sql.Row
对象。
def eval(self, x, y, z)
from pyspark.sql.types import Row
yield Row(x, y, z)
此示例使用 Python 列表从 terminate
方法生成输出行。 为此,可以将状态存储在 UDTF 评估早期步骤的类中。
def terminate(self):
yield [self.x, self.y, self.z]
可以将标量参数作为由文本值或基于它们的函数组成的常数表达式传递给 UDTF。 例如:
SELECT * FROM udtf(42, group => upper("finance_department"));
除了标量输入参数外,Python UDF 还可以接受输入表作为参数。 单个 UDTF 还可以接受表参数和多个标量参数。
然后,任何 SQL 查询都可以使用 TABLE
关键字提供输入表,后跟括号中相应的表标识符,例如 TABLE(t)
。 或者,可以传递表子查询,例如 TABLE(SELECT a, b, c FROM t)
或 TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))
。
然后,输入表参数表示为 eval
方法的 pyspark.sql.Row
参数,对输入表中每一行的 eval
方法进行一次调用。 可以使用标准 PySpark 列字段注释与每行中的列进行交互。 以下示例演示如何显式导入 PySpark Row
类型,然后在 id
字段中筛选传递的表:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
@udtf(returnType="id: int")
class FilterUDTF:
def eval(self, row: Row):
if row["id"] > 5:
yield row["id"],
spark.udtf.register("filter_udtf", FilterUDTF)
若要查询函数,请使用 TABLE
SQL 关键字:
SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
| 6|
| 7|
| 8|
| 9|
+---+
使用表参数调用 UDTF 时,任何 SQL 查询都可以根据一个或多个输入表列的值跨多个 UDTF 调用对输入表进行分区。
若要指定分区,请在 TABLE
参数之后的函数调用中使用 PARTITION BY
子句。
这可以保证具有分区列值的每个唯一组合的所有输入行都由 UDTF 类的一个实例使用。
请注意,除了简单的列引用外,PARTITION BY
子句还接受基于输入表列的任意表达式。 例如,可以指定字符串的 LENGTH
、从日期中提取月份或连接两个值。
还可以指定 WITH SINGLE PARTITION
而不是 PARTITION BY
以仅请求一个分区,其中所有输入行必须由 UDTF 类的一个实例使用。
在每个分区中,可以选择指定 UDTF 的 eval
方法使用输入行时所需的顺序。 为此,请在上述 PARTITION BY
或 WITH SINGLE PARTITION
子句后提供 ORDER BY
子句。
例如,请考虑以下 UDTF:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
@udtf(returnType="a: string, b: int")
class FilterUDTF:
def __init__(self):
self.key = ""
self.max = 0
def eval(self, row: Row):
self.key = row["a"]
self.max = max(self.max, row["b"])
def terminate(self):
yield self.key, self.max
spark.udtf.register("filter_udtf", FilterUDTF)
可以通过多种方法在输入表上调用 UDF 时指定分区选项:
-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
| a | b |
+-------+----+
| "abc" | 2 |
| "abc" | 4 |
| "def" | 6 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "abc" | 4 |
| "def" | 8 |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
| a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
| a | b |
+-------+----+
| "def" | 8 |
+-------+----+
请注意,对于在 SQL 查询中调用 UDTF 对输入表进行分区的上述每种方法,UDTF 的 analyze
方法都有相应的方法可以自动指定相同的分区方法。
- 可以更新
analyze
方法来设置字段partitionBy=[PartitioningColumn("a")]
,然后简单地使用SELECT * FROM udtf(TABLE(t))
调用函数,而不是将 UDTF 作为SELECT * FROM udtf(TABLE(t) PARTITION BY a)
调用。 - 通过相同的标记,你无需在 SQL 查询中指定
TABLE(t) WITH SINGLE PARTITION ORDER BY b
,而可以使analyze
设置字段withSinglePartition=true
和orderBy=[OrderingColumn("b")]
,并仅传递TABLE(t)
。 - 无需在 SQL 查询中传递
TABLE(SELECT a FROM t)
,而是通过analyze
设置select=[SelectedColumn("a")]
,然后只传递TABLE(t)
。
在以下示例中,analyze
返回常数输出架构,从输入表中选择列的子集,并指定输入表根据 date
列的值在多个 UDTF 调用中进行分区:
@staticmethod
def analyze(*args) -> AnalyzeResult:
"""
The input table will be partitioned across several UDTF calls based on the monthly
values of each `date` column. The rows within each partition will arrive ordered by the `date`
column. The UDTF will only receive the `date` and `word` columns from the input table.
"""
from pyspark.sql.functions import (
AnalyzeResult,
OrderingColumn,
PartitioningColumn,
)
assert len(args) == 1, "This function accepts one argument only"
assert args[0].isTable, "Only table arguments are supported"
return AnalyzeResult(
schema=StructType()
.add("month", DateType())
.add('longest_word", IntegerType()),
partitionBy=[
PartitioningColumn("extract(month from date)")],
orderBy=[
OrderingColumn("date")],
select=[
SelectedColumn("date"),
SelectedColumn(
name="length(word),
alias="length_word")])