共用方式為


DirectMLX

DirectMLX 是 DirectML 的C++標頭專用協助程式連結庫,旨在讓您更輕鬆地將個別運算符撰寫成圖表。

DirectMLX 為所有 DirectML (DML) 運算符類型提供方便的包裝函式,以及直覺式運算元多載,讓您更輕鬆地具現化 DML 運算符,並將其鏈結至複雜的圖形。

在哪裡可以找到 DirectMLX.h

DirectMLX.h 在 MIT 授權下以開放原始碼軟體的形式散發。 您可以在 DirectML GitHub 上找到最新版本。

版本需求

DirectMLX 需要 DirectML 1.4.0 版或更新版本(請參閱 DirectML 版本歷程記錄)。 不支援舊版的 DirectML。

DirectMLX.h 需要支援 C++11 的編譯程式,包括 (但不限於):

  • Visual Studio 2017
  • Visual Studio 2019
  • Clang 10 (叮噹 10)

請注意,C++17 (或更新版本) 編譯程式是我們建議的選項。 可以編譯C++11,但需要使用第三方連結庫(例如 GSLAbseil)來取代遺漏的標準連結庫功能。

如果您有無法編譯 DirectMLX.h的組態,請在 我們的 GitHub 上提出問題

基本用法

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

dml::Graph graph(device);

// Input tensor of type FLOAT32 and sizes { 1, 2, 3, 4 }
auto x = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, {1, 2, 3, 4}));

// Create an operator to compute the square root of x
auto y = dml::Sqrt(x);

// Compile a DirectML operator from the graph. When executed, this compiled operator will compute
// the square root of its input.
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { y });

// Now initialize and dispatch the DML operator as usual

以下是另一個範例,其會建立能夠計算 二次方公式的 DirectML 圖形。

#include <DirectML.h>
#include <DirectMLX.h>

IDMLDevice* device;

/* ... */

std::pair<dml::Expression, dml::Expression>
    QuadraticFormula(dml::Expression a, dml::Expression b, dml::Expression c)
{
    // Quadratic formula: given an equation of the form ax^2 + bx + c = 0, x can be found by:
    //   x = -b +/- sqrt(b^2 - 4ac) / (2a)
    // https://en.wikipedia.org/wiki/Quadratic_formula

    // Note: DirectMLX provides operator overloads for common mathematical expressions. So for 
    // example a*c is equivalent to dml::Multiply(a, c).
    auto x1 = -b + dml::Sqrt(b*b - 4*a*c) / (2*a);
    auto x2 = -b - dml::Sqrt(b*b - 4*a*c) / (2*a);

    return { x1, x2 };
}

/* ... */

dml::Graph graph(device);

dml::TensorDimensions inputSizes = {1, 2, 3, 4};
auto a = dml::InputTensor(graph, 0, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto b = dml::InputTensor(graph, 1, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));
auto c = dml::InputTensor(graph, 2, dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, inputSizes));

auto [x1, x2] = QuadraticFormula(a, b, c);

// When executed with input tensors a, b, and c, this compiled operator computes the two outputs
// of the quadratic formula, and returns them as two output tensors x1 and x2
DML_EXECUTION_FLAGS flags = DML_EXECUTION_FLAG_NONE;
ComPtr<IDMLCompiledOperator> op = graph.Compile(flags, { x1, x2 });

// Now initialize and dispatch the DML operator as usual

其他範例

您可以在 DirectML GitHub 存放庫中找到使用 DirectMLX 的完整範例。

編譯時間選項

DirectMLX 支援編譯時期 #define 來客製化標頭的各個部分。

選項 說明
DMLX_NO_EXCEPTIONS 如果使用#define,則錯誤會導致呼叫std::abort ,而非擲回例外狀況。 如果例外狀況無法使用,則預設會定義這個值(例如,如果編譯程式選項中已停用例外狀況)。
DMLX_USE_WIL 如果定義了 #define,則會拋出 Windows 實作庫 的例外狀況類型。 否則,會改用標準例外狀況類型(例如 std::runtime_error)。 如果定義 DMLX_NO_EXCEPTIONS ,這個選項就沒有作用。
DMLX_USE_ABSEIL 如果 #define 被設定,請使用 Abseil 來替代 C++11 中無法使用的標準庫類型。 這些類型包括 absl::optional (取代 std::optional)、 absl::Span (取代 std::span) 和 absl::InlinedVector
DMLX_USE_GSL 控制是否要使用 GSL 來取代 std::span。 如果 #define'd,在沒有原生std::span實作的編譯器上,gsl::span 的用法會被替換為 std::span。 否則,會改為提供內嵌置入實作。 請注意,此選項僅在使用不支援 C++20 的編譯器進行編譯時使用,且未使用其他如 Abseil 的替代標準函式庫時才適用。

控制張量佈局

對於大多數運算符,DirectMLX 會代表您計算運算符輸出張量的屬性。 例如,當在dml::Reduce軸上執行大小為{ 0, 2, 3 }的輸入張量{ 3, 4, 5, 6 }時,DirectMLX 會自動計算輸出張量的屬性,包括{ 1, 4, 1, 1 }的正確形狀。

不過,輸出張量的其他屬性包括 StridesTotalTensorSizeInBytesGuaranteedBaseOffsetAlignment。 在預設情況下,DirectMLX 會設定這些屬性,使得張量沒有步幅、沒有保證的基底偏移對齊,以及 DMLCalcBufferTensorSize 所計算的位元組總大小。

DirectMLX 支援透過使用被稱為張量策略的對象來自定義這些輸出張量屬性的能力。 TensorPolicy 是由 DirectMLX 調用的可自訂回調,根據張量計算出的數據類型、標誌和尺寸返回輸出張量的屬性。

Tensor 原則可以在 dml::Graph 物件上設定,並且將用於該圖形上的所有後續運算符。 當建構 TensorDesc 時,也可以直接設定 Tensor 原則。

因此,DirectMLX 所產生的張量配置可以透過設定 TensorPolicy 來控制,以設定其張量的適當步幅。

範例 1

// Define a policy, which is a function that returns a TensorProperties given a data type,
// flags, and sizes.
dml::TensorProperties MyCustomPolicy(
    DML_TENSOR_DATA_TYPE dataType,
    DML_TENSOR_FLAGS flags,
    Span<const uint32_t> sizes)
{
    // Compute your custom strides, total tensor size in bytes, and guaranteed base
    // offset alignment
    dml::TensorProperties props;
    props.strides = /* ... */;
    props.totalTensorSizeInBytes = /* ... */;
    props.guaranteedBaseOffsetAlignment = /* ... */;
    return props;
};

// Set the policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy(&MyCustomPolicy));

範例 2

DirectMLX 也提供一些內建的替代張量處理策略。 例如,InterleavedChannel 原則為提高便利性而提供,可以用來產生具有特定跨度的張量,使其以 NHWC 順序編寫。

// Set the InterleavedChannel policy on the dml::Graph
dml::Graph graph(/* ... */);
graph.SetTensorPolicy(dml::TensorPolicy::InterleavedChannel());

// When executed, the tensor `result` will be in NHWC layout (rather than the default NCHW)
auto result = dml::Convolution(/* ... */);

另請參閱