Load data on Serverless GPU compute

This section covers how to load data on serverless GPU compute, specifically for ML and DL applications. See the tutorial to learn how to load and transform data using the Apache Spark Python API.

Load tabular data

Use Spark Connect to load tabular machine learning data from Delta tables.

For single-node training, you can convert Apache Spark DataFrames into pandas DataFrames using the PySpark method toPandas(), and then optionally convert to NumPy format using the pandas method to_numpy().
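For example, here is a minimal sketch of loading a Delta table and converting it for single-node training. The table and column names (catalog.schema.features, feature_a, feature_b, label) are placeholders for your own data:

# Load a Delta table as a Spark DataFrame. The table name is a placeholder.
df = spark.read.table("catalog.schema.features")

# Convert the Spark DataFrame to pandas for single-node training.
pdf = df.toPandas()

# Optionally convert to NumPy arrays using the pandas to_numpy() method.
X = pdf[["feature_a", "feature_b"]].to_numpy()
y = pdf["label"].to_numpy()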

Note

Spark Connect defers analysis and name resolution to execution time, which may change the behavior of your code. See Compare Spark Connect to Spark Classic.

Spark Connect supports most PySpark APIs, including Spark SQL, Pandas API on Spark, Structured Streaming, and MLlib (DataFrame-based). See the PySpark API reference documentation for the latest supported APIs.

For other limitations, see Serverless compute limitations.

Load data inside the @distributed decorator

When using the Serverless GPU API for distributed training, move your data loading code inside the function decorated with @distributed. The dataset can exceed the maximum size that pickle allows, so it is recommended to generate the dataset inside the decorated function, as shown below:

from serverless_gpu import distributed

# Defining the dataset at module level requires pickling it to send to the
# remote workers, which can fail if the dataset exceeds the pickle size limit.
dataset = get_dataset(file_path)

@distributed(gpus=8, remote=True)
def run_train():
    # Good practice: load the dataset inside the decorated function so it is
    # created directly on each worker instead of being pickled.
    dataset = get_dataset(file_path)
    ...

Data loading performance

The /Workspace and /Volumes directories are hosted on remote Unity Catalog storage, so if your dataset is stored in Unity Catalog, data loading speed is limited by the available network bandwidth. If you are training for multiple epochs, it's recommended to first copy the data locally to the /tmp directory, which is backed by fast local NVMe SSD storage.
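For example, the following sketch copies a dataset file from a Unity Catalog volume to /tmp before training starts. The volume path is a placeholder for your own data:

import os
import shutil

# Placeholder path to a dataset file in a Unity Catalog volume.
src = "/Volumes/catalog/schema/volume/train.parquet"
dst = os.path.join("/tmp", os.path.basename(src))

# Copy the file once; later epochs read the local NVMe-backed copy in /tmp.
if not os.path.exists(dst):
    shutil.copy(src, dst)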

If your dataset is large, we also recommend the following techniques to parallelize training and data loading:

  • When training multiple epochs, update your dataset to cache files locally in the /tmp directory before reading each file. On subsequent epochs, use the cached version.
  • Parallelize data fetching by enabling workers in the torch DataLoader API. Set num_workers to at least 2. By default, each worker prefetches two work items. To improve performance, increase num_workers (which increases the number of parallel reads) or prefetch_factor (which increases the number of items each worker prefetches), as shown in the sketch after this list.
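
The following sketch configures a DataLoader with parallel workers and prefetching. It assumes an existing torch Dataset named train_dataset; the batch size and worker count are illustrative:

from torch.utils.data import DataLoader

# train_dataset is assumed to be an existing torch Dataset, for example one
# that caches files to /tmp on first read, as described above.
loader = DataLoader(
    train_dataset,
    batch_size=64,
    num_workers=4,      # parallel data-loading workers; use at least 2
    prefetch_factor=2,  # items prefetched per worker (default is 2)
    pin_memory=True,    # speeds up host-to-GPU transfers
)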