This article walks through match-finding with conditional k-nearest neighbors (KNN). You set up code that queries artworks by culture and by medium, using a dataset assembled from the Metropolitan Museum of Art in New York City and the Rijksmuseum in Amsterdam.
The data structure behind the KNN model is a BallTree, a recursive binary tree in which each node (or "ball") contains a partition of the data points to be queried. Building a BallTree assigns each data point to the "ball" whose center it is closest to (with respect to a specified feature). The resulting structure supports binary-tree-style traversal and lends itself to finding the k nearest neighbors at a BallTree leaf.
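To make the BallTree description concrete, here is a minimal sketch of a generic Euclidean ball tree in plain Python. It is not SynapseML's implementation: the names `Ball`, `build`, and `knn`, the pivot-based split, and the `leaf_size` parameter are all assumptions chosen for illustration.

```python
import heapq
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

Point = Tuple[float, ...]


@dataclass
class Ball:
    center: Point
    radius: float
    points: Optional[List[Point]] = None  # populated only on leaves
    left: Optional["Ball"] = None
    right: Optional["Ball"] = None


def build(points: List[Point], leaf_size: int = 2) -> Ball:
    """Recursively partition points into nested balls (hypothetical helper)."""
    dim = len(points[0])
    center = tuple(sum(p[d] for p in points) / len(points) for d in range(dim))
    radius = max(math.dist(center, p) for p in points)
    if len(points) <= leaf_size:
        return Ball(center, radius, points=points)
    # Pick two far-apart pivots, then assign each point to the closer pivot.
    a = max(points, key=lambda p: math.dist(points[0], p))
    b = max(points, key=lambda p: math.dist(a, p))
    left = [p for p in points if math.dist(p, a) <= math.dist(p, b)]
    right = [p for p in points if math.dist(p, a) > math.dist(p, b)]
    if not right:  # all points coincide, so splitting further is pointless
        return Ball(center, radius, points=points)
    return Ball(center, radius,
                left=build(left, leaf_size), right=build(right, leaf_size))


def knn(root: Ball, query: Point, k: int) -> List[Tuple[float, Point]]:
    """Return up to k (distance, point) pairs, nearest first."""
    best: List[Tuple[float, Point]] = []  # max-heap via negated distances

    def visit(node: Ball) -> None:
        # Lower bound on the distance from the query to any point in this ball.
        bound = max(0.0, math.dist(query, node.center) - node.radius)
        if len(best) == k and bound >= -best[0][0]:
            return  # this ball cannot improve the current k best
        if node.points is not None:
            for p in node.points:
                d = math.dist(query, p)
                if len(best) < k:
                    heapq.heappush(best, (-d, p))
                elif d < -best[0][0]:
                    heapq.heapreplace(best, (-d, p))
            return
        # Visit the child with the closer center first to tighten the bound early.
        for child in sorted((node.left, node.right),
                            key=lambda c: math.dist(query, c.center)):
            visit(child)

    visit(root)
    return sorted((-neg_d, p) for neg_d, p in best)


sample = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (5.0, 5.0), (6.0, 5.0), (9.0, 9.0)]
tree = build(sample)
print(knn(tree, (5.2, 5.0), k=2))
```

The pruning test is what makes the tree useful: a ball whose closest possible point is already farther than the current k-th best match is skipped without inspecting its contents.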
Import necessary Python libraries and prepare dataset.
from synapse.ml.core.platform import *
if running_on_binder():
    from IPython import get_ipython
from pyspark.sql.types import BooleanType
from pyspark.sql.types import *
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import lit, array, array_contains, udf, col, struct
from synapse.ml.nn import ConditionalKNN, ConditionalKNNModel
from PIL import Image
from io import BytesIO
import requests
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# Bootstrap Spark Session
spark = SparkSession.builder.getOrCreate()
The dataset comes from a table containing artwork information from both the Met and the Rijksmuseum. The following cell loads the table and displays it:
# loads the dataset; the two CKNN models for querying by medium and culture are trained in a later step
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))
Two KNN models are used: one for culture, and one for medium.
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings', "musical instruments","glass", 'accessories', 'photographs', "metalwork",
# "sculptures", "weapons", "stone", "precious", "paper", "woodwork", "leatherwork", "uncategorized"]
mediums = ["paintings", "glass", "ceramics"]
# cultures = ['african (general)', 'american', 'ancient american', 'ancient asian', 'ancient european', 'ancient middle-eastern', 'asian (general)',
# 'austrian', 'belgian', 'british', 'chinese', 'czech', 'dutch', 'egyptian']#, 'european (general)', 'french', 'german', 'greek',
# 'iranian', 'italian', 'japanese', 'latin american', 'middle eastern', 'roman', 'russian', 'south asian', 'southeast asian',
# 'spanish', 'swiss', 'various']
cultures = ["japanese", "american", "african (general)"]
# Uncomment the above for more robust and large scale searches!
classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)
selected_ids = {"AK-RBK-17525-2", "AK-MAK-1204", "AK-RAK-2015-2-9"}
small_df = df.where(
    udf(
        lambda medium, culture, id_val: (medium in medium_set)
        or (culture in culture_set)
        or (id_val in selected_ids),
        BooleanType(),
    )("Classification", "Culture", "id")
)
small_df.count()
Create ConditionalKNN models for both the medium and culture columns; each model takes in an output column, features column (feature vector), values column (cell values under the output column), and label column (the quality that the respective KNN is conditioned on).
medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)
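As a rough illustration of what "conditioning" means here, the following brute-force sketch restricts the search to items whose label is in a conditioner set and then ranks the survivors by similarity. It is a plain-Python analogue, not the SynapseML API: the function name `conditional_knn`, the tuple layout of `index`, the `score` field, and the use of the inner product as the similarity measure (reasonable for normalized feature vectors) are all assumptions.

```python
from typing import Dict, List, Sequence, Set, Tuple

# Each indexed item is (features, value, label), mirroring the roles of the
# features, values, and label columns described above.
Item = Tuple[Sequence[float], str, str]


def conditional_knn(
    index: List[Item],
    query: Sequence[float],
    conditioner: Set[str],
    k: int,
) -> List[Dict[str, object]]:
    """Return the k most similar items whose label is in the conditioner set."""
    scored = [
        (sum(q * f for q, f in zip(query, features)), value, label)
        for features, value, label in index
        if label in conditioner  # the "conditional" part: filter by label first
    ]
    scored.sort(key=lambda item: item[0], reverse=True)  # highest similarity first
    return [
        {"value": value, "label": label, "score": score}
        for score, value, label in scored[:k]
    ]


toy_index: List[Item] = [
    ((1.0, 0.0), "url_a", "paintings"),
    ((0.8, 0.6), "url_b", "glass"),
    ((0.0, 1.0), "url_c", "paintings"),
    ((0.6, 0.8), "url_d", "ceramics"),
]
# Same query vector, different conditioners, different matches.
print(conditional_knn(toy_index, (1.0, 0.0), {"paintings"}, k=1))
print(conditional_knn(toy_index, (1.0, 0.0), {"glass", "ceramics"}, k=2))
```

The same query can return very different matches depending on the conditioner, which is how a single artwork can be matched against several cultures or mediums in turn.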
After the initial dataset and category setup, prepare methods that will query and visualize the conditional KNN's results.
add_matches() creates a DataFrame with a handful of matches per category.
def add_matches(classes, cknn, df):
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results
plot_urls() calls plot_img() to visualize the top matches for each category in a grid.
def plot_img(axis, url, title):
    try:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except Exception:
        # skip images that fail to download or decode; the grid cell stays blank
        pass
    if title is not None:
        axis.set_title(title, fontsize=4)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    # url_arr has one row per category (nx) and one column per queried artwork (ny)
    nx, ny = url_arr.shape
    plt.figure(figsize=(nx * 5, ny * 5), dpi=1600)
    fig, axes = plt.subplots(ny, nx)
    # reshape required in the case of 1 image query
    if len(axes.shape) == 1:
        axes = axes.reshape(1, -1)
    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)
    plt.savefig(filename, dpi=1600)  # saves the results as a PNG
    display(plt.show())
Define test_all() to take in the data, the CKNN models, the art id values to query on, and the file path prefix for saving the output visualizations. The medium and culture models were trained in an earlier step.
# main method to test a particular dataset with two CKNN models and a set of art IDs,
# saving the results to <root>matches_by_culture.png and <root>matches_by_medium.png
def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    is_nice_obj = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_nice_obj("id"))
    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)
    results = results_df_culture.collect()
    original_urls = [row["Thumbnail_Url"] for row in results]
    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")
    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")
    return results_df_culture
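The grid that test_all() hands to plot_urls() can be easier to follow with plain Python lists. The sketch below mimics only the list-building step, using dictionaries in place of Spark rows; the helper name `build_url_grid` and the sample URLs are hypothetical.

```python
from typing import Dict, List


def build_url_grid(results: List[Dict], labels: List[str]) -> List[List[str]]:
    """Row 0 holds the original thumbnails; each later row holds the top match per label."""
    original_urls = [row["Thumbnail_Url"] for row in results]
    match_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in labels
    ]
    return [original_urls] + match_urls


# Two queried artworks, each with a list of matches per label (made-up values).
fake_results = [
    {
        "Thumbnail_Url": "orig_1",
        "Matches_paintings": [{"value": "p_1"}, {"value": "p_1b"}],
        "Matches_glass": [{"value": "g_1"}],
    },
    {
        "Thumbnail_Url": "orig_2",
        "Matches_paintings": [{"value": "p_2"}],
        "Matches_glass": [{"value": "g_2"}],
    },
]
grid = build_url_grid(fake_results, ["paintings", "glass"])
for name, urls in zip(["Original", "paintings", "glass"], grid):
    print(name, urls)
```

Each inner list corresponds to one category and each position within it to one queried artwork, which is why plot_urls() reads the array shape as (number of categories, number of queries).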
The following cell performs batched queries given desired image IDs and a filename to save the visualization.
# sample query; root is a path prefix for the output files, so it ends with a separator
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root="./")