Teilen über


Benutzerdefinierte Python-Tabellenfunktionen (UDTFs) im Unity-Katalog

Von Bedeutung

Das Registrieren von Python UDTFs im Unity-Katalog befindet sich in der öffentlichen Vorschau.

Eine benutzerdefinierte Tabellenfunktion (Unity Catalog, UDTF) registriert Funktionen, die vollständige Tabellen anstelle von Skalarwerten zurückgeben. Im Gegensatz zu skalaren Funktionen, die einen einzelnen Ergebniswert aus jedem Aufruf zurückgeben, werden UDTFs in einer SQL-Anweisungsklausel FROM aufgerufen und können mehrere Zeilen und Spalten zurückgeben.

UDTFs sind besonders nützlich für:

  • Transformieren von Arrays oder komplexen Datenstrukturen in mehrere Zeilen
  • Integrieren externer APIs oder Dienste in SQL-Workflows
  • Implementieren von benutzerdefinierter Datengenerierungs- oder Anreicherungslogik
  • Verarbeiten von Daten, die zustandsbehaftete Vorgänge über Zeilen hinweg erfordern

Jeder UDTF-Aufruf akzeptiert null oder mehr Argumente. Diese Argumente können skalare Ausdrücke oder Tabellenargumente sein, die ganze Eingabetabellen darstellen.

UDTFs können auf zwei Arten registriert werden:

  • Unity-Katalog: Registrieren Sie den UDTF als geregeltes Objekt im Unity-Katalog.
  • Sitzungsbereich: Registrieren Sie sich für das lokale SparkSession, isolierte Notizbuch oder den aktuellen Auftrag. Siehe python user-defined table functions (UDTFs).

Anforderungen

Unity Catalog Python UDTFs werden für die folgenden Computetypen unterstützt:

  • Klassische Compute mit Standardzugriffsmodus (Databricks Runtime 17.1 und höher)
  • SQL Warehouse (Serverless oder Pro)

Erstellen eines UDTF im Unity-Katalog

Verwenden Sie SQL DDL, um eine gesteuerte UDTF im Unity Catalog zu erstellen. UDTFs werden mithilfe einer SQL-Anweisungsklausel FROM aufgerufen.

CREATE OR REPLACE FUNCTION square_numbers(start INT, end INT)
RETURNS TABLE (num INT, squared INT)
LANGUAGE PYTHON
HANDLER 'SquareNumbers'
DETERMINISTIC
AS $$
class SquareNumbers:
    """
    Basic UDTF that computes a sequence of integers
    and includes the square of each number in the range.
    """
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
$$;

SELECT * FROM square_numbers(1, 5);

+-----+---------+
| num | squared |
+-----+---------+
| 1   | 1       |
| 2   | 4       |
| 3   | 9       |
| 4   | 16      |
| 5   | 25      |
+-----+---------+

Azure Databricks implementiert Python UDTFs als Python-Klassen mit einer obligatorischen eval Methode, die Ausgabezeilen liefert.

Tabellenargumente

Hinweis

TABLE Argumente werden in Databricks Runtime 17.2 und höher unterstützt.

UDTFs können ganze Tabellen als Eingabeargumente akzeptieren und komplexe zustandsbehaftete Transformationen und Aggregationen ermöglichen.

eval() und terminate() Lebenszyklusmethoden

Tabellenargumente in UDTFs verwenden die folgenden Funktionen, um jede Zeile zu verarbeiten:

  • eval(): Wird einmal für jede Zeile in der Eingabetabelle aufgerufen. Dies ist die Hauptverarbeitungsmethode und ist erforderlich.
  • terminate(): Wird einmal am Ende jeder Partition aufgerufen, nachdem alle Zeilen durch eval() verarbeitet wurden. Verwenden Sie diese Methode, um endgültige aggregierte Ergebnisse zu erzielen oder Bereinigungsvorgänge auszuführen. Diese Methode ist optional, aber für zustandsbehaftete Vorgänge wie Aggregationen, Zählen oder Batchverarbeitung unerlässlich.

Weitere Informationen über eval() und terminate() Methoden finden Sie in der Apache Spark-Dokumentation: Python UDTF.

