Esplorazione dell'arte attraverso la cultura e media con vicini veloci, condizionali, k-nearest

Questo articolo funge da linea guida per la ricerca di corrispondenze tramite k-near-neighbors. Si configura il codice che consente di eseguire query che coinvolgono culture e mezzi d'arte accumulati dal Metropolitan Museum of Art di Nyc e dal Rijks museum di Amsterdam.

Prerequisiti

  • Collegare il notebook a una lakehouse. Sul lato sinistro selezionare Aggiungi per aggiungere una lakehouse esistente o creare una lakehouse.

Panoramica di BallTree

La struttura che funziona dietro il modello KNN è un BallTree, che è un albero binario ricorsivo in cui ogni nodo (o "palla") contiene una partizione dei punti di dati su cui eseguire una query. La creazione di un BallTree comporta l'assegnazione di punti dati alla "palla" il cui centro è più vicino (rispetto a una determinata caratteristica specificata), con conseguente struttura che consente l'attraversamento binario simile all'albero binario e si presta a trovare i vicini k più vicini a una foglia di BallTree.

Attrezzaggio

Importare le librerie Python necessarie e preparare il set di dati.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Il set di dati proviene da una tabella contenente informazioni sull'opera d'arte provenienti dai musei Met e Rijks. Lo schema è il seguente:

  • id: identificatore univoco per un pezzo d'arte
    • Id met di esempio: 388395
    • Id Rijks di esempio: SK-A-2344
  • Titolo: Titolo del pezzo d'arte, come scritto nel database del museo
  • Artista: arte artista, come scritto nel database del museo
  • Thumbnail_Url: posizione di un'anteprima JPEG del pezzo d'arte
  • Image_Url Posizione di un'immagine del pezzo d'arte ospitato sul sito Web Met/Rijks
  • Cultura: Categoria di cultura che il pezzo d'arte rientra
    • Categorie cultura di esempio: latino americano, egiziano e così via.
  • Classificazione: categoria di media sotto cui rientra il pezzo d'arte
    • Categorie medie di esempio: legno, dipinti e così via.
  • Museum_Page: Collegamento all'opera d'arte sul sito Web Met/Rijks
  • Norm_Features: Incorporamento dell'immagine dell'arte
  • Museo: specifica il museo da cui proviene il pezzo
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definire le categorie su cui eseguire query

Vengono usati due modelli KNN: uno per le impostazioni cultura e uno per il supporto.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definire e adattare i modelli ConditionalKNN

Creare modelli ConditionalKNN per le colonne medie e cultura; ogni modello accetta una colonna di output, una colonna feature (vettore di funzionalità), una colonna di valori (valori di cella nella colonna di output) e una colonna etichetta (la qualità su cui è condizionata la rispettiva chiave KNN).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definire metodi di corrispondenza e visualizzazione

Dopo la configurazione iniziale del set di dati e della categoria, preparare i metodi che eseguiranno query e visualizzeranno i risultati del knn condizionale.

addMatches() crea un dataframe con una manciata di corrispondenze per categoria.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() chiama plot_img per visualizzare le corrispondenze principali per ogni categoria in una griglia.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Combinazione delle funzionalità

Definire test_all() per accettare i dati, i modelli CKNN, i valori di ID immagine su cui eseguire query e il percorso del file in cui salvare la visualizzazione di output. I modelli di media e cultura sono stati precedentemente sottoposti a training e caricati.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

La cella seguente esegue query in batch in base agli ID immagine desiderati e un nome file per salvare la visualizzazione.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")