diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 619df9a79ece..1b83906dc493 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -2275,10 +2275,91 @@ struct ElideNoOpAsyncExecuteOp : public OpRewritePattern<AsyncExecuteOp> {
   }
 };
 
+// Deduplicates values yielded more than once from an async.execute region.
+// Despite the "Cmd" in its name, this pattern matches AsyncExecuteOp.
+//
+// Each distinct yielded value keeps one result on a rebuilt op; uses of the
+// results that duplicated it are redirected to that surviving result.
+struct DeduplicateYieldCmdExecuteOp : public OpRewritePattern<AsyncExecuteOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AsyncExecuteOp op,
+                                PatternRewriter &rewriter) const override {
+    auto yield =
+        cast<IREE::Stream::YieldOp>(op.getBody().front().getTerminator());
+    const size_t oldYieldCount = yield.getResourceOperands().size();
+
+    // keepYield[i]: whether old result i is the first yield of its value.
+    // remapping[i]: index of the new result that replaces old result i.
+    llvm::SmallVector<bool> keepYield;
+    llvm::SmallVector<Value> yieldOperands;
+    llvm::SmallVector<int64_t> remapping;
+    for (size_t i = 0; i < oldYieldCount; ++i) {
+      Value operand = yield.getResourceOperands()[i];
+      // Linear scan over the distinct values seen so far; on a miss the
+      // iterator distance equals the index the value is about to receive.
+      auto find =
+          std::find(yieldOperands.begin(), yieldOperands.end(), operand);
+      const bool isFirstUse = find == yieldOperands.end();
+      keepYield.push_back(isFirstUse);
+      remapping.push_back(find - yieldOperands.begin());
+      if (isFirstUse)
+        yieldOperands.push_back(operand);
+    }
+
+    // Nothing to do when every yielded value is already unique.
+    if (yieldOperands.size() == oldYieldCount)
+      return failure();
+
+    // Result types and sizes for the surviving results only.
+    llvm::SmallVector<Type> newTypes;
+    llvm::SmallVector<Value> newResultSizes;
+    for (size_t i = 0; i < oldYieldCount; ++i) {
+      if (!keepYield[i])
+        continue;
+      newTypes.push_back(op.getResults()[i].getType());
+      newResultSizes.push_back(op.getResultSizes()[i]);
+    }
+
+    // NOTE(review): the trailing empty vector looks like the tied-operand
+    // list, in which case result->operand ties on the original op are not
+    // carried over to the rebuilt op -- confirm that this is intended.
+    auto newExecuteOp = rewriter.create<AsyncExecuteOp>(
+        op.getLoc(), newTypes, newResultSizes, op.getAwaitTimepoint(),
+        op.getResourceOperands(), op.getResourceOperandSizes(),
+        llvm::SmallVector<int64_t>());
+    newExecuteOp.setAffinityAttr(op.getAffinityAttr());
+
+    // Move the body over, then rewrite its terminator without duplicates.
+    rewriter.inlineRegionBefore(op.getRegion(), newExecuteOp.getRegion(),
+                                newExecuteOp.getRegion().end());
+    yield = cast<IREE::Stream::YieldOp>(
+        newExecuteOp.getBody().front().getTerminator());
+    llvm::SmallVector<Value> newYieldVals;
+    llvm::SmallVector<Value> newYieldSizes;
+    for (size_t i = 0; i < oldYieldCount; ++i) {
+      if (!keepYield[i])
+        continue;
+      newYieldVals.push_back(yield.getResourceOperands()[i]);
+      newYieldSizes.push_back(yield.getResourceOperandSizes()[i]);
+    }
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<IREE::Stream::YieldOp>(yield, newYieldVals,
+                                                       newYieldSizes);
+
+    // Old result i maps to new result remapping[i]; the timepoint goes last.
+    llvm::SmallVector<Value> replace;
+    for (int64_t newIndex : remapping)
+      replace.push_back(newExecuteOp.getResult(newIndex));
+    replace.push_back(newExecuteOp.getResultTimepoint());
+    rewriter.replaceOp(op, replace);
+    return success();
+  }
+};
+
 } // namespace
 
 void AsyncExecuteOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
+  results.insert<DeduplicateYieldCmdExecuteOp>(context);
   results.insert<ElideImmediateTimepointWait<AsyncExecuteOp>>(context);
   results.insert<ChainDependentAwaits<AsyncExecuteOp>>(context);
   results.insert<CloneCapturedAsyncExecuteSubviewOps>(context);
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
index cc7ce41ffc1e..43515da8cfc1 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/test/async_folding.mlir
@@ -483,6 +483,23 @@ util.func private @ElideImmediateAsyncExecuteWaits(%arg0: !stream.resource<*>, %
   util.return %0#0, %0#1 : !stream.resource<*>, !stream.timepoint
 }
 
+// -----
+
+// CHECK-LABEL: @DeduplicateAsyncExecuteReturns
+util.func private @DeduplicateAsyncExecuteReturns(%arg0: !stream.resource<*>, %arg1: index) -> (!stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[RESULTS:.+]]:2, %[[TP:.+]] = stream.async.execute
+  %0:4 = stream.async.execute with(%arg0 as %arg2: !stream.resource<*>{%arg1}) -> %arg0 as !stream.resource<*>{%arg1}, %arg0 as !stream.resource<*>{%arg1}, %arg0 as !stream.resource<*>{%arg1} {
+    // CHECK: %[[DISPATCH:.+]]:2 = stream.async.dispatch
+    %1, %2 = stream.async.dispatch @executable::@dispatch0[%c1, %c1, %c1](%arg2[%c0 to %arg1 for %arg1]) : (!stream.resource<*>{%arg1}) -> !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}
+    // CHECK: stream.yield %[[DISPATCH]]#0, %[[DISPATCH]]#1 :
+    stream.yield %1, %2, %1 : !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}, !stream.resource<*>{%arg1}
+  } => !stream.timepoint
+  // CHECK: util.return %[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#0, %[[TP]]
+  util.return %0#0, %0#1, %0#2, %0#3 : !stream.resource<*>, !stream.resource<*>, !stream.resource<*>, !stream.timepoint
+}
+
 // -----
 
 // CHECK-LABEL: @ChainAsyncExecuteWaits