Zeilenzugriffsmuster

eval() empfängt Zeilen von TABLE Argumenten als pyspark.sql.Row-Objekte. Sie können auf Werte nach Spaltenname (row['id'], row['name']) oder nach Index (row[0], row[1]) zugreifen.

  • Schemaflexibilität: Deklarieren von TABLE Argumenten ohne Schemadefinitionen (z. B data TABLE. , t TABLE). Die Funktion akzeptiert eine beliebige Tabellenstruktur, sodass ihr Code überprüfen sollte, ob erforderliche Spalten vorhanden sind.

Siehe Beispiel: Abgleichen von IP-Adressen mit CIDR-Netzwerkblöcken und Beispiel: Batchbildbeschriftung mithilfe von Azure Databricks-Vision-Endpunkten.

Umgebungsisolation

Hinweis

Freigegebene Isolationsumgebungen erfordern Databricks Runtime 17.2 und höher. In früheren Versionen werden alle Unity Catalog Python UDTFs im strikten Isolationsmodus ausgeführt.

Unity Catalog Python UDTFs mit demselben Besitzer und derselben Sitzung können standardmäßig eine Isolationsumgebung verwenden. Dies verbessert die Leistung und verringert die Speicherauslastung, indem die Anzahl der separaten Umgebungen reduziert wird, die gestartet werden müssen.

Strenge Isolierung

Um sicherzustellen, dass ein UDTF immer in einer eigenen, vollständig isolierten Umgebung ausgeführt wird, fügen Sie die STRICT ISOLATION Merkmalsklausel hinzu.

Die meisten UDTFs benötigen keine strenge Isolierung. Standarddatenverarbeitungs-UDTFs profitieren von der standardmäßigen freigegebenen Isolationsumgebung und können schneller mit geringerer Arbeitsspeicherauslastung ausgeführt werden.

Fügen Sie die STRICT ISOLATION Merkmalsklausel zu UDTFs hinzu, die:

  • Führen Sie Eingaben als Code mithilfe eval()von , exec()oder ähnlichen Funktionen aus.
  • Schreiben Sie Dateien in das lokale Dateisystem.
  • Ändern sie globale Variablen oder den Systemstatus.
  • Zugreifen oder Ändern von Umgebungsvariablen

Im folgenden UDTF-Beispiel wird eine benutzerdefinierte Umgebungsvariable festgelegt, die Variable zurückgelesen und eine Reihe von Zahlen mithilfe der Variablen multipliziert. Da die UDTF die Prozessumgebung stummschaltet, führen Sie sie in STRICT ISOLATION. Andernfalls könnte es Umgebungsvariablen für andere UDFs/UDTFs in derselben Umgebung durchlecken oder außer Kraft setzen, was zu einem falschen Verhalten führt.

CREATE OR REPLACE TEMPORARY FUNCTION multiply_numbers(factor STRING)
RETURNS TABLE (original INT, scaled INT)
LANGUAGE PYTHON
STRICT ISOLATION
HANDLER 'Multiplier'
AS $$
import os

class Multiplier:
    def eval(self, factor: str):
        # Save the factor as an environment variable
        os.environ["FACTOR"] = factor

        # Read it back and convert it to a number
        scale = int(os.getenv("FACTOR", "1"))

        # Multiply 0 through 4 by the factor
        for i in range(5):
            yield (i, i * scale)
$$;

SELECT * FROM multiply_numbers("3");

Festlegen DETERMINISTIC , ob Ihre Funktion konsistente Ergebnisse erzeugt

Fügen Sie DETERMINISTIC ihrer Funktionsdefinition hinzu, wenn sie dieselben Ausgaben für die gleichen Eingaben erzeugt. Dadurch können Abfrageoptimierungen die Leistung verbessern.

Standardmäßig wird angenommen, dass Batch Unity Catalog Python UDTFs nicht deterministisch ist, es sei denn, es wird explizit deklariert. Beispiele für nicht deterministische Funktionen sind: Generieren von Zufallswerten, Zugreifen auf aktuelle Uhrzeiten oder Datumsangaben oder Durchführen externer API-Aufrufe.

