Udforskning af kunst på tværs af kultur og medium med hurtige, betingede, k-nærmeste naboer

Denne artikel fungerer som en retningslinje for match-finding via k-nærmeste-naboer. Du opretter kode, der tillader forespørgsler, der involverer kulturer og kunstmedier samlet fra Metropolitan Museum of Art i NYC og Rijksmuseum i Amsterdam.

Forudsætninger

  • Vedhæft din notesbog til et lakehouse. I venstre side skal du vælge Tilføj for at tilføje et eksisterende lakehouse eller oprette et lakehouse.

Oversigt over BallTree

Strukturen, der fungerer bag KNN-modellen, er en BallTree, som er et rekursivt binært træ, hvor hver node (eller "kugle") indeholder en partition af de datapunkter, der skal forespørges på. Oprettelse af en BallTree omfatter tildeling af datapunkter til den "kugle", hvis centrum de er tættest på (med hensyn til en bestemt bestemt funktion), hvilket resulterer i en struktur, der tillader binær-træ-lignende traversal og egner sig til at finde k-nærmeste naboer på en BallTree blad.

Opsætte

Importér nødvendige Python-biblioteker, og forbered datasæt.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Vores datasæt kommer fra en tabel, der indeholder illustrationsoplysninger fra både Met- og Rijks-museerne. Skemaet er som følger:

  • id: Et entydigt id for et kunstværk
    • Eksempel på met-id: 388395
    • Eksempel på Rijks-id: SK-A-2344
  • Titel: Kunststykketitel, som skrevet i museets database
  • Kunstner: Kunststykkekunstner, som skrevet i museets database
  • Thumbnail_Url: Placering af et JPEG-miniaturebillede af kunststykket
  • Image_Url Placering af et billede af kunststykket, der hostes på Met/Rijks-webstedet
  • Kultur: Kulturkategori, som kunststykket falder ind under
    • Eksempel på kulturkategorier: latinamerikansk, egyptisk osv.
  • Klassificering: Kategori af medium, at stykket falder ind under
    • Eksempel på mellemstore kategorier: træarbejde, malerier osv.
  • Museum_Page: Link til kunstværket på Met/Rijks hjemmeside
  • Norm_Features: Integrering af billedstykkebilledet
  • Museum: Angiver, hvilket museum stykket stammer fra
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definer kategorier, der skal forespørges på

Der bruges to KNN-modeller: én til kultur og en til medium.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definer og tilpas betingedeKNN-modeller

Opret betingedeKNN-modeller for både mellem- og kulturkolonnerne. hver model bruger en outputkolonne, en funktionskolonne (funktionsvektor), en værdikolonne (celleværdier under outputkolonnen) og en etiketkolonne (den kvalitet, som den respektive KNN er betinget af).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definer matchende og visualiseringsmetoder

Efter den indledende konfiguration af datasæt og kategori skal du forberede metoder, der forespørger og visualiserer det betingede KNN's resultater.

addMatches() opretter en Dataframe med en håndfuld match pr. kategori.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() kalder plot_img for at visualisere de mest populære match for hver kategori i et gitter.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Sætte det hele sammen

Definer test_all() for at hente data, CKNN-modeller, de kunst-id-værdier, der skal forespørgs på, og den filsti, outputvisualiseringen skal gemmes i. Medium- og kulturmodellerne blev tidligere oplært og indlæst.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Følgende celle udfører batchforespørgsler med de ønskede billed-id'er og et filnavn for at gemme visualiseringen.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")