directml.h) (DML_GATHER_OPERATOR_DESC 结构

使用 IndicesTensor 重新映射索引,沿从输入张量中收集元素。 此运算符执行以下伪代码,其中“...”表示一系列坐标,其确切行为由轴和索引维度计数确定:

output[...] = input[..., indices[...], ...]

语法

struct DML_GATHER_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *IndicesTensor;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  Axis;
  UINT                  IndexDimensions;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

要从中读取的张量。

IndicesTensor

类型: const DML_TENSOR_DESC*

包含索引的张量。 此张量维度的 DimensionCount 必须与 InputTensor.DimensionCount 匹配

DML_FEATURE_LEVEL_3_0开始,将此张量一起使用带符号整数类型时,此运算符支持负索引值。 负索引被解释为相对于轴维度的末尾。 例如,索引为 -1 表示沿该维度的最后一个元素。

无效索引将产生不正确的输出,但不会失败,并且所有读取都将安全地固定在输入张量内存中。

OutputTensor

类型: const DML_TENSOR_DESC*

将结果写入到的张量。 此张量维度的 DimensionCountDataType 必须与 InputTensor.DimensionCount 匹配。 预期的 OutputTensor.Sizes 是在当前上拆分的 InputTensor.Sizes 前导段和尾部段的串联,其中插入了 IndicesTensor.Sizes

OutputTensor.Sizes = {
    InputTensor.Sizes[0..Axis],
    IndicesTensor.Sizes[(IndicesTensor.DimensionCount - IndexDimensions) .. IndicesTensor.DimensionCount],
    InputTensor.Sizes[(Axis+1) .. InputTensor.DimensionCount]
}

维度是右对齐的,以便裁剪输入大小中的任何前导 1 值,否则会溢出输出 DimensionCount

此张量中相关维度的数目取决于 IndexDimensionsInputTensor的原始秩。 原始秩是使用前导维度填充之前维度的数目。 输出中相关维度的数目可以通过 InputTensor + IndexDimensions - 1 的原始排名来计算。 此值必须小于或等于 OutputTensorDimensionCount

Axis

类型: UINT

要收集的 InputTensor 的轴维度,范围为 [0, *InputTensor.DimensionCount*)

IndexDimensions

类型: UINT

忽略任何不相关的前导维度后的实际索引维度 IndicesTensor 数,范围 [0, IndicesTensor.DimensionCount) 。 例如,给定 IndicesTensor.Sizes = { 1, 1, 4, 6 }IndexDimensions = 3,实际有意义的索引为 。{ 1, 4, 6 }

示例

示例 1。 1D 重新映射

Axis: 0
IndexDimensions: 1

InputTensor: (Sizes:{4}, DataType:FLOAT32)
    [11,12,13,14]

IndicesTensor: (Sizes:{5}, DataType:UINT32)
    [3,1,3,0,2]

// output[x] = input[indices[x]]
OutputTensor: (Sizes:{5}, DataType:FLOAT32)
    [14,12,14,11,13]

示例 2。 2D 输出、1D 索引、轴 0、串联行

Axis: 0
IndexDimensions: 1

InputTensor: (Sizes:{3,2}, DataType:FLOAT32)
    [[1,2], // row 0
     [3,4], // row 1
     [5,6]] // row 2

IndicesTensor: (Sizes:{1, 4}, DataType:UINT32)
    [[0,
      1,
      1,
      2]]

// output[y, x] = input[indices[y], x]
OutputTensor: (Sizes:{4,2}, DataType:FLOAT32)
    [[1,2], // input row 0
     [3,4], // input row 1
     [3,4], // input row 1
     [5,6]] // input row 2

示例 3。 2D,轴 1,交换列

Axis: 1
IndexDimensions: 2

InputTensor: (Sizes:{3,2}, DataType:FLOAT32)
    [[1,2],
     [3,4],
     [5,6]]

IndicesTensor: (Sizes:{1, 2}, DataType:UINT32)
    [[1,0]]

// output[y, x] = input[y, indices[x]]
OutputTensor: (Sizes:{3,2}, DataType:FLOAT32)
    [[2,1],
     [4,3],
     [6,5]]

示例 4. 2D、轴 1、嵌套索引

Axis: 2
IndexDimensions: 2

InputTensor: (Sizes:{1, 3,3}, DataType:FLOAT32)
    [ [[1,2,3],
       [4,5,6],
       [7,8,9]] ]

IndicesTensor: (Sizes:{1, 1,2}, DataType:UINT32)
    [ [[0,2]] ]

// output[z, y, x] = input[z, indices[y, x]]
OutputTensor: (Sizes:{3,1,2}, DataType:FLOAT32)
    [[[1,3]],
     [[4,6]],
     [[7,9]]]

示例 5. 2D、轴 0、嵌套索引

Axis: 1
IndexDimensions: 2

InputTensor: (Sizes:{1, 3,2}, DataType:FLOAT32)
    [ [[1,2],
       [3,4],
       [5,6]] ]

IndicesTensor: (Sizes:{1, 2,2}, DataType:UINT32)
    [ [[0,1],
       [1,2]] ]

// output[z, y, x] = input[indices[z, y], x]
OutputTensor: (Sizes:{2,2,2}, DataType:FLOAT32)
    [[[1,2], [3,4]],
     [[3,4], [5,6]]]

可用性

此运算符是在 中 DML_FEATURE_LEVEL_1_0引入的。

张量约束

  • IndicesTensorInputTensorOutputTensor 必须具有相同的 DimensionCount
  • InputTensorOutputTensor 必须具有相同的 数据类型

张量支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_3_0 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
IndicesTensor 输入 1 到 8 INT64、INT32、UINT64、UINT32
OutputTensor 输出 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_2_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
IndicesTensor 输入 4 UINT32
OutputTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8

DML_FEATURE_LEVEL_1_0 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16
IndicesTensor 输入 4 UINT32
OutputTensor 输出 4 FLOAT32、FLOAT16

要求

   
标头 directml.h