Explorar el arte a través de la cultura y el medio con vecinos más próximos de k (KNN), rápidos y condicionales

Este artículo sirve como guía para buscar coincidencias a través de los vecinos más próximos de k. Se configura un código que permite realizar consultas sobre las culturas y los medios de arte acumulados en el Museo Metropolitano de Arte de Nueva York y el Rijksmuseum de Amsterdam.

Requisitos previos

  • Adjunte el bloc de notas a una casa de lago. En el lado izquierdo, seleccione Añadir para añadir un almacén de lago existente o crear uno.

Información general del BallTree

La estructura que funciona detrás del modelo KNN es un BallTree, que es un árbol binario recursivo donde cada nodo (o "bola") contiene una partición de los puntos de datos a consultar. Para crear un BallTree hay que asignar los puntos de datos a la "bola" cuyo centro esté más próximo (con respecto a una determinada característica especificada), lo que da lugar a una estructura que permite recorrerla como un árbol binario y se presta a la búsqueda de los vecinos más próximos de k en una hoja del BallTree.

Configuración

Importar las bibliotecas Python necesarias y preparar el conjunto de datos.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Nuestro conjunto de datos procede de una tabla que contiene información sobre obras de arte de los museos Met y Rijks. El esquema es el siguiente

  • id: un identificador único para una obra de arte
    • Muestra Met id: 388395
    • Muestra Rijks id: SK-A-2344
  • Título: título de la obra de arte, tal y como figura en la base de datos del museo
  • Artista: artista de la obra, tal como figura en la base de datos del museo
  • Url_miniatura: ubicación de una miniatura JPEG de la obra de arte
  • Url_de_imagen ubicación de una imagen de la obra de arte hospedada en el sitio Web del Met/Rijks
  • Cultura: categoría cultural a la que pertenece la obra de arte
    • Ejemplos de categorías culturales: latinoamericana, egipcia, etc.
  • Clasificación: categoría del medio al que pertenece la obra de arte
    • Ejemplo de categorías de soporte: obra de madera, pintura, etc.
  • Museum_Page: vínculo a la obra de arte en el sitio web de Met/Rijks
  • Norm_Features: inserción de la imagen de la obra de arte
  • Museum: especifica de qué museo procede la obra
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Definir las categorías que se consultarán

Se utilizan dos modelos KNN: uno para la cultura y otro para el medio.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definir y ajustar modelos ConditionalKNN

Cree modelos ConditionalKNN para las columnas de medio y cultura; cada modelo toma una columna de salida, una columna de características (vector de características), una columna de valores (valores de celda bajo la columna de salida) y una columna de etiqueta (la calidad a la que está condicionado el KNN respectivo).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definir los métodos de comparación y visualización

Después de la configuración inicial del conjunto de datos y las categorías, prepare los métodos que consultarán y visualizarán los resultados del KNN condicional.

addMatches() crea un Dataframe con un puñado de coincidencias por categoría.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() llama a plot_img para visualizar las mejores coincidencias de cada categoría en una cuadrícula.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Poner en práctica lo aprendido

Defina test_all() para tomar los datos, los modelos CKNN, los valores de id de arte a consultar, y la ruta del archivo para guardar la visualización de salida. Los modelos de cultura y media se entrenaron y cargaron previamente.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demostración

La siguiente celda realiza consultas por lotes dados los IDs de imagen deseados y un nombre de archivo para guardar la visualización.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")