Aracılığıyla paylaş


Sunucusuz GPU işlemi üzerinde PyTorch FSDP kullanarak dağıtılmış eğitim

Bu not defteri, Databricks'in sunucusuz GPU hesaplaması üzerinde PyTorch'un Tam Parçalanmış Veri Paralelliği (FSDP) ile dağıtılmış eğitimi kullanarak bir Transformer modelini eğitmeyi gösterir. FSDP, birden çok GPU'da model parametrelerini, gradyanları ve iyileştirici durumlarını parçalayan ve tek bir GPU'ya sığmayan büyük modellerin verimli bir şekilde eğitilmesini sağlayan bir veri paralelliği tekniğidir.

Bu örnekte şunları nasıl yapacağınızı öğreneceksiniz:

Bu not defteri kendi içinde tutmak için yapay verileri kullanır, ancak kendi veri kümelerinizle çalışacak şekilde uyarlayabilirsiniz.

Temel kavramlar:

  • FSDP (Tam Parçalı Veri Paralel): Bellek kullanımını azaltmak ve daha büyük modellerin eğitilmesine olanak tanımak için GPU'lar arasında model parametrelerini parçalayan bir PyTorch dağıtılmış eğitim stratejisi.
  • Sunucusuz GPU işlemi: İş yükleriniz için kaynakları otomatik olarak ölçeklendire ve sağlayan Databricks tarafından yönetilen GPU işlemi.

Daha fazla bilgi için bkz . Çoklu GPU ve çok düğümlü dağıtılmış eğitim.

Bağımlılıkları yükleme

Deneme izleme ve model günlüğü için en son MLflow sürümünü yükleyin.

%pip install -U mlflow
%restart_python

Unity Kataloğu konumlarını yapılandırma

Modelin ve denetim noktalarının depolanacağı Unity Kataloğu konumlarını ayarlayın. Bu değerleri çalışma alanı yapılandırmanızla eşleşecek şekilde güncelleştirin. Belirtilen USE CATALOG kataloğunda ve USE SCHEMA şemada ayrıcalıklara ihtiyacınız vardır.

# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")

Yardımcı işlevleri ve yapay veri kümesini tanımlama

Bu bölüm, dağıtılmış eğitim kurulumu için yardımcı program işlevlerini ve gösterim amacıyla yapay veri kümesi sınıfını tanımlar. Üretimde, SyntheticDataset öğesini kendi veri yükleme mantığınızla değiştirirsiniz.

Önemli bileşenler:

  • setup(): Dağıtılmış eğitim süreci grubunu başlatır ve GPU cihazlarını yapılandırr
  • cleanup(): Eğitimden sonra dağıtılmış işlem grubunu temizler
  • AppState: PyTorch'un dağıtılmış denetim noktası API'si ile uyumlu denetim noktası modeli ve iyileştirici durumu için sarmalayıcı sınıfı
  • SyntheticDataset: Eğitim tanıtımı için rastgele veriler oluşturur
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time

# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

def setup():
    """Initialize the distributed training process group"""
    # Check if we're in a distributed environment
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
    else:
        # Fallback for single GPU
        rank = 0
        world_size = 1
        local_rank = 0

    # Initialize process group
    if world_size > 1:
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # Set device
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    return rank, world_size, device

def cleanup():
    """Clean up the distributed training process group"""
    if dist.is_initialized():
        dist.destroy_process_group()

class SyntheticDataset(Dataset):
    """Simple synthetic dataset for demo purposes"""
    def __init__(self, size=10000, input_dim=512, num_classes=10):
        self.size = size
        self.input_dim = input_dim
        self.num_classes = num_classes

        # Generate synthetic data
        np.random.seed(42)  # For reproducible results
        self.data = torch.randn(size, input_dim)
        # Create labels with some pattern
        self.labels = torch.randint(0, num_classes, (size,))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

FSDP ile Transformer modelini tanımlama

Bu bölümde sınıflandırma için basit bir Transformer modeli ve FSDP parçalama uygulama mantığı tanımlanmıştır. FSDP genellikle 7B' den fazla parametreye sahip büyük dil modelleri için kullanılırken, bu örnekte birden çok H100 GPU arasında parçalanmış daha küçük bir 10M parametre modeline sahip teknik gösterilmektedir.