Siehe CREATE FUNCTION (SQL und Python).

Praktische Beispiele

Die folgenden Beispiele veranschaulichen reale Anwendungsfälle für Unity Catalog Python UDTFs, die von einfachen Datentransformationen bis hin zu komplexen externen Integrationen vorankommen.

Beispiel: Erneute Implementierung explode

Während Spark eine integrierte Funktion bereitstellt, veranschaulicht das Erstellen Einer eigenen Version das grundlegende UDTF-Muster explode , bei dem eine einzelne Eingabe ausgeführt und mehrere Ausgabezeilen erzeugt werden.

CREATE OR REPLACE FUNCTION my_explode(arr ARRAY<STRING>)
RETURNS TABLE (element STRING)
LANGUAGE PYTHON
HANDLER 'MyExplode'
DETERMINISTIC
AS $$
class MyExplode:
    def eval(self, arr):
        if arr is None:
            return
        for element in arr:
            yield (element,)
$$;

Verwenden Sie die Funktion direkt in einer SQL-Abfrage:

SELECT element FROM my_explode(array('apple', 'banana', 'cherry'));
+---------+
| element |
+---------+
| apple   |
| banana  |
| cherry  |
+---------+

Oder wenden Sie sie auf vorhandene Tabellendaten mit einer LATERAL Verknüpfung an:

SELECT s.*, e.element
FROM my_items AS s,
LATERAL my_explode(s.items) AS e;

Beispiel: IP-Adress-Geolocation über REST-API

In diesem Beispiel wird veranschaulicht, wie UDTFs externe APIs direkt in Ihren SQL-Workflow integrieren können. Analysten können Daten mit Echtzeit-API-Aufrufen mit vertrauter SQL-Syntax anreichern, ohne dass separate ETL-Prozesse erforderlich sind.

CREATE OR REPLACE FUNCTION ip_to_location(ip_address STRING)
RETURNS TABLE (city STRING, country STRING)
LANGUAGE PYTHON
HANDLER 'IPToLocationAPI'
AS $$
class IPToLocationAPI:
    def eval(self, ip_address):
        import requests
        api_url = f"https://api.ip-lookup.example.com/{ip_address}"
        try:
            response = requests.get(api_url)
            response.raise_for_status()
            data = response.json()
            yield (data.get('city'), data.get('country'))
        except requests.exceptions.RequestException as e:
            # Return nothing if the API request fails
            return
$$;

Hinweis

Python UDTFs ermöglichen TCP/UDP-Netzwerkdatenverkehr über ports 80, 443 und 53, wenn serverlose Compute- oder Compute-Konfiguration mit standardzugriffsmodus konfiguriert ist.

Verwenden Sie die Funktion, um Webprotokolldaten mit geografischen Informationen zu bereichern:

SELECT
  l.timestamp,
  l.request_path,
  geo.city,
  geo.country
FROM web_logs AS l,
LATERAL ip_to_location(l.ip_address) AS geo;

Dieser Ansatz ermöglicht die geografische Analyse in Echtzeit, ohne dass vorverarbeitete Nachschlagetabellen oder separate Datenpipelinen erforderlich sind. UdTF verarbeitet HTTP-Anforderungen, JSON-Analyse und Fehlerbehandlung, wodurch externe Datenquellen über STANDARD-SQL-Abfragen zugänglich sind.

Beispiel: Abgleichen von IP-Adressen mit CIDR-Netzwerkblöcken

In diesem Beispiel wird der Abgleich von IP-Adressen mit CIDR-Netzwerkblöcken veranschaulicht, einer gängigen Data Engineering-Aufgabe, die komplexe SQL-Logik erfordert.

Erstellen Sie zunächst Beispieldaten mit IPv4- und IPv6-Adressen:

-- An example IP logs with both IPv4 and IPv6 addresses
CREATE OR REPLACE TEMPORARY VIEW ip_logs AS
VALUES
  ('log1', '192.168.1.100'),
  ('log2', '10.0.0.5'),
  ('log3', '172.16.0.10'),
  ('log4', '8.8.8.8'),
  ('log5', '2001:db8::1'),
  ('log6', '2001:db8:85a3::8a2e:370:7334'),
  ('log7', 'fe80::1'),
  ('log8', '::1'),
  ('log9', '2001:db8:1234:5678::1')
