Changes to dispatch formation to allow for fusion with trunc in case of gemm + transpose gemm fusion.

Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar committed Feb 21, 2025
1 parent d9a3a7c commit 1a4fdd9
Showing 3 changed files with 169 additions and 3 deletions.
compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp (3 changes: 1 addition & 2 deletions)
@@ -693,11 +693,10 @@ isContractionOpSequence(Value yielded) {
/// TODO: The logic below is quite convoluted. Might be better
/// off having a dedicated operation for this.
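/// (Illustrative note, not in the original comment: a horizontally fused
/// contraction here is a single linalg op that computes several gemms
/// sharing an operand and yields multiple results, like the op in the
/// @horizontal_fusion3 test below.)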
bool isaHorizontallyFusedContraction(Operation *op) {
-  auto linalgOp = dyn_cast_or_null<linalg::GenericOp>(op);
+  auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(op);
  if (!linalgOp) {
    return false;
  }

  if (linalgOp->getNumResults() == 1) {
    return false;
  }
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -384,6 +385,82 @@ static bool areOpsFusable(Operation *producer, Operation *consumer,
  return true;
}

/// The logic to decide fusability (using `hasCompatibleOuterParallelLoops`)
/// currently works only when the indexing map corresponding to the result
/// of the producer and the indexing map corresponding to the operand in the
/// consumer are not transposed with respect to each other. To find more
/// fusion opportunities with an elementwise consumer, the consumer's
/// indexing maps can be interchanged to "align" with the producer's
/// indexing map.
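/// For example (illustrative, not from the original comment): with a
/// producer result map (d0, d1) -> (d1, d0) and an identity-mapped
/// consumer, the permutation [1, 0] computed below interchanges the
/// consumer's loops so its maps become (d0, d1) -> (d1, d0), matching the
/// producer's result map.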
static bool makeConsumerFusableViaInterchange(
    OpOperand &fusableOperand,
    const llvm::SmallBitVector &rootOuterParallelLoops) {
  auto producer =
      fusableOperand.get()
          .getDefiningOp<IREE::LinalgExt::LinalgFusionOpInterface>();
  if (!producer) {
    return false;
  }

  Operation *consumer = fusableOperand.getOwner();
  auto genericOp = dyn_cast<linalg::GenericOp>(consumer);
  if (!genericOp) {
    return false;
  }
  assert(genericOp.getNumDpsInputs() > 0 &&
         "expected consumer to have at least one input");

  if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) {
    return false;
  }

  // If the indexing map in the consumer is already "compatible" with the
  // indexing map in the producer, do nothing.
  AffineMap producerIndexingMap = producer.getIndexingMapMatchingResult(
      cast<OpResult>(fusableOperand.get()));
  producerIndexingMap = getProjectedMap(
      producerIndexingMap, getUnusedDimsBitVector(producerIndexingMap));
  AffineMap consumerIndexingMap =
      genericOp.getMatchingIndexingMap(&fusableOperand);
  if (!consumerIndexingMap.isPermutation() ||
      producerIndexingMap == consumerIndexingMap) {
    return false;
  }
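  // The consumer must also write its single result with a full permutation;
  // otherwise the interchange below is not guaranteed to realign the
  // consumer with the producer.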
  OpResult result = cast<OpResult>(genericOp.getResult(0));
  if (!genericOp.getIndexingMapMatchingResult(result).isPermutation()) {
    return false;
  }

  // For now, require that every indexing map corresponding to an input
  // either matches the indexing map of the fused operand or is a projected
  // (non-full) permutation.
  if (!llvm::all_of(
          genericOp.getDpsInputOperands(), [&](OpOperand *inputOperand) {
            AffineMap map = genericOp.getMatchingIndexingMap(inputOperand);
            return map == consumerIndexingMap ||
                   (map.isProjectedPermutation() && !map.isPermutation());
          })) {
    return false;
  }
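  // (Note, added for clarity: operands indexed by projected, non-full
  // permutations, e.g. broadcast operands, remain projected permutations
  // under any loop interchange, so they stay valid after the interchange
  // below.)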

  // Make the consumer's indexing maps match the producer map by
  // interchanging the consumer's loops with the permutation computed as
  // consumerIndexingMap.compose(inverse(producerIndexingMap)).
  AffineMap invProducerIndexingMap = inversePermutation(producerIndexingMap);
  AffineMap permutationMap =
      consumerIndexingMap.compose(invProducerIndexingMap);
  auto perm = llvm::map_to_vector(permutationMap.getResults(),
                                  [](AffineExpr e) -> unsigned {
                                    return cast<AffineDimExpr>(e).getPosition();
                                  });
  IRRewriter rewriter(consumer->getContext());
  FailureOr<linalg::GenericOp> interchangedOp =
      linalg::interchangeGenericOp(rewriter, genericOp, perm);
  (void)interchangedOp;
  assert(succeeded(interchangedOp) && "expected interchange to succeed");
  assert(interchangedOp.value() == genericOp &&
         "expected interchange to happen in place");
  return true;
}

/// For the fusion of root op -> elementwise operation to be bufferized
/// in-place without use of extra memory, the result of the root operation
/// must be able to reuse the buffer for the result of the elementwise
@@ -517,7 +594,13 @@ isFusableWithConsumer(OpOperand &fusedOperand,
}

  if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
-    return false;
+    // Check whether an interchange of the consumer's loops makes it
+    // fusable. Currently this is limited to horizontally fused gemms.
+    if (!IREE::LinalgExt::isaHorizontallyFusedContraction(producer) ||
+        !makeConsumerFusableViaInterchange(fusedOperand,
+                                           rootOuterParallelLoops)) {
+      return false;
+    }
  }

  // Check if the iteration spaces of the producer and consumer are the same.
@@ -1211,3 +1211,87 @@ util.func @avoid_use_def_violation_on_consumer_fusion(%arg0 : tensor<?xf32>,
// CHECK-SAME: ins(%[[DISPATCH1]], %[[BARRIER]] :
// CHECK: flow.return %[[GENERIC2]]
// CHECK: util.return %[[DISPATCH2]]

// -----

util.func @horizontal_fusion3(%lhs : tensor<2x4096x640xf16>,
%rhs0 : tensor<10x64x640xf16>, %rhs1 : tensor<10x64x640xf16>,
%rhs2 : tensor<10x64x640xf16>) ->
(tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>,
tensor<2x10x64x4096xf16>) {
%0 = tensor.empty() : tensor<2x10x64x4096xf32>
%4 = tensor.empty() : tensor<2x10x4096x64xf32>
%cst = arith.constant 0.0 : f32
%1 = linalg.fill ins(%cst : f32)
outs(%0 : tensor<2x10x64x4096xf32>) -> tensor<2x10x64x4096xf32>
%5 = linalg.fill ins(%cst : f32)
outs(%4 : tensor<2x10x4096x64xf32>) -> tensor<2x10x4096x64xf32>
%6:3 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]}
ins(%lhs, %rhs0, %rhs1, %rhs2
: tensor<2x4096x640xf16>, tensor<10x64x640xf16>, tensor<10x64x640xf16>,
tensor<10x64x640xf16>)
outs(%5, %5, %1
: tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>) {
^bb0(%in: f16, %in_0: f16, %in_1: f16, %in_2: f16, %out: f32, %out_3: f32, %out_4: f32):
%14 = arith.extf %in : f16 to f32
%15 = arith.extf %in_0 : f16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.addf %out, %16 : f32
%18 = arith.extf %in_1 : f16 to f32
%19 = arith.mulf %14, %18 : f32
%20 = arith.addf %out_3, %19 : f32
%21 = arith.extf %in_2 : f16 to f32
%22 = arith.mulf %14, %21 : f32
%23 = arith.addf %out_4, %22 : f32
linalg.yield %17, %20, %23 : f32, f32, f32
} -> (tensor<2x10x4096x64xf32>, tensor<2x10x4096x64xf32>, tensor<2x10x64x4096xf32>)
%7 = tensor.empty() : tensor<2x10x4096x64xf16>
%8 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%6#0 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
^bb0(%in: f32, %out: f16):
%14 = arith.truncf %in : f32 to f16
linalg.yield %14 : f16
} -> tensor<2x10x4096x64xf16>
%9 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%6#1 : tensor<2x10x4096x64xf32>) outs(%7 : tensor<2x10x4096x64xf16>) {
^bb0(%in: f32, %out: f16):
%14 = arith.truncf %in : f32 to f16
linalg.yield %14 : f16
} -> tensor<2x10x4096x64xf16>
%2 = tensor.empty() : tensor<2x10x64x4096xf16>
%10 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%6#2 : tensor<2x10x64x4096xf32>) outs(%2 : tensor<2x10x64x4096xf16>) {
^bb0(%in: f32, %out: f16):
%14 = arith.truncf %in : f32 to f16
linalg.yield %14 : f16
} -> tensor<2x10x64x4096xf16>
util.return %8, %9, %10 : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x64x4096xf16>
}
// CHECK-LABEL: func public @horizontal_fusion3
// CHECK: %[[DISPATCH:.+]]:3 = flow.dispatch.region
// CHECK: %[[GENERIC:.+]]:3 = linalg.generic
// CHECK: %[[TRUNC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#0 :
// CHECK: %[[TRUNC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#1 :
// CHECK: %[[TRUNC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC]]#2 :
// CHECK: flow.return %[[TRUNC0]], %[[TRUNC1]], %[[TRUNC2]]
// CHECK: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1, %[[DISPATCH]]#2
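// Note (illustrative, not part of the checks above): the third truncf is
// written with identity indexing maps, while the producer's third result
// uses the transposed map (d0, d1, d3, d2). Dispatch formation interchanges
// the truncf's loops with permutation [0, 1, 3, 2] so its maps align with
// the producer, which is what allows it into the same dispatch as the
// horizontally fused gemm.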
