[DispatchCreation] Collapse dynamic dims #19654

Merged · 2 commits · Feb 14, 2025
@@ -888,12 +888,6 @@ collapseOpIterationDims(AttentionOp op,
}))
return failure();

-FailureOr<SmallVector<int64_t>> staticLoops = op.getStaticLoopRanges();
-if (failed(staticLoops) ||
-llvm::any_of(staticLoops.value(), ShapedType::isDynamic)) {
-return failure();
-}
-
CollapsingInfo collapsingInfo;
if (failed(
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -177,13 +178,6 @@ static bool isEligibleForCollapse(Operation *op) {
return false;
}

-// TODO(guray) There is no mechanism to tell the collapsed indexes to
-// `tensor.expand_shape`. Once we have this support in MLIR, we can enable
-// dynamic tensor shapes.
-if (genericOp.hasDynamicShape()) {
-return false;
-}
-
// TODO(guray) Currently we can only collapse when result of all the
// AffineMaps are dimensions. Possible to collapse cases like
// affine_map<d0, d1+d2> with affine_map<d0, d1+d2>, however, this is not
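The guard removed above is exactly what the deleted TODO described: collapsing used to bail on dynamic shapes because `tensor.expand_shape` had no way to be told the re-expanded sizes. Now that the op carries an explicit `output_shape`, a dynamically shaped generic can be collapsed and re-expanded. A minimal sketch of the reshape pair this relies on, where %t, %d0, and %d1 are illustrative values rather than anything from this diff:

  %collapsed = tensor.collapse_shape %t [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%d0, %d1]
      : tensor<?xf32> into tensor<?x?xf32>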
@@ -684,12 +678,15 @@ hoistTensorReshapesOutOfDispatchRegion(
SmallVector<SmallVector<ReassociationIndices>> allReassociationIndices;
ValueRange dynamicDimsList = dispatchOp.getResultDims();
Location loc = dispatchOp.getLoc();
-for (Value yieldedValue : returnOp->getOperands()) {
+for (auto [resultIndex, yieldedValue] :
+llvm::enumerate(returnOp->getOperands())) {
auto expandShapeOp = yieldedValue.getDefiningOp<tensor::ExpandShapeOp>();
if (!expandShapeOp) {
// 4a. Keep the same yield value if the producer is not a
// `tensor.expand_shape` op.
newReturnTypes.push_back(yieldedValue.getType());
+ValueRange resultDims = dispatchOp.getResultDynamicDims(resultIndex);
+newDynamicDims.append(resultDims.begin(), resultDims.end());
newYieldVals.push_back(yieldedValue);
continue;
}
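When a yielded value is not produced by a `tensor.expand_shape`, its result type is kept unchanged, so the dynamic-dimension SSA values tied to that result have to be carried over into the rebuilt dispatch op; the `resultIndex` from `llvm::enumerate` is what lets `getResultDynamicDims` find them. The invariant being preserved is that every dynamic result of a dispatch region names one index value per dynamic dimension. A hedged sketch, with %d0, %d1, and %v standing in for values defined elsewhere:

  %r = flow.dispatch.region -> (tensor<?x?xf32>{%d0, %d1}) {
    flow.return %v : tensor<?x?xf32>
  }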
@@ -774,9 +771,17 @@ hoistTensorReshapesOutOfDispatchRegion(
rewriter.replaceAllUsesWith(origResult, returnValue);
continue;
}

+auto shapedType = dyn_cast<ShapedType>(origResult.getType());
+assert(shapedType && "result should be shaped type");
+
+ValueRange dynamicDims = dispatchOp.getResultDynamicDims(index);
+SmallVector<OpFoldResult> outputShape =
+mlir::getMixedValues(shapedType.getShape(), dynamicDims, rewriter);
+
auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, origResult.getType(), returnValue,
-allReassociationIndicesRef.front());
+allReassociationIndicesRef.front(), outputShape);
allReassociationIndicesRef = allReassociationIndicesRef.drop_front();
rewriter.replaceAllUsesWith(origResult, newExpandShapeOp.getResult());
}
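`getMixedValues` pairs the static sizes from the result type with the dispatch op's dynamic dimension values, producing one `OpFoldResult` per result dimension, which is what the recreated `tensor.expand_shape` needs as its `output_shape`. A sketch of the op this builds for a hypothetical `tensor<?x4xf32>` result, where %returned and %d0 are placeholder values: static extents stay inline and each dynamic extent pulls the matching SSA value:

  %expanded = tensor.expand_shape %returned [[0, 1]] output_shape [%d0, 4]
      : tensor<?xf32> into tensor<?x4xf32>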
@@ -1048,17 +1053,25 @@ void CollapseDimensionsPass::runOnOperation() {
memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
SmallVector<Operation *> candidateOps;
-block.walk([&](Operation *op) {
-if (isa<tensor::CollapseShapeOp>(op)) {
-candidateOps.push_back(op);
-}
-});
+block.walk([&](Operation *op) { candidateOps.push_back(op); });
if (failed(
applyOpPatternsGreedily(candidateOps, std::move(moveReshapeOps)))) {
funcOp.emitOpError(
"failed to propagate reshape ops introduced during collapse");
return signalPassFailure();
}

+// Expand affine.apply ops that compute collapsed sizes from dynamic dims.
+newDispatchOp->walk([&](affine::AffineApplyOp op) {
+rewriter.setInsertionPoint(op);
+auto maybeExpanded = mlir::affine::expandAffineMap(
+rewriter, op.getLoc(), op.getAffineMap(),
+llvm::to_vector<4>(op.getOperands()));
+if (!maybeExpanded) {
+return;
+}
+rewriter.replaceOp(op, *maybeExpanded);
+});
}
}
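Collapsing a dynamic iteration space leaves behind affine.apply ops that fold the original dynamic extents into the collapsed ones, typically as a product of dims. The walk above uses `expandAffineMap` to rewrite each of those into equivalent arith ops. An illustrative before/after for a two-dim fold; the actual map depends on which loops get collapsed:

  // What collapsing introduces:
  %sz = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%d0, %d1]
  // What the walk rewrites it into:
  %sz = arith.muli %d0, %d1 : index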

@@ -753,3 +753,37 @@ util.func public @uncollapsable_consumer_partial(%arg0: tensor<10x20x30x2304xf32
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[RES]]
// CHECK: flow.return %[[RES]]

// -----

util.func @elementwise_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst_0 = arith.constant 0 : index
%cst_1 = arith.constant 1 : index
%0 = tensor.dim %arg0, %cst_0 : tensor<?x?xf32>
%1 = tensor.dim %arg0, %cst_1 : tensor<?x?xf32>
%3 = flow.dispatch.region -> (tensor<?x?xf32>{%0, %1}) {
%5 = tensor.empty(%0, %1) : tensor<?x?xf32>
%cst = arith.constant 1.000000e+02 : f32
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%7 = arith.addf %in, %cst : f32
linalg.yield %7 : f32
} -> tensor<?x?xf32>
flow.return %6 : tensor<?x?xf32>
}
util.return %3 : tensor<?x?xf32>
}
// CHECK-LABEL: util.func public @elementwise_dynamic
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]]
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]]
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[VAL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK: flow.return %[[VAL]] : tensor<?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]]
// CHECK-SAME: {{.+}} output_shape [%[[DIM0]], %[[DIM1]]]
// CHECK: util.return %[[EXPAND]] : tensor<?x?xf32>