Hinweis
Für den Zugriff auf diese Seite ist eine Autorisierung erforderlich. Sie können versuchen, sich anzumelden oder das Verzeichnis zu wechseln.
Für den Zugriff auf diese Seite ist eine Autorisierung erforderlich. Sie können versuchen, das Verzeichnis zu wechseln.
In diesem Artikel wird die Übereinstimmungssuche über den Algorithmus "k-nearest-neighbors" beschrieben. Sie erstellen Coderessourcen, die Abfragen mit Kulturen und Medien von Kunst aus dem Metropolitan Museum of Art in NYC und dem Amsterdam Rijksmuseum ermöglichen.
Voraussetzungen
- Ein Notizbuch, das an ein Seehaus angeschlossen ist. Besuchen Sie die Daten in Ihrem Seehaus mit einem Notizbuch , um weitere Informationen zu erhalten.
Übersicht über den BallTree
Das k-NN-Modell basiert auf der BallTree-Datenstruktur . BallTree ist eine rekursive binäre Struktur, in der jeder Knoten (oder "Ball") eine Partition oder Teilmenge der Datenpunkte enthält, die Sie abfragen möchten. Um eine BallTree zu erstellen, bestimmen Sie die "Ball"-Mitte (basierend auf einem bestimmten angegebenen Feature), das jedem Datenpunkt am nächsten kommt. Weisen Sie dann jedem Datenpunkt den entsprechenden nächstgelegenen "Ball" zu. Diese Zuordnungen erstellen eine Struktur, die binäre baumähnliche Traversale ermöglicht, und bietet sich an, k-nächste Nachbarn auf einem BallTree-Blatt zu finden.
Konfiguration
Importieren Sie die erforderlichen Python-Bibliotheken, und bereiten Sie das Dataset vor:
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
Das Dataset stammt aus einer Tabelle, die Kunstinformationen sowohl aus dem Met Museum als auch aus dem Rijksmuseum enthält. Die Tabelle weist dieses Schema auf:
-
ID: Ein eindeutiger Bezeichner für jedes bestimmte Kunstwerk
- Beispiel-Met-ID: 388395
- Beispiel-Rijks-ID: SK-A-2344
- Titel: Art piece title, wie in der Datenbank des Museums geschrieben
- Künstler: Kunststückkünstler, wie in der Datenbank des Museums geschrieben
- Thumbnail_Url: Speicherort einer JPEG-Miniaturansicht des Kunstwerks
- Image_Url Website-URL-Speicherort des Kunststückbilds, gehostet auf der Met/Rijks-Website
-
Kultur: Kulturkategorie des Kunststücks
- Beispielkulturkategorien: Lateinamerikanisch, Ägyptisch usw.
-
Klassifikation: Mittlere Kategorie des Kunststücks
- Mustermittelkategorien: Holzarbeiten, Gemälde usw.
- Museum_Page: URL-Link zum Kunstwerk, gehostet auf der Met/Rijks-Website
- Norm_Features: Einbetten des Grafikbilds
- Museum: Das Museum, in dem das eigentliche Kunststück gehostet wird
# loads the dataset and the two trained conditional k-NN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Um die Abfrage zu erstellen, definieren Sie die Kategorien
Verwenden Sie zwei k-NN-Modelle: eine für Kultur und eine für mittel:
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
Definieren und Anpassen von bedingten k-NN-Modellen
Erstellen Sie bedingte k-NN-Modelle sowohl für mittlere als auch für Kulturspalten. Jedes Modell benötigt
- eine Ausgabespalte
- Eine Featurespalte (Featurevektor)
- eine Wertespalte (Zellwerte unter der Ausgabespalte)
- eine Etikettenspalte (die Qualität, auf der die jeweilige k-NN bedingt ist)
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
Definieren von Abgleichs- und Visualisierungsmethoden
Bereiten Sie nach dem anfänglichen Dataset und der Kategorieeinrichtung die Methoden zum Abfragen und Visualisieren der Ergebnisse der bedingten k-NN vor:
addMatches()
erstellt einen Dataframe mit einer Handvoll Übereinstimmungen pro Kategorie:
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
Aufrufe plot_img
zum Visualisieren der wichtigsten Übereinstimmungen für jede Kategorie in einem Raster:
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
Alles zusammensetzen
So nehmen Sie sich ein
- die Daten
- die bedingten k-NN-Modelle
- die zu abfragenden Grafik-ID-Werte
- Dateipfad, in dem die Ausgabevisualisierung gespeichert wird
definieren einer Funktion, die aufgerufen wird test_all()
Die Mittel- und Kulturmodelle wurden zuvor trainiert und geladen.
# main method to test a particular dataset with two conditional k-NN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
Demo
In der folgenden Zelle werden Batchabfragen ausgeführt, wobei die gewünschten Bild-IDs und ein Dateiname zum Speichern der Visualisierung verwendet werden.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")