Skip to content

Commit

Permalink
revise overload resolution for splats/truncations
Browse files Browse the repository at this point in the history
Allow truncations when matching arguments for intrinsic overloads. This eliminates the need for explicit scalar extractions from vectors for arguments that are scalar by nature. It applies to any vector passed where a scalar is expected: the truncation is allowed, but a warning is emitted, the same as is done for other assignments of vectors to scalars.

This maintains splats as the preferred transformation and promotes perfect matches to be preferred over splats. This removes the need to carefully order intrinsics to ensure that the right variant gets matched first, before another one incorrectly takes its place with a faulty cast.

Allowing truncations causes problems with a small subset of intrinsics that have explicit overloads for various matrix, vector, and scalar combinations, namely the mul overloads. These could be simplified to accept a new range of template types, except that the dimensions need to be matched in unconventional ways.

For these, the notion of uncastable or "ONLY" variants of the template/layout types is introduced. These are indicated with a trailing "!" after the parameter typename in gen_intrin_main, which directs them to an array that contains a NOCAST enum value that, when encountered, skips the attempts to splat or truncate.

Fixes microsoft#7079
  • Loading branch information
pow2clk committed Feb 3, 2025
1 parent c65a179 commit 819545b
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 93 deletions.
9 changes: 6 additions & 3 deletions include/dxc/dxcapi.internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ enum LEGAL_INTRINSIC_TEMPLATES {
LITEMPLATE_MATRIX = 3, // Matrix types (eg. float3x3).
LITEMPLATE_ANY =
4, // Any one of scalar, vector or matrix types (but not object).
LITEMPLATE_OBJECT = 5, // Object types.
LITEMPLATE_ARRAY = 6, // Scalar array.
LITEMPLATE_OBJECT = 5, // Object types.
LITEMPLATE_ARRAY = 6, // Scalar array.
LITEMPLATE_SCALAR_ONLY = 7, // Uncastable scalar types.
LITEMPLATE_VECTOR_ONLY = 8, // Uncastable vector types (eg. float3).
LITEMPLATE_MATRIX_ONLY = 9, // Uncastable matrix types (eg. float3x3).

LITEMPLATE_COUNT = 7
LITEMPLATE_COUNT = 10
};

// INTRIN_COMPTYPE_FROM_TYPE_ELT0 is for object method intrinsics to indicate
Expand Down
35 changes: 5 additions & 30 deletions lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4196,12 +4196,10 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
loadArgs.emplace_back(opArg); // opcode
loadArgs.emplace_back(helper.handle); // resource handle

// offsets
if (opcode == OP::OpCode::TextureLoad) {
// set mip level
loadArgs.emplace_back(helper.mipLevel);
}

