

Get started: Serverless GPU compute with A10 GPUs

This notebook demonstrates how to use serverless GPU compute to run GPU workloads on A10 GPUs directly from Databricks notebooks. You'll learn how to use the serverless_gpu Python library to run functions on a single GPU or on multiple GPUs for distributed training.

Serverless GPU compute gives you on-demand access to GPU resources without having to manage clusters. The serverless_gpu library runs your GPU workloads and provisions the required resources automatically. To learn more, see the Serverless GPU API documentation.

Requirements

Before running this notebook, connect it to Serverless GPU compute:

  1. From the compute selector, select Serverless GPU.
  2. In the Environment tab on the right side, select A10 as the Accelerator.

Verify GPU connection

Run the nvidia-smi command to confirm that your notebook is connected to an A10 GPU and to view its specifications. In the output, the GPU is listed as NVIDIA A10G.

%sh nvidia-smi
Wed Jan 14 19:31:33 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    On  |   00000000:00:1E.0 Off |                    0 |
|  0%   22C    P8             23W /  300W |       1MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
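
To check the connection from Python instead of reading the table by eye, you can ask nvidia-smi for machine-readable output with its --query-gpu and --format options. The sketch below assumes nvidia-smi is on the PATH (as the %sh cell above suggests); parse_gpu_csv, query_gpus, and is_a10 are hypothetical helper names. Because the GPU reports itself as NVIDIA A10G, the check compares the last word of the name against A10 and A10G rather than doing a substring match, which would also match A100.

```python
import subprocess
from typing import List, Tuple


def parse_gpu_csv(text: str) -> List[Tuple[str, int]]:
    """Parse output of:
    nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits

    Each non-empty line looks like "NVIDIA A10G, 23028" (name, total memory in MiB).
    """
    gpus = []
    for line in text.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so the name keeps any internal spaces.
        name, mem = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, int(mem)))
    return gpus


def query_gpus() -> List[Tuple[str, int]]:
    """Run nvidia-smi and return (name, memory_mib) for every visible GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_gpu_csv(out)


def is_a10(gpus: List[Tuple[str, int]]) -> bool:
    """True if at least one GPU is visible and every GPU's model token is A10 or A10G."""
    # Assumption: the model is the last word of the name, e.g. "NVIDIA A10G".
    return bool(gpus) and all(name.split()[-1] in {"A10", "A10G"} for name, _ in gpus)
```

On a connected notebook, calling is_a10(query_gpus()) should return True; on a node without nvidia-smi, query_gpus raises an exception instead.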

Import the serverless GPU library

Import the serverless_gpu library to access the API for running functions on GPU resources.

import serverless_gpu
Warning: serverless_gpu is in Beta. The API is subject to change.

Create a distributed function

Use the @distributed decorator to create a DistributedFunction that runs on GPU resources. This decorator accepts the following parameters:

  • gpus (int): Number of GPUs to use
  • gpu_type (Optional[Union[GPUType, str]]): The GPU type to use (required if remote=True). Available types: 'a10', 'h100'
  • remote (bool): Whether to run the function on remote GPUs (defaults to False)
  • run_async (bool): Whether to run the function asynchronously (defaults to False)

from serverless_gpu import distributed

@distributed(gpus=1, gpu_type='a10', remote=True)
def foo(x):
  print('hello_world', x)
  return x


foo
DistributedFunction(gpus=1, gpu_type=GPUType.A10, remote=True, run_async=False, func=foo)
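
The parameter rules listed above can be mirrored in a small local check. The sketch below is hypothetical: GpuKind and check_distributed_args are illustrative stand-ins, not part of serverless_gpu, which performs its own validation. It encodes only what this page states (gpu_type is required when remote=True, the available types are 'a10' and 'h100', and a string such as 'a10' is accepted in place of the enum, as the GPUType.A10 in the repr above suggests), plus an assumed rule that gpus must be a positive integer.

```python
from enum import Enum
from typing import Optional, Union


class GpuKind(Enum):
    """Hypothetical stand-in for serverless_gpu's GPUType, limited to the documented types."""
    A10 = "a10"
    H100 = "h100"


def check_distributed_args(
    gpus: int,
    gpu_type: Optional[Union[GpuKind, str]] = None,
    remote: bool = False,
) -> Optional[GpuKind]:
    """Validate arguments per the documented rules and return the normalized GPU type."""
    # Assumed rule: gpus is a positive int (bool is excluded because it subclasses int).
    if not isinstance(gpus, int) or isinstance(gpus, bool) or gpus < 1:
        raise ValueError(f"gpus must be a positive int, got {gpus!r}")
    if gpu_type is None:
        # Documented rule: gpu_type is required if remote=True.
        if remote:
            raise ValueError("gpu_type is required when remote=True")
        return None
    if isinstance(gpu_type, GpuKind):
        return gpu_type
    try:
        # Normalize a string such as 'a10' to the enum member, like GPUType.A10 in the repr.
        return GpuKind(gpu_type)
    except ValueError:
        raise ValueError(f"unknown gpu_type {gpu_type!r}; expected 'a10' or 'h100'") from None
```

For example, check_distributed_args(gpus=1, gpu_type='a10', remote=True) returns GpuKind.A10, matching the arguments used for foo above.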

Run the distributed function

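Launch the DistributedFunction by calling its .distributed() method, passing the function's arguments as keyword arguments. The call returns a list with one result per GPU, so with gpus=1 the result is a single-element list.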

foo.distributed(x=5)
[5]
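
Because .distributed() takes the function's arguments as keywords, a misspelled or missing argument only shows up once the remote run starts. One way to catch this earlier, sketched below under the assumption that you keep an undecorated copy of the function, is to bind the keywords to the function's signature with Python's inspect module without calling it. validate_call_kwargs is a hypothetical helper, not part of serverless_gpu.

```python
import inspect
from typing import Any, Callable


def validate_call_kwargs(func: Callable[..., Any], /, **kwargs: Any) -> inspect.BoundArguments:
    """Check that kwargs match func's signature without calling func.

    inspect.Signature.bind raises TypeError on a missing or unexpected
    argument, so mistakes surface before a remote launch.
    """
    bound = inspect.signature(func).bind(**kwargs)
    bound.apply_defaults()
    return bound


def foo(x):
    # Undecorated copy of the notebook's foo, used only for the local check.
    print('hello_world', x)
    return x
```

In the notebook, you could run validate_call_kwargs(foo, x=5) on the plain function and then wrap it explicitly, for example foo_remote = distributed(gpus=1, gpu_type='a10', remote=True)(foo) followed by foo_remote.distributed(x=5). Applying the decorator this way is equivalent to the @distributed(...) syntax used above.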

Distributed training with multiple GPUs

You can run a function on multiple A10 GPUs in parallel for distributed training workloads. The serverless_gpu.runtime module provides helper functions to call from inside the function body during distributed execution:

  • get_local_rank(): Get the local rank of the current GPU
  • get_global_rank(): Get the global rank across all GPUs
  • get_world_size(): Get the total number of GPUs
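
The three helpers are related. Under the convention used by common launchers such as PyTorch's torchrun (an assumption here, since this page does not spell out how serverless_gpu assigns ranks), each node numbers its GPUs 0 to gpus_per_node - 1 by local rank, global ranks are assigned contiguously node by node, and the world size is num_nodes * gpus_per_node. The arithmetic is sketched below; global_rank and rank_table are hypothetical helper names.

```python
from typing import List, Tuple


def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    """Global rank under contiguous per-node numbering (assumed convention)."""
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError("local_rank must be in [0, gpus_per_node)")
    return node_rank * gpus_per_node + local_rank


def rank_table(num_nodes: int, gpus_per_node: int) -> List[Tuple[int, int, int]]:
    """List (node_rank, local_rank, global_rank) for every GPU in the run.

    The length of the result is the world size: num_nodes * gpus_per_node.
    """
    return [
        (node, local, global_rank(node, local, gpus_per_node))
        for node in range(num_nodes)
        for local in range(gpus_per_node)
    ]
```

For example, rank_table(2, 2) gives [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)]: local ranks repeat on each node, while global ranks run from 0 to world size minus 1.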

Note: Multi-node runs of up to 70 nodes can take up to 20 minutes to start, and startup time increases as more nodes are added. Larger runs might see longer wait times or occasional failures.

# The runtime module includes helpers to be used during the GPU runtime (i.e. in the function body).
# These helpers include get_local_rank, get_global_rank, get_world_size
from serverless_gpu import runtime as rt

@distributed(gpus=3, gpu_type='a10', remote=True)
def multi_a10():
  return rt.get_global_rank(), rt.get_world_size()


multi_a10.distributed() # returns a list, one element per GPU

[(0, 3), (1, 3), (2, 3)]
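
A common way to use the global rank and world size in training is to give each GPU a disjoint slice of the data. The sketch below shows round-robin sharding as a plain function so it can be tested locally; shard_indices is a hypothetical helper, not part of serverless_gpu. Inside a @distributed function, you would pass rt.get_global_rank() and rt.get_world_size() as the rank and world_size arguments.

```python
from typing import List


def shard_indices(num_items: int, rank: int, world_size: int) -> List[int]:
    """Return the item indices assigned to `rank` under round-robin sharding.

    Rank r gets indices r, r + world_size, r + 2 * world_size, ..., so the
    shards are disjoint and together cover range(num_items).
    """
    if world_size < 1 or not 0 <= rank < world_size:
        raise ValueError("need world_size >= 1 and 0 <= rank < world_size")
    return list(range(rank, num_items, world_size))


# Inside a @distributed function (runs only on Serverless GPU compute),
# where `dataset` is a hypothetical indexable dataset:
#   my_indices = shard_indices(len(dataset), rt.get_global_rank(), rt.get_world_size())
```

With 10 items and the world size of 3 used by multi_a10, rank 0 gets [0, 3, 6, 9], rank 1 gets [1, 4, 7], and rank 2 gets [2, 5, 8]. Round-robin keeps shard sizes within one item of each other; contiguous blocks work equally well when each GPU should read a consecutive range.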

Next steps

To learn more about serverless GPU compute, see the Serverless GPU API documentation.
