Udostępnij za pośrednictwem


Eksploruj sztukę w różnych kulturach i mediach za pomocą szybkiego, warunkowego algorytmu k najbliższych sąsiadów

W tym artykule opisano znajdowanie dopasowań za pomocą algorytmu k najbliższych sąsiadów. Tworzysz zasoby kodu, które umożliwiają wykonywanie zapytań obejmujących kultury i medium sztuki zgromadzone z Metropolitan Museum of Art w Nowym Jorku i Amsterdamie Rijks.

Wymagania wstępne

Omówienie obiektu BallTree

Model k-NN opiera się na strukturze danych BallTree . BallTree to cykliczne drzewo binarne, w którym każdy węzeł (lub "piłka") zawiera partycję lub podzbiór punktów danych, które chcesz wykonać. Aby utworzyć obiekt BallTree, określ środek "piłki" (na podstawie określonej funkcji) najbliżej każdego punktu danych. Następnie przypisz każdy punkt danych do odpowiadającego mu najbliższej "kuli". Przypisania te tworzą strukturę, która umożliwia przechodzenie podobne do drzewa binarnego i nadaje się do znajdowania najbliższych sąsiadów w liściu BallTree.

Konfiguracja

Zaimportuj niezbędne biblioteki języka Python i przygotuj zestaw danych:

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Zestaw danych pochodzi z tabeli zawierającej informacje o sztuce zarówno z Muzeum Met, jak i Rijks pst. Tabela ma następujący schemat:

  • IDENTYFIKATOR: unikatowy identyfikator dla każdego konkretnego dzieła sztuki
    • Przykładowy identyfikator met: 388395
    • Przykładowy identyfikator Rijks: SK-A-2344
  • Tytuł: Tytuł sztuki, napisany w bazie danych muzeum
  • Artysta: Artysta sztuki, napisany w bazie danych muzeum
  • Thumbnail_Url: Lokalizacja miniatury JPEG dzieła sztuki
  • Image_Url Lokalizacja adresu URL witryny internetowej obrazu fragmentu sztuki hostowana w witrynie internetowej Met/Rijks
  • Kultura: Kategoria kultury utworu artystycznego
    • Kategorie kultury przykładowej: ameryka łacińska, egipska itp.
  • Klasyfikacja: Średnia kategoria dzieła sztuki
    • Przykładowe kategorie średnie: stolarka, obrazy itp.
  • Museum_Page: link url do dzieła sztuki hostowany w witrynie internetowej Met/Rijks
  • Norm_Features: Osadzanie obrazu utworu sztuki
  • Muzeum: Muzeum hostuje rzeczywisty kawałek sztuki
# loads the dataset and the two trained conditional k-NN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Aby utworzyć zapytanie, zdefiniuj kategorie

Użyj dwóch modeli k-NN: jeden dla kultury i jeden dla średniej:

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definiowanie i dopasowywanie warunkowych modeli k-NN

Utwórz warunkowe modele k-NN zarówno dla kolumn średnich, jak i kolumn kultury. Każdy model przyjmuje

  • kolumna wyjściowa
  • kolumna funkcji (wektor funkcji)
  • kolumna wartości (wartości komórek w kolumnie wyjściowej)
  • kolumna etykiety (jakość, na podstawie którego jest warunek k-NN)
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definiowanie pasujących i wizualizowania metod

Po początkowej konfiguracji zestawu danych i kategorii przygotuj metody do wykonywania zapytań i wizualizowania wyników warunkowego k-NN:

addMatches() tworzy ramkę danych z kilkoma dopasowaniami na kategorię:

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() wywołania plot_img w celu wizualizacji pierwszych dopasowań dla każdej kategorii w siatce:

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Ułóż wszystko razem

Aby wziąć udział

  • dane
  • warunkowe modele k-NN
  • wartości identyfikatora sztuki do zapytania
  • ścieżka pliku, w której jest zapisywana wizualizacja wyjściowa

definiowanie funkcji o nazwie test_all()

Modele średnie i kulturowe zostały wcześniej wytrenowane i załadowane.

# main method to test a particular dataset with two conditional k-NN models and a set of art IDs, saving the result to filename.png

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Poniższa komórka wykonuje zapytania wsadowe, biorąc pod uwagę żądane identyfikatory obrazów i nazwę pliku w celu zapisania wizualizacji.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")