forked from OSchip/llvm-project
[mlir][linalg] Add pattern to push reshape after elementwise operation
This helps expose more fusion opportunities. Differential Revision: https://reviews.llvm.org/D100685
This commit is contained in:
parent
f6d8cf7798
commit
d40a19c3a8
|
@ -106,6 +106,10 @@ void populateElementwiseOpsFusionPatterns(
|
|||
RewritePatternSet &patterns,
|
||||
LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
|
||||
|
||||
/// Patterns to push reshape op towards the end of the graph in order to expose
|
||||
/// more fusion opportunities.
|
||||
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
|
||||
/// and permute the loop nest according to `interchangeVector`
|
||||
/// The permutation is expressed as a list of integers that specify
|
||||
|
|
|
@ -998,6 +998,161 @@ struct FoldProducerReshapeOpByLinearization
|
|||
}
|
||||
};
|
||||
|
||||
/// Translates reassociation maps into reassociation index groups: each map
/// yields one group holding the dimension positions of its results, in order.
/// Every result of every map is cast to an AffineDimExpr.
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> groups;
  groups.reserve(maps.size());
  for (AffineMap reassociationMap : maps) {
    // Start a fresh group for this map and fill it in place.
    groups.emplace_back();
    ReassociationIndices &group = groups.back();
    for (AffineExpr result : reassociationMap.getResults())
      group.push_back(result.cast<AffineDimExpr>().getPosition());
  }
  return groups;
}
|
||||
|
||||
/// Pattern to move rank reducing reshape after an elementwise linalg generic
|
||||
/// op. This is useful to expose more fusion opportunities between named ops and
|
||||
/// generic op. This can only be done if there is no broadcast or permutation
|
||||
/// within the dimensions we need to merge.
|
||||
///
|
||||
/// For example,
|
||||
///
|
||||
/// %0 = linalg.tensor_reshape %A [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
|
||||
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
||||
/// %2 = linalg.generic {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
|
||||
/// ["parallel", "parallel", "parallel"]} {
|
||||
/// } -> tensor<112x112x16xf32>
|
||||
///
|
||||
/// into
|
||||
///
|
||||
/// %2 = linalg.generic {indexing_maps = [
|
||||
/// affine_map<(d0, d1) -> (d0, d1)>,
|
||||
/// affine_map<(d0, d1) -> (d1)>,
|
||||
/// affine_map<(d0, d1) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
|
||||
/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
|
||||
/// } -> tensor<12544x16xf32>
|
||||
/// %3 = linalg.tensor_reshape %2 [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
|
||||
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
||||
template <typename GenericOpTy>
|
||||
struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
|
||||
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply to elementwise linalg on tensor.
|
||||
if (!op.hasTensorSemantics() ||
|
||||
op.getNumParallelLoops() != op.getNumLoops())
|
||||
return failure();
|
||||
// Only support identity output maps. It could be extended to permuations if
|
||||
// needed.
|
||||
if (llvm::any_of(op.getOutputIndexingMaps(),
|
||||
[](AffineMap map) { return !map.isIdentity(); }))
|
||||
return failure();
|
||||
int64_t destRank = op.getNumParallelLoops();
|
||||
SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
|
||||
TensorReshapeOp reshapeFound;
|
||||
// 1. Look for tensor_reshape operands and figure out save the dimensions
|
||||
// merged.
|
||||
for (auto operand : llvm::enumerate(op.getInputs())) {
|
||||
TensorReshapeOp reshapeOp =
|
||||
operand.value().template getDefiningOp<TensorReshapeOp>();
|
||||
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
|
||||
reshapeOp.getResultType().getRank()) {
|
||||
continue;
|
||||
}
|
||||
// TODO: We could support non-identity map as long as the merged
|
||||
// dimensions are still contiguous.
|
||||
if (!op.getIndexingMaps()[operand.index()].isIdentity())
|
||||
continue;
|
||||
if (reshapeFound) {
|
||||
// Only support a second reshape op if it has the same reassociate maps.
|
||||
if (reshapeFound.getReassociationMaps() ==
|
||||
reshapeOp.getReassociationMaps())
|
||||
newOperands[operand.index()] = reshapeOp.src();
|
||||
continue;
|
||||
}
|
||||
reshapeFound = reshapeOp;
|
||||
newOperands[operand.index()] = reshapeOp.src();
|
||||
}
|
||||
if (!reshapeFound)
|
||||
return failure();
|
||||
|
||||
// Calculate the reassociation indices and rassociated reverse map.
|
||||
SmallVector<ReassociationIndices> reassociation =
|
||||
getReassociationIndices(reshapeFound.getReassociationMaps());
|
||||
SmallVector<unsigned, 4> remap(destRank);
|
||||
for (auto &indices : llvm::enumerate(reassociation)) {
|
||||
for (int64_t index : indices.value()) {
|
||||
remap[index] = indices.index();
|
||||
}
|
||||
}
|
||||
// 2. Verify that we can merge the dimensions in the linalg and that we
|
||||
// don't need to create new reshapes operands. Inserting new reshape
|
||||
// operands would defeat the purpose of the transformation.
|
||||
for (auto operand : llvm::enumerate(op.getInputs())) {
|
||||
if (operand.value() == newOperands[operand.index()]) {
|
||||
AffineMap map = op.getIndexingMaps()[operand.index()];
|
||||
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
|
||||
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Calculate the affine map remapping and the reassociation to apply to
|
||||
// output tensors.
|
||||
SmallVector<AffineMap, 4> newMaps;
|
||||
unsigned newRank = reassociation.size();
|
||||
for (auto map : op.getIndexingMaps()) {
|
||||
SmallVector<AffineExpr> newExprs;
|
||||
for (auto expr : map.getResults()) {
|
||||
unsigned position = expr.template cast<AffineDimExpr>().getPosition();
|
||||
// Skip dimension merged except for the last of the group.
|
||||
if (reassociation[remap[position]].back() == position) {
|
||||
newExprs.push_back(
|
||||
getAffineDimExpr(remap[position], op.getContext()));
|
||||
}
|
||||
}
|
||||
newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
|
||||
}
|
||||
|
||||
// 4. Reshape the output tensors.
|
||||
SmallVector<Value> newOutputs;
|
||||
SmallVector<Type> newOutputTypes;
|
||||
for (auto output : op.outputs()) {
|
||||
Value newOutput = rewriter.create<TensorReshapeOp>(
|
||||
op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
|
||||
newOutputTypes.push_back(newOutput.getType());
|
||||
newOutputs.push_back(newOutput);
|
||||
}
|
||||
// 5. Create a new generic op with lowerer rank.
|
||||
SmallVector<StringRef, 4> iteratorTypes(newRank,
|
||||
getParallelIteratorTypeName());
|
||||
auto newOp =
|
||||
rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
|
||||
newOutputs, newMaps, iteratorTypes);
|
||||
rewriter.inlineRegionBefore(op.region(), newOp.region(),
|
||||
newOp.region().begin());
|
||||
// 6. Reshape the so that the type matches the uses.
|
||||
SmallVector<Value> newResults;
|
||||
for (auto result : llvm::enumerate(newOp->getResults())) {
|
||||
newResults.push_back(rewriter.create<TensorReshapeOp>(
|
||||
op->getLoc(), op.getOutputTensorTypes()[result.index()],
|
||||
result.value(), reassociation));
|
||||
}
|
||||
rewriter.replaceOp(op, newResults);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to fuse a tensor_reshape op with its consumer
|
||||
/// generic/indexed_generic op, when the reshape op is collapsing
|
||||
/// dimensions. The dimensionality of the loop in the consumer is expanded.
|
||||
|
@ -1333,6 +1488,12 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
|
|||
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
|
||||
}
|
||||
|
||||
/// Adds the PushExpandingReshape pattern, instantiated for GenericOp and for
/// IndexedGenericOp, to `patterns`.
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<PushExpandingReshape<GenericOp>>(ctx);
  patterns.add<PushExpandingReshape<IndexedGenericOp>>(ctx);
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
|
||||
return std::make_unique<FusionOfTensorOpsPass>();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
|
||||
// CHECK-LABEL: func @reshape
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
|
||||
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
|
||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
|
||||
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
|
||||
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
|
||||
func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
  // Expanding reshape feeding the identity-mapped first input.
  %expanded = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<?x16xf32> into tensor<?x112x16xf32>
  // All-parallel elementwise op; %B is only indexed by the unmerged d2.
  %res = linalg.generic {indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%expanded, %B : tensor<?x112x16xf32>, tensor<16xf32>)
      outs(%init : tensor<?x112x16xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):  // no predecessors
      %diff = subf %lhs, %rhs : f32
      linalg.yield %diff : f32
  } -> tensor<?x112x16xf32>
  return %res : tensor<?x112x16xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
