Entdecken Sie Kunst kultur- und medienübergreifend mit dem konditionalen k-Nächste-Nachbarn-Verfahren

In diesem Artikel verwenden Sie den bedingte k-nächstgelegenen Nachbaralgorithmus (k-NN) von SynapseML, um visuell ähnliche Grafiken zu finden. Sie fragen ein Dataset von Kunst aus dem Metropolitan Museum of Art in NYC ab, filtern nach Kultur und mittleren Kategorien.

Voraussetzungen

  • Erstellen Sie ein neues Notizbuch.
  • Fügen Sie Ihr Notebook an ein Lakehouse an. Wählen Sie auf der linken Seite Ihres Notebooks Hinzufügen aus, um ein vorhandenes Lakehouse hinzuzufügen oder ein neues zu erstellen.

Importieren von Bibliotheken

Importieren Sie in der ersten Notizbuchzelle die erforderlichen Python-Bibliotheken:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import lit, array, udf
from synapse.ml.nn import ConditionalKNN
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt

Alle Importe sollten ohne Fehler abgeschlossen werden. Wenn ModuleNotFoundError angezeigt wird, bestätigen Sie, dass Sie Fabric Laufzeit 1.2 oder höher verwenden.

Laden des Datasets

Das Dataset ist eine Parkettdatei mit Grafikmetadaten aus dem Metropolitan Museum of Art. Laden Sie es in einen Spark DataFrame:

df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Das Dataset enthält ungefähr 51.000 Zeilen.

Datensatzschema

Die Tabelle enthält die folgenden Spalten:

  • id: Ein eindeutiger Bezeichner für jedes Kunstwerk (z. B 388395. )
  • Titel: Titel des Kunstwerks, wie in der Datenbank des Museums gespeichert
  • Künstler: Künstler des Kunstwerks gemäß den Angaben in der Datenbank des Museums
  • Thumbnail_Url: URL einer JPEG-Miniaturansicht des Kunststücks
  • Image_Url: Website-URL des vollständigen Grafikbilds
  • Kultur: Kulturkategorie (z. B. Japanisch, Amerikanisch, Italienisch)
  • Klassifizierung: Mittelkategorie (z. B. Gemälde, Keramik, Glas)
  • Museum_Page: URL-Link zur Seite "Kunstwerke" auf der Museumswebsite
  • Norm_Features: Vorab berechneter Bildeinbettungsvektor (wird für die Ähnlichkeitssuche verwendet)
  • Museum: Das Museum, das das Kunstwerk beherbergt

Definieren von Kategorien und Filtern der Daten

Definieren Sie die Kultur- und Mittelkategorien, die Sie abfragen möchten. Filtern Sie dann das Dataset so, dass nur Grafiken enthalten sind, die Ihren ausgewählten Kategorien entsprechen:

mediums = ["paintings", "glass", "ceramics"]
cultures = ["japanese", "american", "african (general)"]

# For more categories, uncomment the extended lists:
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings',
#            'musical instruments', 'glass', 'accessories', 'photographs',
#            'metalwork', 'sculptures', 'weapons', 'stone', 'precious',
#            'paper', 'woodwork', 'leatherwork', 'uncategorized']
# cultures = ['african (general)', 'american', 'ancient american',
#             'ancient asian', 'ancient european', 'ancient middle-eastern',
#             'asian (general)', 'austrian', 'belgian', 'british', 'chinese',
#             'czech', 'dutch', 'egyptian', 'european (general)', 'french',
#             'german', 'greek', 'iranian', 'italian', 'japanese',
#             'latin american', 'middle eastern', 'roman', 'russian',
#             'south asian', 'southeast asian', 'spanish', 'swiss', 'various']

classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)

small_df = df.where(
    udf(
        lambda medium, culture: (medium in medium_set) or (culture in culture_set),
        BooleanType(),
    )("Classification", "Culture")
)

small_df.cache()
print(f"Filtered dataset row count: {small_df.count()}")

Die Ausgabe zeigt je nach ausgewählten Kategorien eine Anzahl von mehreren tausend Zeilen an.

Anpassen von bedingten k-NN-Modellen

Erstellen Sie zwei konditionierte k-NN-Modelle – eines, das auf dem Medium (Klassifizierung) basiert, und eines, das auf der Kultur basiert. Jedes Modell akzeptiert:

  • Eine Ausgabespalte zum Speichern von Übereinstimmungen
  • Eine Featurespalte mit dem Bildeinbettungsvektor
  • Eine Wertespalte, die angibt, was für jede Übereinstimmung zurückgegeben werden soll (URL der Miniaturansicht)
  • Eine Bezeichnungsspalte , die die Konditionierungskategorie angibt
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definieren von Abgleichs- und Visualisierungsmethoden

Definieren Sie Hilfsfunktionen, um die Modelle abzufragen und Ergebnisse anzuzeigen.