t(log_id, ip_address);

Definieren und registrieren Sie als nächstes die UDTF. Beachten Sie die Python-Klassenstruktur:

  • Der t TABLE Parameter akzeptiert eine Eingabetabelle mit einem beliebigen Schema. Die UDTF passt sich automatisch an, um die bereitgestellten Spalten zu verarbeiten. Diese Flexibilität bedeutet, dass Sie dieselbe Funktion in verschiedenen Tabellen verwenden können, ohne die Funktionssignatur zu ändern. Sie müssen jedoch das Schema der Zeilen sorgfältig überprüfen, um die Kompatibilität sicherzustellen.
  • Die __init__ Methode wird für umfangreiche einmalige Einrichtung verwendet, z. B. das Laden der großen Netzwerkliste. Diese Arbeit erfolgt einmal pro Partition der Eingabetabelle.
  • Die eval Methode verarbeitet jede Zeile und enthält die Kernabgleichslogik. Diese Methode wird genau einmal für jede Zeile in der Eingabepartition ausgeführt, und jede Ausführung wird von der entsprechenden Instanz der IpMatcher UDTF-Klasse für diese Partition ausgeführt.
  • Die HANDLER Klausel gibt den Namen der Python-Klasse an, die die UDTF-Logik implementiert.
CREATE OR REPLACE TEMPORARY FUNCTION ip_cidr_matcher(t TABLE)
RETURNS TABLE(log_id STRING, ip_address STRING, network STRING, ip_version INT)
LANGUAGE PYTHON
HANDLER 'IpMatcher'
COMMENT 'Match IP addresses against a list of network CIDR blocks'
AS $$
class IpMatcher:
    def __init__(self):
        import ipaddress
        # Heavy initialization - load networks once per partition
        self.nets = []
        cidrs = ['192.168.0.0/16', '10.0.0.0/8', '172.16.0.0/12',
                 '2001:db8::/32', 'fe80::/10', '::1/128']
        for cidr in cidrs:
            self.nets.append(ipaddress.ip_network(cidr))

    def eval(self, row):
        import ipaddress
	    # Validate that required fields exist
        required_fields = ['log_id', 'ip_address']
        for field in required_fields:
            if field not in row:
                raise ValueError(f"Missing required field: {field}")
        try:
            ip = ipaddress.ip_address(row['ip_address'])
            for net in self.nets:
                if ip in net:
                    yield (row['log_id'], row['ip_address'], str(net), ip.version)
                    return
            yield (row['log_id'], row['ip_address'], None, ip.version)
        except ValueError:
            yield (row['log_id'], row['ip_address'], 'Invalid', None)
$$;

ip_cidr_matcher Nachdem sie im Unity-Katalog registriert ist, rufen Sie sie direkt aus SQL mithilfe der TABLE() Syntax auf:

-- Process all IP addresses
SELECT
  *
FROM
  ip_cidr_matcher(t => TABLE(ip_logs))
ORDER BY
  log_id;
+--------+-------------------------------+-----------------+-------------+
| log_id | ip_address                    | network         | ip_version  |
+--------+-------------------------------+-----------------+-------------+
| log1   | 192.168.1.100                 | 192.168.0.0/16  | 4           |
| log2   | 10.0.0.5                      | 10.0.0.0/8      | 4           |
| log3   | 172.16.0.10                   | 172.16.0.0/12   | 4           |
| log4   | 8.8.8.8                       | null            | 4           |
| log5   | 2001:db8::1                   | 2001:db8::/32   | 6           |
| log6   | 2001:db8:85a3::8a2e:370:7334  | 2001:db8::/32   | 6           |
| log7   | fe80::1                       | fe80::/10       | 6           |
| log8   | ::1                           | ::1/128         | 6           |
| log9   | 2001:db8:1234:5678::1         | 2001:db8::/32   | 6           |
+--------+-------------------------------+-----------------+-------------+

Beispiel: Batchbildbeschriftung mithilfe von Azure Databricks-Vision-Endpunkten

