-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Tensor] Rework ReifyRankedShapedTypeInterface
implementation for tensor.expand_shape
op.
#113501
[mlir][Tensor] Rework ReifyRankedShapedTypeInterface
implementation for tensor.expand_shape
op.
#113501
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: None (MaheshRavishankar) Changes: The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`. Full diff: https://github.com/llvm/llvm-project/pull/113501.diff — 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+ // Return output shape as mixes static/dynamic shapes.
+ SmallVector<OpFoldResult> getMixedOutputShape();
+
// Infer the output shape for a tensor.expand_shape when it is possible
// to do so.
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// Return a vector of OpFoldResults with the same size a staticValues, but
/// all elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues,
+ MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
builder, loc, src, dstStaticShape, reassocation);
}
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+ ReifyCollapseShapeOp, CollapseShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
- auto reshapeOp = cast<OpTy>(op);
+ auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
namespace {
+struct ReifyExpandShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp> {
+ LogicalResult
+ reifyResultShapes(Operation *op, OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ SmallVector<OpFoldResult> resultShapes =
+ expandShapeOp.getMixedOutputShape();
+ reifyResultShapes.emplace_back(std::move(resultShapes));
+ return success();
+ }
+};
+
struct ReifyPadOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- ExpandShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
- CollapseShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
return *outputShape;
}
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+ return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
- ValueRange dynamicValues, Builder &b) {
+ ValueRange dynamicValues,
+ MLIRContext *context) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
- : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+ : OpFoldResult{IntegerAttr::get(
+ IntegerType::get(context, 64), staticValues[idx])});
}
return res;
}
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues, Builder &b) {
+ return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
assert(shapedType.getRank() ==
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
- for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
- // reifyResultShapes must return:
- // * Attribute for static dimensions
- // * Value for dynamic dimensions
- assert(shapedType.isDynamicDim(dim) ==
- reifiedReturnShapes[resultIdx][dim].is<Value>() &&
- "incorrect implementation of ReifyRankedShapedTypeOpInterface");
- }
++resultIdx;
}
// Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
// -----
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
|
@llvm/pr-subscribers-mlir-tensor Author: None (MaheshRavishankar) Changes: The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`. Full diff: https://github.com/llvm/llvm-project/pull/113501.diff — 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+ // Return output shape as mixes static/dynamic shapes.
+ SmallVector<OpFoldResult> getMixedOutputShape();
+
// Infer the output shape for a tensor.expand_shape when it is possible
// to do so.
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// Return a vector of OpFoldResults with the same size a staticValues, but
/// all elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues,
+ MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
builder, loc, src, dstStaticShape, reassocation);
}
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+ ReifyCollapseShapeOp, CollapseShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
- auto reshapeOp = cast<OpTy>(op);
+ auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
namespace {
+struct ReifyExpandShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp> {
+ LogicalResult
+ reifyResultShapes(Operation *op, OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ SmallVector<OpFoldResult> resultShapes =
+ expandShapeOp.getMixedOutputShape();
+ reifyResultShapes.emplace_back(std::move(resultShapes));
+ return success();
+ }
+};
+
struct ReifyPadOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- ExpandShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
- CollapseShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
return *outputShape;
}
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+ return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
- ValueRange dynamicValues, Builder &b) {
+ ValueRange dynamicValues,
+ MLIRContext *context) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
- : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+ : OpFoldResult{IntegerAttr::get(
+ IntegerType::get(context, 64), staticValues[idx])});
}
return res;
}
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+ ValueRange dynamicValues, Builder &b) {
+ return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
assert(shapedType.getRank() ==
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
- for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
- // reifyResultShapes must return:
- // * Attribute for static dimensions
- // * Value for dynamic dimensions
- assert(shapedType.isDynamicDim(dim) ==
- reifiedReturnShapes[resultIdx][dim].is<Value>() &&
- "incorrect implementation of ReifyRankedShapedTypeOpInterface");
- }
++resultIdx;
}
// Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
// -----
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
… for `tensor.expand_shape` op. The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`. Signed-off-by: MaheshRavishankar <[email protected]>
b5a2101
to
e10d18d
Compare
Signed-off-by: MaheshRavishankar <[email protected]>
The op carries the output-shape directly. This can be used directly. Also adds a method to get the shape as a `SmallVector<OpFoldResult>`.