Poznámka:
Přístup k této stránce vyžaduje autorizaci. Můžete se zkusit přihlásit nebo změnit adresáře.
Přístup k této stránce vyžaduje autorizaci. Můžete zkusit změnit adresáře.
Tento notebook demonstruje, jak trénovat model Transformer pomocí distribuovaného učení s využitím PyTorch Fully Sharded Data Parallel (FSDP) na serverless GPU výpočetních prostředcích Databricks. FSDP je technika paralelizmu dat, která sharduje parametry modelu, gradienty a stavy optimalizátoru napříč několika GPU, což umožňuje efektivní trénování velkých modelů, které se nevejdou do jedné GPU.
V tomto příkladu se naučíte:
- Nastavení distribuovaného trénování pomocí bezserverového rozhraní API pro distribuované trénování GPU
- Definujte a natrénujte Transformer model s 10M parametry pomocí FSDP
- Ukládání distribuovaných kontrolních bodů během trénování
- Sledování experimentů pomocí MLflow
- Nahrát kontrolní body pro inference nebo pokračující trénink
Tento poznámkový blok používá syntetická data k tomu, aby byla samostatná, ale můžete je přizpůsobit pro práci s vlastními datovými sadami.
Klíčové koncepty:
- FSDP (plně oddělené datové paralelní):: Strategie distribuovaného trénování PyTorch, která rozděluje parametry modelu napříč GPU, aby se snížilo využití paměti a umožnilo trénování větších modelů.
- Výpočetní prostředí GPU bez serveru: Výpočetní výkon GPU spravovaný službou Databricks, který automaticky škáluje a zřizuje prostředky pro vaše úlohy.
Další informace najdete v tématu Více GPU a distribuované trénování s více uzly.
Nainstalujte závislosti
Nainstalujte nejnovější verzi MLflow pro sledování experimentů a protokolování modelu.
%pip install -U mlflow
%restart_python
Konfigurace umístění katalogu Unity
Nastavte umístění katalogu Unity, kde se uloží model a kontrolní body. Aktualizujte tyto hodnoty tak, aby odpovídaly konfiguraci pracovního prostoru. Potřebujete zadaná oprávnění USE CATALOG k katalogu USE SCHEMA a schématu.
# You must have `USE CATALOG` privileges on the catalog, and you must have `USE SCHEMA` privileges on the schema.
# If necessary, change the catalog and schema name here.
dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("model_name", "transformer_fsdp")
dbutils.widgets.text("uc_volume", "checkpoints")
UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_VOLUME = dbutils.widgets.get("uc_volume")
MODEL_NAME = dbutils.widgets.get("model_name")
UC_MODEL_NAME = f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}"
print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
Definování pomocných funkcí a syntetických datových sad
Tato část definuje pomocné funkce pro distribuované trénování a syntetickou třídu datových sad pro demonstrační účely. V produkční fázi byste SyntheticDataset nahradili vlastní logikou načítání dat.
Klíčové komponenty:
-
setup(): Inicializuje distribuovanou skupinu procesů trénování a nakonfiguruje zařízení GPU. -
cleanup(): Po trénování vyčistí distribuovanou skupinu procesů. -
AppState: Třída obálky pro model kontrolních bodů a stav optimalizátoru, který je kompatibilní s rozhraním API distribuovaného kontrolního bodu PyTorch -
SyntheticDataset: Generuje náhodná data pro ukázku trénování.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
import torch.multiprocessing as mp
from torch.distributed.fsdp import fully_shard
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import numpy as np
import os
import time
# Below is an example of distributed checkpoint based on
# https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_stat_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to hande calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
def setup():
"""Initialize the distributed training process group"""
# Check if we're in a distributed environment
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
local_rank = int(os.environ.get('LOCAL_RANK', 0))
else:
# Fallback for single GPU
rank = 0
world_size = 1
local_rank = 0
# Initialize process group
if world_size > 1:
if not dist.is_initialized():
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
# Set device
if torch.cuda.is_available():
device = torch.device(f'cuda:{local_rank}')
torch.cuda.set_device(device)
else:
device = torch.device('cpu')
return rank, world_size, device
def cleanup():
"""Clean up the distributed training process group"""
if dist.is_initialized():
dist.destroy_process_group()
class SyntheticDataset(Dataset):
"""Simple synthetic dataset for demo purposes"""
def __init__(self, size=10000, input_dim=512, num_classes=10):
self.size = size
self.input_dim = input_dim
self.num_classes = num_classes
# Generate synthetic data
np.random.seed(42) # For reproducible results
self.data = torch.randn(size, input_dim)
# Create labels with some pattern
self.labels = torch.randint(0, num_classes, (size,))
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
Definování modelu Transformer pomocí FSDP
Tato část definuje jednoduchý model Transformer pro klasifikaci a logiku pro použití horizontálního dělení FSDP. I když se FSDP obvykle používá pro velké jazykové modely s více než 7 miliardami parametrů, tento příklad ukazuje techniku s modelem s 10 miliony parametry, rozděleným mezi více grafických procesorů H100.
Architektura modelu:
-
TransformerBlock: Jedna vrstva transformátoru s více hlavami pozornosti a MLP -
SimpleTransformer: Zásobník transformátorových bloků se vstupní projekcí a klasifikační hlavou -
apply_fsdp(): Obaluje vrstvy modelu použitím FSDP pro distribuované trénování.
FSDP horizontálně dělí parametry modelu, gradienty a stavy optimalizátoru napříč jednotlivými GPU, snižuje požadavky na paměť na jednotlivé GPU a umožňuje trénování větších modelů.
class TransformerBlock(nn.Module):
"""Simple transformer block for testing FSDP"""
def __init__(self, dim=512, num_heads=8, mlp_ratio=4):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
mlp_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, dim),
)
def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# MLP
mlp_out = self.mlp(x)
x = self.norm2(x + mlp_out)
return x
class SimpleTransformer(nn.Module):
"""Simple transformer model for classification with FSDP"""
def __init__(self, input_dim=512, num_layers=64, num_classes=10):
super().__init__()
self.input_projection = nn.Linear(input_dim, input_dim)
self.layers = nn.ModuleList([
TransformerBlock(dim=input_dim) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(input_dim)
self.classifier = nn.Linear(input_dim, num_classes)
def forward(self, x):
# Add sequence dimension for transformer
x = x.unsqueeze(1) # [batch, 1, input_dim]
x = self.input_projection(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# Global average pooling
x = x.mean(dim=1) # [batch, input_dim]
return self.classifier(x)
def apply_fsdp(model, world_size):
"""Apply FSDP to the model"""
if world_size > 1:
print("Applying FSDP to model layers...")
# Apply fsdp to each transformer layer
for i, layer in enumerate(model.layers):
fully_shard(layer)
print(f"Applied FSDP to layer {i}")
# Apply FSDP to the entire model
fully_shard(model)
print("Applied FSDP to entire model")
else:
print("Single GPU detected, skipping FSDP setup")
return model
Definování distribuované trénovací funkce
Trénovací funkce je obalena dekorátorem @distributed z API bezserverového GPU. Tento dekorátor zpracovává následující:
- Zřízení zadaného počtu GPU (v tomto příkladu 8 H100 GPU)
- Nastavení distribuovaného trénovacího prostředí
- Správa životního cyklu vzdálených výpočetních prostředků
Trénovací funkce zahrnuje:
- Inicializace modelů a zabalení FSDP
- Načítání dat s
DistributedSamplerpro paralelní zpracování dat - Trénovací smyčka s aktualizacemi gradientu
- Pravidelné ukládání kontrolních bodů pomocí distribuovaného rozhraní API kontrolního bodu PyTorch
- Zaznamenávání pomocí MLflow pro monitorování experimentů
Kontrolní body se ukládají do svazku katalogu Unity a protokolují se jako artefakty MLflow pro správu verzí a reprodukovatelnost.
from serverless_gpu import distributed
from serverless_gpu.compute import GPUType
NUM_WORKERS = 8
CHECKPOINT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{MODEL_NAME}"
@distributed(gpus=NUM_WORKERS, gpu_type=GPUType.H100)
def run_fsdp_training(num_workers=NUM_WORKERS):
"""
Self-contained FSDP training demo using PyTorch 2.0+
Trains a simple neural network on synthetic data using FSDP
"""
import mlflow
mlflow.start_run(run_name='fsdp_example')
def main_training():
"""Main training function"""
print("Starting FSDP Training Demo...")
# Setup distributed training
rank, world_size, device = setup()
print(f"Rank: {rank}, World Size: {world_size}, Device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"Current CUDA device: {torch.cuda.current_device()}")
# Create dataset and data loader
dataset = SyntheticDataset(size=10000, input_dim=512, num_classes=10)
# Use DistributedSampler if we have multiple processes
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
shuffle = False
else:
sampler = None
shuffle = True
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=shuffle,
sampler=sampler,
num_workers=num_workers,
pin_memory=True
)
# Create model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10).to(device)
# Apply FSDP
model = apply_fsdp(model, world_size)
print(f"Model created and moved to device: {device}")
if rank == 0:
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# Training loop
num_epochs = 5
loss_history = []
print(f"Training for {num_epochs} epochs...")
writer = StorageWriter(cache_staged_state_dict=False, path=CHECKPOINT_DIR)
for epoch in range(num_epochs):
if sampler:
sampler.set_epoch(epoch)
model.train()
total_loss = 0.0
num_batches = 0
epoch_start_time = time.time()
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
mlflow.log_metric(
key='loss',
value=loss.item(),
step=batch_idx,
)
# Update weights
optimizer.step()
total_loss += loss.item()
num_batches += 1
if batch_idx % 10 == 0:
print(f'Saving checkpoint to {CHECKPOINT_DIR}/step{batch_idx}')
state_dict = { 'app': AppState(model, optimizer) }
ckpt_start_time = time.time()
dcp.save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}/step{batch_idx}")
ckpt_time = time.time() - ckpt_start_time
print(f'Checkpointing took {ckpt_time:.2f}s')
mlflow.log_artifacts(f'{CHECKPOINT_DIR}/step{batch_idx}', artifact_path=f'checkpoints/step{batch_idx}')
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
# Calculate average loss for this epoch
avg_loss = total_loss / num_batches
mlflow.log_metric(key='avg_loss', value=avg_loss)
loss_history.append(avg_loss)
epoch_time = time.time() - epoch_start_time
if rank == 0:
print(f'Epoch {epoch+1}/{num_epochs} with {num_batches} completed in {epoch_time:.2f}s. Average Loss: {avg_loss:.6f}')
# Verify loss is decreasing
if rank == 0:
print("\n=== FSDP Training Results ===")
print("Loss history:")
for i, loss in enumerate(loss_history):
print(f"Epoch {i+1}: {loss:.6f}")
# Check if loss is generally decreasing
initial_loss = loss_history[0]
final_loss = loss_history[-1]
loss_reduction = ((initial_loss - final_loss) / initial_loss) * 100
print(f"\nInitial Loss: {initial_loss:.6f}")
print(f"Final Loss: {final_loss:.6f}")
print(f"Loss Reduction: {loss_reduction:.2f}%")
if final_loss < initial_loss:
print("✅ SUCCESS: FSDP training is working! Loss is decreasing.")
else:
print("❌ WARNING: Loss did not decrease. Check training configuration.")
print(f"\nFSDP training completed successfully on {world_size} GPU(s)")
# Cleanup
cleanup()
mlflow.end_run()
return {
'initial_loss': loss_history[0] if loss_history else None,
'final_loss': loss_history[-1] if loss_history else None,
'loss_history': loss_history,
'world_size': world_size,
'device': str(device),
'fsdp_enabled': world_size > 1
}
# Run the training
return main_training()
Spusťte distribuované školení
Spuštěním trénovací funkce zahajte distribuované trénování napříč grafickými procesory 8 H100. Tato .distributed() metoda aktivuje vzdálené spouštění na výpočetních prostředcích GPU bez serveru. Průběh trénování, metriky ztráty a kontrolní body se budou protokolovat do MLflow.
Dokončení této buňky může trvat několik minut, protože připravuje prostředky GPU, trénuje model po dobu 5 epoch a ukládá kontrolní body modelu.
print("Starting FSDP Demo on Databricks Serverless GPU...")
result = run_fsdp_training.distributed()
print("FSDP Demo completed!")
print(f"Training Results: {result}")
Načtení kontrolního bodu modelu
Tato sekce ukazuje, jak načíst uložený záchytný bod pro predikci nebo pokračování trénování. Kontrolní bod obsahuje váhy modelu a stav optimalizátoru uložený během trénování.
Všimněte si, že při načítání kontrolních bodů mimo distribuovaný kontext trénování (není inicializována žádná skupina procesů), rozhraní API distribuovaného kontrolního bodu PyTorch automaticky zakáže kolektivní operace a načte kontrolní bod do jednoho zařízení.
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = SimpleTransformer(input_dim=512, num_layers=4, num_classes=10)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
state_dict = { 'app': AppState(model, optimizer)}
# print(state_dict)
# since no progress group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=f'{CHECKPOINT_DIR}/step0',
)
model.load_state_dict(state_dict['app'].state_dict()['model'])
run_checkpoint_load_example()
Další kroky
Teď, když jste se naučili používat PyTorch FSDP k distribuovanému trénování na bezserverových výpočetních prostředcích GPU, projděte si tyto zdroje informací:
- Distribuované trénování s více GPU a více uzly – Informace o různých strategiích distribuovaného trénování
- Osvědčené postupy pro bezserverové výpočetní prostředí GPU – Optimalizace úloh GPU
- Řešení potíží s bezserverovým výpočetním prostředím GPU – Běžné problémy a řešení
- Dokumentace k PyTorch FSDP – podrobné informace o funkcích a konfiguraci FSDP