Erkunden von Kunst in Kultur und Medium mit schnellen, bedingten, k-nächstgelegenen Nachbarn

Dieses Notebook dient als Leitfaden für die Übereinstimmungssuche über k-nearest-neighbors. Wir haben Code eingerichtet, der Abfragen von Kulturen und Medien der Kunst ermöglicht, die aus dem Metropolitan Museum of Art in NYC und dem Rijksmuseum in Amsterdam stammen.

Voraussetzungen

  • Schließen Sie Ihr Notizbuch an ein Lakehouse an. Wählen Sie auf der linken Seite Hinzufügen aus, um ein vorhandenes Lakehouse hinzuzufügen oder ein Lakehouse zu erstellen.

Übersicht über den BallTree

Die Struktur, die hinter dem kNN-Modell funktioniert, ist ein BallTree, bei dem es sich um eine rekursive binäre Struktur handelt, in der jeder Knoten (oder "Ball") eine Partition der abzufragenden Datenpunkte enthält. Das Erstellen eines BallTree umfasst das Zuweisen von Datenpunkten zu dem "Ball", dessen Mittelpunkt sie am nächsten sind (in Bezug auf ein bestimmtes angegebenes Feature), was zu einer Struktur führt, die binärbaumähnlichen Durchlauf ermöglicht und sich eignet, um k-nächste Nachbarn an einem BallTree-Blatt zu finden.

Einrichten

Importieren Sie die erforderlichen Python-Bibliotheken, und bereiten Sie das Dataset vor.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Unser Dataset stammt aus einer Tabelle mit Kunstwerken aus den Museen Met und Rijks. Das Schema lautet wie folgt:

  • id: Ein eindeutiger Bezeichner für ein Kunstwerk
    • Beispiel für die Met-ID: 388395
    • Beispiel für Rijks-ID: SK-A-2344
  • Titel: Titel des Kunstwerks, wie in der Datenbank des Museums geschrieben
  • Künstler: Art-Piece-Künstler, wie in der Datenbank des Museums geschrieben
  • Thumbnail_Url: Speicherort einer JPEG-Miniaturansicht des Kunstwerks
  • Image_Url Ort eines Bilds des Kunstwerks, das auf der Met/Rijks-Website gehostet wird
  • Kultur: Kategorie der Kultur, unter die das Kunstwerk fällt
    • Beispielkulturkategorien: lateinamerikanisch, ägyptisch usw.
  • Klassifizierung: Kategorie des Mediums, unter das das Kunstwerk fällt
    • Beispielmedienkategorien: Holzarbeiten, Gemälde usw.
  • Museum_Page: Link zum Kunstwerk auf der Website von Met/Rijks
  • Norm_Features: Einbettung des Kunstwerkbilds
  • Museum: Gibt an, aus welchem Museum das Stück stammt
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definieren von Kategorien, für die abgefragt werden soll

Wir verwenden zwei kNN-Modelle: eines für Kultur und eines für medium.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definieren und Anpassen von bedingtenKNN-Modellen

Wir erstellen bedingteKNN-Modelle sowohl für die Medium- als auch für die Kulturspalte. Jedes Modell akzeptiert eine Ausgabespalte, bietet Spalten (Featurevektor), Wertespalte (Zellenwerte unter der Ausgabespalte) und Bezeichnungsspalte (die Qualität, auf die der jeweilige KNN unterliegt).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definieren von Abgleichs- und Visualisierungsmethoden

Nach der einrichtung des anfänglichen Datasets und der Kategorie bereiten wir Methoden vor, die die Ergebnisse des bedingten kNN abfragen und visualisieren.

addMatches() erstellt einen Dataframe mit einer Handvoll Übereinstimmungen pro Kategorie.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() aufruft plot_img , um die wichtigsten Übereinstimmungen für jede Kategorie in einem Raster zu visualisieren.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Zusammenfügen des Gesamtbilds

Wir definieren test_all() , um die Daten, CKNN-Modelle, die Kunst-ID-Werte für die Abfrage und den Dateipfad zu übernehmen, in dem die Ausgabevisualisierung gespeichert werden soll. Die Medien- und Kulturmodelle wurden zuvor trainiert und geladen.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Die folgende Zelle führt Batchabfragen mit gewünschten Bild-IDs und einem Dateinamen aus, um die Visualisierung zu speichern.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")

Nächste Schritte