Share via


Gedistribueerde training met PyTorch FSDP op serverloze GPU-rekenkracht

Dit notebook laat zien hoe u een Transformer-model traint met behulp van gedistribueerde training met de Fully Sharded Data Parallel (FSDP) van PyTorch op serverloze GPU-rekenkracht van Databricks. FSDP is een techniek voor gegevensparallellisme die modelparameters, gradienten en optimizerstatussen over meerdere GPU's versnipperd, waardoor efficiënte training mogelijk is van grote modellen die niet passen op één enkele GPU.

In dit voorbeeld leert u het volgende:

  • Gedistribueerde training instellen met de serverloze GPU gedistribueerde trainings-API
  • Een 10M parameter Transformer-model definiëren en trainen met behulp van FSDP
  • Gedistribueerde controlepunten opslaan tijdens de training
  • Experimenten bijhouden met MLflow
  • Controlepunten laden voor inferentie of voortzetting van training

In dit notebook worden synthetische gegevens gebruikt om deze zelfvoorzienend te houden, maar u kunt deze aanpassen aan uw eigen datasets.

Sleutelbegrippen:

  • FSDP (Fully Sharded Data Parallel): een gedistribueerde trainingsstrategie van PyTorch die shards-modelparameters over GPU's gebruikt om het geheugengebruik te verminderen en het trainen van grotere modellen mogelijk te maken.
  • Serverloze GPU-rekenkracht: door Databricks beheerde GPU-rekenkracht die automatisch resources voor uw workloads schaalt en in richt.

Zie gedistribueerde training voor meerdere GPU's en gedistribueerde knooppunten voor meer informatie.

Afhankelijkheden installeren

Installeer de nieuwste versie van MLflow voor het bijhouden van experimenten en modelregistratie.

%pip install -U mlflow
%restart_python

Unity Catalog-locaties configureren

Stel de Unity Catalog-locaties in waar het model en controlepunten worden opgeslagen. Werk deze waarden bij zodat deze overeenkomen met uw werkruimteconfiguratie. U hebt USE CATALOG en USE SCHEMA bevoegdheden nodig voor de opgegeven catalogus en het opgegeven schema.

# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

Helperfuncties en synthetische gegevensset definiëren

In deze sectie worden hulpprogrammafuncties gedefinieerd voor gedistribueerde trainingsinstallatie en een synthetische gegevenssetklasse voor demonstratiedoeleinden. In productie vervangt u de SyntheticDataset door uw eigen logica voor het laden van gegevens.

Belangrijkste onderdelen:

  • setup(): Initialiseert de gedistribueerde trainingsprocesgroep en configureert GPU-apparaten
  • cleanup(): De gedistribueerde procesgroep opschonen na de training
  • AppState: Een wrapper-klasse voor het controlepuntmodel en de optimalisatiestatus die compatibel is met de gedistribueerde controlepunt-API van PyTorch
  • SyntheticDataset: Genereert willekeurige gegevens voor trainingsdemonstratie
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time

# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

def setup():
    """Initialize the distributed training process group"""
    # Check if we're in a distributed environment
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
    else:
        # Fallback for single GPU
        rank = 0
        world_size = 1
        local_rank = 0

    # Initialize process group
    if world_size > 1:
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Set device
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    return rank, world_size, device

def cleanup():
    """Clean up the distributed training process group"""
    if dist.is_initialized():
        dist.destroy_process_group()

