DML_GATHER_ELEMENTS_OPERATOR_DESC 结构 (directml.h)
使用索引张量从输入张量沿给定轴收集元素,以重新映射到输入。 此运算符执行以下伪代码,其确切行为取决于轴、输入维度计数和索引维度计数。
output[i, j, k, ...] = input[index[i, j, k, ...], j, k, ...] // if axis == 0
output[i, j, k, ...] = input[i, index[i, j, k, ...], k, ...] // if axis == 1
output[i, j, k, ...] = input[i, j, index[i, j, k, ...], ...] // if axis == 2
...
语法
struct DML_GATHER_ELEMENTS_OPERATOR_DESC {
const DML_TENSOR_DESC *InputTensor;
const DML_TENSOR_DESC *IndicesTensor;
const DML_TENSOR_DESC *OutputTensor;
UINT Axis;
};
成员
InputTensor
类型: const DML_TENSOR_DESC*
要从中读取的张量。
IndicesTensor
类型: const DML_TENSOR_DESC*
索引沿活动轴进入输入张量。 对于除轴以外的每个维度,Size 必须匹配 InputTensor.Size。
从 DML_FEATURE_LEVEL_3_0
开始,使用此张量使用带符号整型类型时,此运算符支持负索引值。 负索引被解释为相对于轴维度的末尾。 例如,索引 -1 引用该维度的最后一个元素。
OutputTensor
类型: const DML_TENSOR_DESC*
要向其写入结果的张量。 大小必须与 IndicesTensor.Sizes 匹配,DataType 必须与 InputTensor.DataType 匹配。
Axis
类型: UINT
要沿收集的 InputTensor 的轴维度,范围为 [0, *InputTensor.DimensionCount*)
。
示例
Axis = 0
InputTensor: (Sizes:{3,3}, DataType:FLOAT32)
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
IndicesTensor: (Sizes:{2,3}, DataType:UINT32)
[[1, 2, 0],
[2, 0, 0]]
// output[y, x] = input[indices[y, x], x]
OutputTensor: (Sizes:{2,3}, DataType:UINT32)
[[4, 8, 3], // select elements vertically from data
[7, 2, 3]]
可用性
此运算符是在 中引入的 DML_FEATURE_LEVEL_2_1
。
张量约束
IndicesTensor
、 InputTensor 和 OutputTensor 必须具有相同的 DimensionCount。- InputTensor 和 OutputTensor 必须具有相同的 数据类型。
Tensor 支持
DML_FEATURE_LEVEL_4_1 及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT64、FLOAT32、FLOAT16、INT64、INT32、INT16、INT8、UINT64、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_3_0及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 1 到 8 | INT64、INT32、UINT64、UINT32 |
OutputTensor | 输出 | 1 到 8 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
DML_FEATURE_LEVEL_2_1及更高版本
张 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
InputTensor | 输入 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
IndicesTensor | 输入 | 4 | UINT32 |
OutputTensor | 输出 | 4 | FLOAT32、FLOAT16、INT32、INT16、INT8、UINT32、UINT16、UINT8 |
要求
要求 | 值 |
---|---|
最低受支持的客户端 | Windows 10内部版本 20348 |
最低受支持的服务器 | Windows 10内部版本 20348 |
标头 | directml.h |