DML_MULTIHEAD_ATTENTION_OPERATOR_DESC 構造体 (directml.h)

マルチヘッド アテンション操作を実行します (詳細については、「必要なものはアテンション」を参照してください)。 スタックされているかどうかに関係なく、1 つのクエリキーのテンソルのみが存在する必要があります。 たとえば、StackedQueryKey が指定されている場合、クエリテンソルとキー テンソルの両方が null である必要があります。これらは既にスタック レイアウトで提供されているためです。 StackedKeyValueStackedQueryKeyValue も同様です。 スタックされたテンソルは常に 5 つの次元を持ち、常に 4 番目の次元に積み上げされます。

論理的には、アルゴリズムを次の操作に分解できます (角かっこ内の操作は省略可能)。

[Add Bias to query/key/value] -> GEMM(Query, Transposed(Key)) * Scale -> [Add RelativePositionBias] -> [Add Mask] -> Softmax -> GEMM(SoftmaxResult, Value);

重要

この API は、DirectML スタンドアロン再頒布可能パッケージの一部として使用できます (Microsoft.AI.DirectML バージョン 1.12 以降を参照してください)。 DirectML バージョン履歴も参照してください。

構文

struct DML_MULTIHEAD_ATTENTION_OPERATOR_DESC
{
    _Maybenull_ const DML_TENSOR_DESC* QueryTensor;
    _Maybenull_ const DML_TENSOR_DESC* KeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* ValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* StackedQueryKeyValueTensor;
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* MaskTensor;
    _Maybenull_ const DML_TENSOR_DESC* RelativePositionBiasTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* PastValueTensor;
    const DML_TENSOR_DESC* OutputTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentKeyTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputPresentValueTensor;
    FLOAT Scale;
    FLOAT MaskFilterValue;
    UINT HeadCount;
    DML_MULTIHEAD_ATTENTION_MASK_TYPE MaskType;
};

メンバー

QueryTensor

型: _Maybenull_ const DML_TENSOR_DESC*

hiddenSize = headCount * headSize の場合、図形 [batchSize, sequenceLength, hiddenSize] を使用してクエリを実行します。 このテンソルは、StackedQueryKeyTensor および StackedQueryKeyValueTensor と相互に排他的です。 テンソルは、先頭の次元が 1 である限り、4 つまたは 5 つの次元を持つこともできます。

KeyTensor

型: _Maybenull_ const DML_TENSOR_DESC*

hiddenSize = headCount * headSize の場合の、形状 [batchSize, keyValueSequenceLength, hiddenSize] のキー。 このテンソルは、StackedQueryKeyTensorStackedKeyValueTensor、および StackedQueryKeyValueTensor と相互に排他的です。 テンソルは、先頭の次元が 1 である限り、4 つまたは 5 つの次元を持つこともできます。

ValueTensor

型: _Maybenull_ const DML_TENSOR_DESC*

valueHiddenSize = headCount * valueHeadSize の場合に形状 [batchSize, keyValueSequenceLength, valueHiddenSize] を持つ値。 このテンソルは、StackedKeyValueTensor および StackedQueryKeyValueTensor と相互に排他的です。 テンソルは、先頭の次元が 1 である限り、4 つまたは 5 つの次元を持つこともできます。

StackedQueryKeyTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, sequenceLength, headCount, 2, headSize] のスタック クエリとキー。 このテンソルは、QueryTensorKeyTensorStackedKeyValueTensorStackedQueryKeyValueTensor と相互に排他的です。

StackedQueryKeyTensor layout

StackedKeyValueTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, keyValueSequenceLength, headCount, 2, headSize] のスタック キーと値。 このテンソルは、KeyTensorValueTensorStackedQueryKeyTensorStackedQueryKeyValueTensor と相互に排他的です。

StackedKeyValueTensor layout

StackedQueryKeyValueTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, sequenceLength, headCount, 3, headSize] のスタック クエリ、キー、値。 このテンソルは、QueryTensoKeyTensorValueTensorStackedQueryKeyTensorStackedKeyValueTensor と相互に排他的です。

StackedQueryKeyValueTensor layout

BiasTensor

型: _Maybenull_ const DML_TENSOR_DESC*

これは、最初の GEMM 操作の前にクエリ/キー/に追加される形状 [hiddenSize + hiddenSize + valueHiddenSize] のバイアスです。 このテンソルは、先頭の寸法が 1 である限り、2、3、4、または 5 次元を持つこともできます。

MaskTensor

型: _Maybenull_ const DML_TENSOR_DESC*

これは、QxK GEMM 操作の後に MaskFilterValue に設定された値を取得する要素を決定するマスクです。 このマスクの動作は MaskType の値に依存し、RelativePositionBiasTensor の後、または RelativePositionBiasTensor が null の場合は最初の GEMM 操作の後に適用されます。 詳細については、MaskType の定義を参照してください。

