Распределенное обучение с использованием PyTorch FSDP на вычислениях на GPU без сервера

В этой записной книжке показано, как обучить модель преобразователя с помощью распределенного обучения с помощью полного сегментированного параллелизма данных (FSDP) в Databricks бессерверных вычислений GPU. FSDP — это метод параллелизма данных, который сегментирует параметры модели, градиенты и состояния оптимизатора в нескольких GPU, что позволяет эффективно обучать большие модели, которые не подходят для одного GPU.

В этом примере вы узнаете, как:

  • Настройка распределенного обучения с помощью БЕССерверного API распределенного обучения GPU
  • Определение и обучение модели преобразователя 10M параметров с помощью FSDP
  • Сохранение распределенных контрольных точек во время обучения
  • Отслеживание экспериментов с помощью MLflow
  • Загрузка контрольных точек для вывода или продолжения обучения

Эта записная книжка использует искусственные данные, чтобы она была независимой, но вы можете адаптировать её для работы с собственными наборами данных.

Основные понятия:

  • FSDP (полностью сегментированная параллельная передача данных) — стратегия распределенного обучения PyTorch, которая сегментирует параметры модели в gpu, чтобы уменьшить использование памяти и включить обучение более крупных моделей.
  • Бессерверные вычислительные ресурсы GPU: управляемые вычислительные ресурсы Databricks, которые автоматически масштабирует и подготавливает ресурсы для рабочих нагрузок.

Для получения дополнительной информации см. Обучение с использованием нескольких GPU и распределённого обучения на множестве узлов.

Установка зависимостей

Установите последнюю версию MLflow для отслеживания экспериментов и ведения журнала моделей.

%pip install -U mlflow
%restart_python

Настройка расположений каталога Unity

Настройте расположения каталога Unity, в которых будут храниться модель и контрольные точки. Обновите эти значения, чтобы соответствовать конфигурации рабочей области. Требуются привилегии USE CATALOG и USE SCHEMA для указанного каталога и схемы.

# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

Определение вспомогательных функций и синтетических наборов данных

В этом разделе определяются служебные функции для распределенной настройки обучения и класса искусственного набора данных для демонстрационных целей. В рабочей среде вы замените SyntheticDataset на вашу собственную логику загрузки данных.

Ключевые компоненты:

  • setup(): инициализирует группу распределенных процессов обучения и настраивает устройства GPU
  • cleanup(): очищает распределенную группу процессов после обучения
  • AppState: класс-оболочка для модели контрольной точки и состояния оптимизатора, совместимого с API распределенной контрольной точки PyTorch.
  • SyntheticDataset: создает случайные данные для демонстрации обучения
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time

# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

def setup():
    """Initialize the distributed training process group"""
    # Check if we're in a distributed environment
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
    else:
        # Fallback for single GPU
        rank = 0
        world_size = 1
        local_rank = 0

    # Initialize process group
    if world_size > 1:
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Set device
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    return rank, world_size, device

def cleanup():
    """Clean up the distributed training process group"""
    if dist.is_initialized():
        dist.destroy_process_group()

