Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Feb 21, 2025
1 parent 1a4fdd9 commit 454adc6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
36 changes: 22 additions & 14 deletions compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,12 @@ static bool makeConsumerFusableViaInterchange(
return false;
}

Operation *consumer = fusableOperand.getOwner();
auto genericOp = dyn_cast<linalg::GenericOp>(consumer);
if (!genericOp) {
auto consumer = dyn_cast<linalg::GenericOp>(fusableOperand.getOwner());
if (!consumer) {
return false;
}
assert(genericOp.getNumDpsInputs() > 0 &&
"expected consumer to have at least one input");

if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) {
if (!linalg::isElementwise(consumer) || consumer.getNumResults() != 1) {
return false;
}

Expand All @@ -420,22 +417,32 @@ static bool makeConsumerFusableViaInterchange(
producerIndexingMap = getProjectedMap(
producerIndexingMap, getUnusedDimsBitVector(producerIndexingMap));
AffineMap consumerIndexingMap =
genericOp.getMatchingIndexingMap(&fusableOperand);
consumer.getMatchingIndexingMap(&fusableOperand);

// Since the iteration space of the consumer is going to be permuted
// to make it match with the indexing map in the producer, the interchange
// requires the indexing map in the consumer to be a permutation.
// If the producer indexing map and consumer indexing map are the same,
// then the permutation of iteration space becomes a no-op, in which
// case the permutation wasn't required for fusion. Return false here
// to indicate that the permutation is not going to "enhance" the
// fusion opportunities.
if (!consumerIndexingMap.isPermutation() ||
producerIndexingMap == consumerIndexingMap) {
return false;
}
OpResult result = cast<OpResult>(genericOp.getResult(0));
if (!genericOp.getIndexingMapMatchingResult(result).isPermutation()) {
OpResult result = cast<OpResult>(consumer.getResult(0));
if (!consumer.getIndexingMapMatchingResult(result).isPermutation()) {
return false;
}

// For now this is restricting that all indexing maps corresponding to the
// input are same as the indexing map of the fused operand, or are projected
// permutations.
// permutations. This avoids ping-ponging between different iteration space
// permutations without having any way to pick which is better.
if (!llvm::all_of(
genericOp.getDpsInputOperands(), [&](OpOperand *inputOperand) {
AffineMap map = genericOp.getMatchingIndexingMap(inputOperand);
consumer.getDpsInputOperands(), [&](OpOperand *inputOperand) {
AffineMap map = consumer.getMatchingIndexingMap(inputOperand);
return map == consumerIndexingMap ||
(map.isProjectedPermutation() && !map.isPermutation());
})) {
Expand All @@ -453,10 +460,10 @@ static bool makeConsumerFusableViaInterchange(
});
IRRewriter rewriter(consumer->getContext());
FailureOr<linalg::GenericOp> interchangedOp =
linalg::interchangeGenericOp(rewriter, genericOp, perm);
linalg::interchangeGenericOp(rewriter, consumer, perm);
(void)interchangedOp;
assert(succeeded(interchangedOp) && "expected interchange to succeed");
assert(interchangedOp.value() == genericOp &&
assert(interchangedOp.value() == consumer &&
"expected interchange to happen in place");
return true;
}
Expand Down Expand Up @@ -596,6 +603,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,
if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
// Check if interchange in the consumer makes it fusable.
// Currently limit it to horizontally fused gemms.
// TODO(#20019): Remove this restriction.
if (!IREE::LinalgExt::isaHorizontallyFusedContraction(producer) ||
!makeConsumerFusableViaInterchange(fusedOperand,
rootOuterParallelLoops)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1284,14 +1284,16 @@ util.func @horizontal_fusion3(%lhs : tensor<2x4096x640xf16>,
} -> tensor<2x10x64x4096xf16>
util.return %8, %9, %10 : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
}
// CHECK-LABEL: func public @horizontal_fusion3
// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.region
// CHECK: %[[GENERIC:.+]]:3 = linalg.generic
// CHECK: %[[TRUNC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0 :
// CHECK: %[[TRUNC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#1 :
// CHECK: %[[TRUNC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#2 :
// CHECK: flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
// CHECK: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
// CHECK: #[[INTERCHANGED_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d2)>
// CHECK: func public @horizontal_fusion3
// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.region
// CHECK: %[[GENERIC:.+]]:3 = linalg.generic
// CHECK: %[[TRUNC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0 :
// CHECK: %[[TRUNC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#1 :
// CHECK: %[[TRUNC2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[INTERCHANGED_MAP]], #[[INTERCHANGED_MAP]]]
// CHECK-SAME: ins(%[[GENERIC]]#2 :
// CHECK: flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
// CHECK: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2

0 comments on commit 454adc6

Please sign in to comment.