RelativePositionBiasTensor

型: _Maybenull_ const DML_TENSOR_DESC*

これは、最初の GEMM 操作の結果に追加されるバイアスです。

PastKeyTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, headCount, pastSequenceLength, headSize] を持つ、前のイテレーションのキー テンソル。 このテンソルが null でない場合は、鍵 テンソルと連結され、形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] のテンソルになります。

PastValueTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, headCount, pastSequenceLength, headSize] の、前のイテレーションの値テンソル。 このテンソルが null でない場合は、ValueDesc と連結され、形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] のテンソルになります。

OutputTensor

型: const DML_TENSOR_DESC*

形状 [batchSize, sequenceLength, valueHiddenSize] の出力。

OutputPresentKeyTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, headCount, keyValueSequenceLength, headSize] の場合はクロス アテンション キーの、形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] の場合はセルフ アテンション キーの現在の状態。 キー テンソルの内容、または次のイテレーションに渡す連結された PastKey + キー テンソルの内容が含まれます。

OutputPresentValueTensor

型: _Maybenull_ const DML_TENSOR_DESC*

形状 [batchSize, headCount, keyValueSequenceLength, headSize] の場合はクロス アテンション値の、形状 [batchSize, headCount, pastSequenceLength + keyValueSequenceLength, headSize] の場合はセルフ アテンション 値の現在の状態。 値テンソルの内容、または次のイテレーションに渡すまたは連結された PastValue + Value テンソルの内容が含まれます。

Scale

型: FLOAT

QxK GEMM 操作の結果を、Softmax 操作の前にスケーリングし、乗算します。 通常、その値は 1/sqrt(headSize) です。

MaskFilterValue

型: FLOAT

マスクが埋め込み要素として定義した位置で、 QxK GEMM 操作の結果に追加される値。 この値は、非常に大きな負の数 (通常は -10000.0f) にする必要があります。

HeadCount

型: UINT

アテンション ヘッドの数。

MaskType

型: DML_MULTIHEAD_ATTENTION_MASK_TYPE

MaskTensor の動作について説明します。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN。 マスクに 0 の値が含まれている場合、MaskFilterValue が追加されますが、値 1 が含まれている場合、何も追加されません。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH。 形状 [1, batchSize] のマスクには、各バッチの埋め込みなし領域のシーケンス長が含まれており、シーケンスの長さより後のすべての要素の値が MaskFilterValue に設定されます。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START。 形状 [2, batchSize] のマスクには、埋め込みなしの領域の終了 (排他) インデックスと開始 (包括) インデックスが含まれており、領域外のすべての要素の値が MaskFilterValue に設定されます。

DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END。 形状 [batchSize * 3 + 2]のマスクには、次の値があります: [keyLength[0], ..., keyLength[batchSize - 1], queryStart[0], ..., queryStart[batchSize - 1], queryEnd[batchSize - 1], keyStart[0], ..., keyStart[batchSize - 1], keyEnd[batchSize - 1]]

可用性

この演算子は、DML_FEATURE_LEVEL_6_1で導入されました。

Tensor 制約

BiasTensorKeyTensorOutputPresentKeyTensorOutputPresentValueTensorOutputTensorPastKeyTensorPastValueTensorQueryTensorRelativePositionBiasTensorStackedKeyValueTensorStackedQueryKeyTensorStackedQueryKeyValueTensor、および ValueTensor には、同じ DataType が必要です。

Tensor のサポート

Tensor 種類 サポートされているディメンション数 サポートされるデータ型
QueryTensor 省略可能な入力 3 から 5 まで FLOAT32、FLOAT16
KeyTensor 省略可能な入力 3 から 5 まで FLOAT32、FLOAT16
ValueTensor 省略可能な入力 3 から 5 まで FLOAT32、FLOAT16
StackedQueryKeyTensor 省略可能な入力 5 FLOAT32、FLOAT16
StackedKeyValueTensor 省略可能な入力 5 FLOAT32、FLOAT16
StackedQueryKeyValueTensor 省略可能な入力 5 FLOAT32、FLOAT16
BiasTensor 省略可能な入力 1 から 5 FLOAT32、FLOAT16
MaskTensor 省略可能な入力 1 から 5 INT32
RelativePositionBiasTensor 省略可能な入力 4 から 5 FLOAT32、FLOAT16
PastKeyTensor 省略可能な入力 4 から 5 FLOAT32、FLOAT16
PastValueTensor 省略可能な入力 4 から 5 FLOAT32、FLOAT16
OutputTensor 出力 3 から 5 まで FLOAT32、FLOAT16
OutputPresentKeyTensor 省略可能な出力 4 から 5 FLOAT32、FLOAT16
OutputPresentValueTensor 省略可能な出力 4 から 5 FLOAT32、FLOAT16

要件

   
ヘッダー directml.h