Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[X86] combineX86ShuffleChain - provide list of combined shuffle nodes, replace HasVariableMask bool arg. NFC. #127826

Merged
merged 2 commits into from
Feb 19, 2025

Conversation

RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented Feb 19, 2025

Minor NFC refactor before making better variable mask combining decisions - isTargetShuffleVariableMask doesn't discriminate between fast (AND, PSHUFB etc.) and slow (VPERMV3) variable shuffles, so an opaque HasVariableMask is only of limited use.

…, replace HasVariableMask bool arg. NFC.

Minor NFC refactor before making better variable mask combining decisions - isTargetShuffleVariableMask doesn't discriminate between fast (AND, PSHUFB etc.) and slow (VPERMV3) variable shuffles, so an opaque HasVariableMask is only of limited use.
@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes

Minor NFC refactor before making better variable mask combining decisions - isTargetShuffleVariableMask doesn't discriminate between fast (AND, PSHUFB etc.) and slow (VPERMV3) variable shuffles, so an opaque HasVariableMask is only of limited use.


Full diff: https://github.com/llvm/llvm-project/pull/127826.diff

1 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+24-23)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7485fc48f4132..e7b102b51f488 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -39580,7 +39580,7 @@ static bool matchBinaryPermuteShuffle(
 
 static SDValue combineX86ShuffleChainWithExtract(
     ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth,
-    bool HasVariableMask, bool AllowVariableCrossLaneMask,
+    ArrayRef<const SDNode *> SrcNodes, bool AllowVariableCrossLaneMask,
     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
     const X86Subtarget &Subtarget);
 
@@ -39595,7 +39595,7 @@ static SDValue combineX86ShuffleChainWithExtract(
 /// instruction but should only be used to replace chains over a certain depth.
 static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
                                       ArrayRef<int> BaseMask, int Depth,
-                                      bool HasVariableMask,
+                                      ArrayRef<const SDNode*> SrcNodes,
                                       bool AllowVariableCrossLaneMask,
                                       bool AllowVariablePerLaneMask,
                                       SelectionDAG &DAG,
@@ -40064,6 +40064,10 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   if (Depth < 1)
     return SDValue();
 
+  bool HasVariableMask = llvm::any_of(SrcNodes, [](const SDNode *N) {
+    return isTargetShuffleVariableMask(N->getOpcode());
+  });
+
   // Depth threshold above which we can efficiently use variable mask shuffles.
   int VariableCrossLaneShuffleDepth =
       Subtarget.hasFastVariableCrossLaneShuffle() ? 1 : 2;
@@ -40134,9 +40138,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
     // If that failed and either input is extracted then try to combine as a
     // shuffle with the larger type.
     if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-            Inputs, Root, BaseMask, Depth, HasVariableMask,
-            AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG,
-            Subtarget))
+            Inputs, Root, BaseMask, Depth, SrcNodes, AllowVariableCrossLaneMask,
+            AllowVariablePerLaneMask, DAG, Subtarget))
       return WideShuffle;
 
     // If we have a dual input lane-crossing shuffle then lower to VPERMV3,
@@ -40307,8 +40310,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
   // If that failed and either input is extracted then try to combine as a
   // shuffle with the larger type.
   if (SDValue WideShuffle = combineX86ShuffleChainWithExtract(
-          Inputs, Root, BaseMask, Depth, HasVariableMask,
-          AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget))
+          Inputs, Root, BaseMask, Depth, SrcNodes, AllowVariableCrossLaneMask,
+          AllowVariablePerLaneMask, DAG, Subtarget))
     return WideShuffle;
 
   // If we have a dual input shuffle then lower to VPERMV3,
