Automerge: [AArch64][SME] Spill p-regs as z-regs when streaming hazards are possible (#123752)

This patch adds a new option, `-aarch64-enable-zpr-predicate-spills`
(disabled by default), which replaces predicate spills with vector
spills in streaming[-compatible] functions.

For example:

```
str	p8, [sp, #7, mul vl]            // 2-byte Folded Spill
// ...
ldr	p8, [sp, #7, mul vl]            // 2-byte Folded Reload
```

Becomes:

```
mov	z0.b, p8/z, #1
str	z0, [sp]                        // 16-byte Folded Spill
// ...
ldr	z0, [sp]                        // 16-byte Folded Reload
ptrue	p4.b
cmpne	p8.b, p4/z, z0.b, #0
```

This is done to avoid streaming memory hazards between FPR/vector and
predicate spills, which currently occupy the same stack area even when
the `-aarch64-stack-hazard-size` flag is set.

This is implemented with two new pseudos, SPILL_PPR_TO_ZPR_SLOT_PSEUDO
and FILL_PPR_FROM_ZPR_SLOT_PSEUDO. The expansion of these pseudos
handles scavenging the required registers (z0 in the above example) and,
in the worst case, spilling a register to an emergency stack slot to
free one. The condition flags are also preserved around the `cmpne` in
case they are live at the expansion point.
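
When the status flags are live at a fill, the expansion brackets the `cmpne`
with an MRS/MSR pair through a scavenged GPR, so the reload above becomes
roughly the following (a sketch, not verbatim output; x8 stands in for
whichever GPR is scavenged):

```
ldr	z0, [sp]                        // 16-byte Folded Reload
mrs	x8, NZCV                        // save live condition flags
ptrue	p4.b
cmpne	p8.b, p4/z, z0.b, #0            // clobbers NZCV
msr	NZCV, x8                        // restore condition flags
```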
MacDue authored and github-actions[bot] committed Feb 3, 2025
2 parents 717246e + 82c6b8f commit f9490a9
Showing 11 changed files with 1,474 additions and 12 deletions.
313 changes: 308 additions & 5 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -1634,6 +1634,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
case AArch64::STR_PXI:
case AArch64::LDR_ZXI:
case AArch64::LDR_PXI:
case AArch64::PTRUE_B:
case AArch64::CPY_ZPzI_B:
case AArch64::CMPNE_PPzZI_B:
return I->getFlag(MachineInstr::FrameSetup) ||
I->getFlag(MachineInstr::FrameDestroy);
}
@@ -3265,7 +3268,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
break;
case RegPairInfo::PPR:
StrOpc =
Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
break;
case RegPairInfo::VG:
StrOpc = AArch64::STRXui;
@@ -3494,7 +3498,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
break;
case RegPairInfo::PPR:
LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
: AArch64::LDR_PXI;
break;
case RegPairInfo::VG:
continue;
@@ -3720,6 +3725,14 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
continue;
}

// Always save P4 when PPR spills are ZPR-sized and a predicate in p8-p15 is
// spilled. If all of p0-p3 are used as return values, p4 must be free to
// reload p8-p15 (see the sketch after this block).
if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
AArch64::PPR_p8to15RegClass.contains(Reg)) {
SavedRegs.set(AArch64::P4);
}
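
// A hedged illustration (not emitted verbatim by this code): with p0-p3
// holding return values, reloading p8 scavenges p4 as the governing
// predicate, since cmpne only accepts p0-p7 in that position:
//   ldr   z0, [sp]
//   ptrue p4.b
//   cmpne p8.b, p4/z, z0.b, #0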

// MachO's compact unwind format relies on all registers being stored in
// pairs.
// FIXME: the usual format is actually better if unwinding isn't needed.
@@ -4159,8 +4172,295 @@ int64_t AArch64FrameLowering::assignSVEStackObjectOffsets(
true);
}

/// Attempts to scavenge a register from \p ScavengeableRegs given the used
/// registers in \p UsedRegs.
static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
BitVector const &ScavengeableRegs) {
for (auto Reg : ScavengeableRegs.set_bits()) {
if (UsedRegs.available(Reg))
return Reg;
}
return AArch64::NoRegister;
}

/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
/// \p MachineInstrs.
static void propagateFrameFlags(MachineInstr &SourceMI,
ArrayRef<MachineInstr *> MachineInstrs) {
for (MachineInstr *MI : MachineInstrs) {
if (SourceMI.getFlag(MachineInstr::FrameSetup))
MI->setFlag(MachineInstr::FrameSetup);
if (SourceMI.getFlag(MachineInstr::FrameDestroy))
MI->setFlag(MachineInstr::FrameDestroy);
}
}

/// RAII helper class for scavenging or spilling a register. On construction,
/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
/// AllocatableRegs); if no register can be found, spills \p SpillCandidate to
/// \p MaybeSpillFI to free one. The freed register is returned via the
/// freeRegister() accessor. On destruction, if there was a spill, the previous
/// value is reloaded. The spilling and scavenging is only valid at the
/// insertion point \p MBBI; this class should _not_ be used in places that
/// create or manipulate basic blocks, moving the expected insertion point.
struct ScopedScavengeOrSpill {
ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;

ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
Register SpillCandidate, const TargetRegisterClass &RC,
LiveRegUnits const &UsedRegs,
BitVector const &AllocatableRegs,
std::optional<int> *MaybeSpillFI)
: MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
*MF.getSubtarget().getInstrInfo())),
TRI(*MF.getSubtarget().getRegisterInfo()) {
FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs);
if (FreeReg != AArch64::NoRegister)
return;
assert(MaybeSpillFI && "Expected emergency spill slot FI information "
"(attempted to spill in prologue/epilogue?)");
if (!MaybeSpillFI->has_value()) {
MachineFrameInfo &MFI = MF.getFrameInfo();
*MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
TRI.getSpillAlign(RC));
}
FreeReg = SpillCandidate;
SpillFI = MaybeSpillFI->value();
TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI,
Register());
}

bool hasSpilled() const { return SpillFI.has_value(); }

/// Returns the free register (found from scavenging or spilling a register).
Register freeRegister() const { return FreeReg; }

Register operator*() const { return freeRegister(); }

~ScopedScavengeOrSpill() {
if (hasSpilled())
TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI,
Register());
}

private:
MachineBasicBlock &MBB;
MachineBasicBlock::iterator MBBI;
const TargetRegisterClass &RC;
const AArch64InstrInfo &TII;
const TargetRegisterInfo &TRI;
Register FreeReg = AArch64::NoRegister;
std::optional<int> SpillFI;
};

/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
struct EmergencyStackSlots {
std::optional<int> ZPRSpillFI;
std::optional<int> PPRSpillFI;
std::optional<int> GPRSpillFI;
};

/// Registers available for scavenging (ZPR, PPR3b, GPR).
struct ScavengeableRegs {
BitVector ZPRRegs;
BitVector PPR3bRegs;
BitVector GPRRegs;
};

static bool isInPrologueOrEpilogue(const MachineInstr &MI) {
return MI.getFlag(MachineInstr::FrameSetup) ||
MI.getFlag(MachineInstr::FrameDestroy);
}

/// Expands:
/// ```
/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
/// ```
/// To:
/// ```
/// $z0 = CPY_ZPzI_B $p0, 1, 0
/// STR_ZXI $z0, $stack.0, 0
/// ```
/// While ensuring a ZPR ($z0 in this example) is free for the predicate
/// (spilling if necessary).
static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
MachineInstr &MI,
const TargetRegisterInfo &TRI,
LiveRegUnits const &UsedRegs,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
MachineFunction &MF = *MBB.getParent();
auto *TII =
static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());

ScopedScavengeOrSpill ZPredReg(
MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);

SmallVector<MachineInstr *, 2> MachineInstrs;
const DebugLoc &DL = MI.getDebugLoc();
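// Emit a vector that is 1 in lanes where the predicate is true and 0
// elsewhere (zeroing CPY of #1), then store it to the ZPR-sized slot.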
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
.addReg(*ZPredReg, RegState::Define)
.add(MI.getOperand(0))
.addImm(1)
.addImm(0)
.getInstr());
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
.addReg(*ZPredReg)
.add(MI.getOperand(1))
.addImm(MI.getOperand(2).getImm())
.setMemRefs(MI.memoperands())
.getInstr());
propagateFrameFlags(MI, MachineInstrs);
}

/// Expands:
/// ```
/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
/// ```
/// To:
/// ```
/// $z0 = LDR_ZXI %stack.0, 0
/// $p0 = PTRUE_B 31, implicit $vg
/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv
/// ```
/// While ensuring a ZPR ($z0 in this example) is free for the predicate
/// (spilling if necessary). If the status flags are in use at the point of
/// expansion they are preserved (by moving them to/from a GPR). This may cause
/// an additional spill if no GPR is free at the expansion point.
static bool expandFillPPRFromZPRSlotPseudo(MachineBasicBlock &MBB,
MachineInstr &MI,
const TargetRegisterInfo &TRI,
LiveRegUnits const &UsedRegs,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
MachineFunction &MF = *MBB.getParent();
auto *TII =
static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());

ScopedScavengeOrSpill ZPredReg(
MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);

ScopedScavengeOrSpill PredReg(
MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI);

// Elide NZCV spills if we know NZCV is not used.
bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
std::optional<ScopedScavengeOrSpill> NZCVSaveReg;
if (IsNZCVUsed)
NZCVSaveReg.emplace(
MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs,
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI);
SmallVector<MachineInstr *, 4> MachineInstrs;
const DebugLoc &DL = MI.getDebugLoc();
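// Reload the vector, then re-derive the predicate: compare the vector
// not-equal to zero under an all-true governing predicate.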
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
.addReg(*ZPredReg, RegState::Define)
.add(MI.getOperand(1))
.addImm(MI.getOperand(2).getImm())
.setMemRefs(MI.memoperands())
.getInstr());
if (IsNZCVUsed)
MachineInstrs.push_back(
BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
.addReg(NZCVSaveReg->freeRegister(), RegState::Define)
.addImm(AArch64SysReg::NZCV)
.addReg(AArch64::NZCV, RegState::Implicit)
.getInstr());
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
.addReg(*PredReg, RegState::Define)
.addImm(31));
MachineInstrs.push_back(
BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
.addReg(MI.getOperand(0).getReg(), RegState::Define)
.addReg(*PredReg)
.addReg(*ZPredReg)
.addImm(0)
.addReg(AArch64::NZCV, RegState::ImplicitDefine)
.getInstr());
if (IsNZCVUsed)
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
.addImm(AArch64SysReg::NZCV)
.addReg(NZCVSaveReg->freeRegister())
.addReg(AArch64::NZCV, RegState::ImplicitDefine)
.getInstr());

propagateFrameFlags(MI, MachineInstrs);
return PredReg.hasSpilled();
}

/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
/// operations within the MachineBasicBlock \p MBB.
static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
const TargetRegisterInfo &TRI,
ScavengeableRegs const &SR,
EmergencyStackSlots &SpillSlots) {
LiveRegUnits UsedRegs(TRI);
UsedRegs.addLiveOuts(MBB);
bool HasPPRSpills = false;
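// Iterate in reverse, stepping liveness backwards, so UsedRegs reflects
// the registers live at each candidate expansion point.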
for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
UsedRegs.stepBackward(MI);
switch (MI.getOpcode()) {
case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR,
SpillSlots);
MI.eraseFromParent();
break;
case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots);
MI.eraseFromParent();
break;
default:
break;
}
}

return HasPPRSpills;
}

void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
MachineFunction &MF, RegScavenger *RS) const {

AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
const TargetSubtargetInfo &TSI = MF.getSubtarget();
const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();

// If predicate spills are 16 bytes, we may need to expand
// SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
assert(Regs.count() > 0 && "Expected scavengeable registers");
return Regs;
};

ScavengeableRegs SR{};
SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
// Only p0-7 are possible as the second operand of cmpne (needed for fills).
SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);

EmergencyStackSlots SpillSlots;
for (MachineBasicBlock &MBB : MF) {
// If we had to spill a predicate (in the range p0-p7) to reload a
// predicate (>= p8), additional spill/fill pseudos will be created. These
// need an additional expansion pass. Note: there will be at most two
// expansion passes, as spilling/filling a predicate in the range p0-p7
// never requires spilling another predicate.
for (int Pass = 0; Pass < 2; Pass++) {
bool HasPPRSpills =
expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots);
assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills");
if (!HasPPRSpills)
break;
}
}
}

MachineFrameInfo &MFI = MF.getFrameInfo();

assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
@@ -4170,7 +4470,6 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
int64_t SVEStackSize =
assignSVEStackObjectOffsets(MFI, MinCSFrameIndex, MaxCSFrameIndex);

AFI->setStackSizeSVE(alignTo(SVEStackSize, 16U));
AFI->setMinMaxSVECSFrameIndex(MinCSFrameIndex, MaxCSFrameIndex);

@@ -5204,9 +5503,13 @@ void AArch64FrameLowering::emitRemarks(

unsigned RegTy = StackAccess::AccessType::GPR;
if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector) {
// SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
// spill/fill the predicate as a data vector (so count as an FPR access).
if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) {
RegTy = StackAccess::PPR;
} else
RegTy = StackAccess::FPR;
} else if (AArch64InstrInfo::isFpOrNEON(MI)) {
RegTy = StackAccess::FPR;