Die add_matches() Funktion wendet ein bedingtes k-NN-Modell für alle angegebenen Kategorien an und fügt jeweils eine Übereinstimmungsspalte hinzu:

def add_matches(classes, cknn, df):
    """Apply conditional k-NN for each category label, adding match columns."""
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

Die Funktionen plot_img() und plot_urls() stellen Abfrageergebnisse als Bildraster dar:

def plot_img(axis, url, title):
    """Download and display an image from a URL on a matplotlib axis."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except Exception as e:
        axis.text(0.5, 0.5, "Image\nunavailable", ha="center", va="center", fontsize=6)
    if title is not None:
        axis.set_title(title, fontsize=10)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    """Create a grid visualization of artwork thumbnails and save to file."""
    nx, ny = url_arr.shape

    fig, axes = plt.subplots(ny, nx, figsize=(nx * 5, ny * 5), dpi=150)

    # Reshape required for a single-image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.show()

Ausführen der Abfrage und Visualisieren von Ergebnissen

Definieren Sie die test_all() Funktion, um die Abfrage beider Modelle zu orchestrieren und Visualisierungen zu generieren:

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    """Query both k-NN models for given art IDs and save visualizations."""
    is_match = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_match("id"))

    test_count = test_df.count()
    if test_count == 0:
        print("Warning: No matching art IDs found. Verify IDs exist in the filtered dataset.")
        return None

    print(f"Querying {test_count} artwork(s)...")

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Wählen Sie nun Beispielart-IDs aus dem gefilterten Dataset aus, und führen Sie die Abfrage aus:

# Select 3 sample artwork IDs from the filtered dataset
sample_rows = small_df.select("id").take(3)
selected_ids = {row["id"] for row in sample_rows}
print(f"Selected art IDs: {selected_ids}")

# Run the query and generate visualizations
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root="./")

Zwei Bildraster werden inline angezeigt. Das erste Raster zeigt das originale Kunstwerk mit nächsten Nachbarn in allen Kulturen. Das zweite Gitter zeigt die nächstgelegenen Nachbarn über verschiedene Medien hinweg.

Cleanup

Entfernen Sie zwischengespeicherte Daten und gespeicherte Dateien, wenn Sie mit der Erkundung fertig sind:

small_df.unpersist()
import os
for f in ["./matches_by_culture.png", "./matches_by_medium.png"]:
    if os.path.exists(f):
        os.remove(f)
        print(f"Removed {f}")
print("OK Cleanup complete")

Troubleshooting

Thema Ursache Resolution
ModuleNotFoundError: No module named 'synapse.ml' Notebook, das die Fabric-Runtime nicht verwendet Überprüfen Sie, ob Ihr Notebook mit einem Fabric-Lakehouse mit Runtime 1.2 oder höher verbunden ist.
Py4JJavaError während spark.read.parquet(...) Netzwerkkonnektivitätsproblem Stellen Sie sicher, dass Ihr Arbeitsbereich mmlspark.blob.core.windows.net über Port 443 erreichen kann.
Leeres Ergebnis aus test_all() (0 Zeilen) Ausgewählte IDs befinden sich nicht im gefilterten Dataset Verwenden Sie small_df.select("id").show(5), um gültige IDs aus den gefilterten Daten auszuwählen.
HTTPError oder leere Bilder in der Visualisierung Die Miniaturbild-URL ist nicht mehr zugänglich. Einige Miniaturansichten sind im Laufe der Zeit möglicherweise nicht mehr verfügbar. Die plot_img Funktion zeigt "Bild nicht verfügbar" für fehlgeschlagene Downloads an.
OutOfMemoryError während der Modellanpassung Dataset zu groß für verfügbaren Arbeitsspeicher Verringern der Anzahl von Kategorien in mediums und cultures Listen
Langsame Modellanpassung (>10 Minuten) Großes Dataset mit vielen Kategorien Beginnen Sie mit weniger Kategorien (jeweils 3), und erweitern Sie dann, sobald die Pipeline arbeitet.

Wie bedingtes k-NN funktioniert

Das bedingte k-NN-Modell basiert auf der BallTree-Datenstruktur . Ein BallTree ist eine rekursive binäre Struktur, in der jeder Knoten (oder "Ball") eine Partition der Datenpunkte enthält, die Sie abfragen möchten.

So erstellen Sie ein BallTree:

  1. Bestimmen Sie die "Ball"-Mitte, die jedem Datenpunkt am nächsten kommt, basierend auf einem angegebenen Feature.
  2. Weisen Sie jedem Datenpunkt den nächstgelegenen Ball zu.
  3. Wiederholen Sie dies rekursiv, indem Sie eine Struktur erstellen, die Traversierungen von Binärbäumen ermöglicht.

Diese Struktur ermöglicht effiziente k-Nächste-Nachbarn-Abfragen in jedem Blattknoten.