Image Data Support in Apache Spark
This post is co-authored by the Microsoft Azure Machine Learning team, in collaboration with the Databricks Machine Learning team.
Introduction
Apache Spark is increasingly used for deep learning applications in image processing and computer vision at scale. Problems such as image classification and object detection are solved with deep learning frameworks such as Cognitive Toolkit (CNTK), TensorFlow, BigDL and DeepLearning4J, which are integrated into Spark through libraries such as MMLSpark or TensorFlowOnSpark. However, until now there has been no common interface for importing images or representing them in Spark DataFrames. Consequently, the different frameworks cannot easily communicate with each other or with core Spark components such as SparkML pipelines or Deep Learning Pipelines. To overcome this problem, the Microsoft Azure Machine Learning team collaborated with Databricks and the Apache Spark community to make images a first-class citizen in core Spark, based on existing industry standards.
Importing and Representing Images in Spark DataFrames
An image processing and computer vision pipeline typically consists of four stages: image import, preprocessing, model training and inferencing.
To accurately represent an image throughout this pipeline, you need certain pieces of data:
- The pixel values that represent the image itself.
- Image resolution or bit depth, e.g. 8-bit, 16-bit, 32-bit and so on.
- Number and order of color channels, e.g. grayscale, RGB, CMYK, etc.
- Height and width of the image.
- Metadata about the origin of the image, such as the file system path.
Having all of these pieces of data is important. For example, a pre-trained deep neural network assumes a specific image size, normalization and color channel order that match the data the model was originally trained on. A mistake in any of these can cause the accuracy of the model to suffer catastrophically. Therefore, it is important to have a consistent representation of image metadata throughout the machine learning pipeline.
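To make the channel order point concrete, here is a minimal sketch in plain Python. It assumes a 3-channel image with 8 bits per channel, stored as interleaved bytes. The helper name swap_first_and_third_channel is hypothetical and is not part of Spark or MMLSpark.

```python
def swap_first_and_third_channel(pixels: bytes) -> bytes:
    """Swap channels 1 and 3 of an interleaved 3-channel, 8-bit buffer.

    This turns BGR into RGB (or RGB into BGR). Illustrative helper only.
    """
    if len(pixels) % 3 != 0:
        raise ValueError("buffer length must be a multiple of 3")
    out = bytearray(pixels)
    out[0::3] = pixels[2::3]  # third channel moves into the first slot
    out[2::3] = pixels[0::3]  # first channel moves into the third slot
    return bytes(out)


# Two pixels written in BGR order: pure blue, then pure red.
bgr = bytes([255, 0, 0, 0, 0, 255])
rgb = swap_first_and_third_channel(bgr)
```

A model trained on RGB input that is fed the unswapped BGR buffer would see blue pixels as red, which is exactly the kind of silent mismatch that consistent metadata helps you catch.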
The structure spark.ml.image.imageSchema is used to capture this information in a standardized way. This makes it easy to build reusable image pipelines that feed into different deep learning libraries, and to use libraries such as OpenCV efficiently for image preprocessing. The pixels are stored as uncompressed binary data, ensuring close-to-metal performance and low conversion overhead. The OpenCV convention is used to describe the bit depth and color channels of the image.
Formally, the schema is defined as:
StructType(
  StructField("origin", StringType, true) ::
  StructField("height", IntegerType, false) ::
  StructField("width", IntegerType, false) ::
  StructField("nChannels", IntegerType, false) ::
  StructField("mode", IntegerType, false) ::
  StructField("data", BinaryType, false) :: Nil)
where "nChannels" is the number of color channels, "mode" is the OpenCV-compatible type code (for example, CV_8UC3), and "data" holds the pixel bytes in OpenCV-compatible order.
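The OpenCV convention packs bit depth and channel count into a single integer type code, such as the schema's mode value. The following sketch decodes such a code in plain Python. The depth table follows the constants in OpenCV's core headers (CV_8U through CV_64F); the function names decode_cv_type and expected_data_length are hypothetical helpers, not Spark APIs.

```python
# Depth codes as defined in OpenCV's core headers: code -> (name, bits per channel).
CV_DEPTHS = {
    0: ("CV_8U", 8),
    1: ("CV_8S", 8),
    2: ("CV_16U", 16),
    3: ("CV_16S", 16),
    4: ("CV_32S", 32),
    5: ("CV_32F", 32),
    6: ("CV_64F", 64),
}


def decode_cv_type(mode: int):
    """Split an OpenCV type code into (name, bits per channel, channel count)."""
    depth = mode & 7            # the low three bits hold the depth code
    channels = (mode >> 3) + 1  # the remaining bits hold (channels - 1)
    name, bits = CV_DEPTHS[depth]  # raises KeyError for depth codes not listed above
    return f"{name}C{channels}", bits, channels


def expected_data_length(height: int, width: int, mode: int) -> int:
    """Number of bytes an uncompressed image of this shape should occupy."""
    _, bits, channels = decode_cv_type(mode)
    return height * width * channels * (bits // 8)
```

For example, decode_cv_type(16) yields ("CV_8UC3", 8, 3), the usual code for an 8-bit, 3-channel color image. Comparing expected_data_length against the actual length of the data field is a cheap way to spot rows whose metadata and pixels disagree.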
The method spark.readImages lets you read images in common formats (jpg, png, etc.) from HDFS storage into a DataFrame. Each image is stored as a row in the imageSchema format. The API is defined as:
readImages(
  path: String,
  sparkSession: SparkSession,
  recursive: Boolean,
  numPartitions: Int,
  dropImageFailures: Boolean,
  sampleRatio: Double,
  seed: Long)
The recursive option allows you to read images from subfolders, for example when positive and negative labeled samples are kept in separate folders. The sampleRatio parameter allows you to experiment with a smaller sample of images before training a model on the full dataset.
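To build intuition for how sampleRatio and seed interact, here is a conceptual sketch in plain Python. It is not Spark's implementation; it only assumes that each file is kept independently with probability sampleRatio and that a fixed seed makes the draw repeatable. The function sample_paths and the file paths are hypothetical.

```python
import random


def sample_paths(paths, sample_ratio: float, seed: int):
    """Keep each path independently with probability sample_ratio."""
    if not 0.0 <= sample_ratio <= 1.0:
        raise ValueError("sample_ratio must be between 0 and 1")
    rng = random.Random(seed)  # a fixed seed makes the selection repeatable
    return [p for p in paths if rng.random() < sample_ratio]


# Hypothetical file listing standing in for an image folder.
paths = [f"/data/images/img_{i:04d}.jpg" for i in range(1000)]
subset = sample_paths(paths, sample_ratio=0.1, seed=42)
```

Under this model the subset holds roughly, not exactly, 10% of the files, and rerunning with the same seed returns the same subset, which keeps experiments on a sample comparable from run to run.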
The readImages API is available through the MMLSpark library, together with additional methods for preprocessing images, or as a stand-alone reference implementation at spark-packages.org.
Example of Image Transformation Pipeline
This Jupyter notebook demonstrates how image data can be read in and processed within a SparkML pipeline. The following lines show how you can read a collection of images into a Spark DataFrame. Note how the readImages function appears as a member of the Spark session object, similar to spark.read.csv or spark.read.json. You can then inspect the schema and analyze the properties of your image dataset:
images = spark.readImages(IMAGE_PATH, recursive = True, sampleRatio = 0.1).cache()
images.printSchema()
print(images.count())
You can then extract the pixel data, and pass it to deep learning models for image classification and computer vision through Apache Spark Deep Learning Pipelines.
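As a rough illustration of what extracting the pixel data involves, the sketch below reads one pixel out of the flat byte buffer. It assumes a row-major, interleaved layout with 8 bits per channel, and it uses a plain dictionary with the schema's field names in place of a real DataFrame row. The function pixel_at and the sample record are hypothetical.

```python
def pixel_at(image, row: int, col: int):
    """Return the channel values of one pixel from a row-major, interleaved buffer.

    `image` is any mapping that uses the schema's field names.
    Assumes 8 bits per channel.
    """
    h, w, c = image["height"], image["width"], image["nChannels"]
    if not (0 <= row < h and 0 <= col < w):
        raise IndexError("pixel coordinates out of range")
    start = (row * w + col) * c  # skip whole rows, then whole pixels
    return tuple(image["data"][start:start + c])


# A hand-built 2x2, 3-channel record standing in for one DataFrame row.
record = {
    "origin": "example.png",  # hypothetical path
    "height": 2, "width": 2, "nChannels": 3, "mode": 16,
    "data": bytes(range(12)),
}
first_pixel_second_row = pixel_at(record, 1, 0)
```

In practice you would more likely reshape the whole buffer into an array of shape (height, width, nChannels) with a numerical library; the index arithmetic is the same.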
You can also apply MMLSpark's image transformations to resize and crop the images as pipeline stages. The transformed data can then be fed into, for example, a deep learning model to classify the images.
from mmlspark import ImageTransformer
tr = (ImageTransformer()  # images are resized and then cropped
      .setOutputCol("transformed")
      .resize(height = 200, width = 200)
      .crop(0, 0, height = 180, width = 180))
smallImages = tr.transform(images).select("transformed")
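To show what a crop does to the underlying bytes, here is a minimal sketch in plain Python. It is not MMLSpark's implementation; it assumes a row-major, interleaved buffer with 8 bits per channel, and the function crop_buffer with its parameter names is hypothetical.

```python
def crop_buffer(data: bytes, height: int, width: int, channels: int,
                x: int, y: int, crop_height: int, crop_width: int) -> bytes:
    """Cut a rectangle out of a row-major, interleaved 8-bit pixel buffer.

    (x, y) is the top-left corner of the window, counted in pixels.
    """
    if x < 0 or y < 0 or x + crop_width > width or y + crop_height > height:
        raise ValueError("crop window falls outside the image")
    row_bytes = width * channels
    out = bytearray()
    for r in range(y, y + crop_height):
        start = r * row_bytes + x * channels
        out += data[start:start + crop_width * channels]  # one cropped row
    return bytes(out)


# A 3x3 single-channel image with pixel values 0..8.
full = bytes(range(9))
top_left = crop_buffer(full, height=3, width=3, channels=1,
                       x=0, y=0, crop_height=2, crop_width=2)
```

Because the cropped image has a new height and width, the metadata fields must be updated along with the bytes, which is why it helps to have transformations operate on the whole image record rather than on the raw buffer alone.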
You can use deep neural networks such as CNTK or TensorFlow to extract high-order features from the images and then pass them to SparkML machine learning algorithms, using the transfer learning approach.
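# cntkModel is assumed to be a pre-trained model already loaded as a Spark transformer (not defined in this snippet)
featurizedImages = cntkModel.transform(smallImages).select(["features","labels"])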
from mmlspark import TrainClassifier
from pyspark.ml.classification import RandomForestClassifier
model = TrainClassifier(model=RandomForestClassifier(),labelCol="labels").fit(featurizedImages)
The result is an end-to-end pipeline that you can use to read, preprocess and classify images in a scalable fashion.
Next Steps
The image APIs were recently merged into Apache Spark core and are included in the Spark 2.3 release. Try them out and send us your feedback. Also try the image preprocessing functionality in the MMLSpark library.
ML Blog Team