Share via


DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC-Struktur (directml.h)

Berechnet Backpropagationsverläufe für die Batchnormalisierung. DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC führt mehrere Berechnungen durch, die in den separaten Ausgabebeschreibungen beschrieben sind.

OutputScaleGradientTensor und OutputBiasGradientTensor werden mithilfe von Summen für den Dimensionssatz berechnet, für die die Größen MeanTensor, ScaleTensor und VarianceTensor gleich einem sind.

Syntax

struct DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC {
  const DML_TENSOR_DESC *InputTensor;
  const DML_TENSOR_DESC *InputGradientTensor;
  const DML_TENSOR_DESC *MeanTensor;
  const DML_TENSOR_DESC *VarianceTensor;
  const DML_TENSOR_DESC *ScaleTensor;
  const DML_TENSOR_DESC *OutputGradientTensor;
  const DML_TENSOR_DESC *OutputScaleGradientTensor;
  const DML_TENSOR_DESC *OutputBiasGradientTensor;
  FLOAT                 Epsilon;
};

Member

InputTensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die Eingabedaten enthält. Dies ist in der Regel derselbe Tensor, der als InputTensor bereitgestellt wurde, um im Vorwärtsdurchlauf DML_BATCH_NORMALIZATION_OPERATOR_DESC .

InputGradientTensor

Typ: const DML_TENSOR_DESC*

Der eingehende Gradienten-Tensor. Dies wird in der Regel aus der Ausgabe der Backpropagation einer vorherigen Ebene abgerufen.

MeanTensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die Mitteldaten enthält. Dies ist in der Regel derselbe Tensor, der als MeanTensor bereitgestellt wurde, um im Vorwärtsdurchlauf DML_BATCH_NORMALIZATION_OPERATOR_DESC .

VarianceTensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die Varianzdaten enthält. Dies ist in der Regel derselbe Tensor, der als VarianceTensor bereitgestellt wurde, um im Vorwärtsdurchlauf DML_OPERATOR_BATCH_NORMALIZATION .

ScaleTensor

Typ: const DML_TENSOR_DESC*

Ein Tensor, der die Skalierungsdaten enthält. Dies ist in der Regel derselbe Tensor, der als ScaleTensor bereitgestellt wurde, um im Vorwärtsdurchlauf DML_BATCH_NORMALIZATION_OPERATOR_DESC .

OutputGradientTensor

Typ: const DML_TENSOR_DESC*

Für jeden entsprechenden Wert in den Eingaben ist OutputGradient = InputGradient * (Scale / sqrt(Variance + Epsilon)).

OutputScaleGradientTensor

Typ: const DML_TENSOR_DESC*

Die folgende Berechnung erfolgt oder jeder entsprechende Wert in den Eingaben.

OutputScaleGradient = sum(InputGradient * (Input - Mean) / sqrt(Variance + Epsilon))

OutputBiasGradientTensor

Typ: const DML_TENSOR_DESC*

Die folgende Berechnung erfolgt oder jeder entsprechende Wert in den Eingaben.

OutputBiasGradient = sum(InputGradient)

Epsilon

Typ: FLOAT

Ein kleiner Wert, der der Varianz hinzugefügt wird, um null zu vermeiden.

Hinweise

Verfügbarkeit

Dieser Operator wurde in DML_FEATURE_LEVEL_3_1eingeführt.

Tensoreinschränkungen

  • InputGradientTensor, InputTensor, MeanTensor, OutputBiasGradientTensor, OutputGradientTensor, OutputScaleGradientTensor, ScaleTensor und VarianceTensor müssen denselben DataType und DimensionCount aufweisen.
  • MeanTensor, OutputBiasGradientTensor, OutputScaleGradientTensor, ScaleTensor und VarianceTensor müssen die gleichen Größen aufweisen.
  • InputGradientTensor, InputTensor und OutputGradientTensor müssen die gleichen Größen aufweisen.

Tensorunterstützung

Tensor Variante Dimensionen Unterstützte Dimensionsanzahl Unterstützte Datentypen
InputTensor Eingabe { InputDimensions[] } 1 bis 8 FLOAT32, FLOAT16
InputGradientTensor Eingabe { InputDimensions[] } 1 bis 8 FLOAT32, FLOAT16
MeanTensor Eingabe { MeanDimensions[] } 1 bis 8 FLOAT32, FLOAT16
VarianceTensor Eingabe { MeanDimensions[] } 1 bis 8 FLOAT32, FLOAT16
ScaleTensor Eingabe { MeanDimensions[] } 1 bis 8 FLOAT32, FLOAT16
OutputGradientTensor Ausgabe { InputDimensions[] } 1 bis 8 FLOAT32, FLOAT16
OutputScaleGradientTensor Ausgabe { MeanDimensions[] } 1 bis 8 FLOAT32, FLOAT16
OutputBiasGradientTensor Ausgabe { MeanDimensions[] } 1 bis 8 FLOAT32, FLOAT16

Anforderungen

Anforderung Wert
Unterstützte Mindestversion (Client) Windows Build 22000
Unterstützte Mindestversion (Server) Windows Build 22000
Kopfzeile directml.h