Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Tensor] Rework ReifyRankedShapedTypeInterface implementation for tensor.expand_shape op. #113501

Merged

Conversation

MaheshRavishankar
Copy link
Contributor

The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a SmallVector<OpFoldResult>.

@llvmbot
Copy link
Member

llvmbot commented Oct 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: None (MaheshRavishankar)

Changes

The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a SmallVector&lt;OpFoldResult&gt;.


Full diff: https://github.com/llvm/llvm-project/pull/113501.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3)
  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+19-8)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+8-2)
  • (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (-8)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+3-6)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
 
+    // Return output shape as mixes static/dynamic shapes.
+    SmallVector<OpFoldResult> getMixedOutputShape();
+
     // Infer the output shape for a tensor.expand_shape when it is possible
     // to do so.
     static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// Return a vector of OpFoldResults with the same size a staticValues, but
 /// all elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context);
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                          ValueRange dynamicValues, Builder &b);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
                    builder, loc, src, dstStaticShape, reassocation);
 }
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+          ReifyCollapseShapeOp, CollapseShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
+    auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
     reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
         reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
 
 namespace {
 
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> resultShapes =
+        expandShapeOp.getMixedOutputShape();
+    reifyResultShapes.emplace_back(std::move(resultShapes));
+    return success();
+  }
+};
+
 struct ReifyPadOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
                                                              PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
   return *outputShape;
 }
 
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value src,
                           ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
-                                         ValueRange dynamicValues, Builder &b) {
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context) {
   SmallVector<OpFoldResult> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
     int64_t value = staticValues[idx];
     res.push_back(ShapedType::isDynamic(value)
                       ? OpFoldResult{dynamicValues[numDynamic++]}
-                      : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+                      : OpFoldResult{IntegerAttr::get(
+                            IntegerType::get(context, 64), staticValues[idx])});
   }
   return res;
 }
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b) {
+  return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
     assert(shapedType.getRank() ==
                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
-      // reifyResultShapes must return:
-      // * Attribute for static dimensions
-      // * Value for dynamic dimensions
-      assert(shapedType.isDynamicDim(dim) ==
-                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
-             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    }
     ++resultIdx;
   }
   // Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @dim_reshape_expansion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME:   %[[ARG1:.+]]: index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME:     %[[ARG0:.+]]: index
-// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT:   return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

@llvmbot
Copy link
Member

llvmbot commented Oct 23, 2024

@llvm/pr-subscribers-mlir-tensor

Author: None (MaheshRavishankar)

Changes

The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a SmallVector&lt;OpFoldResult&gt;.


Full diff: https://github.com/llvm/llvm-project/pull/113501.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3)
  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+19-8)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+8-2)
  • (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (-8)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+3-6)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
 
+    // Return output shape as mixes static/dynamic shapes.
+    SmallVector<OpFoldResult> getMixedOutputShape();
+
     // Infer the output shape for a tensor.expand_shape when it is possible
     // to do so.
     static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// Return a vector of OpFoldResults with the same size a staticValues, but
 /// all elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context);
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                          ValueRange dynamicValues, Builder &b);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
                    builder, loc, src, dstStaticShape, reassocation);
 }
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+          ReifyCollapseShapeOp, CollapseShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
+    auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
     reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
         reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
 
 namespace {
 
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> resultShapes =
+        expandShapeOp.getMixedOutputShape();
+    reifyResultShapes.emplace_back(std::move(resultShapes));
+    return success();
+  }
+};
+
 struct ReifyPadOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
                                                              PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
   return *outputShape;
 }
 
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value src,
                           ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
-                                         ValueRange dynamicValues, Builder &b) {
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context) {
   SmallVector<OpFoldResult> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
     int64_t value = staticValues[idx];
     res.push_back(ShapedType::isDynamic(value)
                       ? OpFoldResult{dynamicValues[numDynamic++]}
-                      : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+                      : OpFoldResult{IntegerAttr::get(
+                            IntegerType::get(context, 64), staticValues[idx])});
   }
   return res;
 }
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b) {
+  return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
     assert(shapedType.getRank() ==
                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
-      // reifyResultShapes must return:
-      // * Attribute for static dimensions
-      // * Value for dynamic dimensions
-      assert(shapedType.isDynamicDim(dim) ==
-                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
-             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    }
     ++resultIdx;
   }
   // Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @dim_reshape_expansion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME:   %[[ARG1:.+]]: index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME:     %[[ARG0:.+]]: index
-// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT:   return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 3, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 6, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 7, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 7, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 7, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 7, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 13, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 13, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 20, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 20, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 25, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 25, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 25, 2025
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jan 25, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
… for `tensor.expand_shape` op.

The op carries the output-shape directly. This can be used
directly. Also adds a method to get the shape as a
`SmallVector<OpFoldResult>`.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar merged commit 092372d into llvm:main Jan 27, 2025
8 checks passed
pashu123 pushed a commit to pashu123/iree that referenced this pull request Jan 27, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants