共用方式為


在無伺服器 GPU 運算中載入資料

本節介紹關於在無伺服器 GPU 運算中載入資料的資訊,特別是針對機器學習與深度學習應用。 請參考 教學 ,了解如何使用 Spark Python API 載入和轉換資料。

載入表格式資料

使用 Spark Connect 從 Delta 表格載入表格式機器學習資料。

針對單節點訓練,你可以使用 PySpark 方法toPandas()將 Apache Spark DataFrames 轉換成 pandas DataFrames,然後再用 PySpark 方法to_numpy()選擇性地轉換成 NumPy 格式。

備註

Spark Connect 會將分析和名稱解析延遲到執行時間,這可能會改變程式碼的行為。 請參閱 比較 Spark Connect 與 Spark Classic

Spark Connect 支援大多數 PySpark API,包括 Spark SQL、Spark 上的 Pandas API、結構化串流以及基於 DataFrame 的 MLlib。 請參閱 PySpark API 參考文件 以了解最新支援的 API。

關於其他限制,請參見 無伺服器運算限制

裝潢機內部 @distributed 的載荷資料

使用 無伺服器 GPU API 進行分散式訓練時,請將資料載入程式碼移入 @distributed 裝飾器內部。 資料集大小可能會超過 pickle 所允許的最大限制,因此建議在裝飾器中生成資料集,如下所示:

from serverless_gpu import distributed

# this may cause pickle error
dataset = get_dataset(file_path)
@distributed(gpus=8, remote=True)
def run_train():
  # good practice
  dataset = get_dataset(file_path)
  ....

資料載入效能

/Workspace/Volumes 目錄都存放於遠端的 Unity Catalog 儲存空間。 如果你的資料集儲存在 Unity Catalog,資料載入速度會受到可用網路頻寬的限制。 如果你要訓練多個時代,建議先把資料本地複製,特別是複製到 /tmp 存放在超高速儲存裝置(NVMe SSD 硬碟)上的目錄。

如果您的資料集規模龐大,我們也建議以下技術來平行化訓練與資料載入:

  • 訓練多個世代時,先將資料集更新至本地目錄中的快取檔案 /tmp,再讀取每個檔案。 在後續的時代,則使用快取版本。
  • 透過在 torch DataLoader API 中啟用工作者來實現資料擷取的平行化。 設定 num_workers 至少2。 預設情況下,每位工作者會預先取得兩個工作項目。 為了提升效能,可以增加 num_workers (增加平行讀取次數)或 prefetch_factor (增加每個工作者預取的項目數量)。