Skip to content

Commit

Permalink
Add support for Sve.DuplicateSelectedScalarToVector() (#103228)
Browse files Browse the repository at this point in the history
* Add support for Sve.DuplicateScalarToVector()

* Fix build errors from API mismatch

* Use ConstantExpected instead of ConstantExpectedAttribute

* Fix API incompatibility in Sve.PlatformNotSupported.cs

* Fix emitting larger than 9-bit immediate with Sve LDR/STR instructions

* Add HW_Category_SIMDByIndexedElement flag

* Fix issue for getting incorrect register type for op2
  • Loading branch information
SwapnilGaikwad authored Jun 27, 2024
1 parent 742bae8 commit 5b962c3
Show file tree
Hide file tree
Showing 16 changed files with 644 additions and 22 deletions.
7 changes: 7 additions & 0 deletions src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -1631,6 +1631,13 @@ class CodeGen final : public CodeGenInterface

void instGen_Set_Reg_To_Zero(emitAttr size, regNumber reg, insFlags flags = INS_FLAGS_DONT_CARE);

void instGen_Set_Reg_To_Base_Plus_Imm(emitAttr size,
regNumber dstReg,
regNumber baseReg,
ssize_t imm,
insFlags flags = INS_FLAGS_DONT_CARE DEBUGARG(size_t targetHandle = 0)
DEBUGARG(GenTreeFlags gtFlags = GTF_EMPTY));

void instGen_Set_Reg_To_Imm(emitAttr size,
regNumber reg,
ssize_t imm,
Expand Down
13 changes: 12 additions & 1 deletion src/coreclr/jit/codegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,8 +2215,19 @@ void CodeGen::genEHCatchRet(BasicBlock* block)
GetEmitter()->emitIns_R_L(INS_adr, EA_PTRSIZE, block->GetTarget(), REG_INTRET);
}

// move an immediate value into an integer register
// move an immediate value + base address into an integer register
void CodeGen::instGen_Set_Reg_To_Base_Plus_Imm(emitAttr size,
regNumber dstReg,
regNumber baseReg,
ssize_t imm,
insFlags flags DEBUGARG(size_t targetHandle)
DEBUGARG(GenTreeFlags gtFlags))
{
instGen_Set_Reg_To_Imm(size, dstReg, imm);
GetEmitter()->emitIns_R_R_R(INS_add, size, dstReg, dstReg, baseReg);
}

// move an immediate value into an integer register
void CodeGen::instGen_Set_Reg_To_Imm(emitAttr size,
regNumber reg,
ssize_t imm,
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4679,6 +4679,7 @@ class Compiler
unsigned immNumber,
var_types simdBaseType,
CorInfoType simdBaseJitType,
CORINFO_CLASS_HANDLE op1ClsHnd,
CORINFO_CLASS_HANDLE op2ClsHnd,
CORINFO_CLASS_HANDLE op3ClsHnd,
unsigned* immSimdSize,
Expand Down
13 changes: 9 additions & 4 deletions src/coreclr/jit/emitarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7899,10 +7899,13 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
{
useRegForImm = true;
regNumber rsvdReg = codeGen->rsGetRsvdReg();
codeGen->instGen_Set_Reg_To_Imm(EA_PTRSIZE, rsvdReg, imm);
// For larger imm values (> 9 bits), calculate base + imm in a reserved register first.
codeGen->instGen_Set_Reg_To_Base_Plus_Imm(EA_PTRSIZE, rsvdReg, reg2, imm);
reg2 = rsvdReg;
imm = 0;
}
break;
}
break;

default:
NYI("emitIns_R_S"); // FP locals?
Expand Down Expand Up @@ -8150,9 +8153,11 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
{
useRegForImm = true;
regNumber rsvdReg = codeGen->rsGetRsvdReg();
codeGen->instGen_Set_Reg_To_Imm(EA_PTRSIZE, rsvdReg, imm);
// For larger imm values (> 9 bits), calculate base + imm in a reserved register first.
codeGen->instGen_Set_Reg_To_Base_Plus_Imm(EA_PTRSIZE, rsvdReg, reg2, imm);
reg2 = rsvdReg;
imm = 0;
}
break;
}
break;

