On-demand state repartitioning for stateful streaming queries

On-demand state repartitioning lets you change the number of partitions for a stateful Structured Streaming query without losing checkpoint state.

Without on-demand state repartitioning, the number of shuffle partitions is fixed when the checkpoint is created. If you later change spark.sql.shuffle.partitions, queries with existing checkpoints ignore the new value, so applying a new partition count requires restarting the query with a new checkpoint.

On-demand state repartitioning has the following benefits:

  • Tune queries by resizing the number of partitions without rebuilding the checkpoint.
  • Scale queries up or down to match workload changes.

Requirements

Change the number of partitions

To change the number of shuffle and streaming state partitions, set the Spark configuration spark.sql.streaming.stateStore.partitions and restart the query:

Python

query.stop()
spark.conf.set("spark.sql.streaming.stateStore.partitions", "<numPartitions>")
query = df.writeStream.start()

Scala

query.stop()
spark.conf.set("spark.sql.streaming.stateStore.partitions", "<numPartitions>")
val query = df.writeStream.start()

For stateful queries, spark.sql.streaming.stateStore.partitions takes precedence over spark.sql.shuffle.partitions. After the query restarts and the last planned microbatch completes, the query runs a repartition operation to redistribute state data into the new number of partitions. After the repartition operation completes, the query resumes processing.
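
The stop, set, and restart sequence can be wrapped in a small helper so the partition count is validated before the running query is stopped. The following is a minimal sketch, not part of any Spark or Databricks API: restart_with_partitions and start_query are hypothetical names, and start_query is assumed to be a zero-argument function that rebuilds and starts the query against the same checkpoint location.

```python
from typing import Any, Callable

STATE_PARTITIONS_KEY = "spark.sql.streaming.stateStore.partitions"


def restart_with_partitions(
    spark: Any,
    query: Any,
    num_partitions: int,
    start_query: Callable[[], Any],
) -> Any:
    """Stop `query`, set the new state partition count, and restart it.

    `start_query` is a hypothetical caller-supplied function that starts
    the same query against the same checkpoint location.
    """
    # Reject invalid counts before touching the running query.
    if (
        isinstance(num_partitions, bool)
        or not isinstance(num_partitions, int)
        or num_partitions < 1
    ):
        raise ValueError(
            f"num_partitions must be a positive integer, got {num_partitions!r}"
        )

    query.stop()
    # The configuration value is passed as a string, as in the snippets above.
    spark.conf.set(STATE_PARTITIONS_KEY, str(num_partitions))
    return start_query()
```

With a live session, a call might look like query = restart_with_partitions(spark, query, 100, start_query), where start_query wraps the writeStream definition of the query.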

Monitor repartition state

After the next microbatch completes, StreamingQueryProgress events include the duration of the repartition operation. In an event's durationMs metrics, controlBatch.REPARTITION shows the duration value in milliseconds. Larger state sizes might increase the time to repartition. See Monitoring Structured Streaming queries on Azure Databricks.
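
As a sketch of how the metric might be read programmatically, the helper below pulls controlBatch.REPARTITION out of a progress event's durationMs metrics. It assumes the event is available as a dict-like object (for example, what PySpark's query.lastProgress returns) and that the metric key is spelled exactly as described above. The function name repartition_duration_ms is hypothetical.

```python
from typing import Any, Mapping, Optional

# Metric key as named in the text above; treat the exact spelling as an assumption.
REPARTITION_METRIC = "controlBatch.REPARTITION"


def repartition_duration_ms(progress: Optional[Mapping[str, Any]]) -> Optional[int]:
    """Return the repartition duration in milliseconds from a progress event.

    Returns None when there is no event yet or the event did not include
    a repartition operation.
    """
    if not progress:
        return None
    durations = progress.get("durationMs") or {}
    return durations.get(REPARTITION_METRIC)
```

For a running query, repartition_duration_ms(query.lastProgress) returns None for microbatches that did not repartition. To find the event that did, apply the helper to each entry in query.recentProgress.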

Example

The following example scales a query down from 200, the default, to 100 shuffle partitions. Stop the query, set the new partition count, and restart:

Python

from pyspark.sql.functions import window

# Start the query with the default partition count (200)
query = (df
  .withWatermark("event_time", "10 minutes")
  .groupBy(
    window("event_time", "5 minutes"),
    "id")
  .count()
  .writeStream
  .format("delta")
  .option("checkpointLocation", "/checkpoint/path")
  .outputMode("append")
  .start()
)

# Stop the query and scale down to 100 partitions
query.stop()

spark.conf.set("spark.sql.streaming.stateStore.partitions", "100")

# Restart the query with the same options
query = (df
  .withWatermark("event_time", "10 minutes")
  .groupBy(
    window("event_time", "5 minutes"),
    "id")
  .count()
  .writeStream
  .format("delta")
  .option("checkpointLocation", "/checkpoint/path")
  .outputMode("append")
  .start()
)

Scala

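import org.apache.spark.sql.functions.window
import spark.implicits._  // enables the $"column" syntax

// Start the query with the default partition count (200)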
val query = df
  .withWatermark("event_time", "10 minutes")
  .groupBy(
    window($"event_time", "5 minutes"),
    $"id")
  .count()
  .writeStream
  .format("delta")
  .option("checkpointLocation", "/checkpoint/path")
  .outputMode("append")
  .start()

// Stop the query and scale down to 100 partitions
query.stop()

spark.conf.set("spark.sql.streaming.stateStore.partitions", "100")

// Restart the query with the same options
val query2 = df
  .withWatermark("event_time", "10 minutes")
  .groupBy(
    window($"event_time", "5 minutes"),
    $"id")
  .count()
  .writeStream
  .format("delta")
  .option("checkpointLocation", "/checkpoint/path")
  .outputMode("append")
  .start()