DML_GEMM_OPERATOR_DESC 结构 (directml.h)

执行形式的 Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)常规矩阵乘法函数,其中 x 表示矩阵乘法,使用 * 标量表示乘法。

此运算符需要具有布局 { BatchCount, ChannelCount, Height, Width }的 4D 张量,它将执行 BatchCount * ChannelCount 数量的独立矩阵乘法。

例如,如果 ATensor的大小为{ BatchCount, ChannelCount, M, K }而 BTensor的大小为{ BatchCount, ChannelCount, K, N }OutputTensorSize{ BatchCount, ChannelCount, M, N },则此运算符执行维度 {M,K} x {K,N} = {M,N} 的 BatchCount * ChannelCount 独立矩阵乘法。

语法

struct DML_GEMM_OPERATOR_DESC {
  const DML_TENSOR_DESC   *ATensor;
  const DML_TENSOR_DESC   *BTensor;
  const DML_TENSOR_DESC   *CTensor;
  const DML_TENSOR_DESC   *OutputTensor;
  DML_MATRIX_TRANSFORM    TransA;
  DML_MATRIX_TRANSFORM    TransB;
  FLOAT                   Alpha;
  FLOAT                   Beta;
  const DML_OPERATOR_DESC *FusedActivation;
};

成员

ATensor

类型: const DML_TENSOR_DESC*

包含 A 矩阵的张量。 如果 transA 是DML_MATRIX_TRANSFORM_NONE,则此张的大小应为{ BatchCount, ChannelCount, M, K },或者{ BatchCount, ChannelCount, K, M }如果 TransADML_MATRIX_TRANSFORM_TRANSPOSE

BTensor

类型: const DML_TENSOR_DESC*

包含 B 矩阵的张量。 如果 TransB 是DML_MATRIX_TRANSFORM_NONE,则此张的大小应为{ BatchCount, ChannelCount, K, N },或者{ BatchCount, ChannelCount, N, K }如果 TransBDML_MATRIX_TRANSFORM_TRANSPOSE

CTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

包含 C 矩阵的张量,或 nullptr。 如果未提供值,则默认值为 0。 如果提供,则此张量 的大小 应为 { BatchCount, ChannelCount, M, N }

OutputTensor

类型: const DML_TENSOR_DESC*

要向其写入结果的张量。 此张 量的大小为{ BatchCount, ChannelCount, M, N }

TransA

类型: DML_MATRIX_TRANSFORM

要应用于 ATensor 的转换;转置或无转换。

TransB

类型: DML_MATRIX_TRANSFORM

要应用于 BTensor 的转换;转置或无转换。

Alpha

类型: FLOAT

输入 ATensorBTensor 的乘积的标量乘数的值。

Beta

类型: FLOAT

可选输入 CTensor 的标量乘数值。 如果未提供 CTensor ,则忽略此值。

FusedActivation

类型:_Maybenull_ const DML_OPERATOR_DESC*

在 GEMM 之后应用的可选融合激活层。 有关详细信息,请参阅 使用融合运算符提高性能

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_1_0

张量约束

  • ATensorBTensorCTensorOutputTensor 必须具有相同的 DataTypeDimensionCount
  • CTensorOutputTensor 必须具有相同 的大小

Tensor 支持

DML_FEATURE_LEVEL_4_0及更高版本

种类 维度 支持的维度计数 支持的数据类型
ATensor 输入 { [BatchCount], [ChannelCount], M, K } 2 到 4 FLOAT32、FLOAT16
BTensor 输入 { [BatchCount], [ChannelCount], K, N } 2 到 4 FLOAT32、FLOAT16
CTensor 可选输入 { [BatchCount], [ChannelCount], M, N } 2 到 4 FLOAT32、FLOAT16
OutputTensor 输出 { [BatchCount], [ChannelCount], M, N } 2 到 4 FLOAT32、FLOAT16

DML_FEATURE_LEVEL_1_0 及更高版本

种类 维度 支持的维度计数 支持的数据类型
ATensor 输入 { BatchCount, ChannelCount, M, K } 4 FLOAT32、FLOAT16
BTensor 输入 { BatchCount, ChannelCount, K, N } 4 FLOAT32、FLOAT16
CTensor 可选输入 { BatchCount, ChannelCount, M, N } 4 FLOAT32、FLOAT16
OutputTensor 输出 { BatchCount, ChannelCount, M, N } 4 FLOAT32、FLOAT16

要求

要求
Header directml.h

另请参阅