待定
重要
此 API 作为 DirectML 独立可再发行组件包的一部分提供(请参阅 Microsoft.AI.DirectML 版本 1.15.0 及更高版本。 另请参阅 directML 版本历史记录 。
语法
struct DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC
{
_Maybenull_ const DML_TENSOR_DESC* QueryTensor;
_Maybenull_ const DML_TENSOR_DESC* KeyTensor;
_Maybenull_ const DML_TENSOR_DESC* ValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
_Maybenull_ const DML_TENSOR_DESC* MaskTensor;
_Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
_Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
_Maybenull_ const DML_TENSOR_DESC* PastSequenceLengthsTensor;
const DML_TENSOR_DESC* OutputTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
FLOAT Scale;
FLOAT MaskFilterValue;
UINT QueryHeadCount;
UINT KeyValueHeadCount;
DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};
成员
QueryTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
使用形状 [batchSize, sequenceLength, hiddenSize]
进行查询,其中 hiddenSize = headCount * headSize
。 此张量与 StackedQueryKeyTensor 和 StackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。
KeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
具有形状 [batchSize, keyValueSequenceLength, hiddenSize]
的键,其中 hiddenSize = headCount * headSize
。 此张量与 StackedQueryKeyTensor、 StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。
ValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
具有形状 [batchSize, keyValueSequenceLength, valueHiddenSize]
的值,其中 valueHiddenSize = headCount * valueHeadSize
。 此张量与 StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互斥。 张量也可以有 4 或 5 个维度,只要前导维度是 1 的。
StackedQueryKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
带有形状 [batchSize, sequenceLength, headCount, 2, headSize]
的堆积查询和键。 此张量与 QueryTensor、 KeyTensor、 StackedKeyValueTensor 和 StackedQueryKeyValueTensor 互斥。
StackedKeyValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
带有形状 [batchSize, keyValueSequenceLength, headCount, 2, headSize]
的堆积键和值。 此张量与 KeyTensor、 ValueTensor、 StackedQueryKeyTensor 和 StackedQueryKeyValueTensor 互斥。
StackedQueryKeyValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
带有形状 [batchSize, sequenceLength, headCount, 3, headSize]
的堆积查询、键和值。 此张量与 QueryTensor、KeyTensor、ValueTensor、StackedQueryKeyTensor 和 StackedKeyValueTensor 互斥。
BiasTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是形状 [hiddenSize + hiddenSize + valueHiddenSize]
的偏差,在第一个 GEMM作之前会添加到 查询/键/值 。 只要前导维度为 1,此张量也可以有 2、3、4 或 5 个维度。
MaskTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是一个掩码,用于确定哪些元素在 QxK GEMM作后将其值设置为 MaskFilterValue 。 此掩码的行为取决于 MaskType 的值,并在 RelativePositionBiasTensor 之后应用;如果 RelativePositionBiasTensor 为 null,则应用在第一个 GEMM作之后。 有关详细信息,请参阅 MaskType 的定义。
RelativePositionBiasTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
这是添加到第一个 GEMM作结果中的偏差。
PastKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
上一次迭代中具有形状 [batchSize, headCount, pastSequenceLength, headSize]
的键张量。 当此张量不为 null 时,它将与键张量连接,这会导致形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的张量。
PastValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
上一次迭代中具有形状 [batchSize, headCount, pastSequenceLength, headSize]
的值张量。 当此张量不为 null 时,它将与 ValueDesc 连接,这会导致形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
的张量。
PastSequenceLengthsTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
待定
OutputTensor
类型: const DML_TENSOR_DESC*
形状 [batchSize, sequenceLength, valueHiddenSize]
的输出。
OutputPresentKeyTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
当前状态用于交叉关注键,具有形状 [batchSize, headCount, keyValueSequenceLength, headSize]
或当前状态,以自我注意与形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
。 它包含键张量的内容或串联的 PastKey 密钥 + 张量的内容,以传递给下一次迭代。
OutputPresentValueTensor
类型:_Maybenull_ const DML_TENSOR_DESC*
交叉关注值的当前状态,具有形状 [batchSize, headCount, keyValueSequenceLength, headSize]
或当前状态,以便通过形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
引起自我注意。 它包含值张量的内容或串联 的 PastValue + 值 张量的内容,以传递给下一次迭代。
Scale
类型: FLOAT
缩放以将 QxK GEMM作的结果相乘,但在 Softmax作之前。 此值通常是 1/sqrt(headSize)
。
MaskFilterValue
类型: FLOAT
将 QxK GEMM作结果添加到定义为填充元素的掩码的位置的值。 此值应为非常大的负数(通常为 -10000.0f)。
QueryHeadCount
类型: UINT
待定
KeyValueHeadCount
类型: UINT
待定
MaskType
类型: DML_MULTIHEAD_ATTENTION_MASK_TYPE
描述 MaskTensor 的行为。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN。 当掩码包含值 0 时,将添加 MaskFilterValue ;但当它包含值 1 时,不会添加任何内容。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH。 形状 [1, batchSize]
的掩码包含每个批的未板区域的序列长度,序列长度后的所有元素都将其值设置为 MaskFilterValue。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START。 形状 [2, batchSize]
的掩码包含已解压缩区域的结束(独占)和开始(非独占)索引,区域外部的所有元素都将其值设置为 MaskFilterValue。
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END。 形状[batchSize * 3 + 2]
的掩码具有以下值: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]
可用性
此运算符是在 DML_FEATURE_LEVEL_6_3 中引入的。
Tensor 约束
BiasTensor、KeyTensor、OutputPresentKeyTensor、OutputPresentValueTensor、OutputTensor、PastKeyTensor、PastSequenceLengthsTensor、PastValueTensor、QueryTensor、RelativePositionBiasTensor、StackedKeyValueTensor、StackedQueryKeyTensor、StackedQueryKeyValueTensor 和 ValueTensor 必须具有相同的 DataType。
Tensor 支持
张量 | 种类 | 支持的维度计数 | 支持的数据类型 |
---|---|---|---|
QueryTensor 的 | 可选输入 | 3 到 5 | FLOAT32,FLOAT16 |
KeyTensor 的 | 可选输入 | 3 到 5 | FLOAT32,FLOAT16 |
ValueTensor | 可选输入 | 3 到 5 | FLOAT32,FLOAT16 |
StackedQueryKeyTensor | 可选输入 | 5 | FLOAT32,FLOAT16 |
StackedKeyValueTensor | 可选输入 | 5 | FLOAT32,FLOAT16 |
StackedQueryKeyValueTensor | 可选输入 | 5 | FLOAT32,FLOAT16 |
BiasTensor | 可选输入 | 1 到 5 | FLOAT32,FLOAT16 |
MaskTensor 的 | 可选输入 | 1 到 5 | INT32 |
RelativePositionBiasTensor | 可选输入 | 4 到 5 | FLOAT32,FLOAT16 |
PastKeyTensor | 可选输入 | 4 到 5 | FLOAT32,FLOAT16 |
PastValueTensor | 可选输入 | 4 到 5 | FLOAT32,FLOAT16 |
PastSequenceLengthsTensor | 可选输入 | 1 到 5 | FLOAT32,FLOAT16 |
OutputTensor | 输出 | 3 到 5 | FLOAT32,FLOAT16 |
输出PresentKeyTensor | 可选输出 | 4 到 5 | FLOAT32,FLOAT16 |
输出PresentValueTensor | 可选输出 | 4 到 5 | FLOAT32,FLOAT16 |