Unity 目录中的 Python 用户定义表函数 (UDTF)

重要

在 Unity 目录中注册 Python UDTF 处于 公共预览阶段

Unity Catalog 用户定义的表函数(UDTF)注册的是返回完整表而非标量值的函数。 与从每个调用返回单个结果值的标量函数不同,UDDF 在 SQL 语句的 FROM 子句中调用,并可以返回多个行和列。

UDTF 特别有用的情况包括:

  • 将数组或复杂数据结构转换为多行
  • 将外部 API 或服务集成到 SQL 工作流中
  • 实现自定义数据生成或扩充逻辑
  • 处理需要跨行执行有状态操作的数据

每个 UDTF 调用都接受零个或多个参数。 这些参数可以是表示整个输入表的标量表达式或表参数。

可以通过两种方式注册 UDTF:

要求

以下计算类型支持 Unity Catalog Python UDTFs:

  • 无服务器笔记本和作业
  • 具有标准访问模式的经典计算(Databricks Runtime 17.1 及更高版本)
  • SQL 仓库(无服务器或专业版)

在 Unity Catalog 中创建 UDTF

使用 SQL DDL 在 Unity 目录中创建受治理的 UDTF。 可以通过 SQL 语句的 FROM 子句调用 UDTF。

CREATE OR REPLACE FUNCTION square_numbers(start INT, end INT)
RETURNS TABLE (num INT, squared INT)
LANGUAGE PYTHON
HANDLER 'SquareNumbers'
DETERMINISTIC
AS $$
class SquareNumbers:
    """
    Basic UDTF that computes a sequence of integers
    and includes the square of each number in the range.
    """
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
$$;

SELECT * FROM square_numbers(1, 5);

+-----+---------+
| num | squared |
+-----+---------+
| 1   | 1       |
| 2   | 4       |
| 3   | 9       |
| 4   | 16      |
| 5   | 25      |
+-----+---------+

Azure Databricks 将 Python UDTF 实现为 Python 类,并使用一个必需的方法eval来生成输出行。

表参数

注释

在 Databricks Runtime 17.2 及更高版本中支持 TABLE 参数。

UDTF 可以接受整个表作为输入参数,实现复杂的有状态转换和聚合。

eval()terminate() 生命周期方法

UDF 中的表参数使用以下函数来处理每一行:

  • eval():为输入表中的每一行调用一次。 这是主要的处理方法,是必需的。
  • terminate():在每个分区的末尾调用一次,之后所有行都经过 eval()处理。 使用此方法生成最终聚合结果或执行清理作。 此方法是可选的,但对于有状态作(如聚合、计数或批处理)至关重要。

有关eval()terminate()方法的详细信息,请参阅Apache Spark 文档:Python UDTF

行访问模式

