Exploration de l’art à travers la culture et les médias avec des k-plus proches voisins, rapides et conditionnels

Cet article sert de guide pour la recherche de correspondances au moyen des k-plus proches voisins. Vous avez mis en place un code qui permet des requêtes impliquant des cultures et des médias d’art provenant du Metropolitan Museum of Art de New York et du Rijksmuseum à Amsterdam.

Prérequis

  • Attachez votre cahier à une cabane au bord du lac. Sur le côté gauche, sélectionnez Ajouter pour ajouter un lakehouse existant ou en créer un.

Vue d’ensemble de BallTree

La structure fonctionnant derrière le modèle KNN est un BallTree. Il s’agit d’une arborescence binaire récursive dans laquelle chaque nœud (ou « boule ») contient une partition des points de données à interroger. La création d’un BallTree implique l’affectation des points de données à la « boule » dont ils sont le plus proches (par rapport à une certaine caractéristique spécifiée), ce qui aboutit à une structure autorisant une traversée du type arbre binaire et se prêtant à la recherche de k-plus proches voisins à une feuille BallTree.

Programme d’installation

Importez les bibliothèques Python indispensables et préparez le jeu de données.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Notre jeu de données provient d’une table contenant des informations sur les œuvres d’art des musées Met et Rijks. Le schéma se présente comme suit :

  • id : un identificateur unique d’une œuvre d’art
    • Exemple d’ID Met : 388395
    • Exemple d’ID Rijks : SK-A-2344
  • Titre : Titre d’une œuvre d’art, tel qu’il est écrit dans la base de données du musée
  • Artiste : Artiste de l’œuvre d’art, tel qu’il est écrit dans la base de données du musée
  • Thumbnail_Url : Emplacement d’une miniature JPEG de l’œuvre d’art
  • Image_Url Emplacement d’une image de l’œuvre d’art hébergée sur le site Web Met/Rijks
  • Culture : catégorie culturelle dont dépend l’œuvre d’art
    • Exemples de catégories culturelles : Amérique latine, Égyptienne, etc.
  • Classification : catégorie de média dont dépend l’œuvre d’art
    • Exemples de catégories de média : boiseries, peintures, etc.
  • Museum_Page: Lien vers l’œuvre d’art sur le site web Met/Rijks
  • Norm_Features : Incorporation de l’image de l’œuvre d’art
  • Musée : Indique le musée d’origine de la pièce
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Définir des catégories à interroger

Deux modèles KNN sont utilisés : l’un pour la culture et l’autre pour le média.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Définir et ajuster des modèles ConditionalKNN

Créez des modèles ConditionalKNN pour les colonnes média et culture. Chaque modèle accepte une colonne de sortie, une colonne de caractéristiques (vecteur de caractéristique), une colonne de valeurs (valeurs de cellule sous la colonne de sortie) et une colonne d’étiquette (la qualité sur laquelle le KNN respectif est conditionné).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Définir des méthodes de correspondance et de visualisation

Après la configuration initiale du jeu de données et de la catégorie, préparez les méthodes qui vont interroger et visualiser les résultats conditionnels du KNN.

addMatches() crée un Dataframe avec quelques correspondances par catégorie.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() appelle plot_img pour visualiser les premières correspondances pour chaque catégorie dans une grille.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Mise en application de toutes les fonctionnalités combinées

Définissez test_all() pour récupérer les données, les modèles CKNN, les valeurs de l’ID de l’œuvre d’art à interroger et le chemin d’accès au fichier dans lequel enregistrer la visualisation de sortie. Les modèles de média et de culture ont été précédemment entraînés et chargés.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Démo

La cellule suivante effectue des requêtes par lot, en fonction des ID d’image souhaités et d’un nom de fichier, pour enregistrer la visualisation.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")