In diesem Beispiel wird die Batchbildbeschriftung mithilfe eines Azure Databricks-Visionsmodells veranschaulicht, das den Endpunkt bedient. Es wird gezeigt, wie terminate() für die Batchverarbeitung und partitionsbasierte Ausführung verwendet wird.

  1. Erstellen einer Tabelle mit öffentlichen Bild-URLs:

    CREATE OR REPLACE TEMPORARY VIEW sample_images AS
    VALUES
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg', 'scenery'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Camponotus_flavomarginatus_ant.jpg/1024px-Camponotus_flavomarginatus_ant.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Cat_August_2010-4.jpg/1200px-Cat_August_2010-4.jpg', 'animals'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/M101_hires_STScI-PRC2006-10a.jpg/1024px-M101_hires_STScI-PRC2006-10a.jpg', 'scenery')
    images(image_url, category);
    
  2. Erstellen Sie einen Unity-Katalog Python UDTF, um Bildbeschriftungen zu generieren:

    1. Initialisieren Sie die UDTF mit der Konfiguration, einschließlich Batchgröße, Azure Databricks-API-Token, Visionmodell-Endpunkt und Arbeitsbereichs-URL.
    2. Sammeln Sie in der eval Methode die Bild-URLs in einem Puffer. Wenn der Puffer die Batchgröße erreicht, lösen Sie die Batchverarbeitung aus. Dadurch wird sichergestellt, dass mehrere Bilder in einem einzelnen API-Aufruf und nicht in einzelnen Aufrufen pro Bild verarbeitet werden.
    3. Laden Sie in der Batchverarbeitungsmethode alle gepufferten Bilder herunter, codieren Sie sie als Base64, und senden Sie sie an eine einzelne API-Anforderung an Databricks VisionModel. Das Modell verarbeitet alle Bilder gleichzeitig und gibt Beschriftungen für den gesamten Batch zurück.
    4. Die terminate Methode wird genau einmal am Ende jeder Partition ausgeführt. Verarbeiten Sie in der terminate-Methode alle verbleibenden Bilder im Puffer und geben Sie alle gesammelten Beschriftungen als Ergebnisse aus.

Hinweis

