Skip to content

Commit

Permalink
[mlir][Tensor] Rework ReifyRankedShapedTypeInterface implementation…
Browse files Browse the repository at this point in the history
… for `tensor.expand_shape` op.

The op carries the output-shape directly. This can be used
directly. Also adds a method to get the shape as a
`SmallVector<OpFoldResult>`.

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Jan 25, 2025
1 parent 3b35b4c commit e10d18d
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 116 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);

// Return output shape as mixes static/dynamic shapes.
SmallVector<OpFoldResult> getMixedOutputShape();

// Infer the output shape for a tensor.expand_shape when it is possible
// to do so.
static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// Return a vector of OpFoldResults with the same size a staticValues, but
/// all elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues,
MLIRContext *context);
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b);

Expand Down
115 changes: 20 additions & 95 deletions mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@
using namespace mlir;
using namespace mlir::tensor;

/// Compute a map that for a given dimension of the expanded type gives the
/// dimension in the collapsed type it maps to. Essentially its the inverse of
/// the `reassocation` maps.
static llvm::DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
for (const auto &map : enumerate(reassociation)) {
unsigned startPos =
cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
unsigned endPos =
cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
expandedDimToCollapsedDim[dim] = map.index();
}
}
return expandedDimToCollapsedDim;
}

/// For reshape op compute the shape at dimension `dimIndex` of the output in
/// terms of shape of the `src`, when the reshape op is a collapsing
/// operation. It is the product of the shape of the collapsed dimensions of the
Expand Down Expand Up @@ -76,84 +58,15 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
}));
}

/// For an expanding reshape op, compute the value for a dimension of the output
/// from the shape of the input.
static OpFoldResult getExpandedOutputDimFromInputShape(
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
// Static dimension: return Attribute.
return builder.getIndexAttr(dstStaticShape[dimIndex]);
}
unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
unsigned startPos =
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
.getPosition();
unsigned endPos =
cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
.getPosition();
int64_t linearizedStaticDim = 1;
for (auto d :
llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
if (d.index() + startPos == static_cast<unsigned>(dimIndex))
continue;
assert(!ShapedType::isDynamic(d.value()) &&
"single dimension cannot be expanded into multiple dynamic "
"dimensions");
linearizedStaticDim *= d.value();
}
OpFoldResult sourceDim =
builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();

// Dynamic dimension: return Value.
return affine::makeComposedAffineApply(
builder, loc,
AffineMap::get(
0, 1,
builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
sourceDim)
->getResult(0);
}

/// Given the `src` of an expanding reshape op, the reassociation maps and the
/// result type, compute the shape of the result of the reshape.
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
getExpandedDimToCollapsedDimMap(reassociation);
return llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
dstStaticShape, reassociation,
expandedDimToCollapsedDim);
}));
}

static SmallVector<OpFoldResult, 4>
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape,
ArrayRef<AffineMap> reassocation) {
return dstStaticShape.size() >
static_cast<size_t>(
llvm::cast<ShapedType>(src.getType()).getRank())
? getExpandedOutputShapeFromInputShape(
builder, loc, src, dstStaticShape, reassocation)
: getCollapsedOutputShapeFromInputShape(
builder, loc, src, dstStaticShape, reassocation);
}

template <typename OpTy>
struct ReifyExpandOrCollapseShapeOp
struct ReifyCollapseShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<
ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
ReifyCollapseShapeOp, CollapseShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
auto reshapeOp = cast<OpTy>(op);
reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()));
return success();
Expand All @@ -162,6 +75,20 @@ struct ReifyExpandOrCollapseShapeOp

namespace {

struct ReifyExpandShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
ExpandShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
SmallVector<OpFoldResult> resultShapes =
expandShapeOp.getMixedOutputShape();
reifyResultShapes.emplace_back(std::move(resultShapes));
return success();
}
};

struct ReifyPadOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
PadOp> {
Expand Down Expand Up @@ -202,10 +129,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
ExpandShapeOp::attachInterface<
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
CollapseShapeOp::attachInterface<
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
Expand Down
10 changes: 8 additions & 2 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
/// elements for which ShapedType::isDynamic is true, will be replaced by
/// dynamicValues.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b) {
ValueRange dynamicValues,
MLIRContext *context) {
SmallVector<OpFoldResult> res;
res.reserve(staticValues.size());
unsigned numDynamic = 0;
Expand All @@ -200,10 +201,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
int64_t value = staticValues[idx];
res.push_back(ShapedType::isDynamic(value)
? OpFoldResult{dynamicValues[numDynamic++]}
: OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
: OpFoldResult{IntegerAttr::get(
IntegerType::get(context, 64), staticValues[idx])});
}
return res;
}
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
ValueRange dynamicValues, Builder &b) {
return getMixedValues(staticValues, dynamicValues, b.getContext());
}

/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
Expand Down
8 changes: 0 additions & 8 deletions mlir/lib/Interfaces/InferTypeOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
assert(shapedType.getRank() ==
static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
// reifyResultShapes must return:
// * Attribute for static dimensions
// * Value for dynamic dimensions
assert(shapedType.isDynamicDim(dim) ==
isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
}
++resultIdx;
}
// Assert that every shaped value result was reified.
Expand Down
7 changes: 2 additions & 5 deletions mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-SAME: %[[ARG1:.+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]

// -----

Expand Down
9 changes: 3 additions & 6 deletions mlir/test/Dialect/Tensor/fold-empty-op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}

// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>

func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
Expand All @@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]

func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
Expand Down

0 comments on commit e10d18d

Please sign in to comment.