Compartir a través de


estructura DML_MULTIHEAD_ATTENTION_OPERATOR_DESC (directml.h)

Realiza una operación de atención multi head (para obtener más información, consulta Atención es todo lo que necesitas). Exactamente una consulta, la clave y el tensor de valor deben estar presentes, independientemente de si están apiladas o no. Por ejemplo, si se proporciona StackedQueryKey , los tensores Query y Key deben ser NULL, ya que ya se proporcionan en un diseño apilado. Lo mismo sucede con StackedKeyValue y StackedQueryKeyValue. Los tensores apilados siempre tienen cinco dimensiones y siempre se apilan en la cuarta dimensión.

Lógicamente, el algoritmo se puede descomponer en las siguientes operaciones (las operaciones entre corchetes son opcionales):

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

Importante

Esta API está disponible como parte del paquete redistribuible independiente de DirectML (consulte Microsoft.AI.DirectML versión 1.12 y posteriores). Consulte también historial de versiones de DirectML.

Sintaxis

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

Miembros

QueryTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consulta con forma [batchSize, sequenceLength, hiddenSize], donde hiddenSize = headCount * headSize. Este tensor se excluye mutuamente con StackedQueryKeyTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.

KeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Clave con forma [batchSize, keyValueSequenceLength, hiddenSize], donde hiddenSize = headCount * headSize. Este tensor se excluye mutuamente con StackedQueryKeyTensor, StackedKeyValueTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.

ValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Valor con forma [batchSize, keyValueSequenceLength, valueHiddenSize], donde valueHiddenSize = headCount * valueHeadSize. Este tensor se excluye mutuamente con StackedKeyValueTensor y StackedQueryKeyValueTensor. El tensor también puede tener 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.

StackedQueryKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consulta apilada y clave con forma [batchSize, sequenceLength, headCount, 2, headSize]. Este tensor se excluye mutuamente con QueryTensor, KeyTensor, StackedKeyValueTensor y StackedQueryKeyValueTensor.

Diseño stackedQueryKeyTensor

StackedKeyValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Clave apilada y valor con forma [batchSize, keyValueSequenceLength, headCount, 2, headSize]. Este tensor se excluye mutuamente con KeyTensor, ValueTensor, StackedQueryKeyTensor y StackedQueryKeyValueTensor.

Diseño stackedKeyValueTensor

StackedQueryKeyValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Consulta apilada, clave y valor con forma [batchSize, sequenceLength, headCount, 3, headSize]. Este tensor se excluye mutuamente con QueryTensor, KeyTensor, ValueTensor, StackedQueryKeyTensor y StackedKeyValueTensor.

Diseño stackedQueryKeyValueTensor

BiasTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Este es el sesgo, de la forma [hiddenSize + hiddenSize + valueHiddenSize], que se agrega alvalor declave/de consulta/ antes de la primera operación gemm. Este tensor también puede tener 2, 3, 4 o 5 dimensiones, siempre que las dimensiones iniciales sean 1.

MaskTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Esta es la máscara que determina qué elementos obtienen su valor establecido en MaskFilterValue después de la operación GEMM de QxK. El comportamiento de esta máscara depende del valor de MaskType y se aplica después de RelativePositionBiasTensor, o después de la primera operación GEMM si RelativePositionBiasTensor es null. Consulte la definición de MaskType para obtener más información.

RelativePositionBiasTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Este es el sesgo que se agrega al resultado de la primera operación GEMM.

PastKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Tensor de clave de la iteración anterior con la forma [batchSize, headCount, pastSequenceLength, headSize]. Cuando este tensor no es null, se concatena con el tensor de clave, lo que da como resultado un tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

PastValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Tensor de valor de la iteración anterior con forma [batchSize, headCount, pastSequenceLength, headSize]. Cuando este tensor no es NULL, se concatena con ValueDesc , lo que da como resultado un tensor de forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize].

OutputTensor

Tipo: const DML_TENSOR_DESC*