Model mimarisi:

  • TransformerBlock: Çok başlı dikkat mekanizması ve MLP içeren tek bir transformatör katmanı
  • SimpleTransformer: Giriş projeksiyonu ve sınıflandırma başlığına sahip transformatör blokları yığını
  • apply_fsdp(): Dağıtılmış eğitim için FSDP ile model katmanlarını sarar

FSDP GPU'lar genelinde model parametrelerini, gradyanları ve iyileştirici durumlarını parçalayarak GPU başına bellek gereksinimlerini azaltır ve daha büyük modellerin eğitilmesini sağlar.

class TransformerBlock(nn.Module):
    """Simple transformer block for testing FSDP"""
    def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # MLP
        mlp_out = self.mlp(x)
        x = self.norm2(x + mlp_out)

        return x

class SimpleTransformer(nn.Module):
    """Simple transformer model for classification with FSDP"""
    def __init__(self, input_dim=512, num_layers=64, num_classes=10):
        super().__init__()
        self.input_projection = nn.Linear(input_dim, input_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim=input_dim) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(input_dim)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Add sequence dimension for transformer
        x = x.unsqueeze(1)  # [batch, 1, input_dim]

        x = self.input_projection(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        # Global average pooling
        x = x.mean(dim=1)  # [batch, input_dim]

        return self.classifier(x)

def apply_fsdp(model, world_size):
    """Apply FSDP to the model"""
    if world_size > 1:
        print("Applying FSDP to model layers...")
        # Apply fsdp to each transformer layer
        for i, layer in enumerate(model.layers):
            fully_shard(layer)
            print(f"Applied FSDP to layer {i}")

        # Apply FSDP to the entire model
        fully_shard(model)
        print("Applied FSDP to entire model")
    else:
        print("Single GPU detected, skipping FSDP setup")

    return model

Dağıtılmış eğitim işlevini tanımlama

Eğitim işlevi, sunucusuz GPU API'sinden gelen @distributed dekoratörü ile sarılır. Bu dekoratör şunu işler:

  • Belirtilen sayıda GPU sağlama (bu örnekte 8 H100 GPU)
  • Dağıtılmış eğitim ortamını ayarlama
  • Uzak işlem kaynaklarının yaşam döngüsünü yönetme

Eğitim işlevi şunları içerir:

  • Model başlatma ve FSDP kapsülleme
  • Paralel veri işleme için ile DistributedSampler veri yükleme
  • Gradyan güncelleştirmeleri içeren eğitim döngüsü
  • PyTorch'un dağıtılmış denetim noktası API'sini kullanarak düzenli denetim noktası kaydetme
  • Deneme izleme için MLflow günlüğü

Kontrol noktaları, Unity Catalog hacmine kaydedilir ve sürüm oluşturma ve yeniden üretilebilirlik için MLflow nesneleri olarak günlüğe kaydedilir.

from serverless_gpu import distributed
from serverless_gpu.compute import GPUType

NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
    """
    Self-contained FSDP training demo using PyTorch 2.0+
    Trains a simple neural network on synthetic data using FSDP
    """
    import mlflow
    mlflow.start_run(run_name='fsdp_example')
    def main_training():
        """Main training function"""
        print("Starting FSDP Training Demo...")

        # Setup distributed training
        rank, world_size, device = setup()

        print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA device count: {torch.cuda.device_count()}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

        # Create dataset and data loader
        dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)

        # Use DistributedSampler if we have multiple processes
        if world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
            shuffle = False
        else:
            sampler = None
            shuffle = True

        dataloader = DataLoader(
            dataset,
            batch_size=32,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True
        )

        # Create model
        model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)

        # Apply FSDP
        model = apply_fsdp(model, world_size)

        print(f"Model created and moved to device: {device}")
        if rank == 0:
            print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

        # Training loop
        num_epochs = 5
        loss_history = []

        print(f"Training for {num_epochs} epochs...")
        writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
        for epoch in range(num_epochs):
            if sampler:
                sampler.set_epoch(epoch)

            model.train()
            total_loss = 0.0
            num_batches = 0

            epoch_start_time = time.time()

            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                output = model(data)
                loss = criterion(output, target)

                # Backward pass
                loss.backward()
                mlflow.log_metric(
                    key='loss',
                    value=loss.item(),
                    step=batch_idx,
                )
                # Update weights
                optimizer.step()

                total_loss += loss.item()

                num_batches += 1

                if batch_idx % 10 == 0:
                    print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
                    state_dict = { 'app': AppState(model, optimizer) }
                    ckpt_start_time = time.time()
                    dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
                    ckpt_time = time.time() - ckpt_start_time
                    print(f'Checkpointing took {ckpt_time:.2f}s')
                    mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
                    if rank == 0:
                        print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
            # Calculate average loss for this epoch
            avg_loss = total_loss / num_batches
            mlflow.log_metric(key='avg_loss', value=avg_loss)

            loss_history.append(avg_loss)

            epoch_time = time.time() - epoch_start_time

            if rank == 0:
                print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')

        # Verify loss is decreasing
        if rank == 0:
            print("\n=== FSDP Training Results ===")
            print("Loss history:")
            for i, loss in enumerate(loss_history):
                print(f"Epoch {i+1}: {loss:.6f}")

            # Check if loss is generally decreasing
            initial_loss = loss_history[0]
            final_loss = loss_history[-1]
            loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100

            print(f"\nInitial Loss: {initial_loss:.6f}")
            print(f"Final Loss: {final_loss:.6f}")
            print(f"Loss Reduction: {loss_reduction:.2f}%")

            if final_loss < initial_loss:
                print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
            else:
                print("❌ WARNING: Loss did not decrease. Check training configuration.")

            print(f"\nFSDP training completed successfully on {world_size} GPU(s)")

        # Cleanup
        cleanup()
        mlflow.end_run()

        return {
            'initial_loss': loss_history[0] if loss_history else None,
            'final_loss': loss_history[-1] if loss_history else None,
            'loss_history': loss_history,
            'world_size': world_size,
            'device': str(device),
            'fsdp_enabled': world_size > 1
        }

    # Run the training
    return main_training()

