Udforskning af kunst på tværs af kultur og medium med hurtige, betingede, k-nærmeste naboer
Denne notesbog fungerer som en retningslinje for match-finding via k-nearest-naboer. Vi har oprettet kode, der tillader forespørgsler, der involverer kulturer og kunstmedier, der er samlet fra Metropolitan Museum of Art i NYC og Rijksmuseum i Amsterdam.
Forudsætninger
- Vedhæft din notesbog til et lakehouse. I venstre side skal du vælge Tilføj for at tilføje et eksisterende lakehouse eller oprette et lakehouse.
Oversigt over BallTree
Den struktur, der fungerer bag kNN-modellen, er en BallTree, som er et rekursivt binært træ, hvor hver node (eller "bold") indeholder en partition af de datapunkter, der skal forespørges. Opbygning af en BallTree omfatter tildeling af datapunkter til "kuglen", hvis centrum de er tættest på (med hensyn til en bestemt bestemt funktion), hvilket resulterer i en struktur, der tillader binær-træ-lignende traversal og egner sig til at finde k-nærmeste naboer på en BallTree blad.
Konfiguration
Importér nødvendige Python-biblioteker, og forbered datasæt.
from synapse.ml.core.platform import *
if running_on_binder():
from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
Vores datasæt kommer fra en tabel, der indeholder illustrationsoplysninger fra både Met- og Rijks-museerne. Skemaet er som følger:
- id: Et entydigt id for et kunstværk
- Eksempel på met-id: 388395
- Eksempel på Rijks-id: SK-A-2344
- Titel: Titel på kunststykke, som skrevet i museets database
- Kunstner: Kunststykkekunstner, som skrevet i museets database
- Thumbnail_Url: Placering af et JPEG-miniaturebillede af kunstværket
- Image_Url Placering af et billede af kunststykket, der hostes på Met/Rijks-webstedet
- Kultur: Kulturkategori, som kunststykket falder ind under
- Eksempler på kulturkategorier: latinamerika, egyptisk osv.
- Klassificering: Mellemkategori, som kunstværket falder ind under
- Eksempel på mellemstore kategorier: træværk, malerier osv.
- Museum_Page: Link til kunstværket på Met/Rijks hjemmeside
- Norm_Features: Integrering af billedet af kunstværket
- Museum: Angiver, hvilket museum stykket stammer fra
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
"wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Definer kategorier, der skal forespørges på
Vi bruger to kNN-modeller: en til kultur og en til mellem.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
udf(
lambda medium, culture, id_val: (medium in medium_set)
or (culture in culture_set)
or (id_val in selected_ids),
BooleanType(),
)("Classification", "Culture", "id")
)
small_df.count()
Definer og tilpas betingedeKNN-modeller
Vi opretter betingede KNN-modeller til både mellem- og kulturkolonner. hver model bruger en outputkolonne, funktionskolonne (funktionsvektor), værdikolonne (celleværdier under outputkolonnen) og etiketkolonne (den kvalitet, som den respektive KNN er betinget af).
medium_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Classification")
.fit(small_df)
)
culture_cknn = (
ConditionalKNN()
.setOutputCol("Matches")
.setFeaturesCol("Norm_Features")
.setValuesCol("Thumbnail_Url")
.setLabelCol("Culture")
.fit(small_df)
)
Definer matchende og visualiseringsmetoder
Efter den indledende konfiguration af datasæt og kategori forbereder vi metoder, der forespørger og visualiserer resultaterne af det betingede kNN.
addMatches()
opretter en Dataframe med en håndfuld forekomster pr. kategori.
def add_matches(classes, cknn, df):
results = df
for label in classes:
results = cknn.transform(
results.withColumn("conditioner", array(lit(label)))
).withColumnRenamed("Matches", "Matches_{}".format(label))
return results
plot_urls()
kalder plot_img
for at visualisere de mest populære forekomster for hver kategori i et gitter.
def plot_img(axis, url, title):
try:
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
axis.imshow(img, aspect="equal")
except:
pass
if title is not None:
axis.set_title(title, fontsize=4)
axis.axis("off")
def plot_urls(url_arr, titles, filename):
nx, ny = url_arr.shape
plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
fig, axes = plt.subplots(ny, nx)
# reshape required in the case of 1 image query
if len(axes.shape) == 1:
axes = axes.reshape(1, -1)
for i in range(nx):
for j in range(ny):
if j == 0:
plot_img(axes[j, i], url_arr[i, j], titles[i])
else:
plot_img(axes[j, i], url_arr[i, j], None)
plt.savefig(filename, dpi=1600) # saves the results as a PNG
display(plt.show())
Samling af det hele
Vi definerer test_all()
, hvordan data, CKNN-modeller, de kunst-id-værdier, der skal forespørgs om, skal hentes, og den filsti, outputvisualiseringen skal gemmes i. Mellem- og kulturmodellerne blev tidligere oplært og indlæst.
# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
test_df = data.where(is_nice_obj("id"))
results_df_medium = add_matches(mediums, cknn_medium, test_df)
results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
results = results_df_culture.collect()
original_urls = [row["Thumbnail_Url"] for row in results]
culture_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in cultures
]
culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
medium_urls = [
[row["Matches_{}".format(label)][0]["value"] for row in results]
for label in mediums
]
medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
return results_df_culture
Demo
Følgende celle udfører batchforespørgsler med de ønskede billed-id'er og et filnavn for at gemme visualiseringen.
# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")