|
||||
// CHECK-LABEL: func @reshape_multiple
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
|
||||
// CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
|
||||
// CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
|
||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
|
||||
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
|
||||
// CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
||||
// CHECK: return %[[RR]] : tensor<112x112x16xf32>
|
||||
func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
                       %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
  // Two inputs expanded with identical reassociation maps.
  %expandedA = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %expandedB = linalg.tensor_reshape %B [
      affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %init = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
  %res = linalg.generic {indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
      affine_map<(d0, d1, d2) -> (d2)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%expandedA, %expandedB, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
      outs(%init : tensor<112x112x16xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32, %out: f32):  // no predecessors
      %diff = subf %a, %b : f32
      %prod = mulf %diff, %c : f32
      linalg.yield %prod : f32
  } -> tensor<112x112x16xf32>
  return %res : tensor<112x112x16xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test, since the second source is broadcasted from d1 we cannot merge
|
||||
// d0 and d1 dimensions
|
||||
// CHECK-LABEL: func @reshape_negative
|
||||
// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: } -> tensor<112x112x16xf32>
|
||||
func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
  %expanded = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %init = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
  // %B is indexed by d1, which belongs to the merged (d0, d1) group.
  %res = linalg.generic {indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%expanded, %B : tensor<112x112x16xf32>, tensor<112xf32>)
      outs(%init : tensor<112x112x16xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):  // no predecessors
      %diff = subf %lhs, %rhs : f32
      linalg.yield %diff : f32
  } -> tensor<112x112x16xf32>
  return %res : tensor<112x112x16xf32>
}
|
|
@ -66,6 +66,22 @@ struct TestLinalgElementwiseFusion
|
|||
std::move(fusionPatterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestPushExpandingReshape
|
||||
: public PassWrapper<TestPushExpandingReshape, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &this->getContext();
|
||||
FuncOp funcOp = this->getFunction();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populatePushReshapeOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace test {
|
||||
|
@ -74,6 +90,11 @@ void registerTestLinalgElementwiseFusion() {
|
|||
"test-linalg-elementwise-fusion-patterns",
|
||||
"Test Linalg element wise operation fusion patterns");
|
||||
}
|
||||
|
||||
void registerTestPushExpandingReshape() {
|
||||
PassRegistration<TestPushExpandingReshape> testPushExpandingReshapePass(
|
||||
"test-linalg-push-reshape", "Test Linalg reshape push patterns");
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -78,6 +78,7 @@ void registerTestIRVisitorsPass();
|
|||
void registerTestInterfaces();
|
||||
void registerTestLinalgCodegenStrategy();
|
||||
void registerTestLinalgElementwiseFusion();
|
||||
void registerTestPushExpandingReshape();
|
||||
void registerTestLinalgFusionTransforms();
|
||||
void registerTestLinalgTensorFusionTransforms();
|
||||
void registerTestLinalgGreedyFusion();
|
||||
|
@ -156,6 +157,7 @@ void registerTestPasses() {
|
|||
test::registerTestInterfaces();
|
||||
test::registerTestLinalgCodegenStrategy();
|
||||
test::registerTestLinalgElementwiseFusion();
|
||||
test::registerTestPushExpandingReshape();
|
||||
test::registerTestLinalgFusionTransforms();
|
||||
test::registerTestLinalgTensorFusionTransforms();
|
||||
test::registerTestLinalgGreedyFusion();
|
||||
|
|
Loading…
Reference in New Issue