[RISCV] Convert AVLs with vlenb to VLMAX where possible (#97800)
Given an AVL that's computed from vlenb, if it's equal to VLMAX then we
can replace it with the VLMAX sentinel value.

The main motivation is to be able to express an EVL of VLMAX in VP
intrinsics whilst emitting vsetvli a0, zero, so that we can replace
llvm.riscv.masked.strided.{load,store} with their VP counterparts.

This is done in RISCVVectorPeephole (previously RISCVFoldMasks, renamed
to account for the fact that it no longer just folds masks) instead of
SelectionDAG, since there are multiple places where VP nodes are
lowered, each of which would have needed to be handled.

This also avoids doing it in RISCVInsertVSETVLI, where it's much harder
to look up the value of the AVL; in RISCVVectorPeephole we can take
advantage of DeadMachineInstrElim to remove any leftover
PseudoReadVLENBs.
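As a rough numeric illustration of the identity the new check relies on (a standalone sketch, not code from this commit; the values are taken from the sew32_sll test added below, and the variable names merely mirror convertToVLMAX):

#include <cassert>

int main() {
  // Everything is kept in eighths so that fractional LMULs stay integral,
  // mirroring the "denominator=8" fixed-point trick in convertToVLMAX.
  // sew32_sll operates on <vscale x 16 x i32>, i.e. e32/m8.
  unsigned LMULFixed = 8 * 8; // LMUL = 8, stored as 64 eighths
  unsigned SEW = 32;
  // shl i32 %vscale, 4 lowers to (vlenb >> 3) << 4 == vlenb << 1, so the
  // AVL is VLENB scaled by 2, i.e. ScaleFixed = 8 << 1 = 16 eighths.
  unsigned ScaleFixed = 8u << 1;
  // AVL == VLMAX  <=>  VLENB * Scale == (VLENB * 8 * LMUL) / SEW
  //               <=>  Scale == (8 * LMUL) / SEW
  assert(ScaleFixed == 8 * LMULFixed / SEW); // 16 == 16: fold AVL to VLMAX
  return 0;
}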
lukel97 authored Jul 11, 2024
1 parent 7eae9bb commit c74ba57
Showing 12 changed files with 170 additions and 68 deletions.
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/CMakeLists.txt
@@ -34,7 +34,6 @@ add_llvm_target(RISCVCodeGen
RISCVMakeCompressible.cpp
RISCVExpandAtomicPseudoInsts.cpp
RISCVExpandPseudoInsts.cpp
RISCVFoldMasks.cpp
RISCVFrameLowering.cpp
RISCVGatherScatterLowering.cpp
RISCVInsertVSETVLI.cpp
@@ -55,6 +54,7 @@ add_llvm_target(RISCVCodeGen
RISCVTargetMachine.cpp
RISCVTargetObjectFile.cpp
RISCVTargetTransformInfo.cpp
RISCVVectorPeephole.cpp
GISel/RISCVCallLowering.cpp
GISel/RISCVInstructionSelector.cpp
GISel/RISCVLegalizerInfo.cpp
4 changes: 2 additions & 2 deletions llvm/lib/Target/RISCV/RISCV.h
@@ -40,8 +40,8 @@ void initializeRISCVMakeCompressibleOptPass(PassRegistry &);
FunctionPass *createRISCVGatherScatterLoweringPass();
void initializeRISCVGatherScatterLoweringPass(PassRegistry &);

FunctionPass *createRISCVFoldMasksPass();
void initializeRISCVFoldMasksPass(PassRegistry &);
FunctionPass *createRISCVVectorPeepholePass();
void initializeRISCVVectorPeepholePass(PassRegistry &);

FunctionPass *createRISCVOptWInstrsPass();
void initializeRISCVOptWInstrsPass(PassRegistry &);
4 changes: 2 additions & 2 deletions llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -121,7 +121,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
initializeRISCVOptWInstrsPass(*PR);
initializeRISCVPreRAExpandPseudoPass(*PR);
initializeRISCVExpandPseudoPass(*PR);
initializeRISCVFoldMasksPass(*PR);
initializeRISCVVectorPeepholePass(*PR);
initializeRISCVInsertVSETVLIPass(*PR);
initializeRISCVInsertReadWriteCSRPass(*PR);
initializeRISCVInsertWriteVXRMPass(*PR);
@@ -532,7 +532,7 @@ void RISCVPassConfig::addPreEmitPass2() {
}

void RISCVPassConfig::addMachineSSAOptimization() {
addPass(createRISCVFoldMasksPass());
addPass(createRISCVVectorPeepholePass());

TargetPassConfig::addMachineSSAOptimization();

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp (renamed from RISCVFoldMasks.cpp)
@@ -1,20 +1,31 @@
//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
//===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
//
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
// This pass performs various vector pseudo peephole optimisations after
// instruction selection.
//
// Currently it converts
// Currently it converts vmerge.vvm to vmv.v.v
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
//===---------------------------------------------------------------------===//
// And masked pseudos to unmasked pseudos
// PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
//
// It also converts AVLs to VLMAX where possible
// %vl = VLENB * something
// PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
// ->
// PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVISelDAGToDAG.h"
@@ -26,17 +37,17 @@

using namespace llvm;

#define DEBUG_TYPE "riscv-fold-masks"
#define DEBUG_TYPE "riscv-vector-peephole"

namespace {

class RISCVFoldMasks : public MachineFunctionPass {
class RISCVVectorPeephole : public MachineFunctionPass {
public:
static char ID;
const TargetInstrInfo *TII;
MachineRegisterInfo *MRI;
const TargetRegisterInfo *TRI;
RISCVFoldMasks() : MachineFunctionPass(ID) {}
RISCVVectorPeephole() : MachineFunctionPass(ID) {}

bool runOnMachineFunction(MachineFunction &MF) override;
MachineFunctionProperties getRequiredProperties() const override {
@@ -47,6 +58,7 @@ class RISCVFoldMasks : public MachineFunctionPass {
StringRef getPassName() const override { return "RISC-V Fold Masks"; }

private:
bool convertToVLMAX(MachineInstr &MI) const;
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;

@@ -58,11 +70,65 @@ class RISCVFoldMasks : public MachineFunctionPass {

} // namespace

char RISCVFoldMasks::ID = 0;
char RISCVVectorPeephole::ID = 0;

INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
false)

// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
// to the VLMAX sentinel value.
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
!RISCVII::hasSEWOp(MI.getDesc().TSFlags))
return false;
MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
if (!VL.isReg())
return false;
MachineInstr *Def = MRI->getVRegDef(VL.getReg());
if (!Def)
return false;

// Fixed-point value, denominator=8
uint64_t ScaleFixed = 8;
// Check if the VLENB was potentially scaled with slli/srli
if (Def->getOpcode() == RISCV::SLLI) {
assert(Def->getOperand(2).getImm() < 64);
ScaleFixed <<= Def->getOperand(2).getImm();
Def = MRI->getVRegDef(Def->getOperand(1).getReg());
} else if (Def->getOpcode() == RISCV::SRLI) {
assert(Def->getOperand(2).getImm() < 64);
ScaleFixed >>= Def->getOperand(2).getImm();
Def = MRI->getVRegDef(Def->getOperand(1).getReg());
}

if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
return false;

auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
// Fixed-point value, denominator=8
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
// A Log2SEW of 0 is an operation on mask registers only
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
assert(8 * LMULFixed / SEW > 0);

INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)
// AVL = (VLENB * Scale)
//
// VLMAX = (VLENB * 8 * LMUL) / SEW
//
// AVL == VLMAX
// -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
// -> Scale == (8 * LMUL) / SEW
if (ScaleFixed != 8 * LMULFixed / SEW)
return false;

bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
VL.ChangeToImmediate(RISCV::VLMaxSentinel);

return true;
}

bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
assert(MaskDef && MaskDef->isCopy() &&
MaskDef->getOperand(0).getReg() == RISCV::V0);
Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
@@ -91,7 +157,7 @@ bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {

// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const {
#define CASE_VMERGE_TO_VMV(lmul) \
case RISCV::PseudoVMERGE_VVM_##lmul: \
NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -134,7 +200,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
return true;
}

bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
const RISCV::RISCVMaskedPseudoInfo *I =
RISCV::getMaskedPseudoInfo(MI.getOpcode());
if (!I)
@@ -178,7 +244,7 @@ bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
return true;
}

bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
if (skipFunction(MF.getFunction()))
return false;

@@ -213,6 +279,7 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {

for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
Changed |= convertToVLMAX(MI);
Changed |= convertToUnmasked(MI);
Changed |= convertVMergeToVMv(MI);
}
@@ -221,4 +288,6 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
return Changed;
}

FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }
FunctionPass *llvm::createRISCVVectorPeepholePass() {
return new RISCVVectorPeephole();
}
12 changes: 6 additions & 6 deletions llvm/test/CodeGen/RISCV/rvv/insert-subvector.ll
@@ -305,9 +305,9 @@ define <vscale x 16 x i8> @insert_nxv16i8_nxv1i8_7(<vscale x 16 x i8> %vec, <vsc
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a1, a0, 3
; CHECK-NEXT: sub a1, a0, a1
; CHECK-NEXT: vsetvli zero, a0, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vx v8, v10, a1
; CHECK-NEXT: sub a0, a0, a1
; CHECK-NEXT: vsetvli a1, zero, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vx v8, v10, a0
; CHECK-NEXT: ret
%v = call <vscale x 16 x i8> @llvm.vector.insert.nxv1i8.nxv16i8(<vscale x 16 x i8> %vec, <vscale x 1 x i8> %subvec, i64 7)
ret <vscale x 16 x i8> %v
@@ -318,9 +318,9 @@ define <vscale x 16 x i8> @insert_nxv16i8_nxv1i8_15(<vscale x 16 x i8> %vec, <vs
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a1, a0, 3
; CHECK-NEXT: sub a1, a0, a1
; CHECK-NEXT: vsetvli zero, a0, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vx v9, v10, a1
; CHECK-NEXT: sub a0, a0, a1
; CHECK-NEXT: vsetvli a1, zero, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vx v9, v10, a0
; CHECK-NEXT: ret
%v = call <vscale x 16 x i8> @llvm.vector.insert.nxv1i8.nxv16i8(<vscale x 16 x i8> %vec, <vscale x 1 x i8> %subvec, i64 15)
ret <vscale x 16 x i8> %v
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-to-vmv.mir
@@ -1,5 +1,5 @@
# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
# RUN: llc %s -o - -mtriple=riscv64 -mattr=+v -run-pass=riscv-fold-masks \
# RUN: llc %s -o - -mtriple=riscv64 -mattr=+v -run-pass=riscv-vector-peephole \
# RUN: -verify-machineinstrs | FileCheck %s

---
13 changes: 5 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vadd-vp.ll
@@ -1436,20 +1436,17 @@ define <vscale x 32 x i32> @vadd_vi_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, <v
define <vscale x 32 x i32> @vadd_vi_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, <vscale x 32 x i1> %m) {
; RV32-LABEL: vadd_vi_nxv32i32_evl_nx16:
; RV32: # %bb.0:
; RV32-NEXT: csrr a0, vlenb
; RV32-NEXT: slli a0, a0, 1
; RV32-NEXT: vsetvli zero, a0, e32, m8, ta, ma
; RV32-NEXT: vsetvli a0, zero, e32, m8, ta, ma
; RV32-NEXT: vadd.vi v8, v8, -1, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vadd_vi_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a0, vlenb
; RV64-NEXT: srli a1, a0, 2
; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a1
; RV64-NEXT: slli a0, a0, 1
; RV64-NEXT: vsetvli zero, a0, e32, m8, ta, ma
; RV64-NEXT: srli a0, a0, 2
; RV64-NEXT: vsetvli a1, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a0
; RV64-NEXT: vsetvli a0, zero, e32, m8, ta, ma
; RV64-NEXT: vadd.vi v8, v8, -1, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
48 changes: 48 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vlmax-peephole.ll
@@ -0,0 +1,48 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=riscv64 -mattr=+v -verify-machineinstrs | FileCheck %s

define <vscale x 1 x i1> @sew1_srli(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b) {
; CHECK-LABEL: sew1_srli:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmand.mm v0, v0, v8
; CHECK-NEXT: ret
%vlmax = call i32 @llvm.vscale()
%x = call <vscale x 1 x i1> @llvm.vp.and.nxv1i1(<vscale x 1 x i1> %a, <vscale x 1 x i1> %b, <vscale x 1 x i1> splat (i1 true), i32 %vlmax)
ret <vscale x 1 x i1> %x
}

define <vscale x 1 x i64> @sew64_srli(<vscale x 1 x i64> %a, <vscale x 1 x i64> %b) {
; CHECK-LABEL: sew64_srli:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v9
; CHECK-NEXT: ret
%vlmax = call i32 @llvm.vscale()
%x = call <vscale x 1 x i64> @llvm.vp.add.nxv1i64(<vscale x 1 x i64> %a, <vscale x 1 x i64> %b, <vscale x 1 x i1> splat (i1 true), i32 %vlmax)
ret <vscale x 1 x i64> %x
}

define <vscale x 8 x i64> @sew64(<vscale x 8 x i64> %a, <vscale x 8 x i64> %b) {
; CHECK-LABEL: sew64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e64, m8, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v16
; CHECK-NEXT: ret
%vscale = call i32 @llvm.vscale()
%vlmax = shl i32 %vscale, 3
%x = call <vscale x 8 x i64> @llvm.vp.add.nxv8i64(<vscale x 8 x i64> %a, <vscale x 8 x i64> %b, <vscale x 8 x i1> splat (i1 true), i32 %vlmax)
ret <vscale x 8 x i64> %x
}

define <vscale x 16 x i32> @sew32_sll(<vscale x 16 x i32> %a, <vscale x 16 x i32> %b) {
; CHECK-LABEL: sew32_sll:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a0, zero, e32, m8, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v16
; CHECK-NEXT: ret
%vscale = call i32 @llvm.vscale()
%vlmax = shl i32 %vscale, 4
%x = call <vscale x 16 x i32> @llvm.vp.add.nxv16i32(<vscale x 16 x i32> %a, <vscale x 16 x i32> %b, <vscale x 16 x i1> splat (i1 true), i32 %vlmax)
ret <vscale x 16 x i32> %x
}
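One subtlety these tests exercise: sew1_srli operates on mask registers, which the pseudos encode as Log2SEW == 0 and which convertToVLMAX treats as SEW = 8. A hedged sketch of that case, under the same fixed-point convention as above (again not part of the commit):

#include <cassert>

int main() {
  // sew1_srli: <vscale x 1 x i1> under mf8; llvm.vscale() lowers to
  // vlenb >> 3, so ScaleFixed = 8 >> 3 = 1 (one eighth).
  unsigned ScaleFixed = 8u >> 3;
  unsigned LMULFixed = 8 / 8;                 // fractional LMUL 1/8, in eighths
  unsigned Log2SEW = 0;                       // mask-register operation
  unsigned SEW = Log2SEW ? 1u << Log2SEW : 8; // the pass's SEW == 8 convention
  assert(ScaleFixed == 8 * LMULFixed / SEW);  // 1 == 1: AVL becomes VLMAX
  return 0;
}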
13 changes: 5 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
@@ -1073,20 +1073,17 @@ define <vscale x 32 x i32> @vmax_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i3
define <vscale x 32 x i32> @vmax_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vmax_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
; RV32-NEXT: csrr a1, vlenb
; RV32-NEXT: slli a1, a1, 1
; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vmax.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vmax_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
; RV64-NEXT: srli a2, a1, 2
; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a2
; RV64-NEXT: slli a1, a1, 1
; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV64-NEXT: srli a1, a1, 2
; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a1
; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vmax.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma
13 changes: 5 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vmaxu-vp.ll
@@ -1072,20 +1072,17 @@ define <vscale x 32 x i32> @vmaxu_vx_nxv32i32_evl_nx8(<vscale x 32 x i32> %va, i
define <vscale x 32 x i32> @vmaxu_vx_nxv32i32_evl_nx16(<vscale x 32 x i32> %va, i32 %b, <vscale x 32 x i1> %m) {
; RV32-LABEL: vmaxu_vx_nxv32i32_evl_nx16:
; RV32: # %bb.0:
; RV32-NEXT: csrr a1, vlenb
; RV32-NEXT: slli a1, a1, 1
; RV32-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV32-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV32-NEXT: vmaxu.vx v8, v8, a0, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: vmaxu_vx_nxv32i32_evl_nx16:
; RV64: # %bb.0:
; RV64-NEXT: csrr a1, vlenb
; RV64-NEXT: srli a2, a1, 2
; RV64-NEXT: vsetvli a3, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a2
; RV64-NEXT: slli a1, a1, 1
; RV64-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; RV64-NEXT: srli a1, a1, 2
; RV64-NEXT: vsetvli a2, zero, e8, mf2, ta, ma
; RV64-NEXT: vslidedown.vx v24, v0, a1
; RV64-NEXT: vsetvli a1, zero, e32, m8, ta, ma
; RV64-NEXT: vmaxu.vx v8, v8, a0, v0.t
; RV64-NEXT: vmv1r.v v0, v24
; RV64-NEXT: vsetivli zero, 0, e32, m8, ta, ma