if (opcode == OP::OpCode::TextureLoad) {
// texture coord
unsigned coordSize = DxilResource::GetNumCoords(RK);
bool isVectorAddr = helper.addr->getType()->isVectorTy();
Expand All @@ -4213,22 +4211,6 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
} else
loadArgs.emplace_back(undefI);
}
} else {
if (helper.addr->getType()->isVectorTy()) {
Value *scalarOffset =
Builder.CreateExtractElement(helper.addr, (uint64_t)0);

// TODO: calculate the real address based on opcode

loadArgs.emplace_back(scalarOffset); // offset
} else {
// TODO: calculate the real address based on opcode

loadArgs.emplace_back(helper.addr); // offset
}
}
// offset 0
if (opcode == OP::OpCode::TextureLoad) {
if (helper.offset && !isa<llvm::UndefValue>(helper.offset)) {
unsigned offsetSize = DxilResource::GetNumOffsets(RK);
for (unsigned i = 0; i < 3; i++) {
Expand All @@ -4242,11 +4224,9 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
loadArgs.emplace_back(undefI);
loadArgs.emplace_back(undefI);
}
}

// Offset 1
if (RK == DxilResource::Kind::TypedBuffer) {
loadArgs.emplace_back(undefI);
} else {
loadArgs.emplace_back(helper.addr); // c0
loadArgs.emplace_back(undefI); // c1
}

Value *ResRet = Builder.CreateCall(F, loadArgs, OP->GetOpCodeName(opcode));
Expand Down Expand Up @@ -4420,12 +4400,7 @@ void TranslateStore(DxilResource::Kind RK, Value *handle, Value *val,
if (RK == DxilResource::Kind::RawBuffer ||
RK == DxilResource::Kind::TypedBuffer) {
// Offset 0
if (offset->getType()->isVectorTy()) {
Value *scalarOffset = Builder.CreateExtractElement(offset, (uint64_t)0);
storeArgs.emplace_back(scalarOffset); // offset
} else {
storeArgs.emplace_back(offset); // offset
}
storeArgs.emplace_back(offset); // offset

// Store offset0 for later use
offset0Idx = storeArgs.size() - 1;
Expand Down
61 changes: 50 additions & 11 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ enum ArTypeObjectKind {
// indexer object used to implement .mips[1].
AR_TOBJ_STRING, // Represents a string
AR_TOBJ_DEPENDENT, // Dependent type for template.
AR_TOBJ_NOCAST, // Parameter should not have layout casts (splat,trunc)
};

enum TYPE_CONVERSION_FLAGS {
Expand Down Expand Up @@ -989,9 +990,18 @@ static const ArTypeObjectKind g_NullTT[] = {AR_TOBJ_VOID, AR_TOBJ_UNKNOWN};

static const ArTypeObjectKind g_ArrayTT[] = {AR_TOBJ_ARRAY, AR_TOBJ_UNKNOWN};

static const ArTypeObjectKind g_ScalarOnlyTT[] = {
AR_TOBJ_SCALAR, AR_TOBJ_NOCAST, AR_TOBJ_UNKNOWN};

static const ArTypeObjectKind g_VectorOnlyTT[] = {
AR_TOBJ_VECTOR, AR_TOBJ_NOCAST, AR_TOBJ_UNKNOWN};

static const ArTypeObjectKind g_MatrixOnlyTT[] = {
AR_TOBJ_MATRIX, AR_TOBJ_NOCAST, AR_TOBJ_UNKNOWN};

const ArTypeObjectKind *g_LegalIntrinsicTemplates[] = {
g_NullTT, g_ScalarTT, g_VectorTT, g_MatrixTT,
g_AnyTT, g_ObjectTT, g_ArrayTT,
g_NullTT, g_ScalarTT, g_VectorTT, g_MatrixTT, g_AnyTT,
g_ObjectTT, g_ArrayTT, g_ScalarOnlyTT, g_VectorOnlyTT, g_MatrixOnlyTT,
};
C_ASSERT(ARRAYSIZE(g_LegalIntrinsicTemplates) == LITEMPLATE_COUNT);

Expand Down Expand Up @@ -6113,7 +6123,7 @@ bool HLSLExternalSource::MatchArguments(
ArBasicKind
ComponentType[MaxIntrinsicArgs]; // Component type for each argument,
// AR_BASIC_UNKNOWN if unspecified.
UINT uSpecialSize[IA_SPECIAL_SLOTS]; // row/col matching types, UNUSED_INDEX32
UINT uSpecialSize[IA_SPECIAL_SLOTS]; // row/col matching types, UnusedSize
// if unspecified.
badArgIdx = MaxIntrinsicArgs;

Expand Down Expand Up @@ -6249,12 +6259,15 @@ bool HLSLExternalSource::MatchArguments(
"otherwise intrinsic table was modified and g_MaxIntrinsicParamCount "
"was not updated (or uTemplateId is out of bounds)");

// Compare template
// Compare template to any type matching params requirements.
if ((AR_TOBJ_UNKNOWN == Template[pIntrinsicArg->uTemplateId]) ||
((AR_TOBJ_SCALAR == Template[pIntrinsicArg->uTemplateId]) &&
(AR_TOBJ_VECTOR == TypeInfoShapeKind ||
AR_TOBJ_MATRIX == TypeInfoShapeKind))) {
// Unrestricted or truncation of tuples to scalars are allowed
// Previous params gave no type restrictions
// or truncation of tuples to scalars are allowed
// Later steps harmonize common typed params and will always convert the
// earlier arg into a splat instead.
Template[pIntrinsicArg->uTemplateId] = TypeInfoShapeKind;
} else if (AR_TOBJ_SCALAR == TypeInfoShapeKind) {
if (AR_TOBJ_SCALAR != Template[pIntrinsicArg->uTemplateId] &&
Expand Down Expand Up @@ -6292,6 +6305,11 @@ bool HLSLExternalSource::MatchArguments(
}
}

// If the intrinsic parameter has variable rows or columns but must match
// other argument dimensions, it will be specified in pIntrinsicArg with
// a special value indicating that the dimension depends on passed values.
// uSpecialSize stores the dimensions of the actual passed type.

// Rows
if (AR_TOBJ_SCALAR != TypeInfoShapeKind) {
if (pIntrinsicArg->uRows >= IA_SPECIAL_BASE) {
Expand Down Expand Up @@ -6398,18 +6416,39 @@ bool HLSLExternalSource::MatchArguments(
const ArTypeObjectKind *pTT =
g_LegalIntrinsicTemplates[pArgument->uLegalTemplates];
if (AR_TOBJ_UNKNOWN != Template[i]) {
if ((AR_TOBJ_SCALAR == Template[i]) &&
(AR_TOBJ_VECTOR == *pTT || AR_TOBJ_MATRIX == *pTT)) {
Template[i] = *pTT;
} else {
// See if a perfect match overload is available
while (AR_TOBJ_UNKNOWN != *pTT && AR_TOBJ_NOCAST != *pTT) {
if (Template[i] == *pTT)
break;
pTT++;
}

if (AR_TOBJ_UNKNOWN == *pTT) {
// Perfect match failed and casts are allowed.
// Try splats and truncations to get a match.
pTT = g_LegalIntrinsicTemplates[pArgument->uLegalTemplates];
while (AR_TOBJ_UNKNOWN != *pTT) {
if (Template[i] == *pTT)
if (AR_TOBJ_SCALAR == Template[i] &&
(AR_TOBJ_VECTOR == *pTT || AR_TOBJ_MATRIX == *pTT)) {
// If a scalar was passed in and the expected value was
// matrix/vector convert to the template type for a splat.
// Only applicable to VectorTT and MatrixTT,
// since the vec/mtx has to be first in the list.
Template[i] = *pTT;
break;
} else if (AR_TOBJ_VECTOR == Template[i] && AR_TOBJ_SCALAR == *pTT) {
// If a vector was passed in and the expected value was scalar
// convert to the template type for a truncation.
// Only applicable to ScalarTT,
// since the scalar has to be first in the list.
Template[i] = AR_TOBJ_SCALAR;
break;
}
pTT++;
}
}

if (AR_TOBJ_UNKNOWN == *pTT) {
if (AR_TOBJ_UNKNOWN == *pTT || AR_TOBJ_NOCAST == *pTT) {
Template[i] = g_LegalIntrinsicTemplates[pArgument->uLegalTemplates][0];
badArgIdx = std::min(badArgIdx, i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
// RUN: %dxc -E main -T vs_6_2 -DTY1=float3 -DTY2=bool -enable-16bit-types %s | FileCheck %s
// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=uint16_t -enable-16bit-types %s | FileCheck %s

// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=uint16_t4 -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK
// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=float16_t2 -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK

// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=uint16_t4x4 -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK_ERROR
// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=uint16_t4 -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK_ERROR
// RUN: %dxc -E main -T vs_6_2 -DTY1=float4 -DTY2=float16_t2 -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK_ERROR
// RUN: %dxc -E main -T vs_6_2 -DTY1=uint16_t4x4 -DTY2=float16_t -enable-16bit-types %s | FileCheck %s -check-prefix=CHECK_ERROR

// CHECK: define void @main()
// CHECK_ERROR: note: candidate function not viable: no known conversion from

TY1 main (TY1 a: IN0, TY1 b : IN1, TY2 c : IN2) : OUT {
return refract(a, b, c);
}
}
6 changes: 3 additions & 3 deletions tools/clang/test/SemaHLSL/intrinsic-examples.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ float4 RWByteAddressBufferMain(uint2 a : A, uint2 b : B) : SV_Target
uint status;
// TODO - fix the following error - the subscript exist, but the indexer type is incorrect - message is misleading
r += uav1[b]; // expected-error {{type 'RWByteAddressBuffer' does not provide a subscript operator}} fxc-error {{X3121: array, matrix, vector, or indexable object type expected in index expression}}
r += uav1.Load(a); // expected-error {{no matching member function for call to 'Load'}} expected-note {{candidate function template not viable: requires 2 arguments, but 1 was provided}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 1 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
uav1.Load(a, status); // expected-error {{no matching member function for call to 'Load'}} expected-note {{candidate function template not viable: requires single argument 'byteOffset', but 2 arguments were provided}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 2 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
r += uav1.Load(a); // expected-warning {{implicit truncation of vector type}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 1 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
uav1.Load(a, status); // expected-warning {{implicit truncation of vector type}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 2 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
r += status;
uav1.Load(a, status); // expected-error {{no matching member function for call to 'Load'}} expected-note {{requires single argument 'byteOffset', but 2 arguments were provided}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 2 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
uav1.Load(a, status); // expected-warning {{implicit truncation of vector type}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint)}} fxc-error {{X3013: RWByteAddressBuffer<uint>.Load(uint, out uint status)}} fxc-error {{X3013: 'Load': no matching 2 parameter intrinsic method}} fxc-error {{X3013: Possible intrinsic methods are:}}
r += status;
uav1[b] = r; // expected-error {{type 'RWByteAddressBuffer' does not provide a subscript operator}} fxc-error {{X3121: array, matrix, vector, or indexable object type expected in index expression}}
uav1.Load(a.x, status);
Expand Down
22 changes: 11 additions & 11 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ $type1 [[rn,unsigned_op=umax]] max(in numeric<> a, in $type1 b);
$type1 [[rn,unsigned_op=umin]] min(in numeric<> a, in $type1 b);
$type1 [[]] modf(in float_like<> x, out $type1 ip);
uint<4> [[rn]] msad4(in uint reference, in uint<2> source, in uint<4> accum);
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric b) : mul_ss;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<c2> b) : mul_sv;
numeric<r2, c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<r2, c2> b) : mul_sm;
numeric<c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric b) : mul_vs;
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric<c> b) : mul_vv;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_vm;
numeric<r, c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric b) : mul_ms;
numeric<r> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric<c> b) : mul_mv;
numeric<r, c2> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_mm;
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric! a, in $match<2, 0> numeric! b) : mul_ss;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric! a, in $match<2, 0> numeric<c2>! b) : mul_sv;
numeric<r2, c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric! a, in $match<2, 0> numeric<r2, c2>! b) : mul_sm;
numeric<c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c>! a, in $match<2, 0> numeric! b) : mul_vs;
numeric [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c>! a, in $match<2, 0> numeric<c>! b) : mul_vv;
numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c>! a, in col_major $match<2, 0> numeric<c, c2>! b) : mul_vm;
numeric<r, c> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<r, c>! a, in $match<2, 0> numeric! b) : mul_ms;
numeric<r> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c>! a, in $match<2, 0> numeric<c>! b) : mul_mv;
numeric<r, c2> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c>! a, in col_major $match<2, 0> numeric<c, c2>! b) : mul_mm;
$type1 [[rn]] normalize(in float_like<c> x);
$type1 [[rn]] pow(in float_like<> x, in $type1 y);
void [[]] printf(in string Format, ...);
Expand Down Expand Up @@ -849,8 +849,8 @@ $match<0, -1> void<4> [[]] GatherCmpAlpha(in sampler_cmp s, in float<4> x, in fl
namespace BufferMethods {

void [[]] GetDimensions(out uint_only width) : bufinfo;
$classT [[ro]] Load(in int<1> x) : buffer_load;
$classT [[]] Load(in int<1> x, out uint_only status) : buffer_load_s;
$classT [[ro]] Load(in int x) : buffer_load;
$classT [[]] Load(in int x, out uint_only status) : buffer_load_s;

} namespace

Expand Down
50 changes: 18 additions & 32 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8330,13 +8330,13 @@ def load_intrinsics(self, intrinsic_defs):
params_split_re = re.compile(r"\s*,\s*")
ws_split_re = re.compile(r"\s+")
typeref_re = re.compile(r"\$type(\d+)$")
type_matrix_re = re.compile(r"(\S+)<(\S+)@(\S+)>$")
type_vector_re = re.compile(r"(\S+)<(\S+)>$")
type_matrix_re = re.compile(r"(\S+)<(\S+)@(\S+)>(\!?)$")
type_vector_re = re.compile(r"(\S+)<(\S+)>(\!?)$")
type_any_re = re.compile(r"(\S+)<>$")
type_array_re = re.compile(r"(\S+)\[\]$")
type_object_re = re.compile(
r"""(
sampler\w* | string |
sampler\w* | any_sampler\w* | string |
(?:RW)?(?:Texture\w*|ByteAddressBuffer) |
acceleration_struct | ray_desc |
Node\w* | RWNode\w* | EmptyNode\w* |
Expand Down Expand Up @@ -8381,6 +8381,7 @@ def process_arg(desc, idx, done_args, intrinsic_name):
component_list = "LICOMPTYPE_ANY"
rows = "1"
cols = "1"
only = ""
if type_name == "$classT":
assert idx == 0, "'$classT' can only be used as the return type"
# template_id may be -1 in other places other than return type, for example in Stream.Append().
Expand Down Expand Up @@ -8415,13 +8416,19 @@ def process_arg(desc, idx, done_args, intrinsic_name):
base_type = type_name

def do_matrix(m):
base_type, rows, cols = m.groups()
template_list = "LITEMPLATE_MATRIX"
base_type, rows, cols, only = m.groups()
if only == "!":
template_list = "LITEMPLATE_MATRIX_ONLY"
else:
template_list = "LITEMPLATE_MATRIX"
return base_type, rows, cols, template_list

def do_vector(m):
base_type, cols = m.groups()
template_list = "LITEMPLATE_VECTOR"
base_type, cols, only = m.groups()
if only == "!":
template_list = "LITEMPLATE_VECTOR_ONLY"
else:
template_list = "LITEMPLATE_VECTOR"
return base_type, rows, cols, template_list

def do_any(m):
Expand Down Expand Up @@ -8454,32 +8461,11 @@ def do_object(m):
base_type, rows, cols, template_list = do(m)
break
else:
type_vector_match = type_vector_re.match(type_name)
if type_vector_match:
base_type = type_vector_match.group(1)
cols = type_vector_match.group(2)
template_list = "LITEMPLATE_VECTOR"
if type_name[-1] == "!":
template_list = "LITEMPLATE_SCALAR_ONLY"
base_type = type_name[:-1]
else:
type_any_match = type_any_re.match(type_name)
if type_any_match:
base_type = type_any_match.group(1)
rows = "r"
cols = "c"
template_list = "LITEMPLATE_ANY"
else:
base_type = type_name
if (
base_type.startswith("sampler")
or base_type.startswith("string")
or base_type.startswith("Texture")
or base_type.startswith("wave")
or base_type.startswith("acceleration_struct")
or base_type.startswith("ray_desc")
or base_type.startswith("any_sampler")
):
template_list = "LITEMPLATE_OBJECT"
else:
template_list = "LITEMPLATE_SCALAR"
template_list = "LITEMPLATE_SCALAR"
assert base_type in self.base_types, "Unknown base type '%s' in '%s'" % (
base_type,
desc,
Expand Down

0 comments on commit 819545b

Please sign in to comment.