使用融合运算符提高性能

某些 DirectML 运算符支持称为融合的概念。 运算符融合是提高性能的一种方法,通过将一个运算符(通常是激活函数)合并到另一个运算符以便其一起执行,而无需往返内存。

何时融合激活函数

融合激活函数是实现性能优化的一种手段。 在许多机器学习 (ML) 模型中,极其常见的一种方案是将非线性(激活函数)应用于模型中每一层的输出。

通常,这需要往返图形内存。 例如,如果卷积后跟非融合 Relu 激活函数,则 GPU 必须等待卷积的结果写入 GPU 内存,然后才能开始计算 Relu 激活层。 由于大多数激活函数的计算工作负载往往较小,因此往返图形内存可能是主要的性能瓶颈。

运算符融合允许在前面的运算符(例如卷积)中执行激活函数(如上例中的 Relu)。 这样,GPU 就可以计算激活函数,而无需等待前面的运算符的结果写入内存,从而提高了性能。

由于融合激活函数后会产生相同的结果,但在很多情况下计算速度更快,因此,建议尽可能将激活层融合到其前面的运算符中,从而消除激活层。

如何融合激活函数

支持融合激活函数的运算符在其运算符结构 const DML_OPERATOR_DESC* FusedActivation 中具有其他可选参数。 以卷积为例,卷积支持融合激活函数,其运算符说明中具有相应的 FusedActivation(请参阅 DML_CONVOLUTION_OPERATOR_DESC)。

struct DML_CONVOLUTION_OPERATOR_DESC
{
    const DML_TENSOR_DESC* InputTensor;
    const DML_TENSOR_DESC* FilterTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    const DML_TENSOR_DESC* OutputTensor;
    DML_CONVOLUTION_MODE Mode;
    DML_CONVOLUTION_DIRECTION Direction;
    UINT DimensionCount;
    _Field_size_(DimensionCount) const UINT* Strides;
    _Field_size_(DimensionCount) const UINT* Dilations;
    _Field_size_(DimensionCount) const UINT* StartPadding;
    _Field_size_(DimensionCount) const UINT* EndPadding;
    _Field_size_(DimensionCount) const UINT* OutputPadding;
    UINT GroupCount;
    _Maybenull_ const DML_OPERATOR_DESC* FusedActivation;
};

要融合激活函数,请构造 DML_OPERATOR_DESC,描述要融合的激活函数类型。 例如,要融合 Relu 函数,则正确的运算符类型为 DML_OPERATOR_ACTIVATION_RELU

注意

构造激活函数的运算符说明时,必须将激活函数的 InputTensorOutputTensor 参数设置为 NULL

示例

DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyReluDesc;
leakyReluDesc.InputTensor = nullptr;
leakyReluDesc.OutputTensor = nullptr;
leakyReluDesc.Alpha = 0.01f;

DML_OPERATOR_DESC activationDesc = { DML_OPERATOR_ACTIVATION_LEAKY_RELU, &leakyReluDesc };

DML_CONVOLUTION_OPERATOR_DESC convDesc;
// ...
convDesc.FusedActivation = &activationDesc;

对于完整示例,DirectMLSuperResolution 示例利用融合后的激活函数来提高性能。

支持融合激活函数的运算符

以下列出了 DML_OPERATOR_TYPE 枚举中的常量。 该主题中的每个常量均已链接到要使用的相应说明结构。

  • DML_OPERATOR_BATCH_NORMALIZATION
  • DML_OPERATOR_BATCH_NORMALIZATION_TRAINING
  • DML_OPERATOR_CONVOLUTION
  • DML_OPERATOR_ELEMENT_WISE_ADD1
  • DML_OPERATOR_GEMM
  • DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION
  • DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1

支持融合的激活函数

以下列出了 DML_OPERATOR_TYPE 枚举中的常量。 该主题中的每个常量均已链接到要使用的相应说明结构。

  • DML_OPERATOR_ELEMENT_WISE_CLIP
  • DML_OPERATOR_ACTIVATION_LINEAR
  • DML_OPERATOR_ACTIVATION_SIGMOID
  • DML_OPERATOR_ACTIVATION_HARD_SIGMOID
  • DML_OPERATOR_ACTIVATION_TANH
  • DML_OPERATOR_ACTIVATION_SCALED_TANH
  • DML_OPERATOR_ACTIVATION_RELU
  • DML_OPERATOR_ACTIVATION_LEAKY_RELU
  • DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU
  • DML_OPERATOR_ACTIVATION_ELU
  • DML_OPERATOR_ACTIVATION_CELU
  • DML_OPERATOR_ACTIVATION_SCALED_ELU
  • DML_OPERATOR_ACTIVATION_SOFTPLUS
  • DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS
  • DML_OPERATOR_ACTIVATION_SOFTSIGN
  • DML_OPERATOR_ACTIVATION_IDENTITY
  • DML_OPERATOR_ACTIVATION_SHRINK
  • DML_OPERATOR_ACTIVATION_GELU

未列出的任何运算符都不支持融合激活函数。

另请参阅