DML_GEMM_OPERATOR_DESC 結構 (directml.h)
執行形式的 Output = FusedActivation(Alpha * TransA(A) x TransB(B) + Beta * C)
一般矩陣乘法函式,其中 x
表示矩陣乘法,並以 *
純量表示乘法。
此運算子需要具有版面配置的 { BatchCount, ChannelCount, Height, Width }
4D 張量,而且會執行 BatchCount * ChannelCount 獨立矩陣乘法的數目。
例如,如果 ATensor 的大小為{ BatchCount, ChannelCount, M, K }
,而 BTensor{ BatchCount, ChannelCount, K, N }
的大小為 ,而 OutputTensor的大小{ BatchCount, ChannelCount, M, N }
為 ,則此運算符會執行維度 {M,K} x {K} = {M,K} = {M,N} = {M,N} 的 BatchCount * ChannelCount 獨立矩陣乘法。
語法
struct DML_GEMM_OPERATOR_DESC {
const DML_TENSOR_DESC *ATensor;
const DML_TENSOR_DESC *BTensor;
const DML_TENSOR_DESC *CTensor;
const DML_TENSOR_DESC *OutputTensor;
DML_MATRIX_TRANSFORM TransA;
DML_MATRIX_TRANSFORM TransB;
FLOAT Alpha;
FLOAT Beta;
const DML_OPERATOR_DESC *FusedActivation;
};
成員
ATensor
類型: const DML_TENSOR_DESC*
包含 A 矩陣的張量。 如果 TransA是DML_MATRIX_TRANSFORM_NONE,或 { BatchCount, ChannelCount, K, M }
TransA是DML_MATRIX_TRANSFORM_TRANSPOSE,則此張量的大小應該是 { BatchCount, ChannelCount, M, K }
。
BTensor
類型: const DML_TENSOR_DESC*
包含 B 矩陣的張量。 如果 TransB是DML_MATRIX_TRANSFORM_NONE,或 { BatchCount, ChannelCount, N, K }
TransB是DML_MATRIX_TRANSFORM_TRANSPOSE,則此張量的大小應該是 { BatchCount, ChannelCount, K, N }
。
CTensor
類型: _Maybenull_ const DML_TENSOR_DESC*
包含 C 矩陣或 nullptr
的張量。 未提供時,值預設為 0。 如果提供,則這個 Tensor 的大小 應該是 { BatchCount, ChannelCount, M, N }
。
OutputTensor
類型: const DML_TENSOR_DESC*
要寫入結果的張量。 這個 Tensor 的大小 為 { BatchCount, ChannelCount, M, N }
。
TransA
要套用至 ATensor 的轉換;轉置或無轉換。
TransB
要套用至 BTensor 的轉換;轉置或無轉換。
Alpha
類型: FLOAT
輸入 ATensor 和 BTensor 乘積的純量乘數值。
Beta
類型: FLOAT
選擇性輸入 CTensor 的純量乘數值。 如果未提供 CTensor ,則會忽略此值。
FusedActivation
類型:_Maybenull_ const DML_OPERATOR_DESC*
要在 GEMM 之後套用的選擇性融合啟用層。 如需詳細資訊,請參閱 使用 fused 運算子改善效能。
可用性
這個運算子是在 中 DML_FEATURE_LEVEL_1_0
引進的。
Tensor 條件約束
- ATensor、BTensor、CTensor 和 OutputTensor 必須具有相同的 DataType 和 DimensionCount。
- CTensor 和 OutputTensor 必須具有相同 的大小。
Tensor 支援
DML_FEATURE_LEVEL_4_0和更新版本
張 | 種類 | 維度 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|---|
ATensor | 輸入 | { [BatchCount], [ChannelCount], M, K } | 2 到 4 | FLOAT32、FLOAT16 |
BTensor | 輸入 | { [BatchCount], [ChannelCount], K, N } | 2 到 4 | FLOAT32、FLOAT16 |
CTensor | 選擇性輸入 | { [BatchCount], [ChannelCount], M, N } | 2 到 4 | FLOAT32、FLOAT16 |
OutputTensor | 輸出 | { [BatchCount], [ChannelCount], M, N } | 2 到 4 | FLOAT32、FLOAT16 |
DML_FEATURE_LEVEL_1_0和更新版本
張 | 種類 | 維度 | 支援的維度計數 | 支援的資料類型 |
---|---|---|---|---|
ATensor | 輸入 | { BatchCount, ChannelCount, M, K } | 4 | FLOAT32、FLOAT16 |
BTensor | 輸入 | { BatchCount, ChannelCount, K, N } | 4 | FLOAT32、FLOAT16 |
CTensor | 選擇性輸入 | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32、FLOAT16 |
OutputTensor | 輸出 | { BatchCount, ChannelCount, M, N } | 4 | FLOAT32、FLOAT16 |
規格需求
需求 | 值 |
---|---|
標頭 | directml.h |