在 WSL 中启用 TensorFlow with DirectML

Windows 11 和 Windows 10 版本 21H2 通过使用 TensorFlow 1.15 的 TensorFlow with DirectML 包,为学生、初学者和专业人士提供了在其现有硬件上运行机器学习 (ML) 训练的方法。 要在 TensorFlow 2 上使用 DirectML,请查看 TensorFlow-DirectML-Plugin。 设置后,可以使用现有的示例模型脚本,或查看 DirectML 存储库中的一些示例

安装 Windows 11 或 Windows 10,版本 21H2

要使用这些功能,可以下载并安装 Windows 11Windows 10 版本 21H2

安装最新的 GPU 驱动程序

在 WSL 中安装 TensorFlow with DirectML 包之前,需要安装来自于 GPU 硬件供应商的最新驱动程序。 这些驱动程序使 Windows GPU 能够使用 WSL。

在“设置”应用的“Windows 更新”部分中,选择“检查是否有更新”,或者查看 GPU 硬件供应商网站。

AMD

从其网站下载并安装 AMD 的驱动程序。 以下硬件支持此功能:

  • AMD Radeon™ RX 系列和 Radeon™ VI 显卡。
  • AMD Radeon™ Pro 系列显卡。
  • 搭载 Radeon™ Vega 显卡的 AMD Ryzen 和 Ryzen™ PRO 处理器。
  • 搭载 Radeon™ Vega 显卡的 AMD Ryzen 和 Ryzen™ PRO 移动处理器。

有关兼容 AMD 产品的完整列表,请参阅 AMD 发行说明。

Intel

从其网站下载并安装可用于 DirectML 的 Intel 驱动程序

NVIDIA

从其网站下载并安装可用于 DirectML 的 NVIDIA 驱动程序。 有关详细信息,请参阅适用于 Linux 的 Windows 子系统 (WSL) 中的 NVIDIA GPU 页面。

设置 TensorFlow with DirectML

安装 WSL

安装上述驱动程序后,请确保启用 WSL安装基于 glibc 的分发版(例如 Ubuntu 或 Debian)。 在我们的测试中,使用的是 Ubuntu。 通过在设置应用的 Windows 更新部分中选择“检查更新”,确保你拥有最新的内核。

注意

确保启用了“更新 Windows 时接收其他 Microsoft 产品的更新”。 可以在“设置”应用“Windows 更新”部分的“高级”选项中找到该项。

对于这些功能,需要 5.10.43.3 或更高版本的内核版本。 可以通过在 PowerShell 中运行以下命令来检查版本号:

wsl cat /proc/version

设置 Python 环境

建议在 WSL 实例中设置虚拟 Python 环境。 可以使用许多工具来设置虚拟 Python 环境 - 对于这些说明,我们将使用 Anaconda 的 Miniconda。 此设置的其余部分假定你使用 miniconda 环境。

按照 Anaconda 网站上的指南安装 Miniconda,或在 WSL 中运行以下命令来安装 Miniconda。

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 
bash Miniconda3-latest-Linux-x86_64.sh 

在 WSL 中安装 Miniconda 后,使用名为 directml 的 Python 创建环境并通过以下命令激活它:

注意

在下面的命令中,我们使用 Python 3.6。 但是,tensorflow-directml 包适用于 Python 3.5、3.6 或 3.7 环境。

conda create --name directml python=3.6 

conda activate directml 

安装 TensorFlow with DirectML 包

运行以下命令,通过 pip 安装 TensorFlow with DirectML 包。

注意

tensorflow-directml 包仅支持 TensorFlow 1.15。

pip install tensorflow-directml

安装 tensorflow-directml 包后,可以通过添加两个张量来验证其是否正确运行。 将以下行复制到交互式 Python 会话中。

import tensorflow.compat.v1 as tf 

tf.enable_eager_execution(tf.ConfigProto(log_device_placement=True)) 

print(tf.add([1.0, 2.0], [3.0, 4.0])) 

应会看到类似于以下内容的输出,其中 add 运算符位于 DML 设备上。

2020-06-15 11:27:18.235973: I tensorflow/core/common_runtime/dml/dml_device_factory.cc:45] DirectML device enumeration: found 1 compatible adapters. 

2020-06-15 11:27:18.240065: I tensorflow/core/common_runtime/dml/dml_device_factory.cc:32] DirectML: creating device on adapter 0 (AMD Radeon VII) 

2020-06-15 11:27:18.323949: I tensorflow/stream_executor/platform/default/dso_loader.cc:60] Successfully opened dynamic library libdirectml.so.ba106a7c621ea741d21598708ee581c11918380 

2020-06-15 11:27:18.337830: I tensorflow/core/common_runtime/eager/execute.cc:571] Executing op Add in device /job:localhost/replica:0/task:0/device:DML:0 

tf.Tensor([4. 6.], shape=(2,), dtype=float32) 

TensorFlow with DirectML 示例和反馈

查看我们的示例,或使用你现有的模型脚本。 如果遇到问题或有关于 TensorFlow with DirectML 包的反馈,请与我们的团队联系