eval() 接收自 TABLE 参数的行作为 pyspark.sql.Row 对象。 可以按列名(row['id']row['name'])或索引(row[0]row[1]来访问值。

  • 架构灵活性:声明TABLE没有架构定义的参数(例如,data TABLEt TABLE)。 该函数接受任何表结构,因此代码应验证所需的列是否存在。

请参阅 示例:将 IP 地址与 CIDR 网络块匹配 ,示例 :使用 Azure Databricks 视觉终结点进行批处理图像字幕

计算动态输出架构(多态 UDDF)

注释

多态 UC UDTF 需要 Databricks Runtime 18.1 及更高版本。

多态 UDTF 使用静态 analyze() 方法在查询时动态确定其输出架构,而不是提前声明列。 若要创建一个,请使用 RETURNS TABLE 不带列定义,并在处理程序类上定义方法 analyze()

以下示例从 JSON 字符串中提取调用方指定的字段,根据 fields 参数返回不同的列:

CREATE OR REPLACE FUNCTION extract_fields(json_str STRING, fields STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'ExtractFields'
AS $$
class ExtractFields:
    @staticmethod
    def analyze(json_str, fields):

        # Build the output schema from the requested field names
        from pyspark.sql.types import StructType, StructField, StringType
        from pyspark.sql.udtf import AnalyzeResult
        col_names = [f.strip() for f in fields.value.split(",")]
        return AnalyzeResult(
            StructType([StructField(name, StringType()) for name in col_names])
        )

    def eval(self, json_str: str, fields: str):
        # Parse the JSON and yield only the requested fields
        import json
        data = json.loads(json_str)
        col_names = [f.strip() for f in fields.split(",")]
        yield tuple(data.get(name) for name in col_names)
$$;

-- Extract the name and city
SELECT * FROM extract_fields(
  '{"name": "Alice", "age": 30, "city": "Seattle"}',
  'name, city'
);
+-------+---------+
| name  | city    |
+-------+---------+
| Alice | Seattle |
+-------+---------+

定义analyze方法

处理程序类必须包含一个名为@staticmethod的方法,它接受与UDTF相同的analyze参数,并返回一个描述输出架构的AnalyzeResult。 Azure Databricks在查询规划时调用 analyze(),以在执行函数之前解析架构。

每个参数 analyze 都是类的 AnalyzeArgument 实例:

领域 Description
dataType 输入参数的类型是 DataType。 对于输入表参数,这个 StructType 代表表的列。
value 作为 Optional[Any] 的输入参数的值。 此 None 适用于表参数或非常量表达式。
isTable 输入参数是否为表参数,如BooleanType
isConstantExpression 输入参数是否为常量可折叠表达式,并以 BooleanType的形式呈现。

该方法 analyze 返回类的 AnalyzeResult 实例:

领域 Description
schema 结果表的架构为 StructType
withSinglePartition 如果 True,将所有输入行发送到同一 UDTF 类实例。
partitionBy 如果非空,则按指定的表达式对输入行进行分区,以便每个唯一组合由单独的 UDTF 实例处理。
orderBy 如果为非空,则指定每个分区中的行的顺序。
select 如果不为空,则指定 UDTF 接收的输入参数 TABLE 中的具体列。

警告

对于 Unity Catalog 的多态 UDTF,您必须在 analyze() 方法体中包含所有导入。 顶级导入在 Unity Catalog 沙盒环境中不可用。

将状态从 analyze 转发到 eval

该方法 analyze 在查询规划时运行一次,因此可以使用该方法预处理常量参数、分析配置或生成查找。 若要将这些结果转发到 eval,请创建一个包含自定义字段的 AnalyzeResult@dataclass 子类,从 analyze 返回该子类,并在 __init__ 方法中接受它。 这可避免对每行重复成本高昂的工作。

以下示例将语言代码解析为一次 analyze 完整语言名称并转发该名称,因此 eval 可以在不重复查找的情况下标记每一行:

CREATE OR REPLACE FUNCTION tag_language(t TABLE, lang_code STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'TagLanguage'
AS $$
class TagLanguage:
    @staticmethod
    def analyze(t, lang_code):
        from dataclasses import dataclass
        from pyspark.sql.types import StructType, StructField, StringType
        from pyspark.sql.udtf import AnalyzeResult

        @dataclass
        class LangResult(AnalyzeResult):
            language: str = ""

        # Resolve the language code to a full name once during planning
        languages = {"en": "English", "es": "Spanish", "fr": "French", "de": "German"}
        return LangResult(
            schema=StructType([
                StructField("text", StringType()),
                StructField("language", StringType())
            ]),
            language=languages.get(lang_code.value, "Unknown")
        )

    def __init__(self, result):
        self._language = result.language

    def eval(self, row, lang_code: str):
        # Tag each row with the pre-resolved language name
        yield (row['text'], self._language)
$$;

SELECT * FROM tag_language(
  TABLE(VALUES ('Hola mundo'), ('Buenos días') t(text)),
  'es'
);
+-------------+----------+
| text        | language |
+-------------+----------+
| Hola mundo  | Spanish  |
| Buenos días | Spanish  |
+-------------+----------+

有关转发状态的更多模型和细节,请参阅 关于将状态转发到未来调用的说明

analyze 方法指定分区

当多态 UDTF 接受表参数时,analyze 方法可以通过在 AnalyzeResult 上设置 partitionByorderBywithSinglePartitionselect 来控制输入行在 UDTF 实例之间的分布方式。 这样就无需调用方在 SQL 中指定 PARTITION BYORDER BY

有关完整的分区 API 和示例,请参阅 指定方法 analyze 中的输入行的分区

环境隔离

注释

共享隔离环境需要 Databricks Runtime 17.2 及更高版本。 在早期版本中,所有 Unity 目录 Python UDDF 都以严格的隔离模式运行。

默认情况下,具有相同所有者和会话的 Unity 目录 Python UDF 可以共享隔离环境。 这通过减少需要启动的单独环境的数量来提高性能并减少内存使用量。

严格隔离

若要确保 UDTF 始终在其自己的完全隔离环境中运行,请添加 STRICT ISOLATION 特征子句。

大多数 UDTF 不需要严格的隔离。 标准数据处理 UDF 受益于默认共享隔离环境,运行速度更快,内存消耗较低。

STRICT ISOLATION 特征子句添加到 UDTF,该子句:

  • 使用eval()exec()或类似函数以代码形式运行输入。
  • 将文件写入本地文件系统。
  • 修改全局变量或系统状态。
  • 访问或修改环境变量。

以下 UDTF 示例设置自定义环境变量,回读变量,并使用变量将一组数字相乘。 由于 UDTF 会改变进程环境,因此请在 STRICT ISOLATION 中运行它。 否则,它可能会泄漏或替代同一环境中其他 UDF/UDF 的环境变量,从而导致行为不正确。

CREATE OR REPLACE TEMPORARY FUNCTION multiply_numbers(factor STRING)
RETURNS TABLE (original INT, scaled INT)
LANGUAGE PYTHON
STRICT ISOLATION
HANDLER 'Multiplier'
AS $$
import os

class Multiplier:
    def eval(self, factor: str):
        # Save the factor as an environment variable
        os.environ["FACTOR"] = factor

        # Read it back and convert it to a number
        scale = int(os.getenv("FACTOR", "1"))

        # Multiply 0 through 4 by the factor
        for i in range(5):
            yield (i, i * scale)
$$;

SELECT * FROM multiply_numbers("3");

设定 DETERMINISTIC,如果您的函数生成一致的结果

如果函数为相同的输入生成相同的输出,请在函数定义中添加 DETERMINISTIC。 这允许查询优化来提高性能。

默认情况下,Batch Unity Catalog Python UDTF 被假定为非确定性的,除非已被显式声明。 非确定性函数的示例包括:生成随机值、访问当前时间或日期或进行外部 API 调用。

请参阅 CREATE FUNCTION(SQL 和 Python)

实例

以下示例演示了 Unity 目录 Python UDDF 的实际用例,从简单的数据转换到复杂的外部集成。

示例:重新实现 explode

虽然 Spark 提供了一个内置的 explode 函数,但创建您自己的函数版本可以演示 UDTF 的基本模式,即从单个输入生成多个输出行。

CREATE OR REPLACE FUNCTION my_explode(arr ARRAY<STRING>)
RETURNS TABLE (element STRING)
LANGUAGE PYTHON
HANDLER 'MyExplode'
DETERMINISTIC
AS $$
class MyExplode:
    def eval(self, arr):
        if arr is None:
            return
        for element in arr:
            yield (element,)
$$;

直接在 SQL 查询中使用函数:

SELECT element FROM my_explode(array('apple', 'banana', 'cherry'));
+---------+
| element |
+---------+
| apple   |
| banana  |
| cherry  |
+---------+

或者,将其应用于具有 LATERAL 联接的现有表数据:

SELECT s.*, e.element
FROM my_items AS s,
LATERAL my_explode(s.items) AS e;

示例:通过 REST API 进行 IP 地址的地理定位

此示例演示 UDDF 如何将外部 API 直接集成到 SQL 工作流中。 分析师可以使用熟悉的 SQL 语法通过实时 API 调用来丰富数据,而无需单独的 ETL 进程。

CREATE OR REPLACE FUNCTION ip_to_location(ip_address STRING)
RETURNS TABLE (city STRING, country STRING)
LANGUAGE PYTHON
HANDLER 'IPToLocationAPI'
AS $$
class IPToLocationAPI:
    def eval(self, ip_address):
        import requests
        api_url = f"https://api.ip-lookup.example.com/{ip_address}"
        try:
            response = requests.get(api_url)
            response.raise_for_status()
            data = response.json()
            yield (data.get('city'), data.get('country'))
        except requests.exceptions.RequestException as e:
            # Return nothing if the API request fails
            return
$$;

注释

使用无服务器计算或配置了标准访问模式的计算时,Python UDTF 允许通过端口 80、443 和 53 的 TCP/UDP 网络流量。

使用函数通过地理信息丰富 Web 日志数据:

SELECT
  l.timestamp,
  l.request_path,
  geo.city,
  geo.country
FROM web_logs AS l,
LATERAL ip_to_location(l.ip_address) AS geo;

此方法支持实时地理分析,而无需预先处理的查阅表或单独的数据管道。 UDTF 处理 HTTP 请求、JSON 分析和错误处理,使外部数据源可通过标准 SQL 查询访问。

示例:将 IP 地址与 CIDR 网络块匹配

此示例演示了将 IP 地址与 CIDR 网络块匹配,这是一项需要复杂 SQL 逻辑的常见数据工程任务。

首先,使用 IPv4 和 IPv6 地址创建示例数据:

-- An example IP logs with both IPv4 and IPv6 addresses
CREATE OR REPLACE TEMPORARY VIEW ip_logs AS
VALUES
  ('log1', '192.168.1.100'),
  ('log2', '10.0.0.5'),
  ('log3', '172.16.0.10'),
  ('log4', '8.8.8.8'),
  ('log5', '2001:db8::1'),
  ('log6', '2001:db8:85a3::8a2e:370:7334'),
  ('log7', 'fe80::1'),
  ('log8', '::1'),
  ('log9', '2001:db8:1234:5678::1')
t(log_id, ip_address);

接下来,定义并注册 UDTF。 请注意 Python 类结构:

  • t TABLE 参数接受具有任何架构的输入表。 UDTF 会自动适应以处理所提供的任何列。 这种灵活性意味着可以在不同表中使用相同的函数,而无需修改函数签名。 但是,必须仔细检查行的架构以确保兼容性。
  • 该方法 __init__ 用于繁重的一次性设置,例如加载大型网络列表。 此工作在每个输入表的分区中执行一次。
  • 该方法 eval 处理每行并包含核心匹配逻辑。 此方法对输入分区中的每个行执行一次,每个执行由该分区的 IpMatcher UDTF 类的相应实例执行。
  • HANDLER 子句指定实现 UDTF 逻辑的 Python 类的名称。
CREATE OR REPLACE TEMPORARY FUNCTION ip_cidr_matcher(t TABLE)
RETURNS TABLE(log_id STRING, ip_address STRING, network STRING, ip_version INT)
LANGUAGE PYTHON
HANDLER 'IpMatcher'
COMMENT 'Match IP addresses against a list of network CIDR blocks'
AS $$
class IpMatcher:
    def __init__(self):
        import ipaddress
        # Heavy initialization - load networks once per partition
        self.nets = []
        cidrs = ['192.168.0.0/16', '10.0.0.0/8', '172.16.0.0/12',
                 '2001:db8::/32', 'fe80::/10', '::1/128']
        for cidr in cidrs:
            self.nets.append(ipaddress.ip_network(cidr))

    def eval(self, row):
        import ipaddress
	    # Validate that required fields exist
        required_fields = ['log_id', 'ip_address']
        for field in required_fields:
            if field not in row:
                raise ValueError(f"Missing required field: {field}")
        try:
            ip = ipaddress.ip_address(row['ip_address'])
            for net in self.nets:
                if ip in net:
                    yield (row['log_id'], row['ip_address'], str(net), ip.version)
                    return
            yield (row['log_id'], row['ip_address'], None, ip.version)
        except ValueError:
            yield (row['log_id'], row['ip_address'], 'Invalid', None)
$$;

现在已在 ip_cidr_matcher Unity 目录中注册,请使用 TABLE() 语法直接从 SQL 调用它:

-- Process all IP addresses
SELECT
  *
FROM
  ip_cidr_matcher(t => TABLE(ip_logs))
ORDER BY
  log_id;
+--------+-------------------------------+-----------------+-------------+
| log_id | ip_address                    | network         | ip_version  |
+--------+-------------------------------+-----------------+-------------+
| log1   | 192.168.1.100                 | 192.168.0.0/16  | 4           |
| log2   | 10.0.0.5                      | 10.0.0.0/8      | 4           |
| log3   | 172.16.0.10                   | 172.16.0.0/12   | 4           |
| log4   | 8.8.8.8                       | null            | 4           |
| log5   | 2001:db8::1                   | 2001:db8::/32   | 6           |
| log6   | 2001:db8:85a3::8a2e:370:7334  | 2001:db8::/32   | 6           |
| log7   | fe80::1                       | fe80::/10       | 6           |
| log8   | ::1                           | ::1/128         | 6           |
| log9   | 2001:db8:1234:5678::1         | 2001:db8::/32   | 6           |
+--------+-------------------------------+-----------------+-------------+

示例:使用 Azure Databricks 视觉端点进行批量图像标注

此示例演示了使用 Azure Databricks 的视觉模型服务终端进行批量图像说明生成。 它展示了如何使用 terminate() 进行批处理和基于分区的执行。

  1. 创建具有公共映像 URL 的表:

    CREATE OR REPLACE TEMPORARY VIEW sample_images AS
    VALUES
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg', 'scenery'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Camponotus_flavomarginatus_ant.jpg/1024px-Camponotus_flavomarginatus_ant.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Cat_August_2010-4.jpg/1200px-Cat_August_2010-4.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/M101_hires_STScI-PRC2006-10a.jpg/1024px-M101_hires_STScI-PRC2006-10a.jpg', 'scenery')
    images(image_url, category);
    
  2. 创建 Unity Catalog Python UDTF 用于生成图像标题。

    1. 使用配置初始化 UDTF,包括批大小、Azure Databricks API 令牌、视觉模型终结点和工作区 URL。
    2. eval 方法中,将图像 URL 收集到缓冲区中。 当缓冲区达到批大小时,触发批处理。 这可确保在单个 API 调用中同时处理多个映像,而不是每个映像的单个调用。
    3. 在批处理方法中,下载所有缓冲图像,将其编码为 base64,并将其发送到 Databricks VisionModel 的单个 API 请求。 模型同时处理所有图像,并返回整个批处理的标题。
    4. 该方法 terminate 在每个分区的末尾完全执行一次。 在终止方法中,处理缓冲区中的任何剩余图像,并生成所有收集的字幕作为结果。

注释

用您的 Azure Databricks 实际工作区 URL(<workspace-url>)替换 https://your-workspace.cloud.databricks.com

CREATE OR REPLACE TEMPORARY FUNCTION batch_inference_image_caption(data TABLE, api_token STRING)
RETURNS TABLE (caption STRING)
LANGUAGE PYTHON
HANDLER 'BatchInferenceImageCaption'
COMMENT 'batch image captioning by sending groups of image URLs to a Databricks vision endpoint and returning concise captions for each image.'
AS $$
class BatchInferenceImageCaption:
    def __init__(self):
        self.batch_size = 3
        self.vision_endpoint = "databricks-claude-sonnet-4-5"
        self.workspace_url = "<workspace-url>"
        self.image_buffer = []
        self.results = []

    def eval(self, row, api_token):
        self.image_buffer.append((str(row[0]), api_token))
        if len(self.image_buffer) >= self.batch_size:
            self._process_batch()

    def terminate(self):
        if self.image_buffer:
            self._process_batch()
        for caption in self.results:
            yield (caption,)

    def _process_batch(self):
        batch_data = self.image_buffer.copy()
        self.image_buffer.clear()

        import base64
        import httpx
        import requests

        # API request timeout in seconds
        api_timeout = 60
        # Maximum tokens for vision model response
        max_response_tokens = 300
        # Temperature controls randomness (lower = more deterministic)
        model_temperature = 0.3

        # create a batch for the images
        batch_images = []
        api_token = batch_data[0][1] if batch_data else None

        for image_url, _ in batch_data:
            image_response = httpx.get(image_url, timeout=15)
            image_data = base64.standard_b64encode(image_response.content).decode("utf-8")
            batch_images.append(image_data)

        content_items = [{
            "type": "text",
            "text": "Provide brief captions for these images, one per line."
        }]
        for img_data in batch_images:
            content_items.append({
                "type": "image_url",
                "image_url": {
                    "url": "data:image/jpeg;base64," + img_data
                }
            })

        payload = {
            "messages": [{
                "role": "user",
                "content": content_items
            }],
            "max_tokens": max_response_tokens,
            "temperature": model_temperature
        }

        response = requests.post(
            self.workspace_url + "/serving-endpoints/" +
            self.vision_endpoint + "/invocations",
            headers={
                'Authorization': 'Bearer ' + api_token,
                'Content-Type': 'application/json'
            },
            json=payload,
            timeout=api_timeout
        )

        result = response.json()
        batch_response = result['choices'][0]['message']['content'].strip()

        lines = batch_response.split('\n')
        captions = [line.strip() for line in lines if line.strip()]

        while len(captions) < len(batch_data):
            captions.append(batch_response)

        self.results.extend(captions[:len(batch_data)])
$$;

若要使用批处理图像标题 UDTF,请使用示例图像表调用它:

注释

your_secret_scopeapi_token 替换为 Databricks API 令牌的实际机密范围和密钥名称。

SELECT
  caption
FROM
  batch_inference_image_caption(
    data => TABLE(sample_images),
    api_token => secret('your_secret_scope', 'api_token')
  )
+---------------------------------------------------------------------------------------------------------------+
| caption                                                                                                       |
+---------------------------------------------------------------------------------------------------------------+
| Wooden boardwalk cutting through vibrant wetland grasses under blue skies                                     |
| Black ant in detailed macro photography standing on a textured surface                                        |
| Tabby cat lounging comfortably on a white ledge against a white wall                                          |
| Stunning spiral galaxy with bright central core and sweeping blue-white arms against the black void of space. |
+---------------------------------------------------------------------------------------------------------------+

您还可以按照类别生成图像标题:

SELECT
  *
FROM
  batch_inference_image_caption(
    TABLE(sample_images)
    PARTITION BY category ORDER BY (category),
    secret('your_secret_scope', 'api_token')
  )
+------------------------------------------------------------------------------------------------------+
| caption                                                                                              |
+------------------------------------------------------------------------------------------------------+
| Black ant in detailed macro photography standing on a textured surface                               |
| Stunning spiral galaxy with bright center and sweeping blue-tinged arms against the black of space.  |
| Tabby cat lounging comfortably on white ledge against white wall                                     |
| Wooden boardwalk cutting through lush wetland grasses under blue skies                               |
+------------------------------------------------------------------------------------------------------+

示例:ML 模型评估的 ROC 曲线和 AUC 计算

此示例演示了如何使用 scikit-learn 计算接收机工作特性(ROC)曲线和曲线下面积(AUC)得分,以评估二分类模型。

此示例展示了几个重要的模式:

  • 外部库用法:集成 scikit-learn 进行 ROC 曲线计算
  • 有状态聚合:在计算指标之前累积所有行的预测
  • terminate() 方法用法:仅在评估所有行后处理完整的数据集并生成结果
  • 错误处理:验证输入表中是否存在所需的列

UDTF 使用 eval() 方法累积内存中的所有预测,然后在 terminate() 方法中计算并生成完整的 ROC 曲线。 此模式对于需要完整数据集进行计算的指标非常有用。

CREATE OR REPLACE TEMPORARY FUNCTION compute_roc_curve(t TABLE)
RETURNS TABLE (threshold DOUBLE, true_positive_rate DOUBLE, false_positive_rate DOUBLE, auc DOUBLE)
LANGUAGE PYTHON
HANDLER 'ROCCalculator'
COMMENT 'Compute ROC curve and AUC using scikit-learn'
AS $$
class ROCCalculator:
    def __init__(self):
        from sklearn import metrics
        self._roc_curve = metrics.roc_curve
        self._roc_auc_score = metrics.roc_auc_score

        self._true_labels = []
        self._predicted_scores = []

    def eval(self, row):
        if 'y_true' not in row or 'y_score' not in row:
            raise KeyError("Required columns 'y_true' and 'y_score' not found")

        true_label = row['y_true']
        predicted_score = row['y_score']

        label = float(true_label)
        self._true_labels.append(label)
        self._predicted_scores.append(float(predicted_score))

    def terminate(self):
        false_pos_rate, true_pos_rate, thresholds = self._roc_curve(
            self._true_labels,
            self._predicted_scores,
            drop_intermediate=False
        )

        auc_score = float(self._roc_auc_score(self._true_labels, self._predicted_scores))

        for threshold, tpr, fpr in zip(thresholds, true_pos_rate, false_pos_rate):
            yield float(threshold), float(tpr), float(fpr), auc_score
$$;

使用预测创建示例二进制分类数据:

CREATE OR REPLACE TEMPORARY VIEW binary_classification_data AS
SELECT *
FROM VALUES
  ( 1, 1.0, 0.95, 'high_confidence_positive'),
  ( 2, 1.0, 0.87, 'high_confidence_positive'),
  ( 3, 1.0, 0.82, 'medium_confidence_positive'),
  ( 4, 0.0, 0.78, 'false_positive'),
  ( 5, 1.0, 0.71, 'medium_confidence_positive'),
  ( 6, 0.0, 0.65, 'false_positive'),
  ( 7, 0.0, 0.58, 'true_negative'),
  ( 8, 1.0, 0.52, 'low_confidence_positive'),
  ( 9, 0.0, 0.45, 'true_negative'),
  (10, 0.0, 0.38, 'true_negative'),
  (11, 1.0, 0.31, 'low_confidence_positive'),
  (12, 0.0, 0.15, 'true_negative'),
  (13, 0.0, 0.08, 'high_confidence_negative'),
  (14, 0.0, 0.03, 'high_confidence_negative')
AS data(sample_id, y_true, y_score, prediction_type);

计算 ROC 曲线和 AUC:

SELECT
    threshold,
    true_positive_rate,
    false_positive_rate,
    auc
FROM compute_roc_curve(
  TABLE(
    SELECT y_true, y_score
    FROM binary_classification_data
    WHERE y_true IS NOT NULL AND y_score IS NOT NULL
    ORDER BY sample_id
  )
)
ORDER BY threshold DESC;
+-----------+---------------------+----------------------+-------+
| threshold | true_positive_rate  | false_positive_rate  | auc   |
+-----------+---------------------+----------------------+-------+
| 1.95      | 0.0                 | 0.0                  | 0.786 |
| 0.95      | 0.167               | 0.0                  | 0.786 |
| 0.87      | 0.333               | 0.0                  | 0.786 |
| 0.82      | 0.5                 | 0.0                  | 0.786 |
| 0.78      | 0.5                 | 0.125                | 0.786 |
| 0.71      | 0.667               | 0.125                | 0.786 |
| 0.65      | 0.667               | 0.25                 | 0.786 |
| 0.58      | 0.667               | 0.375                | 0.786 |
| 0.52      | 0.833               | 0.375                | 0.786 |
| 0.45      | 0.833               | 0.5                  | 0.786 |
| 0.38      | 0.833               | 0.625                | 0.786 |
| 0.31      | 1.0                 | 0.625                | 0.786 |
| 0.15      | 1.0                 | 0.75                 | 0.786 |
| 0.08      | 1.0                 | 0.875                | 0.786 |
| 0.03      | 1.0                 | 1.0                  | 0.786 |
+-----------+---------------------+----------------------+-------+

示例:表参数的动态列投影

此示例将多态 UDDF 与表参数组合在一起。 UDTF 接受表和以逗号分隔的列名称列表,然后仅投影输入中的这些列。 该方法 analyze 检查输入表的架构,并生成仅包含所请求列的输出架构。

CREATE OR REPLACE FUNCTION project_columns(t TABLE, columns STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'ProjectColumns'
AS $$
class ProjectColumns:
    @staticmethod
    def analyze(t, columns):
        from pyspark.sql.types import StructType
        from pyspark.sql.udtf import AnalyzeResult

        requested = [c.strip() for c in columns.value.split(",")]
        input_schema = t.dataType
        output_fields = []
        for field in input_schema.fields:
            if field.name in requested:
                output_fields.append(field)
        if not output_fields:
            raise ValueError(
                f"None of the requested columns {requested} "
                f"exist in the input table"
            )
        return AnalyzeResult(schema=StructType(output_fields))

    def eval(self, row, columns: str):
        requested = [c.strip() for c in columns.split(",")]
        yield tuple(row[col] for col in requested if col in row)
$$;

使用函数从表中选择特定列:

SELECT * FROM project_columns(
  TABLE(SELECT * FROM samples.nyctaxi.trips LIMIT 5),
  'pickup_zip, dropoff_zip, fare_amount'
);

局限性

以下限制适用于 Unity Catalog Python UDTF:

后续步骤