From e10d18df69ac1b9f2c2bae44fb2400332d2478ee Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Thu, 10 Oct 2024 23:39:22 -0700
Subject: [PATCH] [mlir][Tensor] Rework `ReifyRankedShapedTypeInterface`
 implementation for `tensor.expand_shape` op.

The op carries its output shape directly, as a mix of static values and
dynamic SSA values, so the reified result shape can simply return it.
Also adds a method to get the shape as a `SmallVector<OpFoldResult>`.
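To illustrate the effect, here is a hypothetical example in the spirit of
the updated tests below (`%src` and `%sz0` are made-up names; the exact IR
depends on the caller). Given the expansion

  %e = tensor.expand_shape %src [[0, 1], [2], [3, 4, 5]]
      output_shape [2, 3, 5, 4, %sz0, 7]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>

reifying the result shape now returns {2, 3, 5, 4, %sz0, 7} straight from
the op's `output_shape`, where the old implementation recomputed the
dynamic extent as `tensor.dim %src, %c2` followed by
`affine.apply affine_map<()[s0] -> (s0 floordiv 28)>` (28 = 4 * 7, the
product of the static dimensions in the same reassociation group). The new
`getMixedValues` overload taking an `MLIRContext *` exists so that
`ExpandShapeOp::getMixedOutputShape()` can build the static entries as
attributes without a `Builder` at hand.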
Signed-off-by: MaheshRavishankar
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |   3 +
 .../mlir/Dialect/Utils/StaticValueUtils.h     |   3 +
 .../IR/TensorInferTypeOpInterfaceImpl.cpp     | 115 +++---------------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |   4 +
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |  10 +-
 mlir/lib/Interfaces/InferTypeOpInterface.cpp  |   8 --
 .../resolve-shaped-type-result-dims.mlir      |   7 +-
 mlir/test/Dialect/Tensor/fold-empty-op.mlir   |   9 +-
 8 files changed, 43 insertions(+), 116 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 8ad1b23cb2bfe..3ef7c74fd3af1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1165,6 +1165,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
 
+    // Return the output shape as mixed static/dynamic values.
+    SmallVector<OpFoldResult> getMixedOutputShape();
+
     // Infer the output shape for a tensor.expand_shape when it is possible
     // to do so.
     static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index d1f7ab1156248..2a3a2defb810d 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -144,6 +144,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// Return a vector of OpFoldResults with the same size as staticValues, but
 /// all elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context);
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                          ValueRange dynamicValues, Builder &b);
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985..f6fea08e2e717 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,84 +58,15 @@ static SmallVector<OpFoldResult> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
-  }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
-
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+          ReifyCollapseShapeOp, CollapseShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
+    auto reshapeOp = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
         reshapeOp.getReassociationMaps()));
     return success();
@@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp
 
 namespace {
 
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+    auto expandShapeOp = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> resultShapes =
+        expandShapeOp.getMixedOutputShape();
+    reifyResultShapes.emplace_back(std::move(resultShapes));
+    return success();
+  }
+};
+
 struct ReifyPadOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
                                                              PadOp> {
@@ -202,10 +129,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 24a1d55315319..117908129561f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1732,6 +1732,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
   return *outputShape;
 }
 
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value src,
                           ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4..fcb736aa031f3 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -191,7 +191,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
-                                         ValueRange dynamicValues, Builder &b) {
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context) {
   SmallVector<OpFoldResult> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
@@ -200,10 +201,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
     int64_t value = staticValues[idx];
     res.push_back(ShapedType::isDynamic(value)
                       ? OpFoldResult{dynamicValues[numDynamic++]}
-                      : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+                      : OpFoldResult{IntegerAttr::get(
+                            IntegerType::get(context, 64), staticValues[idx])});
   }
   return res;
 }
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b) {
+  return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 3eb401c449980..6b5e103cd36c2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
     assert(shapedType.getRank() ==
                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
-      // reifyResultShapes must return:
-      //   * Attribute for static dimensions
-      //   * Value for dynamic dimensions
-      assert(shapedType.isDynamicDim(dim) ==
-                 isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
-             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    }
     ++resultIdx;
   }
   // Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613..3bc1f56d816d7 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: func @dim_reshape_expansion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:.+]]: index
 // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df..850bbcee34020 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT: return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {