Distribuované trénování s využitím PyTorch FSDP na výpočetních prostředcích GPU bez serveru

Tento notebook demonstruje, jak trénovat model Transformer pomocí distribuovaného učení s využitím PyTorch Fully Sharded Data Parallel (FSDP) na serverless GPU výpočetních prostředcích Databricks. FSDP je technika paralelizmu dat, která sharduje parametry modelu, gradienty a stavy optimalizátoru napříč několika GPU, což umožňuje efektivní trénování velkých modelů, které se nevejdou do jedné GPU.

V tomto příkladu se naučíte:

Tento poznámkový blok používá syntetická data k tomu, aby byla samostatná, ale můžete je přizpůsobit pro práci s vlastními datovými sadami.

Klíčové koncepty:

  • FSDP (plně oddělené datové paralelní):: Strategie distribuovaného trénování PyTorch, která rozděluje parametry modelu napříč GPU, aby se snížilo využití paměti a umožnilo trénování větších modelů.
  • Výpočetní prostředí GPU bez serveru: Výpočetní výkon GPU spravovaný službou Databricks, který automaticky škáluje a zřizuje prostředky pro vaše úlohy.

Další informace najdete v tématu Více GPU a distribuované trénování s více uzly.

Nainstalujte závislosti

Nainstalujte nejnovější verzi MLflow pro sledování experimentů a protokolování modelu.

%pip install -U mlflow
%restart_python

Konfigurace umístění katalogu Unity

Nastavte umístění katalogu Unity, kde se uloží model a kontrolní body. Aktualizujte tyto hodnoty tak, aby odpovídaly konfiguraci pracovního prostoru. Potřebujete zadaná oprávnění USE CATALOG k katalogu USE SCHEMA a schématu.

# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

Definování pomocných funkcí a syntetických datových sad

Tato část definuje pomocné funkce pro distribuované trénování a syntetickou třídu datových sad pro demonstrační účely. V produkční fázi byste SyntheticDataset nahradili vlastní logikou načítání dat.

Klíčové komponenty:

  • setup(): Inicializuje distribuovanou skupinu procesů trénování a nakonfiguruje zařízení GPU.
  • cleanup(): Po trénování vyčistí distribuovanou skupinu procesů.
  • AppState: Třída obálky pro model kontrolních bodů a stav optimalizátoru, který je kompatibilní s rozhraním API distribuovaného kontrolního bodu PyTorch
  • SyntheticDataset: Generuje náhodná data pro ukázku trénování.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time

# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

def setup():
    """Initialize the distributed training process group"""
    # Check if we're in a distributed environment
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
    else:
        # Fallback for single GPU
        rank = 0
        world_size = 1
        local_rank = 0

    # Initialize process group
    if world_size > 1:
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Set device
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    return rank, world_size, device

def cleanup():
    """Clean up the distributed training process group"""
    if dist.is_initialized():
        dist.destroy_process_group()

