[mlir][linalg] Add pattern to push reshape after elementwise operation

This helps expose more fusion opportunities.

Differential Revision: https://reviews.llvm.org/D100685
thomasraoux 2021-04-16 13:38:15 -07:00
parent f6d8cf7798
commit d40a19c3a8
5 changed files with 286 additions and 0 deletions


@@ -106,6 +106,10 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
/// Patterns to push reshape op towards the end of the graph in order to expose
/// more fusion opportunities.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
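For context, here is a minimal sketch (not part of the diff) of how a pass could apply these patterns; it mirrors the TestPushExpandingReshape test pass added later in this commit, and the wrapper function name is hypothetical.

void applyPushReshapePatterns(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet patterns(context);
  // Collect the push-reshape patterns and run them with the greedy driver.
  linalg::populatePushReshapeOpsPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}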


@@ -998,6 +998,161 @@ struct FoldProducerReshapeOpByLinearization
}
};
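/// Convert reassociation affine maps into lists of dimension indices. For
/// example, [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] becomes
/// [[0, 1], [2]].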
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
SmallVector<ReassociationIndices> reassociation;
for (AffineMap map : maps) {
ReassociationIndices indices;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
indices.push_back(pos);
}
reassociation.push_back(indices);
}
return reassociation;
}
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
/// %0 = linalg.tensor_reshape %A [
/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
/// %2 = linalg.generic {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
/// ["parallel", "parallel", "parallel"]} {
/// } -> tensor<112x112x16xf32>
///
/// into
///
/// %2 = linalg.generic {indexing_maps = [
/// affine_map<(d0, d1) -> (d0, d1)>,
/// affine_map<(d0, d1) -> (d1)>,
/// affine_map<(d0, d1) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
/// } -> tensor<12544x16xf32>
/// %3 = linalg.tensor_reshape %2 [
/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
template <typename GenericOpTy>
struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
// Only apply to elementwise linalg on tensor.
if (!op.hasTensorSemantics() ||
op.getNumParallelLoops() != op.getNumLoops())
return failure();
// Only support identity output maps. It could be extended to permutations if
// needed.
if (llvm::any_of(op.getOutputIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
return failure();
int64_t destRank = op.getNumParallelLoops();
SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
TensorReshapeOp reshapeFound;
// 1. Look for tensor_reshape operands and save the dimensions being merged.
for (auto operand : llvm::enumerate(op.getInputs())) {
TensorReshapeOp reshapeOp =
operand.value().template getDefiningOp<TensorReshapeOp>();
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
reshapeOp.getResultType().getRank()) {
continue;
}
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
if (!op.getIndexingMaps()[operand.index()].isIdentity())
continue;
if (reshapeFound) {
// Only support a second reshape op if it has the same reassociation maps.
if (reshapeFound.getReassociationMaps() ==
reshapeOp.getReassociationMaps())
newOperands[operand.index()] = reshapeOp.src();
continue;
}
reshapeFound = reshapeOp;
newOperands[operand.index()] = reshapeOp.src();
}
if (!reshapeFound)
return failure();
// Calculate the reassociation indices and the associated reverse map.
SmallVector<ReassociationIndices> reassociation =
getReassociationIndices(reshapeFound.getReassociationMaps());
SmallVector<unsigned, 4> remap(destRank);
for (auto &indices : llvm::enumerate(reassociation)) {
for (int64_t index : indices.value()) {
remap[index] = indices.index();
}
}
// 2. Verify that we can merge the dimensions in the linalg op and that we
// don't need to create new reshape operands. Inserting new reshape
// operands would defeat the purpose of the transformation.
for (auto operand : llvm::enumerate(op.getInputs())) {
if (operand.value() == newOperands[operand.index()]) {
AffineMap map = op.getIndexingMaps()[operand.index()];
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
return failure();
}
}
}
// 3. Calculate the affine map remapping and the reassociation to apply to
// output tensors.
SmallVector<AffineMap, 4> newMaps;
unsigned newRank = reassociation.size();
for (auto map : op.getIndexingMaps()) {
SmallVector<AffineExpr> newExprs;
for (auto expr : map.getResults()) {
unsigned position = expr.template cast<AffineDimExpr>().getPosition();
// Skip merged dimensions except for the last one of each group.
if (reassociation[remap[position]].back() == position) {
newExprs.push_back(
getAffineDimExpr(remap[position], op.getContext()));
}
}
newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
}
// 4. Reshape the output tensors.
SmallVector<Value> newOutputs;
SmallVector<Type> newOutputTypes;
for (auto output : op.outputs()) {
Value newOutput = rewriter.create<TensorReshapeOp>(
op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
newOutputTypes.push_back(newOutput.getType());
newOutputs.push_back(newOutput);
}
// 5. Create a new generic op with lower rank.
SmallVector<StringRef, 4> iteratorTypes(newRank,
getParallelIteratorTypeName());
auto newOp =
rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
newOutputs, newMaps, iteratorTypes);
rewriter.inlineRegionBefore(op.region(), newOp.region(),
newOp.region().begin());
// 6. Reshape the results so that their types match the uses.
SmallVector<Value> newResults;
for (auto result : llvm::enumerate(newOp->getResults())) {
newResults.push_back(rewriter.create<TensorReshapeOp>(
op->getLoc(), op.getOutputTensorTypes()[result.index()],
result.value(), reassociation));
}
rewriter.replaceOp(op, newResults);
return success();
}
};
/// Pattern to fuse a tensor_reshape op with its consumer
/// generic/indexed_generic op, when the reshape op is collapsing
/// dimensions. The dimensionality of the loop in the consumer is expanded.
@@ -1333,6 +1488,12 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
}
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<PushExpandingReshape<GenericOp>,
PushExpandingReshape<IndexedGenericOp>>(context);
}
std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
return std::make_unique<FusionOfTensorOpsPass>();
}


@@ -0,0 +1,98 @@
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-LABEL: func @reshape
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
%0 = linalg.tensor_reshape %A [
affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
: tensor<?x16xf32> into tensor<?x112x16xf32>
%2 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
outs(%init : tensor<?x112x16xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
%s = subf %arg1, %arg2 : f32
linalg.yield %s : f32
} -> tensor<?x112x16xf32>
return %2 : tensor<?x112x16xf32>
}
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-LABEL: func @reshape_multiple
// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: return %[[RR]] : tensor<112x112x16xf32>
func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
%C: tensor<16xf32>) -> tensor<112x112x16xf32> {
%0 = linalg.tensor_reshape %A [
affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%1 = linalg.tensor_reshape %B [
affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%2 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
%3 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
outs(%2 : tensor<112x112x16xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%s = subf %arg1, %arg2 : f32
%m = mulf %s, %arg3 : f32
linalg.yield %m : f32
} -> tensor<112x112x16xf32>
return %3 : tensor<112x112x16xf32>
}
// -----
// Negative test: since the second source is broadcast along d1, we cannot
// merge the d0 and d1 dimensions.
// CHECK-LABEL: func @reshape_negative
// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.generic
// CHECK: } -> tensor<112x112x16xf32>
func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
%20 = linalg.tensor_reshape %A [
affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%21 = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
%22 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
outs(%21 : tensor<112x112x16xf32>) {
^bb0(%arg1: f32, %arg2: f32, %arg3: f32): // no predecessors
%s = subf %arg1, %arg2 : f32
linalg.yield %s : f32
} -> tensor<112x112x16xf32>
return %22 : tensor<112x112x16xf32>
}


@@ -66,6 +66,22 @@ struct TestLinalgElementwiseFusion
std::move(fusionPatterns));
}
};
struct TestPushExpandingReshape
: public PassWrapper<TestPushExpandingReshape, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
}
void runOnFunction() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getFunction();
RewritePatternSet patterns(context);
linalg::populatePushReshapeOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
} // namespace
namespace test {
@@ -74,6 +90,11 @@ void registerTestLinalgElementwiseFusion() {
"test-linalg-elementwise-fusion-patterns",
"Test Linalg element wise operation fusion patterns");
}
void registerTestPushExpandingReshape() {
PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
"test-linalg-push-reshape", "Test Linalg reshape push patterns");
}
} // namespace test
} // namespace mlir


@@ -78,6 +78,7 @@ void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
void registerTestLinalgFusionTransforms();
void registerTestLinalgTensorFusionTransforms();
void registerTestLinalgGreedyFusion();
@@ -156,6 +157,7 @@ void registerTestPasses() {
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgElementwiseFusion();
test::registerTestPushExpandingReshape();
test::registerTestLinalgFusionTransforms();
test::registerTestLinalgTensorFusionTransforms();
test::registerTestLinalgGreedyFusion();