Compartir a través de


DirectMLX

DirectMLX es una biblioteca auxiliar solo de encabezado de C++ para DirectML, diseñada para facilitar la redacción de operadores individuales en grafos.

DirectMLX proporciona encapsuladores cómodos para todos los tipos de operadores de DirectML (DML), así como sobrecargas de operador intuitivas, lo que facilita la creación de instancias de operadores DML y encadenarlos a gráficos complejos.

Dónde encontrar DirectMLX.h

DirectMLX.h se distribuye como software de código abierto bajo la licencia MIT. La última versión se puede encontrar en el repositorio de GitHub de DirectML.

Requisitos de versión

DirectMLX requiere la versión 1.4.0 o posterior de DirectML (consulte el historial de versiones de DirectML). No se admiten versiones anteriores de DirectML.

DirectMLX.h requiere un compilador compatible con C++11, entre los que se incluyen los siguientes:

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10

Tenga en cuenta que un compilador de C++17 (o posterior) es la opción recomendada. Es posible compilar para C++11, pero requiere el uso de bibliotecas de terceros (como GSL y Abseil) para reemplazar la funcionalidad de biblioteca estándar que falta.

Si tiene una configuración que no puede compilar DirectMLX.h, envíe un problema en nuestro GitHub.

Uso básico

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

Este es otro ejemplo, que crea un gráfico de DirectML capaz de calcular la fórmula cuadrática.

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

Más ejemplos

Puede encontrar ejemplos completos con DirectMLX en el repositorio de GitHub de DirectML.

Opciones de tiempo de compilación

DirectMLX admite el valor #define's del tiempo de compilación para personalizar varias partes del encabezado.

Opción Descripción
DMLX_NO_EXCEPTIONS Si el valor es #define'd, los errores generan una llamada a std::abort en lugar de iniciar una excepción. Esto se define de forma predeterminada si no hay excepciones disponibles (por ejemplo, si las excepciones se han deshabilitado en las opciones del compilador).
DMLX_USE_WIL Si el valor es #define'd, las excepciones se inician mediante los tipos de excepciones de la Biblioteca de implementación de Windows. De lo contrario, se usan los tipos de excepción estándar (como std::runtime_error). Esta opción no tiene ningún efecto si se define DMLX_NO_EXCEPTIONS.
DMLX_USE_ABSEIL Si el valor es #define' d, se usa Abseil como reemplazo para los tipos de biblioteca estándar no disponibles en C++11. Estos tipos incluyen absl::optional (en lugar de std::optional), absl::Span (en lugar de std::span) y absl::InlinedVector.
DMLX_USE_GSL Controla si se va a usar GSL como reemplazo de std::span. Si el valor es #define'd, los usos de std::span se reemplazan por gsl::span en los compiladores sin implementaciones nativas std::span. De lo contrario, se proporciona una implementación provisional insertada. Tenga en cuenta que esta opción solo se usa al compilar en un compilador anterior a C++20 sin compatibilidad con std::span, y cuando no se usa ningún otro reemplazo de biblioteca estándar (como Abseil).

Control del diseño de tensores

Para la mayoría de los operadores, DirectMLX calcula las propiedades de los tensores de salida del operador en su nombre. Por ejemplo, al realizar un dml::Reduce entre los ejes { 0, 2, 3 } con un tensor de entrada de tamaños { 3, 4, 5, 6 }, DirectMLX calculará automáticamente las propiedades del tensor de salida, incluida la forma correcta de { 1, 4, 1, 1 }.

Sin embargo, entre las otras propiedades de un tensor de salida se incluyen Strides, TotalTensorSizeInBytes y GuaranteedBaseOffsetAlignment. De forma predeterminada, DirectMLX establece estas propiedades de forma que el tensor no tenga ningún intervalo, ninguna alineación de desplazamiento base garantizada y un tamaño total de tensor en bytes calculado mediante DMLCalcBufferTensorSize.

DirectMLX admite la capacidad de personalizar estas propiedades de tensor de salida mediante objetos conocidos como directivas de tensor. TensorPolicy es una devolución de llamada personalizable invocada por DirectMLX que devuelve propiedades de tensor de salida según el tipo de datos calculado, las marcas y los tamaños de un tensor.

Las directivas del tensor se pueden establecer en el objeto dml::Graph y se usarán para todos los operadores subsiguientes de ese gráfico. Las directivas del tensor también se pueden establecer directamente al construir un TensorDesc.

Por lo tanto, el diseño de tensores que produce DirectMLX se puede controlar estableciendo un valor para TensorPolicy, que define los intervalos adecuados en sus tensores.

Ejemplo 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

Ejemplo 2

DirectMLX también proporciona algunas directivas de tensor alternativas integradas. La directiva InterleavedChannel, por ejemplo, se proporciona por practicidad, y se puede usar para generar tensores con intervalos para que se escriban en orden NHWC.

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

Consulte también