Del via


Utforsk kunst på tvers av kultur og medier med den raske, betingede, k-nærmeste naboalgoritmen

Denne artikkelen beskriver match-finding via k-nærmeste-naboer algoritmen. Du bygger koderessurser som tillater spørringer som involverer kulturer og medier av kunst samlet fra Metropolitan Museum of Art i NYC og Amsterdam Rijksmuseum.

Forutsetning

Oversikt over BallTree

K-NN-modellen er avhengig av BallTree-datastrukturen . BallTree er et rekursivt binært tre, der hver node (eller «ball») inneholder en partisjon, eller et delsett, av datapunktene du vil spørre etter. Hvis du vil bygge en BallTree, må du bestemme «ball»-midten (basert på en bestemt angitt funksjon) nærmest hvert datapunkt. Deretter tilordner du hvert datapunkt til den tilsvarende nærmeste «ballen». Disse oppgavene oppretter en struktur som gjør det mulig for binære trelignende traverser, og gir seg til å finne k-nærmeste naboer på et BallTree-blad.

Konfigurasjon

Importer de nødvendige Python-bibliotekene, og klargjør datasettet:

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Datasettet kommer fra en tabell som inneholder kunstverkinformasjon fra både Met Museum og Rijksmuseum. Tabellen har dette skjemaet:

  • ID: En unik identifikator for hvert bestemt kunstverk
    • Eksempel på met-ID: 388395
    • Eksempel på Rijks-ID: SK-A-2344
  • Tittel: Art piece tittel, som skrevet i museets database
  • Kunstner: Kunstkunstner, som skrevet i museets database
  • Thumbnail_Url: Plassering av et JPEG-miniatyrbilde av kunstverket
  • Image_Url Nettadresseplassering for nettstedet for grafikkbildet, som driftes på Met/Rijks-nettstedet
  • Kultur: Kulturkategori for kunstverket
    • Eksempel på kulturkategorier: latinamerikansk, egyptisk og så videre.
  • Klassifisering: Middels kategori av kunstverket
    • Eksempel på mellomstore kategorier: treverk, malerier osv.
  • Museum_Page: URL-adressekobling til kunstverket, som driftes på Met/Rijks-nettstedet
  • Norm_Features: Innebygging av kunstverkbildet
  • Museum: Museet som er vert for selve kunstverket
# loads the dataset and the two trained conditional k-NN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Hvis du vil bygge spørringen, definerer du kategoriene

Bruk to k-NN-modeller: én for kultur og én for middels:

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definer og tilpass betingede k-NN-modeller

Opprett betingede k-NN-modeller for både mellom- og kulturkolonnene. Hver modell tar

  • en utdatakolonne
  • en funksjonskolonne (funksjonsvektor)
  • en verdikolonne (celleverdier under utdatakolonnen)
  • en etikettkolonne (kvaliteten som den respektive k-NN er betinget av)
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definer samsvarende og visualiseringsmetoder

Etter det første datasettet og kategorioppsettet klargjør du metodene for å spørre og visualisere resultatene av det betingede k-NN:

addMatches() oppretter en dataramme med en håndfull treff per kategori:

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() kaller plot_img for å visualisere de beste treff for hver kategori i et rutenett:

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Sette sammen alt

Slik tar du inn

  • dataene
  • de betingede k-NN-modellene
  • grafikk-ID-verdiene som skal spørres etter
  • filbanen der utdatavisualiseringen lagres

definere en funksjon kalt test_all()

Medie- og kulturmodellene ble tidligere opplært og lastet.

# main method to test a particular dataset with two conditional k-NN models and a set of art IDs, saving the result to filename.png

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demonstrasjon

Følgende celle utfører grupperte spørringer, gitt de ønskede bilde-ID-ene og et filnavn for å lagre visualiseringen.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")