DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC结构 (directml.h)

待定

重要

此 API 作为 DirectML 独立可再发行组件包的一部分提供(请参阅 Microsoft.AI.DirectML 版本 1.15.0 及更高版本。 另请参阅 directML 版本历史记录

语法

struct DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastSequenceLengthsTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT QueryHeadCount;
    UINT KeyValueHeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

成员

QueryTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

使用形状 [batchSize, sequenceLength, hiddenSize]进行查询,其中 hiddenSize = headCount * headSize。 此张量与 StackedQueryKeyTensorStackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。

KeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

具有形状 [batchSize, keyValueSequenceLength, hiddenSize]的键,其中 hiddenSize = headCount * headSize。 此张量与 StackedQueryKeyTensorStackedKeyValueTensorStackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。

ValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

具有形状 [batchSize, keyValueSequenceLength, valueHiddenSize]的值,其中 valueHiddenSize = headCount * valueHeadSize。 此张量与 StackedKeyValueTensorStackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。

StackedQueryKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

带有形状 [batchSize, sequenceLength, headCount, 2, headSize]的堆积查询和键。 此张量与 QueryTensorKeyTensorStackedKeyValueTensorStackedQueryKeyValueTensor 互斥。

StackedQueryKeyTensor 布局

StackedKeyValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

带有形状 [batchSize, keyValueSequenceLength, headCount, 2, headSize]的堆积键和值。 此张量与 KeyTensorValueTensorStackedQueryKeyTensorStackedQueryKeyValueTensor 互斥。

StackedKeyValueTensor 布局

StackedQueryKeyValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

带有形状 [batchSize, sequenceLength, headCount, 3, headSize]的堆积查询、键和值。 此张量与 QueryTensor、KeyTensorValueTensorStackedQueryKeyTensor 和 StackedKeyValueTensor 互斥。

StackedQueryKeyValueTensor 布局

BiasTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是形状 [hiddenSize + hiddenSize + valueHiddenSize]的偏差,在第一个 GEMM作之前会添加到 查询// 。 只要前导维度为 1,此张量也可以有 2、3、4 或 5 个维度。

MaskTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是一个掩码,用于确定哪些元素在 QxK GEMM作后将其值设置为 MaskFilterValue 。 此掩码的行为取决于 MaskType 的值,并在 RelativePositionBiasTensor 之后应用;如果 RelativePositionBiasTensor 为 null,则应用在第一个 GEMM作之后。 有关详细信息,请参阅 MaskType 的定义。

RelativePositionBiasTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

这是添加到第一个 GEMM作结果中的偏差。

PastKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

上一次迭代中具有形状 [batchSize, headCount, pastSequenceLength, headSize]的键张量。 当此张量不为 null 时,它将与键张量连接,这会导致形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]的张量。

PastValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

上一次迭代中具有形状 [batchSize, headCount, pastSequenceLength, headSize]的值张量。 当此张量不为 null 时,它将与 ValueDesc 连接,这会导致形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]的张量。

PastSequenceLengthsTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

待定

OutputTensor

类型: const DML_TENSOR_DESC*

形状 [batchSize, sequenceLength, valueHiddenSize]的输出。

OutputPresentKeyTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

当前状态用于交叉关注键,具有形状 [batchSize, headCount, keyValueSequenceLength, headSize] 或当前状态,以自我注意与形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]。 它包含键张量的内容或串联的 PastKey 密钥 + 张量的内容,以传递给下一次迭代。

OutputPresentValueTensor

类型:_Maybenull_ const DML_TENSOR_DESC*

交叉关注值的当前状态,具有形状 [batchSize, headCount, keyValueSequenceLength, headSize] 或当前状态,以便通过形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]引起自我注意。 它包含值张量的内容或串联 的 PastValue + 张量的内容,以传递给下一次迭代。

Scale

类型: FLOAT

缩放以将 QxK GEMM作的结果相乘,但在 Softmax作之前。 此值通常是 1/sqrt(headSize)

MaskFilterValue

类型: FLOAT

将 QxK GEMM作结果添加到定义为填充元素的掩码的位置的值。 此值应为非常大的负数(通常为 -10000.0f)。

QueryHeadCount

类型: UINT

待定

KeyValueHeadCount

类型: UINT

待定

MaskType

类型: DML_MULTIHEAD_ATTENTION_MASK_TYPE

描述 MaskTensor 的行为。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN。 当掩码包含值 0 时,将添加 MaskFilterValue ;但当它包含值 1 时,不会添加任何内容。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH。 形状 [1, batchSize]的掩码包含每个批的未板区域的序列长度,序列长度后的所有元素都将其值设置为 MaskFilterValue

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START。 形状 [2, batchSize]的掩码包含已解压缩区域的结束(独占)和开始(非独占)索引,区域外部的所有元素都将其值设置为 MaskFilterValue

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END。 形状[batchSize * 3 + 2]的掩码具有以下值: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]

可用性

此运算符是在 DML_FEATURE_LEVEL_6_3 中引入的。

Tensor 约束

BiasTensorKeyTensorOutputPresentKeyTensorOutputPresentValueTensorOutputTensorPastKeyTensorPastSequenceLengthsTensor、PastValueTensorQueryTensorRelativePositionBiasTensorStackedKeyValueTensorStackedQueryKeyTensor、StackedQueryKeyValueTensorValueTensor 必须具有相同的 DataType

Tensor 支持

张量 种类 支持的维度计数 支持的数据类型
QueryTensor 的 可选输入 3 到 5 FLOAT32,FLOAT16
KeyTensor 的 可选输入 3 到 5 FLOAT32,FLOAT16
ValueTensor 可选输入 3 到 5 FLOAT32,FLOAT16
StackedQueryKeyTensor 可选输入 5 FLOAT32,FLOAT16
StackedKeyValueTensor 可选输入 5 FLOAT32,FLOAT16
StackedQueryKeyValueTensor 可选输入 5 FLOAT32,FLOAT16
BiasTensor 可选输入 1 到 5 FLOAT32,FLOAT16
MaskTensor 的 可选输入 1 到 5 INT32
RelativePositionBiasTensor 可选输入 4 到 5 FLOAT32,FLOAT16
PastKeyTensor 可选输入 4 到 5 FLOAT32,FLOAT16
PastValueTensor 可选输入 4 到 5 FLOAT32,FLOAT16
PastSequenceLengthsTensor 可选输入 1 到 5 FLOAT32,FLOAT16
OutputTensor 输出 3 到 5 FLOAT32,FLOAT16
输出PresentKeyTensor 可选输出 4 到 5 FLOAT32,FLOAT16
输出PresentValueTensor 可选输出 4 到 5 FLOAT32,FLOAT16