Skip to content

Commit

Permalink
Add separate enums for the different types of interpretation
Browse files Browse the repository at this point in the history
  • Loading branch information
damyanp committed Feb 27, 2025
1 parent 3102cc5 commit 1cf7cfc
Showing 1 changed file with 95 additions and 44 deletions.
139 changes: 95 additions & 44 deletions proposals/0029-cooperative-vector.md
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,13 @@ can be found in [Minimum Support Set].

### Type Interpretations

The various "interpretation" arguments specify a value from the following enum:
#### From-Register Interpretations

Input vectors stored in registers (e.g. `vector<float, 16>`) are interpreted
according to values from the following enum:

```c++
enum class DXILTypeInterpretation :uint {
enum class DXILRegisterInterpretation : uint {
Float16 = 0,
Float32 = 1,
UnsignedInt8 = 2,
Expand All @@ -351,16 +354,42 @@ enum class DXILTypeInterpretation :uint {
};
```
For matrices and vectors that are specified by resource handles and stored in
raw-buffers, the interpretation value directly specifies the element type. It
is invalid to specify a packed interpretation in these cases.
For these vectors there is a distinction between the physical type and the
logical type. The **input interpretation** argument for these vectors describes
how to convert from the physical to logical type. This allows elements to be
interpreted as types not natively supported by HLSL, e.g. uint8/sint8. For
packed interpretations, a single physical element can expand into multiple
logical elements.
#### Memory Interpretations
For input vectors that come from variables there is a distinction between the
physical type and the logical type. The **input interpretation** argument for
these vectors describes how to convert from the physical to logical type. This
allows elements to be interpreted as types not natively supported by HLSL, e.g.
uint8/sint8. For packed interpretations, a single physical element can expand
into multiple logical elements.
Matrices and vectors that are stored in raw-buffers and specified by resource
handles (e.g. the matrix and bias-vector arguments to dx.op.matvecmul) are
interpreted according to values from the following enum:
The various "interpretation" arguments specify a value from the following enum:
```c++
enum class DXILMemoryInterpretation :uint {
Float16 = 0,
Float32 = 1,
UnsignedInt8 = 2,
UnsignedInt16 = 3,
UnsignedInt32 = 4,
SignedInt8 = 5,
SignedInt16 = 6,
SignedInt32 = 7,
FloatE4M3 = 8,
FloatE5M2 = 9,
Unsupported = 32
};
```

The interpretation value directly specifies the element type. Note that there
are no packed interpretation types for memory interpretations.

#### CheckFeatureSupport

