DML_GATHER_ND1_OPERATOR_DESC 结构 (directml.h)
从输入张量收集元素,使用索引张量将索引重新映射到输入的整个子块。 此运算符执行以下伪代码,其中“...”表示一系列坐标,其确切行为取决于批处理、输入和索引维度计数。
output[batch, ...] = input[batch, indices[batch, ...], ...]
语法
struct DML_GATHER_ND1_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *IndicesTensor;
const DML_TENSOR_DESC *OutputTensor;
UINT InputDimensionCount;
UINT IndicesDimensionCount;
UINT BatchDimensionCount;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要从中读取的张量。
IndicesTensor
类型: const DML_TENSOR_DESC*
包含索引的张量。 此张量的 DimensionCount 必须与 InputTensor.DimensionCount 匹配。 IndexesTensor 的最后一个维度实际上是每个索引元组的坐标数,并且不能超过 InputTensor.DimensionCount。 例如,IndexesDimensionCount = 3 的索引大小{1,4,5,2}
张量表示索引为 InputTensor 的 4x5 数组,该数组由 2 坐标元组编制索引。
从 DML_FEATURE_LEVEL_3_0
开始,使用此张量使用带符号整型类型时,此运算符支持负索引值。 负索引被解释为相对于相应维度的末尾。 例如,索引 -1 引用该维度的最后一个元素。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入结果的张量。 此张 量的 DimensionCount 和 DataType 必须与 InputTensor.DimensionCount 匹配。 预期的 OutputTensor.Sizes 是 IndicesTensor.Sizes 前导段和 InputTensor.Sizes 尾随段的串联,后者生成以下内容。
indexTupleSize = IndicesTensor.Sizes[IndicesTensor.DimensionCount - 1]
OutputTensor.Sizes = {
1...,
IndicesTensor.Sizes[(IndicesTensor.DimensionCount - IndicesDimensionCount) .. (IndicesTensor.DimensionCount - 1)],
InputTensor.Sizes[(InputTensor.DimensionCount - indexTupleSize) .. InputTensor.DimensionCount]
}
维度是右对齐的,如果需要满足 OutputTensor.DimensionCount,前面附加了前导 1 个值。
下面是一个示例。
InputTensor.Sizes = {3,4,5,6,7}
InputDimensionCount = 5
IndicesTensor.Sizes = {1,1, 1,2,3}
IndicesDimensionCount = 3 // can be thought of as a {1,2} array of 3-coordinate tuples
// The {1,2} comes from the indices tensor (ignoring last dimension which is the tuple size),
// and the {6,7} comes from input tensor, ignoring the first 3 dimensions
// since the index tuples are 3 elements (from the indices tensor last dimension).
OutputTensor.Sizes = {1, 1,2,6,7}
InputDimensionCount
类型: UINT
在忽略任何不相关的前导维度(范围为 [1, *InputTensor.DimensionCount*]
)后,InputTensor 中实际输入维度的数目。 例如,给定 InputTensor.Sizes = {1,1,4,6}
和 InputDimensionCount
= 3,实际有意义的索引为 {1,4,6}
。
IndicesDimensionCount
类型: UINT
在忽略任何不相关的前导维度(范围 [1, IndexesTensor.DimensionCount])后,IndexesTensor 中实际索引维度的数目。 例如,给定 IndexesTensor.Sizes = {1,1,4,6}
和 IndexesDimensionCount = 3,实际有意义的索引为 {1,4,6}
。
BatchDimensionCount
类型: UINT
每个张量中的维度数 (InputTensor、 IndexesTensor、 OutputTensor) ,这些维度被视为独立批处理,范围在 [0、 InputTensor.DimensionCount) 和 [0, IndexesTensor.DimensionCount) 范围内。 批计数可以为 0,表示单个批处理。 例如,给定 IndexesTensor.Sizes = {1,3,4,5,6,7}
和 IndexesDimensionCount = 5 且 BatchDimensionCount
= 2,有批处理 {3,4}
和有意义的索引 {5,6,7}
。
备注
DML_GATHER_ND1_OPERATOR_DESC添加 BatchDimensionCount,在 BatchDimensionCount = 0 时等效于 DML_GATHER_ND_OPERATOR_DESC。
示例
示例 1。 1D 重新映射
InputDimensionCount: 2
IndicesDimensionCount: 2
BatchDimensionCount: 0
InputTensor: (Sizes:{2,2}, DataType:FLOAT32)
[[0,1],[2,3]]
IndicesTensor: (Sizes:{2,1}, DataType:UINT32)
[[1],[0]]
// output[y, x] = input[indices[y], x]
OutputTensor: (Sizes:{2,2}, DataType:FLOAT32)
[[2,3],[0,1]]
示例 2。 使用批计数进行 2D 重新映射
InputDimensionCount: 3
IndicesDimensionCount: 3
BatchDimensionCount: 1
// 3 batches.
InputTensor: (Sizes:{1, 3,2,2}, DataType:FLOAT32)
[
[[[0,1],[2,3]], // batch 0
[[4,5],[6,7]], // batch 1
[[8,9],[10,11]]] // batch 2
]
// A 3x2 array of 2D tuples indexing into InputTensor.
// e.g. a tuple of <1,0> in batch 1 corresponds to input value 6.
IndicesTensor: (Sizes:{1, 3,2,2}, DataType:UINT32)
[
[[[0,0],[1,1]],
[[1,1],[0,0]],
[[0,1],[1,0]]]
]
// output[batch, x] = input[batch, indices[batch, x, 0], indices[batch, x, 1]]
OutputTensor: (Sizes:{1,1, 3,2}, DataType:FLOAT32)
[[
[[0,3],
[7,4],
[9,10]]
]]
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_3_0
。
张量约束
- IndexesTensor、 InputTensor 和 OutputTensor 必须具有相同的 DimensionCount。
- InputTensor 和 OutputTensor 必须具有相同的 数据类型。
Tensor 支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
要求
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |