Sdílet prostřednictvím


Prozkoumejte umění napříč kulturou a médium pomocí rychlého, podmíněného algoritmu k-nejbližšího souseda

Tento článek popisuje hledání shody pomocí algoritmu k-nejbližších sousedů. Vytváříte prostředky kódu, které umožňují dotazy zahrnující kultury a média umění maskované z Metropolitního muzea umění v NYC a Amsterdam Rijksmuseum.

Požadavky

Přehled BallTree

Model k-NN spoléhá na datovou strukturu BallTree . BallTree je rekurzivní binární strom, kde každý uzel (nebo "míč") obsahuje oddíl nebo podmnožinu datových bodů, které chcete dotazovat. Chcete-li vytvořit BallTree, určete střed "míč" (na základě určité zadané funkce) nejblíže ke každému datovému bodu. Pak každému datovému bodu přiřaďte odpovídající nejbližší "míč". Tato přiřazení vytvářejí strukturu, která umožňuje procházení jako binární strom, a umožňuje najít k-nejbližší sousedy v BallTree list.

Nastavení

Naimportujte potřebné knihovny Pythonu a připravte datovou sadu:

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Datová sada pochází z tabulky, která obsahuje informace o uměleckých dělech z Met Museum i Rijksmuseum. Tabulka má toto schéma:

  • ID: Jedinečný identifikátor každého konkrétního kusu umění
    • Ukázkové ID met: 388395
    • Ukázkové ID Rijks: SK-A-2344
  • Název: Název umělecké části, jak je napsané v databázi muzea
  • Umělec: umělecký umělec, jak je napsané v databázi muzea
  • Thumbnail_Url: Umístění miniatury obrázku ve formátu JPEG
  • Image_Url Umístění adresy URL webu obrázku umělecké části hostované na webu Met/Rijks
  • Kultura: Kultura kategorie umělecké části
    • Ukázkové jazykové kategorie: latinamerická, egyptská atd.
  • Klasifikace: Střední kategorie umělecké části
    • Ukázkové střední kategorie: dřevo,obrazy atd.
  • Museum_Page: Odkaz na uměleckou část hostovaný na webu Met/Rijks
  • Norm_Features: Vložení obrázku umělecké části
  • Muzeum: Muzeum hostující skutečnou uměleckou část
# loads the dataset and the two trained conditional k-NN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Pokud chcete vytvořit dotaz, definujte kategorie.

Použijte dva modely k-NN: jeden pro jazykovou verzi a jeden pro střední:

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Definování a přizpůsobení podmíněných modelů k-NN

Vytvořte podmíněné modely k-NN pro sloupce se střední i jazykovou verzí. Každý model přebírá

  • výstupní sloupec
  • sloupec funkcí (vektor funkcí)
  • sloupec hodnot (hodnoty buněk ve výstupním sloupci)
  • sloupec popisku (kvalita, na které je příslušný k-NN podmíněný)
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Definování odpovídajících a vizualizačních metod

Po počátečním nastavení datové sady a kategorie připravte metody pro dotazování a vizualizaci výsledků podmíněné sítě k-NN:

addMatches() vytvoří datový rámec s několika shodami pro každou kategorii:

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() volání plot_img k vizualizaci nejlepších shod pro každou kategorii do mřížky:

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Spojte všechno dohromady

Jak se dostat do

  • data
  • podmíněné modely k-NN
  • hodnoty ID umění, na které se mají dotazovat
  • cesta k souboru, kam se uloží výstupní vizualizace

define a function called test_all()

Modely střední a jazykové verze byly dříve natrénovány a načteny.

# main method to test a particular dataset with two conditional k-NN models and a set of art IDs, saving the result to filename.png

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Následující buňka provádí dávkové dotazy s ohledem na ID požadovaných obrázků a název souboru pro uložení vizualizace.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")