DML_TOP_K1_OPERATOR_DESC 结构 (directml.h)

沿 InputTensor 轴从每个序列中选择最大或最小的 K 元素,并分别返回 OutputValueTensorOutputIndexTensor 中这些元素的值和索引。 序列是指沿 InputTensor维度存在的元素集之一。

可以使用 AxisDirection 控制是选择最大的 K 元素还是选择最小的 K 元素。

语法

struct DML_TOP_K1_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *OutputValueTensor;
  const DML_TENSOR_DESC *OutputIndexTensor;
  UINT                  Axis;
  UINT                  K;
  DML_AXIS_DIRECTION    AxisDirection;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

包含要选择的元素的输入张量。

OutputValueTensor

类型: const DML_TENSOR_DESC*

要向其写入前 K 元素的值的输出张量。 根据 AxisDirection 的值,选择前 K 个元素是最大元素还是最小元素。 此张量的大小必须等于 InputTensor参数指定的维度除外,轴参数的大小必须等于 K

如果 axisDirection 是 (DML_AXIS_DIRECTION_DECREASING,则保证从每个输入序列中选择的 K 值按最大到最小) 降序排序。 否则,情况正好相反,并且所选值保证按从小到大) (升序排序。

OutputIndexTensor

类型: const DML_TENSOR_DESC*

要向其写入前 K 元素的索引的输出张量。 此张量的大小必须等于 InputTensor参数指定的维度除外,轴参数的大小必须等于 K

此张量中返回的索引相对于其序列 (的开头而不是张量) 的开头进行测量。 例如,索引 0 始终引用轴中所有序列的第一个元素。

如果 top-K 中的两个或更多个元素具有相同的值 (即,当存在平) 时,将包含这两个元素的索引,并保证按升序元素索引进行排序。 请注意,这与 AxisDirection 的值无关。

Axis

类型: UINT

要选择其上的元素的维度的索引。 此值必须小于 InputTensorDimensionCount

K

类型: UINT

要选择的元素数。 K 必须大于 0,但小于 InputTensor 中沿 Axis 指定的维度的元素数。

AxisDirection

类型: DML_AXIS_DIRECTION

DML_AXIS_DIRECTION枚举中的值。 如果设置为 DML_AXIS_DIRECTION_INCREASING,则此运算符将按值递增的顺序返回 最小的K 元素。 否则,它将按降序返回 最大的K 元素。

示例

示例 1

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0,  1, 10, 11],
   [ 3,  2,  9,  8],
   [ 4,  5,  6,  7]]]]

Axis: 3
K:    2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,3,2}, DataType:FLOAT32)
[[[[11, 10],
   [ 9,  8],
   [ 7,  6]]]]

OutputIndexTensor: (Sizes:{1,1,3,2}, DataType:UINT32)
[[[[3, 2],
   [2, 3],
   [3, 2]]]]

示例 2。 使用不同的轴

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[ 0,  1, 10, 11],
   [ 3,  2,  9,  8],
   [ 4,  5,  6,  7]]]]

Axis: 2
K:    2
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,2,4}, DataType:FLOAT32)
[[[[ 4,  5, 10, 11],
   [ 3,  2,  9,  8]]]]

OutputIndexTensor: (Sizes:{1,1,2,4}, DataType:UINT32)
[[[[2, 2, 0, 0],
   [1, 1, 1, 1]]]]

示例 3。 绑定值

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
   [3, 4, 5, 5],
   [6, 6, 6, 6]]]]

Axis: 3
K:    3
AxisDirection: DML_AXIS_DIRECTION_DECREASING
   
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[3, 2, 2],
   [5, 5, 4],
   [6, 6, 6]]]]

OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[3, 1, 2],
   [2, 3, 1],
   [0, 1, 2]]]]

示例 4. 增加轴方向

InputTensor: (Sizes:{1,1,3,4}, DataType:FLOAT32)
[[[[1, 2, 2, 3],
   [3, 4, 5, 5],
   [6, 6, 6, 6]]]]

Axis: 3
K:    3
AxisDirection: DML_AXIS_DIRECTION_INCREASING
   
OutputValueTensor: (Sizes:{1,1,3,3}, DataType:FLOAT32)
[[[[1, 2, 2],
   [3, 4, 5],
   [6, 6, 6]]]]

OutputIndexTensor: (Sizes:{1,1,3,3}, DataType:UINT32)
[[[[0, 1, 2],
   [0, 1, 2],
   [0, 1, 2]]]]

备注

AxisDirection 设置为 DML_AXIS_DIRECTION_DECREASING 时,此运算符等效于 DML_TOP_K_OPERATOR_DESC

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_2_1

张量约束

  • InputTensorOutputIndexTensorOutputValueTensor 必须具有相同的 DimensionCount
  • InputTensorOutputValueTensor 必须具有相同的 数据类型

Tensor 支持

DML_FEATURE_LEVEL_5_0及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputValueTensor 输出 1 到 8 FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputIndexTensor 输出 1 到 8 UINT64、UINT32

DML_FEATURE_LEVEL_3_1及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputValueTensor 输出 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputIndexTensor 输出 1 到 8 UINT32

DML_FEATURE_LEVEL_2_1及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputValueTensor 输出 4 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputIndexTensor 输出 4 UINT32

要求

   
最低受支持的客户端 Windows 10内部版本 20348
最低受支持的服务器 Windows 10内部版本 20348
标头 directml.h