Dağıtılmış eğitimi başlatın

8 H100 GPU arasında dağıtılmış eğitim başlatmak için eğitim işlevini yürütür. .distributed() yöntemi sunucusuz GPU hesaplamalarında uzaktan yürütmeyi tetikler. Eğitim ilerleme durumu, kayıp ölçümleri ve denetim noktaları MLflow'a kaydedilir.

GPU kaynaklarını sağladığı, modeli 5 epoch eğittiği ve kontrol noktalarını kaydettiği için bu hücrenin tamamlanması birkaç dakika sürebilir.

print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")

Model denetim noktası yükleme

Bu bölümde, çıkarım veya devam eden eğitim için kaydedilmiş bir denetim noktasının nasıl yüklenecekleri gösterilmektedir. Denetim noktası, eğitim sırasında kaydedilmiş model ağırlıklarını ve optimizatör durumunu içerir.

Denetim noktalarını dağıtılmış eğitim bağlamı dışında yüklerken (işlem grubu başlatılmadı), PyTorch'un dağıtılmış denetim noktası API'sinin toplu işlemleri otomatik olarak devre dışı bırakdığını ve denetim noktasını tek bir cihaza yüklediğini unutmayın.

def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    state_dict = { 'app': AppState(model, optimizer)}

    # print(state_dict)
    # since no progress group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=f'{CHECKPOINT_DIR}/step0',
    )
    model.load_state_dict(state_dict['app'].state_dict()['model'])

run_checkpoint_load_example()

Sonraki Adımlar

Sunucusuz GPU işlemi üzerinde dağıtılmış eğitim için PyTorch FSDP'yi kullanmayı öğrendiğinize göre daha fazla bilgi edinmek için şu kaynakları keşfedin:

Örnek defter

Sunucusuz GPU işlemi üzerinde PyTorch FSDP kullanarak dağıtılmış eğitim

Dizüstü bilgisayar al