[mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims #118208

Merged: 3 commits, Feb 19, 2025

Conversation

Groverkss
Member

This PR fixes how the iteration domain of linalg.generic is collapsed when fusing with tensor.expand_shape. Previously, the output_shape for tensor.expand_shape was inferred, which only works in some special cases.

This patch makes the logic set the bounds of the new collapsed iteration domain explicitly, since they are already known.
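
For illustration, a minimal MLIR sketch (function and value names are hypothetical) of why the inference only covers special cases: the old logic could recover a single dynamic size per reassociation group by dividing the source dimension by the product of the group's static sizes, but with two dynamic sizes in one group the split is underdetermined.

func.func @inference_example(%src: tensor<?xf32>, %n: index, %d0: index, %d1: index)
    -> (tensor<?x4xf32>, tensor<?x?xf32>) {
  // One dynamic dim in the group [0, 1]: its size is dim(%src, 0) / 4 and
  // could be inferred from the source alone.
  %a = tensor.expand_shape %src [[0, 1]] output_shape [%n, 4]
      : tensor<?xf32> into tensor<?x4xf32>
  // Two dynamic dims in the same group: dim(%src, 0) = %d0 * %d1 has no
  // unique solution, so the sizes must be supplied explicitly.
  %b = tensor.expand_shape %src [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %a, %b : tensor<?x4xf32>, tensor<?x?xf32>
}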

@llvmbot
Member

llvmbot commented Dec 1, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Kunwar Grover (Groverkss)

Changes

This PR fixes how the iteration domain of linalg.generic is collapsed when fusing with tensor.expand_shape. Previously, the output_shape for tensor.expand_shape was inferred, which only works in some special cases.

This patch makes the logic set the bounds of the new collapsed iteration domain explicitly, since they are already known.
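
Schematically (a sketch adapted from the @fuse_by_collapsing_dynamic_2 test added in this patch; names are illustrative), the collapsed form now queries the loop bounds from the operand of the original expand_shape and forwards them as the explicit output_shape of the re-expanding reshape:

#map = affine_map<(d0) -> (d0)>
func.func @collapsed_form(%arg0: tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %expanded = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1]
      : tensor<?xf32> into tensor<?x?xf32>
  // The bounds of the original (expanded) iteration domain are already known.
  %d0 = tensor.dim %expanded, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %expanded, %c1 : tensor<?x?xf32>
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %init = tensor.empty(%dim) : tensor<?xf32>
  // The generic runs on the collapsed 1-D domain, directly on %arg0.
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel"]}
      ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
    ^bb0(%b0: f32, %b1: f32):
      %neg = arith.negf %b0 : f32
      linalg.yield %neg : f32
  } -> tensor<?xf32>
  // Explicit output_shape: no division-based inference of the dynamic dims.
  %result = tensor.expand_shape %0 [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}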


Patch is 24.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118208.diff

8 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+14-9)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+23-96)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+6-26)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+45-18)
  • (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+3-4)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+2-6)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-6)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c44194a1231588..fa730241203039 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1549,7 +1549,7 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 /// value in the collapsed operation.
 void generateCollapsedIndexingRegion(Location loc, Block *block,
                                      const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
+                                     ArrayRef<OpFoldResult> loopRange,
                                      RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
@@ -1571,10 +1571,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemUIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivUIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1721,14 +1723,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1746,15 +1747,19 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         MemRefType expandShapeResultType = MemRefType::get(
             originalResultType.getShape(), originalResultType.getElementType());
         result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       }
       results.push_back(result);
     } else {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebb88bf695d4c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyCollapseShapeOp, CollapseShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+    auto loc = op->getLoc();
+    auto collapseShape = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+        b, loc, collapseShape.getSrc(),
+        collapseShape.getResultType().getShape(),
+        collapseShape.getReassociationMaps()));
+    return success();
   }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
+};
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
-    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
-        b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps()));
+    auto expandShape = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> outputShape = getMixedValues(
+        expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+    reifiedReturnShapes.push_back(outputShape);
     return success();
   }
 };
@@ -202,10 +131,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 616d4a7d0a0ab5..a6ae728b20fa47 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
     if (!dim.has_value())
       return failure();
 
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+    OpFoldResult outputDim = outputShape[dim.value()];
+    rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+                                  rewriter, dimOp.getLoc(), outputDim));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index f17881d59a266e..f29f231cdeca87 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>) 
+      outs(%init : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          %out = arith.negf %b0 : f32
+          linalg.yield %out : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK:     %[[OUT:.+]] = linalg.generic
+// CHECK-SAME:   ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME:   outs(%{{.*}} : tensor<?xf32>)
+// CHECK:     %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME:    output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:      return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//  CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
 //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+//  CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 //  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 //      CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+//      CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+//      CHECK-DAG:   %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 //  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
 //      CHECK:       linalg.yield %[[T7]]
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C4]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 751ece37bc094f..fd3c3217225086 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-//      CHECK: %[[C112:.*]] = arith.constant 112 : index
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 //      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x...
[truncated]

Contributor

@qedawkins left a comment

This PR LGTM. Holding approval for the dependencies which seem to be blocked.

@Groverkss force-pushed the fix-collapsing-fusion branch from c22b9d4 to 533294d on February 9, 2025 at 00:41
@Groverkss
Member Author

The other patches landed in some form. This is ready to go now.

6 participants