Salida, de la forma [batchSize, sequenceLength, valueHiddenSize].

OutputPresentKeyTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Estado presente para la clave de atención cruzada, con forma [batchSize, headCount, keyValueSequenceLength, headSize] o estado presente para la auto-atención con la forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Contiene el contenido del tensor de clave o el contenido del tensor declavepastkey + concatenado para pasar a la siguiente iteración.

OutputPresentValueTensor

Tipo: _Maybenull_ const DML_TENSOR_DESC*

Estado presente para el valor de atención cruzada, con forma [batchSize, headCount, keyValueSequenceLength, headSize] o estado presente para la auto atención con la forma [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize]. Contiene el contenido del tensor de valor o el contenidodel tensorde valor PastValue + concatenado para pasar a la siguiente iteración.

Scale

Tipo: FLOAT

Escale para multiplicar el resultado de la operación GEMM de QxK, pero antes de la operación Softmax. Este valor suele ser 1/sqrt(headSize).

MaskFilterValue

Tipo: FLOAT

Valor que se agrega al resultado de la operación GEMM de QxK a las posiciones que la máscara definió como elementos de relleno. Este valor debe ser un número negativo muy grande (normalmente -10000,0f).

HeadCount

Tipo: UINT

Número de cabezas de atención.

MaskType

Tipo: DML_MULTIHEAD_ATTENTION_MASK_TYPE

Describe el comportamiento de MaskTensor.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN. Cuando la máscara contiene un valor de 0, MaskFilterValue se agrega; pero cuando contiene un valor de 1, no se agrega nada.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH. La máscara, de la forma [1, batchSize], contiene las longitudes de secuencia del área no rellenada para cada lote y todos los elementos después de que la longitud de la secuencia obtengan su valor establecido en MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START. La máscara, de la forma [2, batchSize], contiene los índices de fin (exclusivo) e inicio (inclusivo) del área no rellenada, y todos los elementos fuera del área obtienen su valor establecido en MaskFilterValue.

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END. La máscara, de la forma [batchSize * 3 + 2], tiene los siguientes valores: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]].

Disponibilidad

Este operador se introdujo en DML_FEATURE_LEVEL_6_1.

Restricciones tensor

BiasTensor, KeyTensor, OutputPresentKeyTensor, OutputPresentValueTensor, OutputTensor, PastKeyTensor, PastValueTensor, QueryTensor, RelativePositionBiasTensor, StackedKeyValueTensor, StackedQueryKeyTensor, StackedQueryKeyValueTensor y ValueTensor deben tener el mismo DataType.

Compatibilidad con Tensor

Tensor Variante Recuentos de dimensiones admitidos Supported data types (Tipos de datos admitidos)
QueryTensor Entrada opcional De 3 a 5 FLOAT32, FLOAT16
KeyTensor Entrada opcional De 3 a 5 FLOAT32, FLOAT16
ValueTensor Entrada opcional De 3 a 5 FLOAT32, FLOAT16
StackedQueryKeyTensor Entrada opcional 5 FLOAT32, FLOAT16
StackedKeyValueTensor Entrada opcional 5 FLOAT32, FLOAT16
StackedQueryKeyValueTensor Entrada opcional 5 FLOAT32, FLOAT16
BiasTensor Entrada opcional De 1 a 5 FLOAT32, FLOAT16
MaskTensor Entrada opcional De 1 a 5 INT32
RelativePositionBiasTensor Entrada opcional De 4 a 5 FLOAT32, FLOAT16
PastKeyTensor Entrada opcional De 4 a 5 FLOAT32, FLOAT16
PastValueTensor Entrada opcional De 4 a 5 FLOAT32, FLOAT16
OutputTensor Salida De 3 a 5 FLOAT32, FLOAT16
OutputPresentKeyTensor Salida opcional De 4 a 5 FLOAT32, FLOAT16
OutputPresentValueTensor Salida opcional De 4 a 5 FLOAT32, FLOAT16

Requisitos

   
Encabezado directml.h