Nota
El acceso a esta página requiere autorización. Puede intentar iniciar sesión o cambiar directorios.
El acceso a esta página requiere autorización. Puede intentar cambiar los directorios.
Realiza una operación de atención multi head (para obtener más información, consulta Atención es todo lo que necesitas). Exactamente una consulta, la clave y el tensor de valor deben estar presentes, independientemente de si están apiladas o no. Por ejemplo, si se proporciona StackedQueryKey , los tensores Query y Key deben ser NULL, ya que ya se proporcionan en un diseño apilado. Lo mismo sucede con StackedKeyValue y StackedQueryKeyValue. Los tensores apilados siempre tienen cinco dimensiones y siempre se apilan en la cuarta dimensión.
Lógicamente, el algoritmo se puede descomponer en las siguientes operaciones (las operaciones entre corchetes son opcionales):
[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);
Importante
Esta API está disponible como parte del paquete redistribuible independiente de DirectML (consulte Microsoft.AI.DirectML versión 1.12 y posteriores). Consulte también historial de versiones de DirectML.
Sintaxis
struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
_Maybenull_ const DML_TENSOR_DESC* QueryTensor;
_Maybenull_ const DML_TENSOR_DESC* KeyTensor;
_Maybenull_ const DML_TENSOR_DESC* ValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
_Maybenull_ const DML_TENSOR_DESC* BiasTensor;
_Maybenull_ const DML_TENSOR_DESC* MaskTensor;
_Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
_Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
const DML_TENSOR_DESC* OutputTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
_Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
FLOAT Scale;
FLOAT MaskFilterValue;
UINT HeadCount;
DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};
Miembros
QueryTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Consulta con forma [batchSize, sequenceLength, hiddenSize]
, donde hiddenSize = headCount * headSize
. Este tensor se excluye mutuamente con StackedQueryKeyTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.
KeyTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Clave con forma [batchSize, keyValueSequenceLength, hiddenSize]
, donde hiddenSize = headCount * headSize
. Este tensor se excluye mutuamente con StackedQueryKeyTensor, StackedKeyValueTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.
ValueTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Valor con forma [batchSize, keyValueSequenceLength, valueHiddenSize]
, donde valueHiddenSize = headCount * valueHeadSize
. Este tensor se excluye mutuamente con StackedKeyValueTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.
StackedQueryKeyTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Consulta apilada y clave con forma [batchSize, sequenceLength, headCount, 2, headSize]
. Este tensor se excluye mutuamente con QueryTensor, KeyTensor, StackedKeyValueTensor y StackedQueryKeyValueTensor.
StackedKeyValueTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Clave apilada y valor con forma [batchSize, keyValueSequenceLength, headCount, 2, headSize]
. Este tensor se excluye mutuamente con KeyTensor, ValueTensor, StackedQueryKeyTensor y StackedQueryKeyValueTensor.
StackedQueryKeyValueTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Consulta apilada, clave y valor con forma [batchSize, sequenceLength, headCount, 3, headSize]
. Este tensor se excluye mutuamente con QueryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor y StackedKeyValueTensor.
BiasTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Este es el sesgo, de la forma [hiddenSize + hiddenSize + valueHiddenSize]
, que se agrega alvalor declave/de consulta/ antes de la primera operación gemm. Este tensor también puede tener 2, 3, 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.
MaskTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Esta es la máscara que determina qué elementos obtienen su valor establecido en MaskFilterValue después de la operación GEMM de QxK. El comportamiento de esta máscara depende del valor de MaskType y se aplica después de RelativePositionBiasTensor, o después de la primera operación GEMM si RelativePositionBiasTensor es null. Consulte la definición de MaskType para obtener más información.
RelativePositionBiasTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Este es el sesgo que se agrega al resultado de la primera operación GEMM.
PastKeyTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Tensor de clave de la iteración anterior con la forma [batchSize, headCount, pastSequenceLength, headSize]
. Cuando este tensor no es null, se concatena con el tensor de clave, lo que da como resultado un tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
.
PastValueTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Tensor de valor de la iteración anterior con forma [batchSize, headCount, pastSequenceLength, headSize]
. Cuando este tensor no es NULL, se concatena con ValueDesc , lo que da como resultado un tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
.
OutputTensor
Tipo: const DML_TENSOR_DESC*
Salida, de la forma [batchSize, sequenceLength, valueHiddenSize]
.
OutputPresentKeyTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Estado presente para la clave de atención cruzada, con forma [batchSize, headCount, keyValueSequenceLength, headSize]
o estado presente para la auto-atención con la forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
. Contiene el contenido del tensor de clave o el contenido del tensor declavepastkey + concatenado para pasar a la siguiente iteración.
OutputPresentValueTensor
Tipo: _Maybenull_ const DML_TENSOR_DESC*
Estado presente para el valor de atención cruzada, con forma [batchSize, headCount, keyValueSequenceLength, headSize]
o estado presente para la auto atención con la forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]
. Contiene el contenido del tensor de valor o el contenidodel tensorde valor PastValue + concatenado para pasar a la siguiente iteración.
Scale
Tipo: FLOAT
Escale para multiplicar el resultado de la operación GEMM de QxK, pero antes de la operación Softmax. Este valor suele ser 1/sqrt(headSize)
.
MaskFilterValue
Tipo: FLOAT
Valor que se agrega al resultado de la operación GEMM de QxK a las posiciones que la máscara definió como elementos de relleno. Este valor debe ser un número negativo muy grande (normalmente -10000,0f).
HeadCount
Tipo: UINT
Número de cabezas de atención.
MaskType
Tipo: DML_MULTIHEAD_ATTENTION_MASK_TYPE
Describe el comportamiento de MaskTensor.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. Cuando la máscara contiene un valor de 0, MaskFilterValue se agrega; pero cuando contiene un valor de 1, no se agrega nada.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. La máscara, de la forma [1, batchSize]
, contiene las longitudes de secuencia del área no rellenada para cada lote y todos los elementos después de que la longitud de la secuencia obtengan su valor establecido en MaskFilterValue.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. La máscara, de la forma [2, batchSize]
, contiene los índices de fin (exclusivo) e inicio (inclusivo) del área no rellenada, y todos los elementos fuera del área obtienen su valor establecido en MaskFilterValue.
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. La máscara, de la forma [batchSize * 3 + 2]
, tiene los siguientes valores: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]
.
Disponibilidad
Este operador se introdujo en DML_FEATURE_LEVEL_6_1.
Restricciones tensor
BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor y ValueTensor deben tener el mismo DataType.
Compatibilidad con Tensor
Tensor | Variante | Recuentos de dimensiones admitidos | Supported data types (Tipos de datos admitidos) |
---|---|---|---|
QueryTensor | Entrada opcional | De 3 a 5 | FLOAT32, FLOAT16 |
KeyTensor | Entrada opcional | De 3 a 5 | FLOAT32, FLOAT16 |
ValueTensor | Entrada opcional | De 3 a 5 | FLOAT32, FLOAT16 |
StackedQueryKeyTensor | Entrada opcional | 5 | FLOAT32, FLOAT16 |
StackedKeyValueTensor | Entrada opcional | 5 | FLOAT32, FLOAT16 |
StackedQueryKeyValueTensor | Entrada opcional | 5 | FLOAT32, FLOAT16 |
BiasTensor | Entrada opcional | De 1 a 5 | FLOAT32, FLOAT16 |
MaskTensor | Entrada opcional | De 1 a 5 | INT32 |
RelativePositionBiasTensor | Entrada opcional | De 4 a 5 | FLOAT32, FLOAT16 |
PastKeyTensor | Entrada opcional | De 4 a 5 | FLOAT32, FLOAT16 |
PastValueTensor | Entrada opcional | De 4 a 5 | FLOAT32, FLOAT16 |
OutputTensor | Salida | De 3 a 5 | FLOAT32, FLOAT16 |
OutputPresentKeyTensor | Salida opcional | De 4 a 5 | FLOAT32, FLOAT16 |
OutputPresentValueTensor | Salida opcional | De 4 a 5 | FLOAT32, FLOAT16 |
Requisitos
Encabezado | directml.h |