本節介紹關於在無伺服器 GPU 運算中載入資料的資訊,特別是針對機器學習與深度學習應用。 請參考 教學 ,了解如何使用 Spark Python API 載入和轉換資料。
載入表格式資料
使用 Spark Connect 從 Delta 表格載入表格式機器學習資料。
針對單節點訓練,你可以使用 PySpark 方法toPandas()將 Apache Spark DataFrames 轉換成 pandas DataFrames,然後再用 PySpark 方法to_numpy()選擇性地轉換成 NumPy 格式。
備註
Spark Connect 會將分析和名稱解析延遲到執行時間,這可能會改變程式碼的行為。 請參閱 比較 Spark Connect 與 Spark Classic。
Spark Connect 支援大多數 PySpark API,包括 Spark SQL、Spark 上的 Pandas API、結構化串流以及基於 DataFrame 的 MLlib。 請參閱 PySpark API 參考文件 以了解最新支援的 API。
關於其他限制,請參見 無伺服器運算限制。
裝潢機內部 @distributed 的載荷資料
使用 無伺服器 GPU API 進行分散式訓練時,請將資料載入程式碼移入 @distributed 裝飾器內部。 資料集大小可能會超過 pickle 所允許的最大限制,因此建議在裝飾器中生成資料集,如下所示:
from serverless_gpu import distributed
# this may cause pickle error
dataset = get_dataset(file_path)
@distributed(gpus=8, remote=True)
def run_train():
# good practice
dataset = get_dataset(file_path)
....
資料載入效能
/Workspace 和 /Volumes 目錄都存放於遠端的 Unity Catalog 儲存空間。 如果你的資料集儲存在 Unity Catalog,資料載入速度會受到可用網路頻寬的限制。 如果你要訓練多個時代,建議先把資料本地複製,特別是複製到 /tmp 存放在超高速儲存裝置(NVMe SSD 硬碟)上的目錄。
如果您的資料集規模龐大,我們也建議以下技術來平行化訓練與資料載入:
- 訓練多個世代時,先將資料集更新至本地目錄中的快取檔案
/tmp,再讀取每個檔案。 在後續的時代,則使用快取版本。 - 透過在 torch DataLoader API 中啟用工作者來實現資料擷取的平行化。 設定
num_workers至少2。 預設情況下,每位工作者會預先取得兩個工作項目。 為了提升效能,可以增加num_workers(增加平行讀取次數)或prefetch_factor(增加每個工作者預取的項目數量)。