[CheckFeatureSupport] can be used to determine what combinations of **TYi**,
**input interpretation**, **matrix interpretation**, **matrix transpose**,
Expand Down Expand Up @@ -530,16 +559,38 @@ typedef enum D3D12_FEATURE {
typedef enum D3D12_COOPERATIVE_VECTOR_DATATYPE {
D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT16 = 0,
D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT32 = 1,
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT8 = 2,
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT16 = 3,
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT32 = 4,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT8 = 5,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT16 = 6,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT32 = 7,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT8_PACKED = 8,
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT8_PACKED = 9,
D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT_E4M3 = 10, // FP8: 1 sign bit, 4 exp bits, 3 mantissa bits
D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT_E5M2 = 11 // FP8: 1 sign bit, 5 exp bits, 2 mantissa bits
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT16 = 2,
D3D12_COOPERATIVE_VECTOR_DATATYPE_UINT32 = 3,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT16 = 4,
D3D12_COOPERATIVE_VECTOR_DATATYPE_SINT32 = 5,
};

typedef enum D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE {
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_FLOAT16 = 0,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_FLOAT32 = 1,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_UINT8 = 2,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_UINT16 = 3,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_UINT32 = 4,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_SINT8 = 5,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_SINT16 = 6,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_SINT32 = 7,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_SINT8_PACKED = 8,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_UINT8_PACKED = 9,
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_FLOAT_E4M3 = 10, // FP8: 1 sign bit, 4 exp bits, 3 mantissa bits
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE_FLOAT_E5M2 = 11 // FP8: 1 sign bit, 5 exp bits, 2 mantissa bits
};

typedef enum D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE {
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT16 = 0,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT32 = 1,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_UINT8 = 2,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_UINT16 = 3,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_UINT32 = 4,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_SINT8 = 5,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_SINT16 = 6,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_SINT32 = 7,
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT_E4M3 = 8, // FP8: 1 sign bit, 4 exp bits, 3 mantissa bits
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT_E5M2 = 9 // FP8: 1 sign bit, 5 exp bits, 2 mantissa bits
};

typedef enum D3D12_COOPERATIVE_VECTOR_TIER
Expand All @@ -558,12 +609,12 @@ typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONSNN // NN tbd when implemented
// Used for VectorMatrixMulAdd intrinsic
typedef struct D3D12_COOPERATIVE_VECTOR_PROPERTIES_INFERENCE
{
D3D12_COOPERATIVE_VECTOR_DATATYPE InputType;
D3D12_COOPERATIVE_VECTOR_DATATYPE InputInterpretation;
D3D12_COOPERATIVE_VECTOR_DATATYPE MatrixInterpretation;
D3D12_COOPERATIVE_VECTOR_DATATYPE BiasInterpretation;
D3D12_COOPERATIVE_VECTOR_DATATYPE OutputType;
BOOL TransposeSupported;
D3D12_COOPERATIVE_VECTOR_DATATYPE InputType;
D3D12_COOPERATIVE_VECTOR_REGISTER_DATATYPE InputInterpretation;
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE MatrixInterpretation;
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE BiasInterpretation;
D3D12_COOPERATIVE_VECTOR_DATATYPE OutputType;
BOOL TransposeSupported;
};

// Used for OuterProductAccumulate and ReduceSumAccumulate intrinsics
Expand Down Expand Up @@ -702,16 +753,16 @@ API.

// Descriptor to query the destination buffer size
typedef struct D3D12_COOPERATIVE_VECTOR_MATRIX_CONVERSION_DEST_INFO {
UINT DestSize; // !< [out]Destination buffer size in bytes
// required for conversion
D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT DestLayout; // !< [in] Is the layout the matrix is converted to
UINT DestStride; // !< [in] Is the number of bytes between a consecutive
// row or column (depending on DestLayout) of the
// destination matrix if it is row-major or
// column-major.
UINT NumRows; // !< [in] Is the number of rows in the matrix.
UINT NumColumns; // !< [in] Is the number of columns in the matrix.
D3D12_COOPERATIVE_VECTOR_DATATYPE DestDataType; // !< [in] the type of a destination matrix element.
UINT DestSize; // !< [out] Destination buffer size in bytes
// required for conversion
D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT DestLayout; // !< [in] Is the layout the matrix is converted to
UINT DestStride; // !< [in] Is the number of bytes between a consecutive
// row or column (depending on DestLayout) of the
// destination matrix if it is row-major or
// column-major.
UINT NumRows; // !< [in] Is the number of rows in the matrix.
UINT NumColumns; // !< [in] Is the number of columns in the matrix.
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE DestDataType; // !< [in] the type of a destination matrix element.
};

// An API to return the number of bytes required in the destination buffer to
Expand Down Expand Up @@ -749,7 +800,7 @@ typedef struct D3D12_COOPERATIVE_VECTOR_MATRIX_CONVERSION_DATA {
typedef struct D3D12_COOPERATIVE_VECTOR_MATRIX_CONVERSION_SRC_INFO {
UINT SrcSize; // !< [in] Is the length in bytes of
// srcData
D3D12_COOPERATIVE_VECTOR_DATATYPE SrcDataType; // !< [in] Is the type of a
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE SrcDataType; // !< [in] Is the type of a
// source matrix
// element
D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT SrcLayout; // !< [in] Is the layout of the
Expand Down Expand Up @@ -791,11 +842,11 @@ void ID3D12CommandList::CooperativeVectorConvertMatrix(D3D12_COOPERATIVE_VECTOR_
* If DestLayout is row-major or column-major, then DestStride should be greater than the length of a row/column, and a
multiple of the element size.
* If SrcComponentType is not a supported MatrixInterpretation value as reported by CheckFeatureSupport() then
SrcComponentType should be `D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT32`.
SrcComponentType should be `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT32`.
* If DestComponentType is not a supported MatrixInterpretation value as reported by CheckFeatureSupport() then
DestComponentType should be `D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT32`.
* If SrcComponentType and DestComponentType are not equal, then one should be `D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT32` or `D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT16` and the other should be a lower-precision floating-point type.
* If DestComponentType is `D3D12_COOPERATIVE_VECTOR_DATATYPE_E4M3` or `D3D12_COOPERATIVE_VECTOR_DATATYPE_E5M2`, then DestLayout should be `D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL` or `D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL`.
DestComponentType should be `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT32`.
* If SrcComponentType and DestComponentType are not equal, then one should be `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT32` or `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT16` and the other should be a lower-precision floating-point type.
* If DestComponentType is `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT_E4M3` or `D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT_E5M2`, then DestLayout should be `D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL` or `D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL`.
*Usage Example:*
Expand All @@ -815,14 +866,14 @@ D3D12_COOPERATIVE_VECTOR_MATRIX_CONVERSION_INFO infoDesc =
// converted
numColumns, // number of columns in weight matrix to
// be converted
D3D12_COOPERATIVE_VECTOR_DATATYPE_E4M3 // convert to FP8 datatype
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT_E4M3 // convert to FP8 datatype
},
//SrcInfo
{
srcSize, // number of bytes of matrix in source
// layout and datatype
D3D12_COOPERATIVE_VECTOR_DATATYPE_FLOAT32, // convert from float
D3D12_COOPERATIVE_VECTOR_MEMORY_DATATYPE_FLOAT32, // convert from float
D3D12_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR, // convert from row major layout
(numColumns * sizeof(float)) // row major stride without padding
},
Expand Down

0 comments on commit 1cf7cfc

Please sign in to comment.