Expand Down
8 changes: 4 additions & 4 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,8 +1526,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
{
unsigned immSimdSize = simdSize;
var_types immSimdBaseType = simdBaseType;
getHWIntrinsicImmTypes(intrinsic, sig, 2, simdBaseType, simdBaseJitType, sigReader.op2ClsHnd,
sigReader.op3ClsHnd, &immSimdSize, &immSimdBaseType);
getHWIntrinsicImmTypes(intrinsic, sig, 2, simdBaseType, simdBaseJitType, sigReader.op1ClsHnd,
sigReader.op2ClsHnd, sigReader.op3ClsHnd, &immSimdSize, &immSimdBaseType);
HWIntrinsicInfo::lookupImmBounds(intrinsic, immSimdSize, immSimdBaseType, 2, &immLowerBound, &immUpperBound);

if (!CheckHWIntrinsicImmRange(intrinsic, simdBaseJitType, immOp2, mustExpand, immLowerBound, immUpperBound,
Expand Down Expand Up @@ -1559,8 +1559,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
#ifdef TARGET_ARM64
unsigned immSimdSize = simdSize;
var_types immSimdBaseType = simdBaseType;
getHWIntrinsicImmTypes(intrinsic, sig, 1, simdBaseType, simdBaseJitType, sigReader.op2ClsHnd,
sigReader.op3ClsHnd, &immSimdSize, &immSimdBaseType);
getHWIntrinsicImmTypes(intrinsic, sig, 1, simdBaseType, simdBaseJitType, sigReader.op1ClsHnd,
sigReader.op2ClsHnd, sigReader.op3ClsHnd, &immSimdSize, &immSimdBaseType);
HWIntrinsicInfo::lookupImmBounds(intrinsic, immSimdSize, immSimdBaseType, 1, &immLowerBound, &immUpperBound);
#else
immUpperBound = HWIntrinsicInfo::lookupImmUpperBound(intrinsic);
Expand Down
19 changes: 17 additions & 2 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ void Compiler::getHWIntrinsicImmOps(NamedIntrinsic intrinsic,
// immNumber -- Which immediate to use (1 for most intrinsics)
// simdBaseType -- base type of the intrinsic
// simdType -- vector size of the intrinsic
// op1ClsHnd -- cls handler for op1
// op2ClsHnd -- cls handler for op2
// op2ClsHnd -- cls handler for op3
// immSimdSize [IN/OUT] -- Size of the immediate to override
Expand All @@ -302,6 +303,7 @@ void Compiler::getHWIntrinsicImmTypes(NamedIntrinsic intrinsic,
unsigned immNumber,
var_types simdBaseType,
CorInfoType simdBaseJitType,
CORINFO_CLASS_HANDLE op1ClsHnd,
CORINFO_CLASS_HANDLE op2ClsHnd,
CORINFO_CLASS_HANDLE op3ClsHnd,
unsigned* immSimdSize,
Expand All @@ -317,7 +319,12 @@ void Compiler::getHWIntrinsicImmTypes(NamedIntrinsic intrinsic,
var_types indexedElementBaseType;
*immSimdSize = 0;

if (sig->numArgs == 3)
if (sig->numArgs == 2)
{
indexedElementBaseJitType = getBaseJitTypeAndSizeOfSIMDType(op1ClsHnd, immSimdSize);
indexedElementBaseType = JitType2PreciseVarType(indexedElementBaseJitType);
}
else if (sig->numArgs == 3)
{
indexedElementBaseJitType = getBaseJitTypeAndSizeOfSIMDType(op2ClsHnd, immSimdSize);
indexedElementBaseType = JitType2PreciseVarType(indexedElementBaseJitType);
Expand Down Expand Up @@ -402,7 +409,15 @@ void HWIntrinsicInfo::lookupImmBounds(
}
else if (category == HW_Category_SIMDByIndexedElement)
{
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
if (intrinsic == NI_Sve_DuplicateSelectedScalarToVector)
{
// For SVE_DUP, the upper bound on index does not depend on the vector length.
immUpperBound = (512 / (BITS_PER_BYTE * genTypeSize(baseType))) - 1;
}
else
{
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
}
}
else
{
Expand Down
29 changes: 24 additions & 5 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTre
const HWIntrinsic intrinInfo(intrin);
var_types indexedElementOpType;

if (intrinInfo.numOperands == 3)
if (intrinInfo.numOperands == 2)
{
indexedElementOpType = intrinInfo.op1->TypeGet();
}
else if (intrinInfo.numOperands == 3)
{
indexedElementOpType = intrinInfo.op2->TypeGet();
}
Expand Down Expand Up @@ -357,13 +361,28 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
}
else
{
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
if (intrin.numOperands == 2)
{
HWIntrinsicImmOpHelper helper(this, intrin.op2, node);

for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const int elementIndex = helper.ImmValue();

GetEmitter()->emitIns_R_R_I(ins, emitSize, targetReg, op1Reg, elementIndex, opt);
}
}
else
{
const int elementIndex = helper.ImmValue();
assert(intrin.numOperands == 3);
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);

GetEmitter()->emitIns_R_R_R_I(ins, emitSize, targetReg, op1Reg, op2Reg, elementIndex, opt);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const int elementIndex = helper.ImmValue();

GetEmitter()->emitIns_R_R_R_I(ins, emitSize, targetReg, op1Reg, op2Reg, elementIndex, opt);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit,
HARDWARE_INTRINSIC(Sve, Divide, -1, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdiv, INS_sve_udiv, INS_sve_sdiv, INS_sve_udiv, INS_sve_fdiv, INS_sve_fdiv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, DotProduct, -1, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdot, INS_sve_udot, INS_sve_sdot, INS_sve_udot, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, DotProductBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdot, INS_sve_udot, INS_sve_sdot, INS_sve_udot, INS_invalid, INS_invalid}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, DuplicateSelectedScalarToVector, -1, 2, true, {INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup, INS_sve_dup}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand)
HARDWARE_INTRINSIC(Sve, ExtractVector, -1, 3, true, {INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext, INS_sve_ext}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAdd, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3197,6 +3197,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x3:
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x4:
case NI_AdvSimd_Arm64_DuplicateSelectedScalarToVector128:
case NI_Sve_DuplicateSelectedScalarToVector:
assert(hasImmediateOperand);
assert(varTypeIsIntegral(intrin.op2));
if (intrin.op2->IsCnsIntOrI())
Expand Down
15 changes: 13 additions & 2 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,11 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
var_types indexedElementOpType;

if (intrin.numOperands == 3)
if (intrin.numOperands == 2)
{
indexedElementOpType = intrin.op1->TypeGet();
}
else if (intrin.numOperands == 3)
{
indexedElementOpType = intrin.op2->TypeGet();
}
Expand Down Expand Up @@ -1678,7 +1682,14 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
{
assert(!isRMW);

srcCount += BuildOperandUses(intrin.op2, RBM_ASIMD_INDEXED_H_ELEMENT_ALLOWED_REGS.GetFloatRegSet());
if (intrin.id == NI_Sve_DuplicateSelectedScalarToVector)
{
srcCount += BuildOperandUses(intrin.op2);
}
else
{
srcCount += BuildOperandUses(intrin.op2, RBM_ASIMD_INDEXED_H_ELEMENT_ALLOWED_REGS.GetFloatRegSet());
}

if (intrin.op3 != nullptr)
{
Expand Down
8 changes: 4 additions & 4 deletions src/coreclr/jit/rationalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,8 @@ void Rationalizer::RewriteHWIntrinsicAsUserCall(GenTree** use, ArrayStack<GenTre

if (immOp2 != nullptr)
{
comp->getHWIntrinsicImmTypes(intrinsicId, &sigInfo, 2, simdBaseType, simdBaseJitType, op2ClsHnd,
op3ClsHnd, &immSimdSize, &immSimdBaseType);
comp->getHWIntrinsicImmTypes(intrinsicId, &sigInfo, 2, simdBaseType, simdBaseJitType, op1ClsHnd,
op2ClsHnd, op3ClsHnd, &immSimdSize, &immSimdBaseType);
HWIntrinsicInfo::lookupImmBounds(intrinsicId, immSimdSize, immSimdBaseType, 2, &immLowerBound,
&immUpperBound);

Expand All @@ -473,8 +473,8 @@ void Rationalizer::RewriteHWIntrinsicAsUserCall(GenTree** use, ArrayStack<GenTre
immSimdBaseType = simdBaseType;
}

comp->getHWIntrinsicImmTypes(intrinsicId, &sigInfo, 1, simdBaseType, simdBaseJitType, op2ClsHnd, op3ClsHnd,
&immSimdSize, &immSimdBaseType);
comp->getHWIntrinsicImmTypes(intrinsicId, &sigInfo, 1, simdBaseType, simdBaseJitType, op1ClsHnd, op2ClsHnd,
op3ClsHnd, &immSimdSize, &immSimdBaseType);
HWIntrinsicInfo::lookupImmBounds(intrinsicId, immSimdSize, immSimdBaseType, 1, &immLowerBound,
&immUpperBound);
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,69 @@ internal Arm64() { }
public static unsafe Vector<ulong> DotProductBySelectedScalar(Vector<ulong> addend, Vector<ushort> left, Vector<ushort> right, [ConstantExpected] byte rightIndex) { throw new PlatformNotSupportedException(); }


/// Broadcast a scalar value

/// <summary>
/// svuint8_t svdup_lane[_u8](svuint8_t data, uint8_t index)
/// DUP Zresult.B, Zdata.B[index]
/// </summary>
public static unsafe Vector<byte> DuplicateSelectedScalarToVector(Vector<byte> data, [ConstantExpected(Min = 0, Max = (byte)(63))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat64_t svdup_lane[_f64](svfloat64_t data, uint64_t index)
/// DUP Zresult.D, Zdata.D[index]
/// </summary>
public static unsafe Vector<double> DuplicateSelectedScalarToVector(Vector<double> data, [ConstantExpected(Min = 0, Max = (byte)(7))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint16_t svdup_lane[_s16](svint16_t data, uint16_t index)
/// DUP Zresult.H, Zdata.H[index]
/// </summary>
public static unsafe Vector<short> DuplicateSelectedScalarToVector(Vector<short> data, [ConstantExpected(Min = 0, Max = (byte)(31))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint32_t svdup_lane[_s32](svint32_t data, uint32_t index)
/// DUP Zresult.S, Zdata.S[index]
/// </summary>
public static unsafe Vector<int> DuplicateSelectedScalarToVector(Vector<int> data, [ConstantExpected(Min = 0, Max = (byte)(15))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svdup_lane[_s64](svint64_t data, uint64_t index)
/// DUP Zresult.D, Zdata.D[index]
/// </summary>
public static unsafe Vector<long> DuplicateSelectedScalarToVector(Vector<long> data, [ConstantExpected(Min = 0, Max = (byte)(7))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint8_t svdup_lane[_s8](svint8_t data, uint8_t index)
/// DUP Zresult.B, Zdata.B[index]
/// </summary>
public static unsafe Vector<sbyte> DuplicateSelectedScalarToVector(Vector<sbyte> data, [ConstantExpected(Min = 0, Max = (byte)(63))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svfloat32_t svdup_lane[_f32](svfloat32_t data, uint32_t index)
/// DUP Zresult.S, Zdata.S[index]
/// </summary>
public static unsafe Vector<float> DuplicateSelectedScalarToVector(Vector<float> data, [ConstantExpected(Min = 0, Max = (byte)(15))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint16_t svdup_lane[_u16](svuint16_t data, uint16_t index)
/// DUP Zresult.H, Zdata.H[index]
/// </summary>
public static unsafe Vector<ushort> DuplicateSelectedScalarToVector(Vector<ushort> data, [ConstantExpected(Min = 0, Max = (byte)(31))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint32_t svdup_lane[_u32](svuint32_t data, uint32_t index)
/// DUP Zresult.S, Zdata.S[index]
/// </summary>
public static unsafe Vector<uint> DuplicateSelectedScalarToVector(Vector<uint> data, [ConstantExpected(Min = 0, Max = (byte)(15))] byte index) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svdup_lane[_u64](svuint64_t data, uint64_t index)
/// DUP Zresult.D, Zdata.D[index]
/// </summary>
public static unsafe Vector<ulong> DuplicateSelectedScalarToVector(Vector<ulong> data, [ConstantExpected(Min = 0, Max = (byte)(7))] byte index) { throw new PlatformNotSupportedException(); }


/// <summary>
/// svuint8_t svext[_u8](svuint8_t op1, svuint8_t op2, uint64_t imm3)
/// EXT Ztied1.B, Ztied1.B, Zop2.B, #imm3
Expand Down
Loading

0 comments on commit 5b962c3

Please sign in to comment.