Skip to content

Commit

Permalink
JIT ARM64-SVE: Add CreateWhileLessThan*
Browse files Browse the repository at this point in the history
  • Loading branch information
a74nh committed Apr 12, 2024
1 parent d1747a7 commit e2c5d99
Show file tree
Hide file tree
Showing 9 changed files with 767 additions and 12 deletions.
38 changes: 38 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2194,6 +2194,44 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_CreateWhileLessThanMask8Bit:
case NI_Sve_CreateWhileLessThanMask16Bit:
case NI_Sve_CreateWhileLessThanMask32Bit:
case NI_Sve_CreateWhileLessThanMask64Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask8Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask16Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask32Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask64Bit:
{
    // Imports the two-operand "while less than [or equal]" mask intrinsics.
    //
    // Target instruction is dependent on whether the inputs are signed or unsigned.
    // This information is lost when the type is converted from CorInfoType to var_types.
    // Ensure this is marked using GTF_UNSIGNED so that codegen can pick the instruction.

    // Validate the signature before walking its argument list.
    assert(sig->numArgs == 2);

    CORINFO_ARG_LIST_HANDLE arg1     = sig->args;
    CORINFO_ARG_LIST_HANDLE arg2     = info.compCompHnd->getArgNext(arg1);
    CORINFO_CLASS_HANDLE    argClass = NO_CLASS_HANDLE;

    // Pop op2 first, then op1. Both operands go through getArgForHWIntrinsic so that
    // the signature type queried for each argument is actually used when popping it.
    const CorInfoType op2CorType = strip(info.compCompHnd->getArgType(sig, arg2, &argClass));
    op2                          = getArgForHWIntrinsic(JITtype2varType(op2CorType), argClass);

    const CorInfoType op1CorType = strip(info.compCompHnd->getArgType(sig, arg1, &argClass));
    op1                          = getArgForHWIntrinsic(JITtype2varType(op1CorType), argClass);

    // The signedness decision below is taken from the second operand only, so both
    // operands are expected to share the same scalar type.
    assert(op1CorType == op2CorType);

    retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);

    if ((op2CorType == CORINFO_TYPE_ULONG) || (op2CorType == CORINFO_TYPE_UINT))
    {
        retNode->gtFlags |= GTF_UNSIGNED;
    }
    else
    {
        // Only 32-bit and 64-bit integer operands are expected here.
        assert((op2CorType == CORINFO_TYPE_LONG) || (op2CorType == CORINFO_TYPE_INT));
    }
}
break;

default:
{
return nullptr;
Expand Down
22 changes: 22 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,28 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
GetEmitter()->emitIns_R_PATTERN(ins, emitSize, targetReg, opt, SVE_PATTERN_ALL);
break;

case NI_Sve_CreateWhileLessThanMask8Bit:
case NI_Sve_CreateWhileLessThanMask16Bit:
case NI_Sve_CreateWhileLessThanMask32Bit:
case NI_Sve_CreateWhileLessThanMask64Bit:
// Emit size is the size of the scalar operands.
emitSize = emitActualTypeSize(intrin.op1->TypeGet());
// Instruction is dependent on whether the inputs are signed or unsigned.
ins = ((node->gtFlags & GTF_UNSIGNED) != 0) ? INS_sve_whilelo : INS_sve_whilelt;
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
break;

case NI_Sve_CreateWhileLessThanOrEqualMask8Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask16Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask32Bit:
case NI_Sve_CreateWhileLessThanOrEqualMask64Bit:
// Emit size is the size of the scalar operands.
emitSize = emitActualTypeSize(intrin.op1->TypeGet());
// Instruction is dependent on whether the inputs are signed or unsigned.
ins = ((node->gtFlags & GTF_UNSIGNED) != 0) ? INS_sve_whilels : INS_sve_whilele;
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
break;

default:
unreached();
}
Expand Down
8 changes: 8 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ HARDWARE_INTRINSIC(Sve, CreateTrueMaskSingle,
HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt32, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskUInt64, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
// CreateWhileLessThanMask* / CreateWhileLessThanOrEqualMask*: two-operand intrinsics flagged as returning a per-element mask.
// Each row lists only the signed instruction (whilelt / whilele). The rows are SpecialImport + SpecialCodeGen:
// the importer marks nodes with unsigned operands using GTF_UNSIGNED, and codegen then emits whilelo / whilels instead.
// NOTE(review): the populated instruction slot presumably corresponds to the unsigned base type of the returned
// vector (byte/ushort/uint/ulong) - confirm against the column order documented at the top of this table.
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask16Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask32Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilelt, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask16Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask32Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)

HARDWARE_INTRINSIC(Sve, LoadVector, -1, 2, true, {INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1h, INS_sve_ld1h, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_LowMaskedOperation)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,189 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateTrueMaskUInt64([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask16Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b16[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELT (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELT (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELO (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b16[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELO (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanMask16Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask32Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b32[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELT (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELT (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELO (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b32[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELO (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanMask32Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask64Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b64[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELT (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELT (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELO (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b64[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELO (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanMask64Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanMask8Bit : While incrementing scalar is less than

/// <summary>
/// svbool_t svwhilelt_b8[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELT (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELT (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELO (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilelt_b8[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELO (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanMask8Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask16Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b16[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELE (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELE (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELS (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b16[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELS (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ushort> CreateWhileLessThanOrEqualMask16Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask32Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b32[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELE (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELE (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELS (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b32[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELS (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<uint> CreateWhileLessThanOrEqualMask32Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask64Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b64[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELE (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELE (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELS (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b64[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELS (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<ulong> CreateWhileLessThanOrEqualMask64Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// CreateWhileLessThanOrEqualMask8Bit : While incrementing scalar is less than or equal to

/// <summary>
/// svbool_t svwhilele_b8[_s32](int32_t op1, int32_t op2)
///   Instruction: WHILELE (signed 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(int left, int right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_s64](int64_t op1, int64_t op2)
///   Instruction: WHILELE (signed 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(long left, long right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_u32](uint32_t op1, uint32_t op2)
///   Instruction: WHILELS (unsigned 32-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(uint left, uint right) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svwhilele_b8[_u64](uint64_t op1, uint64_t op2)
///   Instruction: WHILELS (unsigned 64-bit operands)
/// </summary>
/// <param name="left">First scalar operand (op1).</param>
/// <param name="right">Second scalar operand (op2).</param>
/// <exception cref="PlatformNotSupportedException">Always thrown by this implementation.</exception>
public static unsafe Vector<byte> CreateWhileLessThanOrEqualMask8Bit(ulong left, ulong right) { throw new PlatformNotSupportedException(); }


/// LoadVector : Unextended load

Expand Down
Loading

0 comments on commit e2c5d99

Please sign in to comment.