DML_ARGMAX_OPERATOR_DESC 结构 (directml.h)

输出输入张量一个或多个维度内最大值元素的索引。

每个输出元素都是对输入张量子集应用 argmax 缩减的结果。 argmax 函数输出一组输入元素中最大值元素的索引。 每个缩减所涉及的输入元素由提供的输入轴确定。 同样,每个输出索引都与提供的输入轴相关。 如果指定了所有输入轴,运算符将应用单个 argmax 缩减,并生成单个输出元素。

语法

struct DML_ARGMAX_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *OutputTensor;
  UINT                  AxisCount;
  const UINT            *Axes;
  DML_AXIS_DIRECTION    AxisDirection;
};

成员

InputTensor

类型: const DML_TENSOR_DESC*

要从中读取的张量。

OutputTensor

类型: const DML_TENSOR_DESC*

要向其写入结果的张量。 每个输出元素都是 InputTensor 中元素子集的 argmax 缩减的结果。

  • DimensionCount 必须与 InputTensor.DimensionCount 匹配, (输入张量排名保留) 。
  • 大小 必须与 InputTensor.Size 匹配,但缩小 的轴中包含的维度除外,其大小必须为 1。

AxisCount

类型: UINT

要减少的轴数。 此字段确定 Axes 数组的大小。

Axes

类型:_Field_size_ (AxisCount) const UINT*

要沿其减小的轴。 值必须位于 范围 [0, InputTensor.DimensionCount - 1]中。

AxisDirection

类型: DML_AXIS_DIRECTION

确定当多个输入元素具有相同值时要选择的索引。

  • DML_AXIS_DIRECTION_INCREASING 返回第一个最大值元素 (的索引, argmax({3,2,1,2,3}) = 0 例如,)
  • DML_AXIS_DIRECTION_DECREASING 返回最后一个最大值元素 (的索引, argmax({3,2,1,2,3}) = 4 例如,)

示例

本部分中的示例都使用相同的二维输入张量。

InputTensor: (Sizes:{3, 3}, DataType:FLOAT32)
[[1, 2, 3],
 [3, 0, 4],
 [2, 5, 2]]

示例 1。 将 argmax 应用于列

AxisCount: 1
Axes: {0}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 3}, DataType:UINT32)
[[1,  // argmax({1, 3, 2})
  2,  // argmax({2, 0, 5})
  1]] // argmax({3, 4, 2})

示例 2。 将 argmax 应用于行

AxisCount: 1
Axes: {1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{3, 1}, DataType:UINT32)
[[2], // argmax({1, 2, 3})
 [2], // argmax({3, 0, 4})
 [1]] // argmax({2, 5, 2})

示例 3。 将 argmax 应用于整个张量) (所有轴

AxisCount: 2
Axes: {0, 1}
AxisDirection: DML_AXIS_DIRECTION_INCREASING
OutputTensor: (Sizes:{1, 1}, DataType:UINT32)
[[7]]  // argmax({1, 2, 3, 3, 0, 4, 2, 5, 2})

备注

输出张量大小必须与输入张量大小相同,但缩小的轴必须为 1。

DML_AXIS_DIRECTION_INCREASINGAxisDirection 时,此 API 等效于使用 DML_REDUCE_FUNCTION_ARGMAXDML_REDUCE_OPERATOR_DESC

此功能的子集通过 DML_REDUCE_OPERATOR_DESC 运算符公开,在早期的 DirectML 功能级别上受支持。

可用性

此运算符是在 中引入的 DML_FEATURE_LEVEL_3_0

张量约束

InputTensorOutputTensor 必须具有相同的 DimensionCount

Tensor 支持

DML_FEATURE_LEVEL_4_1 及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8
OutputTensor 输出 1 到 8 INT64、INT32、UINT64、UINT32

DML_FEATURE_LEVEL_3_0及更高版本

种类 支持的维度计数 支持的数据类型
InputTensor 输入 1 到 8 FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8
OutputTensor 输出 1 到 8 INT64、INT32、UINT64、UINT32

要求

   
最低受支持的客户端 Windows 10内部版本 20348
最低受支持的服务器 Windows 10内部版本 20348
标头 directml.h