[mlir][linalg] Add a new pattern to handle folding unit reduction dims.

The output operands will be added to the input operands when the generic op (on
tensors) becomes an elementwise operation, i.e., once its unit reduction dims
have been folded away. The outputs of the generic op are still the same; they
will be cleaned up by the ReplaceWithEmptyTensorIfUnused pattern.
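
For context, a minimal usage sketch (not part of this change): populateFoldUnitExtentDimsPatterns, updated below to also register AddInitOperandsToInput, can be applied with the greedy rewrite driver. The foldUnitDims helper name and the func.func anchor are illustrative assumptions, not code from this commit.

// Illustrative sketch: drive the unit-extent-dim folding patterns (which now
// include AddInitOperandsToInput) to a fixed point on a function.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldUnitDims(mlir::func::FuncOp func) {
  mlir::RewritePatternSet patterns(func.getContext());
  mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
  // Succeeds if the greedy driver converges after applying the pattern set.
  return mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}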

This is https://reviews.llvm.org/D138251, plus a cmake dep fix.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D138843
Hanhan Wang, 2022-11-23 10:46:46 -08:00
commit 9b16d9d271 (parent eac90d1236)
4 changed files with 133 additions and 5 deletions

@@ -52,6 +52,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRLinalgDialect
MLIRLinalgAnalysis
MLIRLinalgUtils

@@ -19,12 +19,15 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -225,6 +228,125 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
}
};
/// Pattern to add init operands to ins when all the loops are parallel and
/// blockArgument corresponding to init is used in the region. This is a fix-up
/// when unit reduction dimensions are all folded away. In this context, it
/// becomes an elementwise generic op. E.g., it converts
///
/// %0 = tensor.empty() : tensor<1x1xf32>
/// %1 = linalg.fill
/// ins(%cst : f32)
/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
/// affine_map<(d0) -> (0, d0)>],
/// iterator_types = ["parallel"]}
/// ins(%arg0 : tensor<1x?x1x1xf32>)
/// outs(%1 : tensor<1x1xf32>) {
/// ^bb0(%in: f32, %out: f32):
/// %3 = arith.addf %in, %out : f32
/// linalg.yield %3 : f32
/// } -> tensor<1x1xf32>
///
/// into
///
/// %0 = tensor.empty() : tensor<1x1xf32>
/// %1 = linalg.fill
/// ins(%cst : f32)
/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
/// %2 = tensor.empty() : tensor<1x1xf32>
/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
/// affine_map<(d0) -> (0, d0)>,
/// affine_map<(d0) -> (0, d0)>],
/// iterator_types = ["parallel"]}
/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
/// outs(%2 : tensor<1x1xf32>) {
/// ^bb0(%in: f32, %in_0: f32, %out: f32):
/// %4 = arith.addf %in, %in_0 : f32
/// linalg.yield %4 : f32
/// } -> tensor<1x1xf32>
struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
return failure();
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
auto outputOperands = genericOp.getDpsInitOperands();
SetVector<OpOperand *> candidates;
for (OpOperand *op : outputOperands) {
if (genericOp.getMatchingBlockArgument(op).use_empty())
continue;
candidates.insert(op);
}
if (candidates.empty())
return failure();
// Compute the modified indexing maps.
int64_t origNumInput = genericOp.getNumDpsInputs();
SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
SmallVector<AffineMap> newIndexingMaps;
newIndexingMaps.append(indexingMaps.begin(),
std::next(indexingMaps.begin(), origNumInput));
for (OpOperand *op : candidates) {
newInputOperands.push_back(op->get());
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
}
newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
indexingMaps.end());
Location loc = genericOp.getLoc();
SmallVector<Value> newOutputOperands = outputOperands;
for (OpOperand *op : candidates) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfterValue(op->get());
auto elemType = op->get().getType().cast<ShapedType>().getElementType();
auto empty = rewriter.create<tensor::EmptyOp>(
loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
auto [start, end] = genericOp.getDpsInitsPositionRange();
newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
}
auto newOp = rewriter.create<GenericOp>(
loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
newIndexingMaps, genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
Region &region = newOp.getRegion();
Block *block = new Block();
region.push_back(block);
BlockAndValueMapping mapper;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(block);
for (auto bbarg : genericOp.getRegionInputArgs())
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
for (OpOperand *op : candidates) {
BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
}
for (OpOperand *op : outputOperands) {
BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
if (candidates.count(op))
block->addArgument(bbarg.getType(), loc);
else
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
}
for (auto &op : genericOp.getBody()->getOperations()) {
rewriter.clone(op, mapper);
}
rewriter.replaceOp(genericOp, newOp.getResults());
return success();
}
};
struct UnitExtentReplacementInfo {
Type type;
AffineMap indexMap;
@@ -536,7 +658,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
RankReducedExtractSliceOp,
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
@@ -544,6 +667,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
namespace {
@@ -555,7 +680,7 @@ struct LinalgFoldUnitExtentDimsPass
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly)
patterns.add<FoldUnitDimLoops>(context);
patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
else
populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));

@@ -384,11 +384,12 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1
// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]

@@ -8331,6 +8331,7 @@ cc_library(
":LinalgUtils",
":MathDialect",
":MemRefDialect",
":MemRefTransforms",
":Pass",
":SCFDialect",
":SCFTransforms",