class SyntheticDataset(Dataset):
    """Simple synthetic dataset for demo purposes"""
    def __init__(self, size=10000, input_dim=512, num_classes=10):
        self.size = size
        self.input_dim = input_dim
        self.num_classes = num_classes

        # Generate synthetic data
        np.random.seed(42)  # For reproducible results
        self.data = torch.randn(size, input_dim)
        # Create labels with some pattern
        self.labels = torch.randint(0, num_classes, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

Определение модели преобразователя с помощью FSDP

В этом разделе определяется простая модель преобразователя для классификации и логики применения сегментирования FSDP. Хотя FSDP обычно используется для больших языковых моделей с параметрами 7B+, в этом примере демонстрируется метод с меньшей моделью 10M параметров, сегментированной по нескольким GPU H100.

Архитектура модели:

  • TransformerBlock: один слой трансформера с многоголовым вниманием и MLP
  • SimpleTransformer: стек блоков преобразователя с входной проекцией и головой классификации
  • apply_fsdp(): оборачивает слои модели с использованием FSDP для распределенного обучения

FSDP разделяет параметры модели, градиенты и состояния оптимизатора между GPU, уменьшая требования к памяти на каждом GPU и что позволяет обучать более крупные модели.

class TransformerBlock(nn.Module):
    """Simple transformer block for testing FSDP"""
    def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.norm2(x + mlp_out)

        return x

class SimpleTransformer(nn.Module):
    """Simple transformer model for classification with FSDP"""
    def __init__(self, input_dim=512, num_layers=64, num_classes=10):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, input_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim=input_dim) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(input_dim)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Add sequence dimension for transformer
        x = x.unsqueeze(1)  # [batch, 1, input_dim]

        x = self.input_projection(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        # Global average pooling
        x = x.mean(dim=1)  # [batch, input_dim]

        return self.classifier(x)

def apply_fsdp(model, world_size):
    """Apply FSDP to the model"""
    if world_size > 1:
        print("Applying FSDP to model layers...")
        # Apply fsdp to each transformer layer
        for i, layer in enumerate(model.layers):
            fully_shard(layer)
            print(f"Applied FSDP to layer {i}")

        # Apply FSDP to the entire model
        fully_shard(model)
        print("Applied FSDP to entire model")
    else:
        print("Single GPU detected, skipping FSDP setup")

    return model

Определение распределенной функции обучения

Функция обучения оборачивается декоратором @distributed, предоставленным бессерверным API GPU. Этот декоратор обрабатывает:

  • Подготовка указанного количества графических процессоров (8 графических процессоров H100 GPU в этом примере)
  • Настройка распределенной среды обучения
  • Управление жизненным циклом удаленных вычислительных ресурсов

Функция обучения включает:

  • Инициализация модели и оболочка FSDP
  • Загрузка данных с помощью DistributedSampler для параллельной обработки данных
  • Цикл обучения с обновлениями градиента
  • Периодическое сохранение чекпойнтов с помощью API для распределенных чекпойнтов PyTorch
  • Ведение журнала MLflow для отслеживания экспериментов

Контрольные точки сохраняются в томе каталога Unity и регистрируются как объекты MLflow для версионирования и воспроизводимости.

from serverless_gpu import distributed
from serverless_gpu.compute import GPUType

NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
    """
    Self-contained FSDP training demo using PyTorch 2.0+
    Trains a simple neural network on synthetic data using FSDP
    """
    import mlflow
    mlflow.start_run(run_name='fsdp_example')
    def main_training():
        """Main training function"""
        print("Starting FSDP Training Demo...")

        # Setup distributed training
        rank, world_size, device = setup()

        print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

        # Create dataset and data loader
        dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)

        # Use DistributedSampler if we have multiple processes
        if world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
            shuffle = False
        else:
            sampler = None
            shuffle = True

        dataloader = DataLoader(
            dataset,
            batch_size=32,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True
        )

        # Create model
        model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)

        # Apply FSDP
        model = apply_fsdp(model, world_size)

        print(f"Model created and moved to device: {device}")
        if rank == 0:
            print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

        # Training loop
        num_epochs = 5
        loss_history = []

        print(f"Training for {num_epochs} epochs...")
        writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
        for epoch in range(num_epochs):
            if sampler:
                sampler.set_epoch(epoch)

            model.train()
            total_loss = 0.0
            num_batches = 0

            epoch_start_time = time.time()

            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(data)
                loss = criterion(output, target)

                # Backward pass
                loss.backward()
                mlflow.log_metric(
                    key='loss',
                    value=loss.item(),
                    step=batch_idx,
                )
                # Update weights
                optimizer.step()

                total_loss += loss.item()

                num_batches += 1

                if batch_idx % 10 == 0:
                    print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
                    state_dict = { 'app': AppState(model, optimizer) }
                    ckpt_start_time = time.time()
                    dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
                    ckpt_time = time.time() - ckpt_start_time
                    print(f'Checkpointing took {ckpt_time:.2f}s')
                    mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
                    if rank == 0:
                        print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
            # Calculate average loss for this epoch
            avg_loss = total_loss / num_batches
            mlflow.log_metric(key='avg_loss', value=avg_loss)

            loss_history.append(avg_loss)

            epoch_time = time.time() - epoch_start_time

            if rank == 0:
                print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')

        # Verify loss is decreasing
        if rank == 0:
            print("\n=== FSDP Training Results ===")
            print("Loss history:")
            for i, loss in enumerate(loss_history):
                print(f"Epoch {i+1}: {loss:.6f}")

            # Check if loss is generally decreasing
            initial_loss = loss_history[0]
            final_loss = loss_history[-1]
            loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100

            print(f"\nInitial Loss: {initial_loss:.6f}")
            print(f"Final Loss: {final_loss:.6f}")
            print(f"Loss Reduction: {loss_reduction:.2f}%")

            if final_loss < initial_loss:
                print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
            else:
                print("❌ WARNING: Loss did not decrease. Check training configuration.")

            print(f"\nFSDP training completed successfully on {world_size} GPU(s)")

        # Cleanup
        cleanup()
        mlflow.end_run()

        return {
            'initial_loss': loss_history[0] if loss_history else None,
            'final_loss': loss_history[-1] if loss_history else None,
            'loss_history': loss_history,
            'world_size': world_size,
            'device': str(device),
            'fsdp_enabled': world_size > 1
        }

    # Run the training
    return main_training()

Запуск распределенного обучения

Выполните функцию обучения, чтобы начать распределенное обучение по 8 GPU H100. Метод активирует удаленное .distributed() выполнение на бессерверных вычислениях GPU. Ход обучения, метрики потери и контрольные точки записываются в MLflow.

Эта ячейка может занять несколько минут, так как она подготавливает ресурсы GPU, обучает модель для 5 эпох и сохраняет контрольные точки.

print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")

Загрузка контрольной точки модели

В этом разделе показано, как загрузить сохраненную контрольную точку для вывода или продолжения обучения. Контрольная точка содержит вес модели и состояние оптимизатора, сохраненное во время обучения.

Обратите внимание, что при загрузке контрольных точек за пределами распределенного контекста обучения (без инициализации группы процессов) API распределенной контрольной точки PyTorch автоматически отключает коллективные операции и загружает контрольную точку на одном устройстве.

def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    state_dict = { 'app': AppState(model, optimizer)}

    # print(state_dict)
    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=f'{CHECKPOINT_DIR}/step0',
    )
    model.load_state_dict(state_dict['app'].state_dict()['model'])

run_checkpoint_load_example()

Дальнейшие действия

Теперь, когда вы узнали, как использовать PyTorch FSDP для распределенного обучения для бессерверных вычислений GPU, изучите эти ресурсы, чтобы узнать больше:

Пример записной книжки

Распределенное обучение с использованием PyTorch FSDP на вычислениях на GPU без сервера

Получите ноутбук