Ersetzen Sie sie <workspace-url> durch Ihre tatsächliche Azure Databricks-Arbeitsbereichs-URL (https://your-workspace.cloud.databricks.com).

CREATE OR REPLACE TEMPORARY FUNCTION batch_inference_image_caption(data TABLE, api_token STRING)
RETURNS TABLE (caption STRING)
LANGUAGE PYTHON
HANDLER 'BatchInferenceImageCaption'
COMMENT 'batch image captioning by sending groups of image URLs to a Databricks vision endpoint and returning concise captions for each image.'
AS $$
class BatchInferenceImageCaption:
    def __init__(self):
        self.batch_size = 3
        self.vision_endpoint = "databricks-claude-sonnet-4-5"
        self.workspace_url = "<workspace-url>"
        self.image_buffer = []
        self.results = []

    def eval(self, row, api_token):
        self.image_buffer.append((str(row[0]), api_token))
        if len(self.image_buffer) >= self.batch_size:
            self._process_batch()

    def terminate(self):
        if self.image_buffer:
            self._process_batch()
        for caption in self.results:
            yield (caption,)

    def _process_batch(self):
        batch_data = self.image_buffer.copy()
        self.image_buffer.clear()

        import base64
        import httpx
        import requests

        # API request timeout in seconds
        api_timeout = 60
        # Maximum tokens for vision model response
        max_response_tokens = 300
        # Temperature controls randomness (lower = more deterministic)
        model_temperature = 0.3

        # create a batch for the images
        batch_images = []
        api_token = batch_data[0][1] if batch_data else None

        for image_url, _ in batch_data:
            image_response = httpx.get(image_url, timeout=15)
            image_data = base64.standard_b64encode(image_response.content).decode("utf-8")
            batch_images.append(image_data)

        content_items = [{
            "type": "text",
            "text": "Provide brief captions for these images, one per line."
        }]
        for img_data in batch_images:
            content_items.append({
                "type": "image_url",
                "image_url": {
                    "url": "data:image/jpeg;base64," + img_data
                }
            })

        payload = {
            "messages": [{
                "role": "user",
                "content": content_items
            }],
            "max_tokens": max_response_tokens,
            "temperature": model_temperature
        }

        response = requests.post(
            self.workspace_url + "/serving-endpoints/" +
            self.vision_endpoint + "/invocations",
            headers={
                'Authorization': 'Bearer ' + api_token,
                'Content-Type': 'application/json'
            },
            json=payload,
            timeout=api_timeout
        )

        result = response.json()
        batch_response = result['choices'][0]['message']['content'].strip()

        lines = batch_response.split('\n')
        captions = [line.strip() for line in lines if line.strip()]

        while len(captions) < len(batch_data):
            captions.append(batch_response)

        self.results.extend(captions[:len(batch_data)])
$$;

Um die Batchbildbeschriftung UDTF zu verwenden, rufen Sie sie mithilfe der Beispielbildtabelle auf:

Hinweis

Ersetzen Sie your_secret_scope und api_token durch den tatsächlichen Secret Scope und Schlüsselnamen für das Databricks-API-Token.

SELECT
  caption
FROM
  batch_inference_image_caption(
    data => TABLE(sample_images),
    api_token => secret('your_secret_scope', 'api_token')
  )
+---------------------------------------------------------------------------------------------------------------+
| caption                                                                                                       |
+---------------------------------------------------------------------------------------------------------------+
| Wooden boardwalk cutting through vibrant wetland grasses under blue skies                                     |
| Black ant in detailed macro photography standing on a textured surface                                        |
| Tabby cat lounging comfortably on a white ledge against a white wall                                          |
| Stunning spiral galaxy with bright central core and sweeping blue-white arms against the black void of space. |
+---------------------------------------------------------------------------------------------------------------+

Sie können auch die Kategorie "Bildbeschriftungen" nach Kategorie generieren:

SELECT
  *
FROM
  batch_inference_image_caption(
    TABLE(sample_images)
    PARTITION BY category ORDER BY (category),
    secret('your_secret_scope', 'api_token')
  )
+------------------------------------------------------------------------------------------------------+
| caption                                                                                              |
+------------------------------------------------------------------------------------------------------+
| Black ant in detailed macro photography standing on a textured surface                               |
| Stunning spiral galaxy with bright center and sweeping blue-tinged arms against the black of space.  |
| Tabby cat lounging comfortably on white ledge against white wall                                     |
| Wooden boardwalk cutting through lush wetland grasses under blue skies                               |
+------------------------------------------------------------------------------------------------------+

Beispiel: ROC-Kurve und AUC-Berechnung für die ML-Modellauswertung

In diesem Beispiel wird das Berechnen von ROC-Kurven (Receiver Operating Characteristic) und der Fläche unter der Kurve (AUC) für die Bewertung binärer Klassifikationsmodelle mithilfe von scikit-learn veranschaulicht.

In diesem Beispiel werden mehrere wichtige Muster vorgestellt:

  • Verwendung externer Bibliotheken: Integration von Scikit-Learn für die Berechnung von ROC-Kurven
  • Zustandsbehaftete Aggregation: Sammelt Vorhersagen über alle Zeilen hinweg, bevor Metriken berechnet werden.
  • terminate() Methodennutzung: Verarbeitet das vollständige Dataset und liefert nur Ergebnisse, nachdem alle Zeilen ausgewertet wurden.
  • Fehlerbehandlung: Überprüfen der erforderlichen Spalten in der Eingabetabelle

Die UDTF sammelt alle Vorhersagen im Arbeitsspeicher mithilfe der eval() Methode, berechnet und liefert dann die vollständige ROC-Kurve in der terminate() Methode. Dieses Muster ist nützlich für Metriken, die das vollständige Dataset für die Berechnung erfordern.

CREATE OR REPLACE TEMPORARY FUNCTION compute_roc_curve(t TABLE)
RETURNS TABLE (threshold DOUBLE, true_positive_rate DOUBLE, false_positive_rate DOUBLE, auc DOUBLE)
LANGUAGE PYTHON
HANDLER 'ROCCalculator'
COMMENT 'Compute ROC curve and AUC using scikit-learn'
AS $$
class ROCCalculator:
    def __init__(self):
        from sklearn import metrics
        self._roc_curve = metrics.roc_curve
        self._roc_auc_score = metrics.roc_auc_score

        self._true_labels = []
        self._predicted_scores = []

    def eval(self, row):
        if 'y_true' not in row or 'y_score' not in row:
            raise KeyError("Required columns 'y_true' and 'y_score' not found")

        true_label = row['y_true']
        predicted_score = row['y_score']

        label = float(true_label)
        self._true_labels.append(label)
        self._predicted_scores.append(float(predicted_score))

    def terminate(self):
        false_pos_rate, true_pos_rate, thresholds = self._roc_curve(
            self._true_labels,
            self._predicted_scores,
            drop_intermediate=False
        )

        auc_score = float(self._roc_auc_score(self._true_labels, self._predicted_scores))

        for threshold, tpr, fpr in zip(thresholds, true_pos_rate, false_pos_rate):
            yield float(threshold), float(tpr), float(fpr), auc_score
$$;

Erstellen Sie Beispiel-Binärklassifizierungsdaten mit Vorhersagen:

CREATE OR REPLACE TEMPORARY VIEW binary_classification_data AS
SELECT *
FROM VALUES
  ( 1, 1.0, 0.95, 'high_confidence_positive'),
  ( 2, 1.0, 0.87, 'high_confidence_positive'),
  ( 3, 1.0, 0.82, 'medium_confidence_positive'),
  ( 4, 0.0, 0.78, 'false_positive'),
  ( 5, 1.0, 0.71, 'medium_confidence_positive'),
  ( 6, 0.0, 0.65, 'false_positive'),
  ( 7, 0.0, 0.58, 'true_negative'),
  ( 8, 1.0, 0.52, 'low_confidence_positive'),
  ( 9, 0.0, 0.45, 'true_negative'),
  (10, 0.0, 0.38, 'true_negative'),
  (11, 1.0, 0.31, 'low_confidence_positive'),
  (12, 0.0, 0.15, 'true_negative'),
  (13, 0.0, 0.08, 'high_confidence_negative'),
  (14, 0.0, 0.03, 'high_confidence_negative')
AS data(sample_id, y_true, y_score, prediction_type);

Berechnen sie die ROC-Kurve und die AUC:

SELECT
    threshold,
    true_positive_rate,
    false_positive_rate,
    auc
FROM compute_roc_curve(
  TABLE(
    SELECT y_true, y_score
    FROM binary_classification_data
    WHERE y_true IS NOT NULL AND y_score IS NOT NULL
    ORDER BY sample_id
  )
)
ORDER BY threshold DESC;
+-----------+---------------------+----------------------+-------+
| threshold | true_positive_rate  | false_positive_rate  | auc   |
+-----------+---------------------+----------------------+-------+
| 1.95      | 0.0                 | 0.0                  | 0.786 |
| 0.95      | 0.167               | 0.0                  | 0.786 |
| 0.87      | 0.333               | 0.0                  | 0.786 |
| 0.82      | 0.5                 | 0.0                  | 0.786 |
| 0.78      | 0.5                 | 0.125                | 0.786 |
| 0.71      | 0.667               | 0.125                | 0.786 |
| 0.65      | 0.667               | 0.25                 | 0.786 |
| 0.58      | 0.667               | 0.375                | 0.786 |
| 0.52      | 0.833               | 0.375                | 0.786 |
| 0.45      | 0.833               | 0.5                  | 0.786 |
| 0.38      | 0.833               | 0.625                | 0.786 |
| 0.31      | 1.0                 | 0.625                | 0.786 |
| 0.15      | 1.0                 | 0.75                 | 0.786 |
| 0.08      | 1.0                 | 0.875                | 0.786 |
| 0.03      | 1.0                 | 1.0                  | 0.786 |
+-----------+---------------------+----------------------+-------+

Einschränkungen

Die folgenden Einschränkungen gelten für Unity Catalog Python UDTFs:

Nächste Schritte