DML_GATHER_ND1_OPERATOR_DESC 结构 (directml.h)

从输入张量收集元素,使用索引张量将索引重新映射到输入的整个子块。 此运算符执行以下伪代码,其中“...”表示一系列坐标,其确切行为取决于批处理、输入和索引维度计数。

output[batch, ...] = input[batch, indices[batch, ...], ...]

语法

struct DML_GATHER_ND1_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *IndicesTensor;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  InputDimensionCount;
  UINT                  IndicesDimensionCount;
  UINT                  BatchDimensionCount;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

要从中读取的张量。

IndicesTensor

类型: const DML_TENSOR_DESC*

包含索引的张量。 此张量的 DimensionCount 必须与 InputTensor.DimensionCount 匹配。 IndexesTensor 的最后一个维度实际上是每个索引元组的坐标数,并且不能超过 InputTensor.DimensionCount。 例如,IndexesDimensionCount = 3 的索引大小{1,4,5,2}张量表示索引为 InputTensor 的 4x5 数组,该数组由 2 坐标元组编制索引。

DML_FEATURE_LEVEL_3_0开始,使用此张量使用带符号整型类型时,此运算符支持负索引值。 负索引被解释为相对于相应维度的末尾。 例如,索引 -1 引用该维度的最后一个元素。

OutputTensor

类型: const DML_TENSOR_DESC*

要向其写入结果的张量。 此张 量的 DimensionCountDataType 必须与 InputTensor.DimensionCount 匹配。 预期的 OutputTensor.SizesIndicesTensor.Sizes 前导段和 InputTensor.Sizes 尾随段的串联,后者生成以下内容。

indexTupleSize = IndicesTensor.Sizes[IndicesTensor.DimensionCount - 1]
OutputTensor.Sizes = {
    1...,
    IndicesTensor.Sizes[(IndicesTensor.DimensionCount - IndicesDimensionCount) .. (IndicesTensor.DimensionCount - 1)],
    InputTensor.Sizes[(InputTensor.DimensionCount - indexTupleSize) .. InputTensor.DimensionCount]
}

维度是右对齐的,如果需要满足 OutputTensor.DimensionCount,前面附加了前导 1 个值。

下面是一个示例。

InputTensor.Sizes = {3,4,5,6,7}
InputDimensionCount = 5
IndicesTensor.Sizes = {1,1, 1,2,3}
IndicesDimensionCount = 3 // can be thought of as a {1,2} array of 3-coordinate tuples

// The {1,2} comes from the indices tensor (ignoring last dimension which is the tuple size),
// and the {6,7} comes from input tensor, ignoring the first 3 dimensions
// since the index tuples are 3 elements (from the indices tensor last dimension).
OutputTensor.Sizes = {1, 1,2,6,7}

InputDimensionCount

类型: UINT

在忽略任何不相关的前导维度(范围为 [1, *InputTensor.DimensionCount*])后,InputTensor 中实际输入维度的数目。 例如,给定 InputTensor.Sizes = {1,1,4,6}InputDimensionCount = 3,实际有意义的索引为 {1,4,6}

IndicesDimensionCount

类型: UINT

在忽略任何不相关的前导维度(范围 [1, IndexesTensor.DimensionCount])后,IndexesTensor 中实际索引维度的数目。 例如,给定 IndexesTensor.Sizes = {1,1,4,6}IndexesDimensionCount = 3,实际有意义的索引为 {1,4,6}

BatchDimensionCount

类型: UINT

每个张量中的维度数 (InputTensorIndexesTensorOutputTensor) ,这些维度被视为独立批处理,范围在 [0、 InputTensor.DimensionCount) 和 [0, IndexesTensor.DimensionCount) 范围内。 批计数可以为 0,表示单个批处理。 例如,给定 IndexesTensor.Sizes = {1,3,4,5,6,7}IndexesDimensionCount = 5 且 BatchDimensionCount = 2,有批处理 {3,4} 和有意义的索引 {5,6,7}

备注

DML_GATHER_ND1_OPERATOR_DESC添加 BatchDimensionCount,在 BatchDimensionCount = 0 时等效于 DML_GATHER_ND_OPERATOR_DESC

示例

示例 1。 1D 重新映射

InputDimensionCount: 2
IndicesDimensionCount: 2
BatchDimensionCount: 0

InputTensor: (Sizes:{2,2}, DataType:FLOAT32)
    [[0,1],[2,3]]

IndicesTensor: (Sizes:{2,1}, DataType:UINT32)
    [[1],[0]]

// output[y, x] = input[indices[y], x]
OutputTensor: (Sizes:{2,2}, DataType:FLOAT32)
    [[2,3],[0,1]]

示例 2。 使用批计数进行 2D 重新映射

InputDimensionCount: 3
IndicesDimensionCount: 3
BatchDimensionCount: 1

// 3 batches.
InputTensor: (Sizes:{1, 3,2,2}, DataType:FLOAT32)
    [
        [[[0,1],[2,3]],   // batch 0
         [[4,5],[6,7]],   // batch 1
         [[8,9],[10,11]]] // batch 2
    ]

// A 3x2 array of 2D tuples indexing into InputTensor.
// e.g. a tuple of <1,0> in batch 1 corresponds to input value 6.
IndicesTensor: (Sizes:{1, 3,2,2}, DataType:UINT32)
    [
        [[[0,0],[1,1]],
         [[1,1],[0,0]],
         [[0,1],[1,0]]]
    ]

// output[batch, x] = input[batch, indices[batch, x, 0], indices[batch, x, 1]]
OutputTensor: (Sizes:{1,1, 3,2}, DataType:FLOAT32)
    [[
        [[0,3],
         [7,4],
         [9,10]]
    ]]

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_3_0

张量约束

  • IndexesTensorInputTensorOutputTensor 必须具有相同的 DimensionCount
  • InputTensorOutputTensor 必须具有相同的 数据类型

Tensor 支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_3_0及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

要求

   
最低受支持的客户端 Windows 10内部版本 20348
最低受支持的服务器 Windows 10内部版本 20348
标头 directml.h