structure DML_GEMM_OPERATOR_DESC (directml.h)
Effectue une fonction de multiplication de matrice générale de la forme Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
, où x
désigne la multiplication de matrice et *
désigne la multiplication avec un scalaire.
Cet opérateur nécessite des tenseurs 4D avec disposition { BatchCount, ChannelCount, Height, Width }
, et il effectuera BatchCount * ChannelCount nombre de multiplications de matrice indépendantes.
Par exemple, si ATensor a Sizes de { BatchCount, ChannelCount, M, K }
, et BTensor a Sizes de { BatchCount, ChannelCount, K, N }
, et OutputTensor a Sizes de { BatchCount, ChannelCount, M, N }
, alors cet opérateur effectue BatchCount * ChannelCount des multiplications de matrice indépendantes de dimensions {M,K} x {K,N} = {M,N}.
Syntaxe
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
Membres
ATensor
Type : const DML_TENSOR_DESC*
Tenseur contenant la matrice A. Les tailles de ce tenseur doivent être { BatchCount, ChannelCount, M, K }
si TransA est DML_MATRIX_TRANSFORM_NONE ou { BatchCount, ChannelCount, K, M }
si TransA est DML_MATRIX_TRANSFORM_TRANSPOSE.
BTensor
Type : const DML_TENSOR_DESC*
Tenseur contenant la matrice B. Les tailles de ce tenseur doivent être { BatchCount, ChannelCount, K, N }
si TransB est DML_MATRIX_TRANSFORM_NONE ou { BatchCount, ChannelCount, N, K }
si TransB est DML_MATRIX_TRANSFORM_TRANSPOSE.
CTensor
Type : _Maybenull_ const DML_TENSOR_DESC*
Tenseur contenant la matrice C, ou nullptr
. Les valeurs par défaut sont 0 lorsqu’elles ne sont pas fournies. Si elle est fournie, les tailles de ce tenseur doivent être { BatchCount, ChannelCount, M, N }
.
OutputTensor
Type : const DML_TENSOR_DESC*
Tenseur dans lequel écrire les résultats. Les tailles de ce tenseur sont { BatchCount, ChannelCount, M, N }
.
TransA
Type : DML_MATRIX_TRANSFORM
Transformation à appliquer à ATensor ; soit une transposition, soit aucune transformation.
TransB
Type : DML_MATRIX_TRANSFORM
Transformation à appliquer à BTensor ; soit une transposition, soit aucune transformation.
Alpha
Type : FLOAT
Valeur du multiplicateur scalaire pour le produit des entrées ATensor et BTensor.
Beta
Type : FLOAT
Valeur du multiplicateur scalaire pour le CTensor d’entrée facultatif. Si CTensor n’est pas fourni, cette valeur est ignorée.
FusedActivation
Type : _Maybenull_ const DML_OPERATOR_DESC*
Couche d’activation fusionnée facultative à appliquer après le GEMM. Pour plus d’informations, consultez Utilisation d’opérateurs fusionnés pour améliorer les performances.
Disponibilité
Cet opérateur a été introduit dans DML_FEATURE_LEVEL_1_0
.
Contraintes tensoriels
- ATensor, BTensor, CTensor et OutputTensor doivent avoir les mêmes DataType et DimensionCount.
- CTensor et OutputTensor doivent avoir les mêmes tailles.
Prise en charge de Tensor
DML_FEATURE_LEVEL_4_0 et versions ultérieures
Tenseur | Genre | Dimensions | Nombre de dimensions pris en charge | Types de données pris en charge |
---|---|---|---|---|
ATensor | Entrée | { [BatchCount], [ChannelCount], M, K } | 2 à 4 | FLOAT32, FLOAT16 |
BTensor | Entrée | { [BatchCount], [ChannelCount], K, N } | 2 à 4 | FLOAT32, FLOAT16 |
CTensor | Entrée facultative | { [BatchCount], [ChannelCount], M, N } | 2 à 4 | FLOAT32, FLOAT16 |
OutputTensor | Sortie | { [BatchCount], [ChannelCount], M, N } | 2 à 4 | FLOAT32, FLOAT16 |
DML_FEATURE_LEVEL_1_0 et versions ultérieures
Tenseur | Genre | Dimensions | Nombre de dimensions pris en charge | Types de données pris en charge |
---|---|---|---|---|
ATensor | Entrée | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32, FLOAT16 |
BTensor | Entrée | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32, FLOAT16 |
CTensor | Entrée facultative | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
OutputTensor | Sortie | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32, FLOAT16 |
Configuration requise
Condition requise | Valeur |
---|---|
En-tête | directml.h |