Explore art across cultures and mediums with conditional k-nearest neighbors

In this article, you use the conditional k-nearest neighbors (k-NN) algorithm from SynapseML to find visually similar artwork. You query a dataset of art from the Metropolitan Museum of Art in NYC and the Rijksmuseum in Amsterdam, restricting matches by culture and by medium.

Prerequisites

  • Create a new notebook.
  • Attach your notebook to a lakehouse. On the left side of your notebook, select Add to add an existing lakehouse or create a new one.

Import libraries

In the first notebook cell, import the required Python libraries:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import lit, array, udf
from synapse.ml.nn import ConditionalKNN
from PIL import Image
from io import BytesIO

import requests
import numpy as np
import matplotlib.pyplot as plt

All imports should complete without errors. If you see ModuleNotFoundError, confirm you're using Fabric runtime 1.2 or later.

Load the dataset

The dataset is a Parquet file containing artwork metadata and pre-computed image features from the Metropolitan Museum of Art and the Rijksmuseum. Load it into a Spark DataFrame:

df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/met_and_rijks.parquet"
)
display(df.drop("Norm_Features"))

The dataset contains approximately 51,000 rows.

Dataset schema

The table contains these columns:

  • id: A unique identifier for each piece of art (for example, 388395)
  • Title: Art piece title as stored in the museum's database
  • Artist: Art piece artist as stored in the museum's database
  • Thumbnail_Url: URL of a JPEG thumbnail of the art piece
  • Image_Url: Website URL of the full art piece image
  • Culture: Culture category (for example, japanese, american, italian)
  • Classification: Medium category (for example, paintings, ceramics, glass)
  • Museum_Page: URL link to the art piece page on the museum website
  • Norm_Features: Pre-computed image embedding vector (used for similarity search)
  • Museum: The museum that hosts the art piece

Define categories and filter the data

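Define the culture and medium categories you want to query. Then filter the dataset to keep artwork whose medium or culture is in your selected categories. A row qualifies if either condition holds, so the filtered data also includes, for example, Japanese artwork in mediums you did not select: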

mediums = ["paintings", "glass", "ceramics"]
cultures = ["japanese", "american", "african (general)"]

# For more categories, uncomment the extended lists:
# mediums = ['prints', 'drawings', 'ceramics', 'textiles', 'paintings',
#            'musical instruments', 'glass', 'accessories', 'photographs',
#            'metalwork', 'sculptures', 'weapons', 'stone', 'precious',
#            'paper', 'woodwork', 'leatherwork', 'uncategorized']
# cultures = ['african (general)', 'american', 'ancient american',
#             'ancient asian', 'ancient european', 'ancient middle-eastern',
#             'asian (general)', 'austrian', 'belgian', 'british', 'chinese',
#             'czech', 'dutch', 'egyptian', 'european (general)', 'french',
#             'german', 'greek', 'iranian', 'italian', 'japanese',
#             'latin american', 'middle eastern', 'roman', 'russian',
#             'south asian', 'southeast asian', 'spanish', 'swiss', 'various']

classes = cultures + mediums
medium_set = set(mediums)
culture_set = set(cultures)

small_df = df.where(
    udf(
        lambda medium, culture: (medium in medium_set) or (culture in culture_set),
        BooleanType(),
    )("Classification", "Culture")
)

small_df.cache()
print(f"Filtered dataset row count: {small_df.count()}")

The output shows a count of several thousand rows, depending on the selected categories.

Fit conditional k-NN models

Create two conditional k-NN models: one conditioned on medium (the Classification column) and one conditioned on culture (the Culture column). Each model accepts:

  • An output column for storing matches
  • A features column containing the image embedding vector
  • A values column specifying what to return for each match (thumbnail URL)
  • A label column indicating the conditioning category

medium_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Classification")
    .fit(small_df)
)
culture_cknn = (
    ConditionalKNN()
    .setOutputCol("Matches")
    .setFeaturesCol("Norm_Features")
    .setValuesCol("Thumbnail_Url")
    .setLabelCol("Culture")
    .fit(small_df)
)
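
To make the idea of conditioning concrete, here is a minimal brute-force sketch in plain NumPy. It is not SynapseML's implementation: the function name conditional_knn, the toy vectors, the placeholder values, and the use of Euclidean distance are all assumptions for illustration. The returned "value" key mirrors how this article reads matches later; the "distance" and "label" keys are also assumed.

```python
import numpy as np


def conditional_knn(features, labels, values, query, conditioner, k=2):
    """Brute-force k-NN restricted to rows whose label is in `conditioner`."""
    features = np.asarray(features, dtype=float)
    query = np.asarray(query, dtype=float)
    # Keep only the candidates whose label is allowed by the conditioner.
    eligible = np.array(
        [i for i, label in enumerate(labels) if label in conditioner], dtype=int
    )
    # Euclidean distance from the query to each eligible candidate (assumed metric).
    distances = np.linalg.norm(features[eligible] - query, axis=1)
    matches = []
    for pos in np.argsort(distances)[:k]:
        idx = int(eligible[pos])
        matches.append(
            {"value": values[idx], "distance": float(distances[pos]), "label": labels[idx]}
        )
    return matches


# Toy data: four 2-D "embeddings" with hypothetical labels and values.
features = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
labels = ["paintings", "glass", "glass", "paintings"]
values = ["url_a", "url_b", "url_c", "url_d"]

query = [1.0, 0.0]
glass_match = conditional_knn(features, labels, values, query, {"glass"}, k=1)
painting_match = conditional_knn(features, labels, values, query, {"paintings"}, k=1)
print(glass_match[0]["value"], painting_match[0]["value"])
```

The same query returns a different nearest neighbor depending on the conditioner, which is what lets one artwork be matched against each culture or medium in turn.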

Define matching and visualization methods

Define helper functions to query the models and display results.

The add_matches() function applies a conditional k-NN model across all specified categories, adding a matches column for each:

def add_matches(classes, cknn, df):
    """Apply conditional k-NN for each category label, adding match columns."""
    results = df
    for label in classes:
        results = cknn.transform(
            results.withColumn("conditioner", array(lit(label)))
        ).withColumnRenamed("Matches", "Matches_{}".format(label))
    return results

The plot_img() and plot_urls() functions render query results as an image grid:

def plot_img(axis, url, title):
    """Download and display an image from a URL on a matplotlib axis."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
        axis.imshow(img, aspect="equal")
    except Exception:
        # Any download or decode failure falls back to a placeholder label
        axis.text(0.5, 0.5, "Image\nunavailable", ha="center", va="center", fontsize=6)
    if title is not None:
        axis.set_title(title, fontsize=10)
    axis.axis("off")


def plot_urls(url_arr, titles, filename):
    """Create a grid visualization of artwork thumbnails and save to file."""
    nx, ny = url_arr.shape

    # squeeze=False always returns a 2-D axes array, even for a single-image query
    fig, axes = plt.subplots(
        ny, nx, figsize=(nx * 5, ny * 5), dpi=150, squeeze=False
    )

    for i in range(nx):
        for j in range(ny):
            if j == 0:
                plot_img(axes[j, i], url_arr[i, j], titles[i])
            else:
                plot_img(axes[j, i], url_arr[i, j], None)

    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.show()
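
The indexing in plot_urls() is easy to misread, so the following sketch shows how the URL array maps onto the grid. It assumes the array is built the way test_all() builds it in this article (original thumbnails first, then one list per category). The file names are hypothetical placeholders, not real thumbnail URLs, and no images are downloaded.

```python
import numpy as np

# Hypothetical placeholder names; real values would be thumbnail URLs.
original_urls = ["orig_0.jpg", "orig_1.jpg"]  # one entry per query artwork
category_urls = [
    ["japanese_0.jpg", "japanese_1.jpg"],  # top match per query, first category
    ["american_0.jpg", "american_1.jpg"],  # top match per query, second category
]

# First axis selects the grid column, second axis selects the grid row.
url_arr = np.array([original_urls] + category_urls)
nx, ny = url_arr.shape  # nx = 1 + number of categories, ny = number of queries

# plot_urls() draws url_arr[i, j] on axes[j, i], so the on-screen layout is
# the transpose: one row per query artwork, one column per category.
layout = url_arr.T
for row in layout:
    print(" | ".join(row))
```

Each printed line corresponds to one row of the saved figure: the original artwork on the left, followed by its best match in each category.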

Run the query and visualize results

Define the test_all() function to orchestrate querying both models and generating visualizations:

def test_all(data, cknn_medium, cknn_culture, test_ids, root):
    """Query both k-NN models for given art IDs and save visualizations."""
    is_match = udf(lambda obj: obj in test_ids, BooleanType())
    test_df = data.where(is_match("id"))

    test_count = test_df.count()
    if test_count == 0:
        print("Warning: No matching art IDs found. Verify IDs exist in the filtered dataset.")
        return None

    print(f"Querying {test_count} artwork(s)...")

    results_df_medium = add_matches(mediums, cknn_medium, test_df)
    results_df_culture = add_matches(cultures, cknn_culture, results_df_medium)

    results = results_df_culture.collect()

    original_urls = [row["Thumbnail_Url"] for row in results]

    culture_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in cultures
    ]
    culture_url_arr = np.array([original_urls] + culture_urls)[:, :]
    plot_urls(culture_url_arr, ["Original"] + cultures, root + "matches_by_culture.png")

    medium_urls = [
        [row["Matches_{}".format(label)][0]["value"] for row in results]
        for label in mediums
    ]
    medium_url_arr = np.array([original_urls] + medium_urls)[:, :]
    plot_urls(medium_url_arr, ["Original"] + mediums, root + "matches_by_medium.png")

    return results_df_culture

Now, select sample art IDs from the filtered dataset and run the query:

# Select 3 sample artwork IDs from the filtered dataset
sample_rows = small_df.select("id").take(3)
selected_ids = {row["id"] for row in sample_rows}
print(f"Selected art IDs: {selected_ids}")

# Run the query and generate visualizations
result_df = test_all(small_df, medium_cknn, culture_cknn, selected_ids, root="./")

Two image grids appear inline. The first grid shows the original artwork with nearest neighbors across cultures. The second grid shows nearest neighbors across mediums.

Cleanup

Remove cached data and saved files when you finish exploring:

import os

# Release the cached DataFrame
small_df.unpersist()

# Delete the image grids written by plot_urls()
for f in ["./matches_by_culture.png", "./matches_by_medium.png"]:
    if os.path.exists(f):
        os.remove(f)
        print(f"Removed {f}")
print("Cleanup complete")

Troubleshooting

Issue | Cause | Resolution
ModuleNotFoundError: No module named 'synapse.ml' | Notebook not using Fabric runtime | Verify your notebook is attached to a Fabric lakehouse with runtime 1.2+
Py4JJavaError during spark.read.parquet(...) | Network connectivity issue | Confirm your workspace can reach mmlspark.blob.core.windows.net on port 443
Empty result from test_all() (0 rows) | Selected IDs aren't in the filtered dataset | Use small_df.select("id").show(5) to pick valid IDs from the filtered data
HTTPError or blank images in visualization | Thumbnail URL no longer accessible | Some thumbnails may become unavailable over time. The plot_img function displays "Image unavailable" for failed downloads.
OutOfMemoryError during model fitting | Dataset too large for available memory | Reduce the number of categories in mediums and cultures lists
Slow model fitting (>10 minutes) | Large dataset with many categories | Start with fewer categories (3 each), then expand once the pipeline works

How conditional k-NN works

The conditional k-NN model relies on the BallTree data structure. A BallTree is a recursive binary tree where each node (or "ball") contains a partition of the data points you want to query.

To build a BallTree:

  1. Determine the "ball" center closest to each data point, based on a specified feature.
  2. Assign each data point to the nearest ball.
  3. Repeat recursively, creating a structure that supports binary-tree traversals.

This structure enables efficient k-nearest neighbor lookups at each leaf node.
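
The build steps above can be sketched in pure Python. This is a simplified illustration, not SynapseML's BallTree: the Ball class, the nearest() helper, the centroid-based center, the median split along the widest dimension, and the Euclidean metric are all assumptions chosen to keep the example short. It also omits the label conditioning that the conditional variant applies.

```python
import heapq
import math


def dist(a, b):
    """Euclidean distance between two equal-length tuples."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


class Ball:
    """A node that covers `points` with a center and a radius."""

    def __init__(self, points, leaf_size=2):
        dims, n = len(points[0]), len(points)
        self.center = tuple(sum(p[d] for p in points) / n for d in range(dims))
        self.radius = max(dist(self.center, p) for p in points)
        self.points, self.left, self.right = points, None, None
        if n > leaf_size:
            # Split at the median along the dimension with the widest spread.
            def spread(d):
                return max(p[d] for p in points) - min(p[d] for p in points)

            axis = max(range(dims), key=spread)
            ordered = sorted(points, key=lambda p: p[axis])
            mid = n // 2
            self.left = Ball(ordered[:mid], leaf_size)
            self.right = Ball(ordered[mid:], leaf_size)


def _search(ball, target, k, best):
    # `best` is a max-heap (negated distances) holding the k closest points so far.
    # Prune: nothing inside this ball can beat the current k-th best distance.
    if len(best) == k and dist(target, ball.center) - ball.radius >= -best[0][0]:
        return
    if ball.left is None:  # leaf node: scan its points directly
        for p in ball.points:
            heapq.heappush(best, (-dist(target, p), p))
            if len(best) > k:
                heapq.heappop(best)  # drop the farthest candidate
        return
    # Visit the closer child first so pruning takes effect sooner.
    for child in sorted((ball.left, ball.right), key=lambda b: dist(target, b.center)):
        _search(child, target, k, best)


def nearest(tree, target, k):
    """Return the k nearest points as (distance, point) pairs, closest first."""
    best = []
    _search(tree, target, k, best)
    return sorted((-neg, p) for neg, p in best)


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0), (6.0, 5.0), (5.0, 6.0)]
tree = Ball(points)
result = nearest(tree, (0.2, 0.1), k=2)
print([p for _, p in result])
```

In this toy run, the ball around the three far-away points is skipped entirely because its closest possible point is already farther than the two candidates found, which is where the speedup over a full scan comes from.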