class SyntheticDataset(Dataset):
    """Simple synthetic dataset for demo purposes"""
    def __init__(self, size=10000, input_dim=512, num_classes=10):
        self.size = size
        self.input_dim = input_dim
        self.num_classes = num_classes

        # Generate synthetic data
        np.random.seed(42)  # For reproducible results
        self.data = torch.randn(size, input_dim)
        # Create labels with some pattern
        self.labels = torch.randint(0, num_classes, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

Het transformatiemodel definiëren met FSDP

In deze sectie wordt een eenvoudig transformermodel gedefinieerd voor classificatie en de logica voor het toepassen van FSDP-sharding. Hoewel FSDP doorgaans wordt gebruikt voor grote taalmodellen met 7B+ parameters, demonstreert dit voorbeeld de techniek met een kleiner model met 10 miljoen parameters dat wordt verdeeld over meerdere H100 GPU's.

Modelarchitectuur:

  • TransformerBlock: Eén transformatorlaag met meerdere hoofden aandacht en MLP
  • SimpleTransformer: Een stapel transformatorblokken met invoerprojectie en classificatiekop
  • apply_fsdp(): Omhult modellagen met FSDP voor gedistribueerde training

FSDP splitst de modelparameters, gradienten en toestanden van de optimizer over GPU's, waardoor de geheugenvereisten per GPU worden verminderd en grotere modellen kunnen worden getraind.

class TransformerBlock(nn.Module):
    """Simple transformer block for testing FSDP"""
    def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.norm2(x + mlp_out)

        return x

class SimpleTransformer(nn.Module):
    """Simple transformer model for classification with FSDP"""
    def __init__(self, input_dim=512, num_layers=64, num_classes=10):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, input_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim=input_dim) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(input_dim)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Add sequence dimension for transformer
        x = x.unsqueeze(1)  # [batch, 1, input_dim]

        x = self.input_projection(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        # Global average pooling
        x = x.mean(dim=1)  # [batch, input_dim]

        return self.classifier(x)

def apply_fsdp(model, world_size):
    """Apply FSDP to the model"""
    if world_size > 1:
        print("Applying FSDP to model layers...")
        # Apply fsdp to each transformer layer
        for i, layer in enumerate(model.layers):
            fully_shard(layer)
            print(f"Applied FSDP to layer {i}")

        # Apply FSDP to the entire model
        fully_shard(model)
        print("Applied FSDP to entire model")
    else:
        print("Single GPU detected, skipping FSDP setup")

    return model

De gedistribueerde trainingsfunctie definiëren

De trainingsfunctie wordt verpakt met de @distributed decorator van de serverloze GPU-API. Deze decorator verwerkt:

  • Het opgegeven aantal GPU's inrichten (8 H100 GPU's in dit voorbeeld)
  • De gedistribueerde trainingsomgeving instellen
  • De levenscyclus van externe rekenresources beheren

De trainingsfunctie bevat:

  • Modelinitialisatie en FSDP-omwikkeling
  • Gegevens laden met DistributedSampler voor parallelle gegevensverwerking
  • Trainingslus met gradiëntupdates
  • Periodiek controlepunt opslaan met behulp van de gedistribueerde controlepunt-API van PyTorch
  • MLflow-logboekregistratie voor het bijhouden van experimenten

Controlepunten worden opgeslagen op een Unity Catalog-volume en geregistreerd als MLflow-artefacten voor versiebeheer en reproduceerbaarheid.

from serverless_gpu import distributed
from serverless_gpu.compute import GPUType

NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
    """
    Self-contained FSDP training demo using PyTorch 2.0+
    Trains a simple neural network on synthetic data using FSDP
    """
    import mlflow
    mlflow.start_run(run_name='fsdp_example')
    def main_training():
        """Main training function"""
        print("Starting FSDP Training Demo...")

        # Setup distributed training
        rank, world_size, device = setup()

        print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

        # Create dataset and data loader
        dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)

        # Use DistributedSampler if we have multiple processes
        if world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
            shuffle = False
        else:
            sampler = None
            shuffle = True

        dataloader = DataLoader(
            dataset,
            batch_size=32,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True
        )

        # Create model
        model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)

        # Apply FSDP
        model = apply_fsdp(model, world_size)

        print(f"Model created and moved to device: {device}")
        if rank == 0:
            print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

        # Training loop
        num_epochs = 5
        loss_history = []

        print(f"Training for {num_epochs} epochs...")
        writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
        for epoch in range(num_epochs):
            if sampler:
                sampler.set_epoch(epoch)

            model.train()
            total_loss = 0.0
            num_batches = 0

            epoch_start_time = time.time()

            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(data)
                loss = criterion(output, target)

                # Backward pass
                loss.backward()
                mlflow.log_metric(
                    key='loss',
                    value=loss.item(),
                    step=batch_idx,
                )
                # Update weights
                optimizer.step()

                total_loss += loss.item()

                num_batches += 1

                if batch_idx % 10 == 0:
                    print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
                    state_dict = { 'app': AppState(model, optimizer) }
                    ckpt_start_time = time.time()
                    dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
                    ckpt_time = time.time() - ckpt_start_time
                    print(f'Checkpointing took {ckpt_time:.2f}s')
                    mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
                    if rank == 0:
                        print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
            # Calculate average loss for this epoch
            avg_loss = total_loss / num_batches
            mlflow.log_metric(key='avg_loss', value=avg_loss)

            loss_history.append(avg_loss)

            epoch_time = time.time() - epoch_start_time

            if rank == 0:
                print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')

        # Verify loss is decreasing
        if rank == 0:
            print("\n=== FSDP Training Results ===")
            print("Loss history:")
            for i, loss in enumerate(loss_history):
                print(f"Epoch {i+1}: {loss:.6f}")

            # Check if loss is generally decreasing
            initial_loss = loss_history[0]
            final_loss = loss_history[-1]
            loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100

            print(f"\nInitial Loss: {initial_loss:.6f}")
            print(f"Final Loss: {final_loss:.6f}")
            print(f"Loss Reduction: {loss_reduction:.2f}%")

            if final_loss < initial_loss:
                print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
            else:
                print("❌ WARNING: Loss did not decrease. Check training configuration.")

            print(f"\nFSDP training completed successfully on {world_size} GPU(s)")

        # Cleanup
        cleanup()
        mlflow.end_run()

        return {
            'initial_loss': loss_history[0] if loss_history else None,
            'final_loss': loss_history[-1] if loss_history else None,
            'loss_history': loss_history,
            'world_size': world_size,
            'device': str(device),
            'fsdp_enabled': world_size > 1
        }

    # Run the training
    return main_training()

De gedistribueerde training uitvoeren

Voer de trainingsfunctie uit om gedistribueerde training te starten over 8 H100 GPU's. De .distributed() methode activeert externe uitvoering op serverloze GPU-rekenkracht. Trainingsvoortgang, metrische gegevens over verlies en controlepunten worden vastgelegd in MLflow.

Het kan enkele minuten duren voordat deze cel GPU-resources inricht, het model traint voor 5 tijdvakken en controlepunten opslaat.

print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")

Een modelcontrolepunt laden

In deze sectie wordt gedemonstreerd hoe u een opgeslagen controlepunt laadt voor deductie of voortgezette training. Het controlepunt bevat de modelgewichten en de optimalisatiestatus die tijdens de training zijn opgeslagen.

Houd er rekening mee dat bij het laden van controlepunten buiten een gedistribueerde trainingscontext (geen procesgroep geïnitialiseerd), de gedistribueerde controlepunt-API van PyTorch collectieve bewerkingen automatisch uitschakelt en het controlepunt op één apparaat laadt.

def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    state_dict = { 'app': AppState(model, optimizer)}

    # print(state_dict)
    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=f'{CHECKPOINT_DIR}/step0',
    )
    model.load_state_dict(state_dict['app'].state_dict()['model'])

run_checkpoint_load_example()

Volgende stappen 

Nu u hebt geleerd hoe u PyTorch FSDP kunt gebruiken voor gedistribueerde training op serverloze GPU-compute, kunt u deze resources verkennen voor meer informatie:

Voorbeeld van notebook

Gedistribueerde training met PyTorch FSDP op serverloze GPU-rekenkracht

Notebook krijgen