Menjelajahi Seni Lintas Budaya dan Sedang dengan Tetangga Yang Cepat, Bersyukur, k-Terdekat

Artikel ini berfungsi sebagai pedoman untuk pencarian kecocokan melalui k-tetangga terdekat. Anda menyiapkan kode yang memungkinkan kueri yang melibatkan budaya dan media seni yang dikumpulkan dari Museum Seni Metropolitan di NYC dan Rijksmuseum di Amsterdam.

Prasyarat

  • Lampirkan buku catatan Anda ke lakehouse. Di sisi kiri, pilih Tambahkan untuk menambahkan lakehouse yang ada atau buat lakehouse.

Gambaran umum BallTree

Struktur yang berfungsi di belakang model KNN adalah BallTree, yang merupakan pohon biner rekursif di mana setiap simpul (atau "bola") berisi partisi titik data yang akan dikueri. Membangun BallTree melibatkan penetapan poin data ke "bola" yang pusatnya paling dekat dengan mereka (sehubungan dengan fitur tertentu), menghasilkan struktur yang memungkinkan traversal seperti pohon biner dan meminjamkan dirinya untuk menemukan tetangga terdekat di daun BallTree.

Penyiapan

Impor pustaka Python yang diperlukan dan siapkan himpunan data.

from synapse.ml.core.platform import *

if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()

Himpunan data kami berasal dari tabel yang berisi informasi karya seni dari museum Met dan Rijks. Skema adalah sebagai berikut:

  • id: Pengidentifikasi unik untuk sepotong seni
    • Id Met Sampel: 388395
    • Contoh id Rijks: SK-A-2344
  • Judul: Judul karya seni, seperti yang ditulis dalam database museum
  • Seniman: Seniman karya seni, seperti yang ditulis dalam database museum
  • Thumbnail_Url: Lokasi gambar mini JPEG dari karya seni
  • Image_Url Lokasi gambar karya seni yang dihosting di situs web Met/Rijks
  • Budaya: Kategori budaya yang dijatuhkan oleh karya seni
    • Contoh kategori budaya: amerika latin, Mesir, dll.
  • Klasifikasi: Kategori medium yang dijatuhkan oleh karya seni
    • Contoh kategori sedang: kayu, lukisan, dll.
  • Museum_Page: Menautkan ke karya seni di situs web Met/Rijks
  • Norm_Features: Penyematan gambar karya seni
  • Museum: Menentukan museum mana yang berasal dari potongan tersebut
# loads the dataset and the two trained CKNN models for querying by medium and culture
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

Tentukan kategori yang akan dikueri pada

Dua model KNN digunakan: satu untuk budaya, dan satu untuk menengah.

# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs',  "metalwork",
#           "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]

mediums = ["paintings", "glass", "ceramics"]

# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
#            'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
#            'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
#            'spanish', 'swiss', 'various']

cultures = ["japanese", "american", "african (general)"]

# Uncomment the above for more robust and large scale searches!

classes = cultures + mediums

medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}

small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)

small_df.count()

Menentukan dan menyesuaikan model ConditionalKNN

Membuat model ConditionalKNN untuk kolom menengah dan budaya; setiap model mengambil kolom output, kolom fitur (vektor fitur), kolom nilai (nilai sel di bawah kolom output), dan kolom label (kualitas tempat KNN masing-masing dikondrasi).

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)

Menentukan metode pencocokan dan visualisasi

Setelah himpunan data awal dan penyiapan kategori, siapkan metode yang akan mengkueri dan memvisualisasikan hasil KNN bersyarat.

addMatches() membuat Dataframe dengan beberapa kecocokan per kategori.

def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

plot_urls() panggilan untuk memvisualisasikan plot_img kecocokan teratas untuk setiap kategori ke dalam kisi.

def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except:
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    nx, ny = url_arr.shape

    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)

    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.savefig(filename, dpi=1600)  # saves the results as a PNG

    display(plt.show())

Merangkum semuanya

Tentukan untuk mengambil data, model CKNN, nilai id seni untuk dikueri test_all() , dan jalur file untuk menyimpan visualisasi output. Model menengah dan budaya sebelumnya dilatih dan dimuat.

# main method to test a particular dataset with two CKNN models and a set of art IDs, saving the result to filename.png


def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Demo

Sel berikut melakukan kueri batch yang diberikan ID gambar yang diinginkan dan nama file untuk menyimpan visualisasi.

# sample query
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root=".")