@@ -40346,7 +40349,7 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
 // extract_subvector(shuffle(x,y,m2),0)
 static SDValue combineX86ShuffleChainWithExtract(
     ArrayRef<SDValue> Inputs, SDValue Root, ArrayRef<int> BaseMask, int Depth,
-    bool HasVariableMask, bool AllowVariableCrossLaneMask,
+    ArrayRef<const SDNode *> SrcNodes, bool AllowVariableCrossLaneMask,
     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
     const X86Subtarget &Subtarget) {
   unsigned NumMaskElts = BaseMask.size();
@@ -40475,7 +40478,7 @@ static SDValue combineX86ShuffleChainWithExtract(
 
   if (SDValue WideShuffle =
           combineX86ShuffleChain(WideInputs, WideRoot, WideMask, Depth,
-                                 HasVariableMask, AllowVariableCrossLaneMask,
+                                 SrcNodes, AllowVariableCrossLaneMask,
                                  AllowVariablePerLaneMask, DAG, Subtarget)) {
     WideShuffle =
         extractSubVector(WideShuffle, 0, DAG, SDLoc(Root), RootSizeInBits);
@@ -40698,7 +40701,7 @@ static SDValue canonicalizeShuffleMaskWithHorizOp(
 // TODO: Extend this to merge multiple constant Ops and update the mask.
 static SDValue combineX86ShufflesConstants(MVT VT, ArrayRef<SDValue> Ops,
                                            ArrayRef<int> Mask,
-                                           bool HasVariableMask,
+                                           ArrayRef<const SDNode *> SrcNodes,
                                            SelectionDAG &DAG, const SDLoc &DL,
                                            const X86Subtarget &Subtarget) {
   unsigned SizeInBits = VT.getSizeInBits();
@@ -40720,6 +40723,9 @@ static SDValue combineX86ShufflesConstants(MVT VT, ArrayRef<SDValue> Ops,
   // only used once or the combined shuffle has included a variable mask
   // shuffle, this is to avoid constant pool bloat.
   bool IsOptimizingSize = DAG.shouldOptForSize();
+  bool HasVariableMask = llvm::any_of(SrcNodes, [](const SDNode *N) {
+    return isTargetShuffleVariableMask(N->getOpcode());
+  });
   if (IsOptimizingSize && !HasVariableMask &&
       llvm::none_of(Ops, [](SDValue SrcOp) { return SrcOp->hasOneUse(); }))
     return SDValue();
@@ -40821,7 +40827,7 @@ namespace llvm {
 static SDValue combineX86ShufflesRecursively(
     ArrayRef<SDValue> SrcOps, int SrcOpIndex, SDValue Root,
     ArrayRef<int> RootMask, ArrayRef<const SDNode *> SrcNodes, unsigned Depth,
-    unsigned MaxDepth, bool HasVariableMask, bool AllowVariableCrossLaneMask,
+    unsigned MaxDepth, bool AllowVariableCrossLaneMask,
     bool AllowVariablePerLaneMask, SelectionDAG &DAG,
     const X86Subtarget &Subtarget) {
   assert(!RootMask.empty() &&
@@ -40877,7 +40883,6 @@ static SDValue combineX86ShufflesRecursively(
   SmallVector<int, 64> OpMask;
   SmallVector<SDValue, 2> OpInputs;
   APInt OpUndef, OpZero;
-  bool IsOpVariableMask = isTargetShuffleVariableMask(Op.getOpcode());
   if (getTargetShuffleInputs(Op, OpDemandedElts, OpInputs, OpMask, OpUndef,
                              OpZero, DAG, Depth, false)) {
     // Shuffle inputs must not be larger than the shuffle result.
@@ -41092,7 +41097,6 @@ static SDValue combineX86ShufflesRecursively(
     return getOnesVector(RootVT, DAG, DL);
 
   assert(!Ops.empty() && "Shuffle with no inputs detected");
-  HasVariableMask |= IsOpVariableMask;
 
   // Update the list of shuffle nodes that have been combined so far.
   SmallVector<const SDNode *, 16> CombinedNodes(SrcNodes);
@@ -41121,15 +41125,14 @@ static SDValue combineX86ShufflesRecursively(
       }
       if (SDValue Res = combineX86ShufflesRecursively(
               Ops, i, Root, ResolvedMask, CombinedNodes, Depth + 1, MaxDepth,
-              HasVariableMask, AllowCrossLaneVar, AllowPerLaneVar, DAG,
-              Subtarget))
+              AllowCrossLaneVar, AllowPerLaneVar, DAG, Subtarget))
         return Res;
     }
   }
 
   // Attempt to constant fold all of the constant source ops.
   if (SDValue Cst = combineX86ShufflesConstants(
-          RootVT, Ops, Mask, HasVariableMask, DAG, DL, Subtarget))
+          RootVT, Ops, Mask, CombinedNodes, DAG, DL, Subtarget))
     return Cst;
 
   // If constant fold failed and we only have constants - then we have
@@ -41231,7 +41234,7 @@ static SDValue combineX86ShufflesRecursively(
 
     // Try to combine into a single shuffle instruction.
     if (SDValue Shuffle = combineX86ShuffleChain(
-            Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask,
+            Ops, Root, Mask, Depth, CombinedNodes, AllowVariableCrossLaneMask,
             AllowVariablePerLaneMask, DAG, Subtarget))
       return Shuffle;
 
@@ -41250,7 +41253,7 @@ static SDValue combineX86ShufflesRecursively(
   // If that failed and any input is extracted then try to combine as a
   // shuffle with the larger type.
   return combineX86ShuffleChainWithExtract(
-      Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask,
+      Ops, Root, Mask, Depth, CombinedNodes, AllowVariableCrossLaneMask,
       AllowVariablePerLaneMask, DAG, Subtarget);
 }
 
@@ -41259,7 +41262,6 @@ static SDValue combineX86ShufflesRecursively(SDValue Op, SelectionDAG &DAG,
                                              const X86Subtarget &Subtarget) {
   return combineX86ShufflesRecursively(
       {Op}, 0, Op, {0}, {}, /*Depth*/ 0, X86::MaxShuffleCombineDepth,
-      /*HasVarMask*/ false,
       /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, DAG,
       Subtarget);
 }
@@ -41897,7 +41899,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
       if (SDValue Res = combineX86ShufflesRecursively(
               {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 0,
               X86::MaxShuffleCombineDepth,
-              /*HasVarMask*/ false, /*AllowCrossLaneVarMask*/ true,
+              /*AllowCrossLaneVarMask*/ true,
               /*AllowPerLaneVarMask*/ true, DAG, Subtarget))
         return DAG.getNode(X86ISD::VBROADCAST, DL, VT,
                            DAG.getBitcast(SrcVT, Res));
@@ -42236,7 +42238,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
           llvm::narrowShuffleMaskElts(EltBits / 8, Mask, ByteMask);
           if (SDValue NewMask = combineX86ShufflesConstants(
                   ShufVT, {MaskLHS, MaskRHS}, ByteMask,
-                  /*HasVariableMask=*/true, DAG, DL, Subtarget)) {
+                  {LHS.getNode(), RHS.getNode()}, DAG, DL, Subtarget)) {
             SDValue NewLHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT,
                                          LHS.getOperand(0), NewMask);
             SDValue NewRHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT,
@@ -43871,7 +43873,6 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
 
     SDValue NewShuffle = combineX86ShufflesRecursively(
         {Op}, 0, Op, DemandedMask, {}, 0, X86::MaxShuffleCombineDepth - Depth,
-        /*HasVarMask*/ false,
         /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, TLO.DAG,
         Subtarget);
     if (NewShuffle)
@@ -51430,7 +51431,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
       if (SDValue Shuffle = combineX86ShufflesRecursively(
               {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 1,
               X86::MaxShuffleCombineDepth,
-              /*HasVarMask*/ false, /*AllowVarCrossLaneMask*/ true,
+              /*AllowVarCrossLaneMask*/ true,
               /*AllowVarPerLaneMask*/ true, DAG, Subtarget))
         return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Shuffle,
                            N0.getOperand(1));

Copy link

github-actions bot commented Feb 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@RKSimon RKSimon merged commit d1889cf into llvm:main Feb 19, 2025
8 checks passed
@RKSimon RKSimon deleted the x86-shuffle-chain-varmask branch February 19, 2025 17:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants