Unity 目錄中的 Python 使用者定義資料表函式 (UDTF)

這很重要

在 Unity Catalog 中登錄 Python UDTF 正處於 公開預覽階段。

Unity Catalog 使用者自定義資料表函式(UDTF)會註冊那些返回完整資料表而非純量值的函式。 與從每次呼叫傳回單一結果值的純量函數不同,UDTF 會在 SQL 陳述式的 FROM 子句中呼叫,並且可以傳回多個列和列。

UDTF 功能特別適用於:

  • 將陣列或複雜的資料結構轉換為多行
  • 將外部 API 或服務整合到 SQL 工作流程中
  • 實作自訂資料產生或擴充邏輯
  • 處理需要跨列有狀態作業的資料

每個 UDTF 呼叫都接受零個或多個引數。 這些引數可以是純量運算式或代表整個輸入表的表引數。

UDTF 可以透過兩種方式註冊:

需求

下列計算類型支援 Unity 目錄 Python UDTF:

  • 無伺服器筆記本和任務
  • 具有標準存取模式的傳統計算 (Databricks Runtime 17.1 和更新版本)
  • SQL 倉儲 (無伺服器或專業版)

在 Unity Catalog 中創建 UDTF

使用 SQL DDL 在 Unity 目錄中建立受控管的 UDTF。 UDTF 是使用 SQL 陳述式的 FROM 子句來呼叫。

CREATE OR REPLACE FUNCTION square_numbers(start INT, end INT)
RETURNS TABLE (num INT, squared INT)
LANGUAGE PYTHON
HANDLER 'SquareNumbers'
DETERMINISTIC
AS $$
class SquareNumbers:
    """
    Basic UDTF that computes a sequence of integers
    and includes the square of each number in the range.
    """
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
$$;

SELECT * FROM square_numbers(1, 5);

+-----+---------+
| num | squared |
+-----+---------+
| 1   | 1       |
| 2   | 4       |
| 3   | 9       |
| 4   | 16      |
| 5   | 25      |
+-----+---------+

Azure Databricks 將 Python UDTF 實作為 Python 類別,並使用必要的 eval 方法來產生輸出資料列。

表格引數

備註

TABLE Databricks Runtime 17.2 和更新版本支援引數。

UDTF 可以接受整個表作為輸入參數,從而實現複雜的有狀態轉換和聚合。

eval()terminate() 生命週期方法

UDTF 中的表格引數會使用下列函式來處理每一列:

  • eval():針對輸入表格中的每一列呼叫一次。 這是主要的處理方法,也是必需的。
  • terminate():在每個分割區結束時呼叫一次,在處理完eval()所有資料列之後。 使用此方法可產生最終彙總結果或執行清除作業。 此方法是可選的,但對於聚合、計數或批次處理等有狀態操作至關重要。

如需eval()terminate()方法的詳細資訊,請參閱Apache Spark 文件:Python UDTF

列存取模式

eval() 從引數接收 TABLE 資料列作為 pyspark.sql.Row 物件。 您可以依資料行名稱 (row['id']row['name']) 或索引 (row[0]row[1]) 來存取值。

  • 模式彈性:宣告TABLE沒有模式定義的引數(例如,data TABLE,)。 t TABLE 函式接受任何資料表結構,因此您的程式碼應該驗證必要的資料行是否存在。

請參閱 範例:將 IP 位址與 CIDR 網路區塊比對 ,以及 範例:使用 Azure Databricks 視覺端點批次影像字幕

計算一個動態輸出模式(多態 UDTF )

備註

多型 UC UDTF 需要 Databricks Runtime 18.1 及以上版本。

多型 UDTF 在查詢時使用靜態 analyze() 方法動態決定其輸出模式,而非事先宣告欄位。 要建立一個,請使用 RETURNS TABLE 無欄位定義,並在處理類別上定義一個 analyze() 方法。

以下範例從 JSON 字串中擷取呼叫者指定的欄位,並根據參數 fields 回傳不同的欄位:

CREATE OR REPLACE FUNCTION extract_fields(json_str STRING, fields STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'ExtractFields'
AS $$
class ExtractFields:
    @staticmethod
    def analyze(json_str, fields):

        # Build the output schema from the requested field names
        from pyspark.sql.types import StructType, StructField, StringType
        from pyspark.sql.udtf import AnalyzeResult
        col_names = [f.strip() for f in fields.value.split(",")]
        return AnalyzeResult(
            StructType([StructField(name, StringType()) for name in col_names])
        )

    def eval(self, json_str: str, fields: str):
        # Parse the JSON and yield only the requested fields
        import json
        data = json.loads(json_str)
        col_names = [f.strip() for f in fields.split(",")]
        yield tuple(data.get(name) for name in col_names)
$$;

-- Extract the name and city
SELECT * FROM extract_fields(
  '{"name": "Alice", "age": 30, "city": "Seattle"}',
  'name, city'
);
+-------+---------+
| name  | city    |
+-------+---------+
| Alice | Seattle |
+-------+---------+

定義 analyze 方法

處理類別必須包含一個方法@staticmethod,名為analyze,該方法接受與 UDTF 相同的參數,並回傳描述輸出結構的AnalyzeResult。 Azure Databricks 在查詢規劃時呼叫 analyze(),以解析結構後再執行函式。

analyze 每個參數都是類別的 AnalyzeArgument 一個實例:

Field Description
dataType 輸入參數的類型作為 DataType。 對於輸入資料表的參數,這是 StructType 代表資料表欄位的。
value 輸入參數的值作為 Optional[Any]。 這是用於 None 表格參數或非恆定表達式。
isTable 輸入參數是否為表參數。BooleanType
isConstantExpression 輸入參數是否為可作為 BooleanType 的常數折疊表示式。

analyze 方法回傳類別的 AnalyzeResult 一個實例:

Field Description
schema 結果資料表的架構為 StructType
withSinglePartition True,則將所有輸入列傳送至同一個 UDTF 類別實例。
partitionBy 若非空,則依指定表達式分割輸入列,使每個獨特組合由獨立的 UDTF 實例處理。
orderBy 若非空,則指定每個分割中列的順序。
select 若非空,則指定 UDTF 從輸入 TABLE 參數接收哪些欄位。

警告

對於 Unity Catalog 多型 UDTF,你必須將所有匯入放入 analyze() 方法主體。 Unity Catalog 沙盒環境中無法提供頂層匯入。

analyzeeval 的前向狀態

這個 analyze 方法在查詢規劃時執行一次,所以你可以用它來預處理常數參數、解析組態或建置查找。 若要將結果轉發到 eval,請建立 AnalyzeResult@dataclass 子類,帶有自訂欄位,從 analyze 返回,然後在 __init__ 方法中接受。 這樣可以避免每排重複昂貴的工作。

以下範例將語言代碼解析為完整語言名稱,然後在 analyze 轉發,這樣 eval 就可以標記每一列而無需重複查找。

CREATE OR REPLACE FUNCTION tag_language(t TABLE, lang_code STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'TagLanguage'
AS $$
class TagLanguage:
    @staticmethod
    def analyze(t, lang_code):
        from dataclasses import dataclass
        from pyspark.sql.types import StructType, StructField, StringType
        from pyspark.sql.udtf import AnalyzeResult

        @dataclass
        class LangResult(AnalyzeResult):
            language: str = ""

        # Resolve the language code to a full name once during planning
        languages = {"en": "English", "es": "Spanish", "fr": "French", "de": "German"}
        return LangResult(
            schema=StructType([
                StructField("text", StringType()),
                StructField("language", StringType())
            ]),
            language=languages.get(lang_code.value, "Unknown")
        )

    def __init__(self, result):
        self._language = result.language

    def eval(self, row, lang_code: str):
        # Tag each row with the pre-resolved language name
        yield (row['text'], self._language)
$$;

SELECT * FROM tag_language(
  TABLE(VALUES ('Hola mundo'), ('Buenos días') t(text)),
  'es'
);
+-------------+----------+
| text        | language |
+-------------+----------+
| Hola mundo  | Spanish  |
| Buenos días | Spanish  |
+-------------+----------+

關於轉發狀態的更多模式與細節,請參見 「轉發狀態到未來 eval 通話」。

從方法 analyze 中指定分割

當多型態 UDTF 接受資料表參數時,analyze 方法可以透過在 AnalyzeResult 上設定 partitionByorderBywithSinglePartitionselect 來控制輸入列在 UDTF 實例間的分布。 這消除了呼叫者必須指定 PARTITION BYORDER BY 在 SQL 中的需求。

完整的分割 API 及範例,請參見指定來自analyze 方法的輸入列的分割

環境隔離

備註

共用隔離環境需要 Databricks Runtime 17.2 和更新版本。 在舊版中,所有 Unity 目錄 Python UDTF 都會以嚴格隔離模式執行。

根據預設,具有相同擁有者和工作階段的 Unity 目錄 Python UDTF 可以共用隔離環境。 這可藉由減少需要啟動的個別環境數目來改善效能並減少記憶體使用量。

嚴格隔離

若要確保 UDTF 一律在其自己的完全隔離環境中執行,請新增 STRICT ISOLATION 特性子句。

大多數 UDTF 不需要嚴格隔離。 標準資料處理 UDTF 受益於預設的共用隔離環境,並以較低的記憶體耗用量執行得更快。

STRICT ISOLATION 特性子句新增至符合以下條件的 UDTF:

  • 改為以代碼形式來運行輸入,使用eval()exec()或類似的函數。
  • 將檔案寫入本機檔案系統。
  • 修改全域變數或系統狀態。
  • 存取或修改環境變數。

下列 UDTF 範例會設定自訂環境變數、讀回變數,並使用變數將一組數字相乘。 因為 UDTF 會改變處理環境,所以在 STRICT ISOLATION 中執行它。 否則,它可能會洩漏或覆寫相同環境中其他 UDF/UDTF 的環境變數,導致不正確的行為。

CREATE OR REPLACE TEMPORARY FUNCTION multiply_numbers(factor STRING)
RETURNS TABLE (original INT, scaled INT)
LANGUAGE PYTHON
STRICT ISOLATION
HANDLER 'Multiplier'
AS $$
import os

class Multiplier:
    def eval(self, factor: str):
        # Save the factor as an environment variable
        os.environ["FACTOR"] = factor

        # Read it back and convert it to a number
        scale = int(os.getenv("FACTOR", "1"))

        # Multiply 0 through 4 by the factor
        for i in range(5):
            yield (i, i * scale)
$$;

SELECT * FROM multiply_numbers("3");

設定 DETERMINISTIC 以確保您的函數產生一致的結果

如果函數定義對相同的輸入產生相同的輸出,則在您的函數定義中新增 DETERMINISTIC。 這允許查詢優化以提高效能。

根據預設,除非明確宣告,否則會假設批次 Unity 目錄 Python UDTF 是不確定的。 非確定性函數的範例包括:產生隨機值、存取目前時間或日期,或進行外部 API 呼叫。

請參閱 CREATE FUNCTION (SQL 和 Python)。

實戰實例

下列範例示範 Unity 目錄 Python UDTF 的實際使用案例,從簡單的資料轉換進展到複雜的外部整合。

範例:重新實作 explode

雖然 Spark 提供內建 explode 函式,但建立您自己的版本會示範採用單一輸入並產生多個輸出資料列的基本 UDTF 模式。

CREATE OR REPLACE FUNCTION my_explode(arr ARRAY<STRING>)
RETURNS TABLE (element STRING)
LANGUAGE PYTHON
HANDLER 'MyExplode'
DETERMINISTIC
AS $$
class MyExplode:
    def eval(self, arr):
        if arr is None:
            return
        for element in arr:
            yield (element,)
$$;

直接在 SQL 查詢中使用函數:

SELECT element FROM my_explode(array('apple', 'banana', 'cherry'));
+---------+
| element |
+---------+
| apple   |
| banana  |
| cherry  |
+---------+

或透過LATERAL聯結將其套用至現有資料表:

SELECT s.*, e.element
FROM my_items AS s,
LATERAL my_explode(s.items) AS e;

範例:透過 REST API 進行 IP 位址的地理位置查找

此範例示範 UDTF 如何將外部 API 直接整合到 SQL 工作流程中。 分析師可以使用熟悉的 SQL 語法透過即時 API 呼叫來豐富資料,而無需單獨的 ETL 流程。

CREATE OR REPLACE FUNCTION ip_to_location(ip_address STRING)
RETURNS TABLE (city STRING, country STRING)
LANGUAGE PYTHON
HANDLER 'IPToLocationAPI'
AS $$
class IPToLocationAPI:
    def eval(self, ip_address):
        import requests
        api_url = f"https://api.ip-lookup.example.com/{ip_address}"
        try:
            response = requests.get(api_url)
            response.raise_for_status()
            data = response.json()
            yield (data.get('city'), data.get('country'))
        except requests.exceptions.RequestException as e:
            # Return nothing if the API request fails
            return
$$;

備註

Python UDTF 函數可在使用無伺服器運算或設置為標準存取模式的運算環境下,允許透過連接埠 80、443 和 53 進行 TCP/UDP 網路流量。

使用此功能以地理資訊豐富 Web 日誌資料:

SELECT
  l.timestamp,
  l.request_path,
  geo.city,
  geo.country
FROM web_logs AS l,
LATERAL ip_to_location(l.ip_address) AS geo;

這種方法可以實現即時地理分析,而無需預先處理的查找表或單獨的資料管道。 UDTF 處理 HTTP 請求、JSON 解析和錯誤處理,使外部資料來源可透過標準 SQL 查詢存取。

範例:將 IP 位址與 CIDR 網路區塊進行比對

此範例示範如何將 IP 位址與 CIDR 網路區塊進行比對,這是一項需要複雜 SQL 邏輯的常見資料工程任務。

首先,建立同時具有 IPv4 和 IPv6 位址的範例資料:

-- An example IP logs with both IPv4 and IPv6 addresses
CREATE OR REPLACE TEMPORARY VIEW ip_logs AS
VALUES
  ('log1', '192.168.1.100'),
  ('log2', '10.0.0.5'),
  ('log3', '172.16.0.10'),
  ('log4', '8.8.8.8'),
  ('log5', '2001:db8::1'),
  ('log6', '2001:db8:85a3::8a2e:370:7334'),
  ('log7', 'fe80::1'),
  ('log8', '::1'),
  ('log9', '2001:db8:1234:5678::1')
t(log_id, ip_address);

接下來,定義並註冊 UDTF。 請注意 Python 類別結構:

  • t TABLE 參數接受具有任何結構描述的輸入表格。 UDTF 會自動調整,以處理提供的任何欄位。 這種靈活性意味著您可以在不同的表中使用相同的函數,而無需修改函數簽名。 不過,您必須仔細檢查資料列的結構描述,以確保相容性。
  • __init__ 方法用於繁重的一次性設定,例如載入大型網路清單。 這項工作會在輸入表格的每個分割區中進行一次。
  • eval 方法會處理每一列,並包含核心比對邏輯。 這個方法只會針對輸入分割區中的每一列執行一次,而且每次執行都會由該分割區的 UDTF 類別的 IpMatcher 對應實例執行。
  • HANDLER 句會指定實作 UDTF 邏輯之 Python 類別的名稱。
CREATE OR REPLACE TEMPORARY FUNCTION ip_cidr_matcher(t TABLE)
RETURNS TABLE(log_id STRING, ip_address STRING, network STRING, ip_version INT)
LANGUAGE PYTHON
HANDLER 'IpMatcher'
COMMENT 'Match IP addresses against a list of network CIDR blocks'
AS $$
class IpMatcher:
    def __init__(self):
        import ipaddress
        # Heavy initialization - load networks once per partition
        self.nets = []
        cidrs = ['192.168.0.0/16', '10.0.0.0/8', '172.16.0.0/12',
                 '2001:db8::/32', 'fe80::/10', '::1/128']
        for cidr in cidrs:
            self.nets.append(ipaddress.ip_network(cidr))

    def eval(self, row):
        import ipaddress
	    # Validate that required fields exist
        required_fields = ['log_id', 'ip_address']
        for field in required_fields:
            if field not in row:
                raise ValueError(f"Missing required field: {field}")
        try:
            ip = ipaddress.ip_address(row['ip_address'])
            for net in self.nets:
                if ip in net:
                    yield (row['log_id'], row['ip_address'], str(net), ip.version)
                    return
            yield (row['log_id'], row['ip_address'], None, ip.version)
        except ValueError:
            yield (row['log_id'], row['ip_address'], 'Invalid', None)
$$;

現在 ip_cidr_matcher 已在 Unity Catalog 中註冊,可以從 SQL 使用 TABLE() 語法直接呼叫它:

-- Process all IP addresses
SELECT
  *
FROM
  ip_cidr_matcher(t => TABLE(ip_logs))
ORDER BY
  log_id;
+--------+-------------------------------+-----------------+-------------+
| log_id | ip_address                    | network         | ip_version  |
+--------+-------------------------------+-----------------+-------------+
| log1   | 192.168.1.100                 | 192.168.0.0/16  | 4           |
| log2   | 10.0.0.5                      | 10.0.0.0/8      | 4           |
| log3   | 172.16.0.10                   | 172.16.0.0/12   | 4           |
| log4   | 8.8.8.8                       | null            | 4           |
| log5   | 2001:db8::1                   | 2001:db8::/32   | 6           |
| log6   | 2001:db8:85a3::8a2e:370:7334  | 2001:db8::/32   | 6           |
| log7   | fe80::1                       | fe80::/10       | 6           |
| log8   | ::1                           | ::1/128         | 6           |
| log9   | 2001:db8:1234:5678::1         | 2001:db8::/32   | 6           |
+--------+-------------------------------+-----------------+-------------+

範例:使用 Azure Databricks 視覺端點進行批次影像描述

此範例示範使用 Azure Databricks 視覺模型服務端點進行批次圖像說明。 它展示了如何使用 terminate() 進行批處理和分區基礎的執行。

  1. 建立具有公開影像URL的表格:

    CREATE OR REPLACE TEMPORARY VIEW sample_images AS
    VALUES
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg', 'scenery'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Camponotus_flavomarginatus_ant.jpg/1024px-Camponotus_flavomarginatus_ant.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Cat_August_2010-4.jpg/1200px-Cat_August_2010-4.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/M101_hires_STScI-PRC2006-10a.jpg/1024px-M101_hires_STScI-PRC2006-10a.jpg', 'scenery')
    images(image_url, category);
    
  2. 建立 Unity 目錄 Python UDTF 以產生影像標題:

    1. 使用設定初始化 UDTF,包括批次大小、Azure Databricks API 令牌、視覺模型端點以及工作區 URL。
    2. 在方法 eval 中,將影像 URL 收集到緩衝區中。 當緩衝區達到批次大小時,觸發批次處理。 這可確保在單一 API 呼叫中一起處理多個影像,而不是每個影像的個別呼叫。
    3. 在批次處理方法中,下載所有緩衝的影像,將它們編碼為 base64,然後將它們傳送至 Databricks VisionModel 的單一 API 要求。 模型會同時處理所有影像,並傳回整個批次的標題。
    4. terminate 方法在每個分割區的結尾只執行一次。 在 terminate 方法中,處理緩衝區中任何剩餘的影像,將收集到的所有標題作為結果返回。

備註

<workspace-url> 替換為您實際的 Azure Databricks 工作區 URL(https://your-workspace.cloud.databricks.com)。

CREATE OR REPLACE TEMPORARY FUNCTION batch_inference_image_caption(data TABLE, api_token STRING)
RETURNS TABLE (caption STRING)
LANGUAGE PYTHON
HANDLER 'BatchInferenceImageCaption'
COMMENT 'batch image captioning by sending groups of image URLs to a Databricks vision endpoint and returning concise captions for each image.'
AS $$
class BatchInferenceImageCaption:
    def __init__(self):
        self.batch_size = 3
        self.vision_endpoint = "databricks-claude-sonnet-4-5"
        self.workspace_url = "<workspace-url>"
        self.image_buffer = []
        self.results = []

    def eval(self, row, api_token):
        self.image_buffer.append((str(row[0]), api_token))
        if len(self.image_buffer) >= self.batch_size:
            self._process_batch()

    def terminate(self):
        if self.image_buffer:
            self._process_batch()
        for caption in self.results:
            yield (caption,)

    def _process_batch(self):
        batch_data = self.image_buffer.copy()
        self.image_buffer.clear()

        import base64
        import httpx
        import requests

        # API request timeout in seconds
        api_timeout = 60
        # Maximum tokens for vision model response
        max_response_tokens = 300
        # Temperature controls randomness (lower = more deterministic)
        model_temperature = 0.3

        # create a batch for the images
        batch_images = []
        api_token = batch_data[0][1] if batch_data else None

        for image_url, _ in batch_data:
            image_response = httpx.get(image_url, timeout=15)
            image_data = base64.standard_b64encode(image_response.content).decode("utf-8")
            batch_images.append(image_data)

        content_items = [{
            "type": "text",
            "text": "Provide brief captions for these images, one per line."
        }]
        for img_data in batch_images:
            content_items.append({
                "type": "image_url",
                "image_url": {
                    "url": "data:image/jpeg;base64," + img_data
                }
            })

        payload = {
            "messages": [{
                "role": "user",
                "content": content_items
            }],
            "max_tokens": max_response_tokens,
            "temperature": model_temperature
        }

        response = requests.post(
            self.workspace_url + "/serving-endpoints/" +
            self.vision_endpoint + "/invocations",
            headers={
                'Authorization': 'Bearer ' + api_token,
                'Content-Type': 'application/json'
            },
            json=payload,
            timeout=api_timeout
        )

        result = response.json()
        batch_response = result['choices'][0]['message']['content'].strip()

        lines = batch_response.split('\n')
        captions = [line.strip() for line in lines if line.strip()]

        while len(captions) < len(batch_data):
            captions.append(batch_response)

        self.results.extend(captions[:len(batch_data)])
$$;

若要使用批次影像標題 UDTF,請使用範例影像資料表呼叫它:

備註

your_secret_scopeapi_token 替換為 Databricks API 權杖的實際秘密範圍和金鑰名稱。

SELECT
  caption
FROM
  batch_inference_image_caption(
    data => TABLE(sample_images),
    api_token => secret('your_secret_scope', 'api_token')
  )
+---------------------------------------------------------------------------------------------------------------+
| caption                                                                                                       |
+---------------------------------------------------------------------------------------------------------------+
| Wooden boardwalk cutting through vibrant wetland grasses under blue skies                                     |
| Black ant in detailed macro photography standing on a textured surface                                        |
| Tabby cat lounging comfortably on a white ledge against a white wall                                          |
| Stunning spiral galaxy with bright central core and sweeping blue-white arms against the black void of space. |
+---------------------------------------------------------------------------------------------------------------+

您還可以按類別生成圖像標題類別:

SELECT
  *
FROM
  batch_inference_image_caption(
    TABLE(sample_images)
    PARTITION BY category ORDER BY (category),
    secret('your_secret_scope', 'api_token')
  )
+------------------------------------------------------------------------------------------------------+
| caption                                                                                              |
+------------------------------------------------------------------------------------------------------+
| Black ant in detailed macro photography standing on a textured surface                               |
| Stunning spiral galaxy with bright center and sweeping blue-tinged arms against the black of space.  |
| Tabby cat lounging comfortably on white ledge against white wall                                     |
| Wooden boardwalk cutting through lush wetland grasses under blue skies                               |
+------------------------------------------------------------------------------------------------------+

範例:用於 ML 模型評估的 ROC 曲線和 AUC 計算

此範例示範使用 scikit-learn 計算接收者工作特性 (ROC) 曲線和曲線下面積 (AUC) 分數,以評估二元分類模型。

此範例展示了幾個重要的模式:

  • 外部庫使用: 集成 scikit-learn 用於 ROC 曲線計算
  • 有狀態聚合:將所有行的預測累積起來後再計算指標
  • terminate() 方法用法:處理完整的資料集,只有在評估完所有資料列後才產生結果
  • 錯誤處理:驗證輸入表格中是否存在必要的資料行

UDTF 使用該 eval() 方法在記憶體中累積所有預測,然後計算並產生該 terminate() 方法中的完整 ROC 曲線。 此模式對於需要完整資料集進行計算的計量很有用。

CREATE OR REPLACE TEMPORARY FUNCTION compute_roc_curve(t TABLE)
RETURNS TABLE (threshold DOUBLE, true_positive_rate DOUBLE, false_positive_rate DOUBLE, auc DOUBLE)
LANGUAGE PYTHON
HANDLER 'ROCCalculator'
COMMENT 'Compute ROC curve and AUC using scikit-learn'
AS $$
class ROCCalculator:
    def __init__(self):
        from sklearn import metrics
        self._roc_curve = metrics.roc_curve
        self._roc_auc_score = metrics.roc_auc_score

        self._true_labels = []
        self._predicted_scores = []

    def eval(self, row):
        if 'y_true' not in row or 'y_score' not in row:
            raise KeyError("Required columns 'y_true' and 'y_score' not found")

        true_label = row['y_true']
        predicted_score = row['y_score']

        label = float(true_label)
        self._true_labels.append(label)
        self._predicted_scores.append(float(predicted_score))

    def terminate(self):
        false_pos_rate, true_pos_rate, thresholds = self._roc_curve(
            self._true_labels,
            self._predicted_scores,
            drop_intermediate=False
        )

        auc_score = float(self._roc_auc_score(self._true_labels, self._predicted_scores))

        for threshold, tpr, fpr in zip(thresholds, true_pos_rate, false_pos_rate):
            yield float(threshold), float(tpr), float(fpr), auc_score
$$;

使用預測建立範例二進位分類資料:

CREATE OR REPLACE TEMPORARY VIEW binary_classification_data AS
SELECT *
FROM VALUES
  ( 1, 1.0, 0.95, 'high_confidence_positive'),
  ( 2, 1.0, 0.87, 'high_confidence_positive'),
  ( 3, 1.0, 0.82, 'medium_confidence_positive'),
  ( 4, 0.0, 0.78, 'false_positive'),
  ( 5, 1.0, 0.71, 'medium_confidence_positive'),
  ( 6, 0.0, 0.65, 'false_positive'),
  ( 7, 0.0, 0.58, 'true_negative'),
  ( 8, 1.0, 0.52, 'low_confidence_positive'),
  ( 9, 0.0, 0.45, 'true_negative'),
  (10, 0.0, 0.38, 'true_negative'),
  (11, 1.0, 0.31, 'low_confidence_positive'),
  (12, 0.0, 0.15, 'true_negative'),
  (13, 0.0, 0.08, 'high_confidence_negative'),
  (14, 0.0, 0.03, 'high_confidence_negative')
AS data(sample_id, y_true, y_score, prediction_type);

計算 ROC 曲線和 AUC:

SELECT
    threshold,
    true_positive_rate,
    false_positive_rate,
    auc
FROM compute_roc_curve(
  TABLE(
    SELECT y_true, y_score
    FROM binary_classification_data
    WHERE y_true IS NOT NULL AND y_score IS NOT NULL
    ORDER BY sample_id
  )
)
ORDER BY threshold DESC;
+-----------+---------------------+----------------------+-------+
| threshold | true_positive_rate  | false_positive_rate  | auc   |
+-----------+---------------------+----------------------+-------+
| 1.95      | 0.0                 | 0.0                  | 0.786 |
| 0.95      | 0.167               | 0.0                  | 0.786 |
| 0.87      | 0.333               | 0.0                  | 0.786 |
| 0.82      | 0.5                 | 0.0                  | 0.786 |
| 0.78      | 0.5                 | 0.125                | 0.786 |
| 0.71      | 0.667               | 0.125                | 0.786 |
| 0.65      | 0.667               | 0.25                 | 0.786 |
| 0.58      | 0.667               | 0.375                | 0.786 |
| 0.52      | 0.833               | 0.375                | 0.786 |
| 0.45      | 0.833               | 0.5                  | 0.786 |
| 0.38      | 0.833               | 0.625                | 0.786 |
| 0.31      | 1.0                 | 0.625                | 0.786 |
| 0.15      | 1.0                 | 0.75                 | 0.786 |
| 0.08      | 1.0                 | 0.875                | 0.786 |
| 0.03      | 1.0                 | 1.0                  | 0.786 |
+-----------+---------------------+----------------------+-------+

範例:從表格參數進行動態欄位投影

此範例結合多型 UDTF 與表格參數。 UDTF 接受一個表格及逗號分隔的欄位名稱清單,然後只投影輸入中的欄位。 該 analyze 方法檢查輸入資料表的結構,並建立僅包含所請求欄位的輸出結構。

CREATE OR REPLACE FUNCTION project_columns(t TABLE, columns STRING)
RETURNS TABLE
LANGUAGE PYTHON
HANDLER 'ProjectColumns'
AS $$
class ProjectColumns:
    @staticmethod
    def analyze(t, columns):
        from pyspark.sql.types import StructType
        from pyspark.sql.udtf import AnalyzeResult

        requested = [c.strip() for c in columns.value.split(",")]
        input_schema = t.dataType
        output_fields = []
        for field in input_schema.fields:
            if field.name in requested:
                output_fields.append(field)
        if not output_fields:
            raise ValueError(
                f"None of the requested columns {requested} "
                f"exist in the input table"
            )
        return AnalyzeResult(schema=StructType(output_fields))

    def eval(self, row, columns: str):
        requested = [c.strip() for c in columns.split(",")]
        yield tuple(row[col] for col in requested if col in row)
$$;

使用這個函式從表格中選擇特定欄位:

SELECT * FROM project_columns(
  TABLE(SELECT * FROM samples.nyctaxi.trips LIMIT 5),
  'pickup_zip, dropoff_zip, fare_amount'
);

局限性

以下限制適用於 Unity Catalog Python UDTFs:

後續步驟