class SyntheticDataset(Dataset):
    """Simple synthetic dataset for demo purposes"""
    def __init__(self, size=10000, input_dim=512, num_classes=10):
        self.size = size
        self.input_dim = input_dim
        self.num_classes = num_classes

        # Generate synthetic data
        np.random.seed(42)  # For reproducible results
        self.data = torch.randn(size, input_dim)
        # Create labels with some pattern
        self.labels = torch.randint(0, num_classes, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

Definování modelu Transformer pomocí FSDP

Tato část definuje jednoduchý model Transformer pro klasifikaci a logiku pro použití horizontálního dělení FSDP. I když se FSDP obvykle používá pro velké jazykové modely s více než 7 miliardami parametrů, tento příklad ukazuje techniku s modelem s 10 miliony parametry, rozděleným mezi více grafických procesorů H100.

Architektura modelu:

  • TransformerBlock: Jedna vrstva transformátoru s více hlavami pozornosti a MLP
  • SimpleTransformer: Zásobník transformátorových bloků se vstupní projekcí a klasifikační hlavou
  • apply_fsdp(): Obaluje vrstvy modelu použitím FSDP pro distribuované trénování.

FSDP horizontálně dělí parametry modelu, gradienty a stavy optimalizátoru napříč jednotlivými GPU, snižuje požadavky na paměť na jednotlivé GPU a umožňuje trénování větších modelů.

class TransformerBlock(nn.Module):
    """Simple transformer block for testing FSDP"""
    def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.norm2(x + mlp_out)

        return x

class SimpleTransformer(nn.Module):
    """Simple transformer model for classification with FSDP"""
    def __init__(self, input_dim=512, num_layers=64, num_classes=10):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, input_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim=input_dim) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(input_dim)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Add sequence dimension for transformer
        x = x.unsqueeze(1)  # [batch, 1, input_dim]

        x = self.input_projection(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        # Global average pooling
        x = x.mean(dim=1)  # [batch, input_dim]

        return self.classifier(x)

def apply_fsdp(model, world_size):
    """Apply FSDP to the model"""
    if world_size > 1:
        print("Applying FSDP to model layers...")
        # Apply fsdp to each transformer layer
        for i, layer in enumerate(model.layers):
            fully_shard(layer)
            print(f"Applied FSDP to layer {i}")

        # Apply FSDP to the entire model
        fully_shard(model)
        print("Applied FSDP to entire model")
    else:
        print("Single GPU detected, skipping FSDP setup")

    return model

Definování distribuované trénovací funkce

Trénovací funkce je obalena dekorátorem @distributed z API bezserverového GPU. Tento dekorátor zpracovává následující:

  • Zřízení zadaného počtu GPU (v tomto příkladu 8 H100 GPU)
  • Nastavení distribuovaného trénovacího prostředí
  • Správa životního cyklu vzdálených výpočetních prostředků

Trénovací funkce zahrnuje:

  • Inicializace modelů a zabalení FSDP
  • Načítání dat s DistributedSampler pro paralelní zpracování dat
  • Trénovací smyčka s aktualizacemi gradientu
  • Pravidelné ukládání kontrolních bodů pomocí distribuovaného rozhraní API kontrolního bodu PyTorch
  • Zaznamenávání pomocí MLflow pro monitorování experimentů

Kontrolní body se ukládají do svazku katalogu Unity a protokolují se jako artefakty MLflow pro správu verzí a reprodukovatelnost.

from serverless_gpu import distributed
from serverless_gpu.compute import GPUType

NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
    """
    Self-contained FSDP training demo using PyTorch 2.0+
    Trains a simple neural network on synthetic data using FSDP
    """
    import mlflow
    mlflow.start_run(run_name='fsdp_example')
    def main_training():
        """Main training function"""
        print("Starting FSDP Training Demo...")

        # Setup distributed training
        rank, world_size, device = setup()

        print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

        # Create dataset and data loader
        dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)

        # Use DistributedSampler if we have multiple processes
        if world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
            shuffle = False
        else:
            sampler = None
            shuffle = True

        dataloader = DataLoader(
            dataset,
            batch_size=32,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True
        )

        # Create model
        model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)

        # Apply FSDP
        model = apply_fsdp(model, world_size)

        print(f"Model created and moved to device: {device}")
        if rank == 0:
            print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

        # Training loop
        num_epochs = 5
        loss_history = []

        print(f"Training for {num_epochs} epochs...")
        writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
        for epoch in range(num_epochs):
            if sampler:
                sampler.set_epoch(epoch)

            model.train()
            total_loss = 0.0
            num_batches = 0

            epoch_start_time = time.time()

            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(data)
                loss = criterion(output, target)

                # Backward pass
                loss.backward()
                mlflow.log_metric(
                    key='loss',
                    value=loss.item(),
                    step=batch_idx,
                )
                # Update weights
                optimizer.step()

                total_loss += loss.item()

                num_batches += 1

                if batch_idx % 10 == 0:
                    print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
                    state_dict = { 'app': AppState(model, optimizer) }
                    ckpt_start_time = time.time()
                    dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
                    ckpt_time = time.time() - ckpt_start_time
                    print(f'Checkpointing took {ckpt_time:.2f}s')
                    mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
                    if rank == 0:
                        print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
            # Calculate average loss for this epoch
            avg_loss = total_loss / num_batches
            mlflow.log_metric(key='avg_loss', value=avg_loss)

            loss_history.append(avg_loss)

            epoch_time = time.time() - epoch_start_time

            if rank == 0:
                print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')

        # Verify loss is decreasing
        if rank == 0:
            print("\n=== FSDP Training Results ===")
            print("Loss history:")
            for i, loss in enumerate(loss_history):
                print(f"Epoch {i+1}: {loss:.6f}")

            # Check if loss is generally decreasing
            initial_loss = loss_history[0]
            final_loss = loss_history[-1]
            loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100

            print(f"\nInitial Loss: {initial_loss:.6f}")
            print(f"Final Loss: {final_loss:.6f}")
            print(f"Loss Reduction: {loss_reduction:.2f}%")

            if final_loss < initial_loss:
                print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
            else:
                print("❌ WARNING: Loss did not decrease. Check training configuration.")

            print(f"\nFSDP training completed successfully on {world_size} GPU(s)")

        # Cleanup
        cleanup()
        mlflow.end_run()

        return {
            'initial_loss': loss_history[0] if loss_history else None,
            'final_loss': loss_history[-1] if loss_history else None,
            'loss_history': loss_history,
            'world_size': world_size,
            'device': str(device),
            'fsdp_enabled': world_size > 1
        }

    # Run the training
    return main_training()

Spusťte distribuované školení

Spuštěním trénovací funkce zahajte distribuované trénování napříč grafickými procesory 8 H100. Tato .distributed() metoda aktivuje vzdálené spouštění na výpočetních prostředcích GPU bez serveru. Průběh trénování, metriky ztráty a kontrolní body se budou protokolovat do MLflow.

Dokončení této buňky může trvat několik minut, protože připravuje prostředky GPU, trénuje model po dobu 5 epoch a ukládá kontrolní body modelu.

print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")

Načtení kontrolního bodu modelu

Tato sekce ukazuje, jak načíst uložený záchytný bod pro predikci nebo pokračování trénování. Kontrolní bod obsahuje váhy modelu a stav optimalizátoru uložený během trénování.

Všimněte si, že při načítání kontrolních bodů mimo distribuovaný kontext trénování (není inicializována žádná skupina procesů), rozhraní API distribuovaného kontrolního bodu PyTorch automaticky zakáže kolektivní operace a načte kontrolní bod do jednoho zařízení.

def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    state_dict = { 'app': AppState(model, optimizer)}

    # print(state_dict)
    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=f'{CHECKPOINT_DIR}/step0',
    )
    model.load_state_dict(state_dict['app'].state_dict()['model'])

run_checkpoint_load_example()

Další kroky

Teď, když jste se naučili používat PyTorch FSDP k distribuovanému trénování na bezserverových výpočetních prostředcích GPU, projděte si tyto zdroje informací:

Ukázkový poznámkový blok

Distribuované trénování s využitím PyTorch FSDP na výpočetních prostředcích GPU bez serveru

Pořiďte si notebook