DML_GATHER_ELEMENTS_OPERATOR_DESC 结构 (directml.h)

使用索引张量从输入张量沿给定轴收集元素,以重新映射到输入。 此运算符执行以下伪代码,其确切行为取决于轴、输入维度计数和索引维度计数。

output[i, j, k, ...] = input[index[i, j, k, ...], j, k, ...] // if axis == 0
output[i, j, k, ...] = input[i, index[i, j, k, ...], k, ...] // if axis == 1
output[i, j, k, ...] = input[i, j, index[i, j, k, ...], ...] // if axis == 2
...

语法

struct DML_GATHER_ELEMENTS_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *IndicesTensor;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  Axis;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

要从中读取的张量。

IndicesTensor

类型: const DML_TENSOR_DESC*

索引沿活动轴进入输入张量。 对于除轴以外的每个维度,Size 必须匹配 InputTensor.Size

DML_FEATURE_LEVEL_3_0开始,使用此张量使用带符号整型类型时,此运算符支持负索引值。 负索引被解释为相对于轴维度的末尾。 例如,索引 -1 引用该维度的最后一个元素。

OutputTensor

类型: const DML_TENSOR_DESC*

要向其写入结果的张量。 大小必须与 IndicesTensor.Sizes 匹配,DataType 必须与 InputTensor.DataType 匹配。

Axis

类型: UINT

要沿收集的 InputTensor 的轴维度,范围为 [0, *InputTensor.DimensionCount*)

示例

Axis = 0

InputTensor: (Sizes:{3,3}, DataType:FLOAT32)
    [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

IndicesTensor: (Sizes:{2,3}, DataType:UINT32)
    [[1, 2, 0],
     [2, 0, 0]]

// output[y, x] = input[indices[y, x], x]
OutputTensor: (Sizes:{2,3}, DataType:UINT32)
    [[4, 8, 3], // select elements vertically from data
     [7, 2, 3]]

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_2_1

张量约束

  • IndicesTensorInputTensorOutputTensor 必须具有相同的 DimensionCount
  • InputTensorOutputTensor 必须具有相同的 数据类型

Tensor 支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_3_0及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
IndicesTensor 输入 4 UINT32
OutputTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

要求

要求
最低受支持的客户端 Windows 10内部版本 20348
最低受支持的服务器 Windows 10内部版本 20348
标头 directml.h