[mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...

after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).

Differential Revision: https://reviews.llvm.org/D102245
This commit is contained in:
Tobias Gysi 2021-05-12 12:43:34 +00:00
parent a4db7025a9
commit 06bb9cf30d
5 changed files with 50 additions and 103 deletions

View File

@ -213,8 +213,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
void interchange(PatternRewriter &rewriter, LinalgOp op,
ArrayRef<unsigned> interchangeVector);
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@ -363,11 +363,11 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
/// permutated according to `permutation`.
LogicalResult
interchangeGenericLinalgOpPrecondition(Operation *op,
ArrayRef<unsigned> interchangeVector);
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
@ -630,18 +630,18 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
};
///
/// Linalg interchange patterns.
/// Linalg generic interchage pattern.
///
/// Apply the `interchange` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct LinalgBaseInterchangePattern : public RewritePattern {
LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector,
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
private:
@ -651,16 +651,6 @@ private:
SmallVector<unsigned, 8> interchangeVector;
};
template <typename OpTy>
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
LinalgInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
interchangeVector, filter, benefit) {}
};
///
/// Linalg promotion patterns.
///

View File

@ -32,68 +32,65 @@
using namespace mlir;
using namespace mlir::linalg;
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
Operation *op, ArrayRef<unsigned> interchangeVector) {
// Transformation applies to generic ops only.
if (!isa<GenericOp, IndexedGenericOp>(op))
return failure();
LinalgOp linalgOp = cast<LinalgOp>(op);
LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
// Interchange vector must be non-empty and match the number of loops.
if (interchangeVector.empty() ||
linalgOp.getNumLoops() != interchangeVector.size())
genericOp.getNumLoops() != interchangeVector.size())
return failure();
// Permutation map must be invertible.
if (!inversePermutation(
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
genericOp.getContext())))
return failure();
return success();
}
void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
ArrayRef<unsigned> interchangeVector) {
void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
GenericOp genericOp,
ArrayRef<unsigned> interchangeVector) {
// 1. Compute the inverse permutation map.
MLIRContext *context = op.getContext();
MLIRContext *context = genericOp.getContext();
AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
assert(interchangeVector.size() == op.getNumLoops() &&
assert(interchangeVector.size() == genericOp.getNumLoops() &&
"expected interchange vector to have entry for every loop");
// 2. Compute the interchanged indexing maps.
SmallVector<Attribute, 4> newIndexingMaps;
ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
}
op->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
genericOp->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
genericOp->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
// 4. Transform the index operations by applying the permutation map.
if (op.hasIndexSemantics()) {
if (genericOp.hasIndexSemantics()) {
// TODO: Remove the assertion and add a getBody() method to LinalgOp
// interface once every LinalgOp has a body.
assert(op->getNumRegions() == 1 &&
op->getRegion(0).getBlocks().size() == 1 &&
assert(genericOp->getNumRegions() == 1 &&
genericOp->getRegion(0).getBlocks().size() == 1 &&
"expected generic operation to have one block.");
Block &block = op->getRegion(0).front();
Block &block = genericOp->getRegion(0).front();
OpBuilder::InsertionGuard guard(rewriter);
for (IndexOp indexOp :
llvm::make_early_inc_range(block.getOps<IndexOp>())) {
rewriter.setInsertionPoint(indexOp);
SmallVector<Value> allIndices;
allIndices.reserve(op.getNumLoops());
llvm::transform(llvm::seq<uint64_t>(0, op.getNumLoops()),
allIndices.reserve(genericOp.getNumLoops());
llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
std::back_inserter(allIndices), [&](uint64_t dim) {
return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
});

View File

@ -393,30 +393,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return success();
}
/// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
StringRef opName, MLIRContext *context,
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
PatternBenefit benefit)
: RewritePattern(opName, benefit, context, {}), filter(filter),
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter filter, PatternBenefit benefit)
: OpRewritePattern(context, benefit), filter(filter),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
GenericOp genericOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
return failure();
// TODO: figure out how this interplays with named ops. In particular this
// should break the named op property.
rewriter.updateRootInPlace(op, [&]() {
interchange(rewriter, linalgOp, interchangeVector);
rewriter.updateRootInPlace(genericOp, [&]() {
interchangeGenericOp(rewriter, genericOp, interchangeVector);
// New filter if specified.
filter.replaceLinalgTransformationFilter(rewriter, op);
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
});
return success();
}

View File

@ -125,37 +125,6 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
#indexed_matmul_trait = {
args_in = 2,
args_out = 1,
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul_indexed",
iterator_types = ["parallel", "parallel", "reduction"]
}
func @permute_generic_indexed(
%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.indexed_generic #indexed_matmul_trait
ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e: f32
}
return
}
// CHECK-LABEL: func @permute_generic_indexed
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
// CHECK-SAME: library_call = "linalg_matmul_indexed"}
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {

View File

@ -194,14 +194,9 @@ static void applyPatterns(FuncOp funcOp) {
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
// Linalg generic interchange pattern.
//===--------------------------------------------------------------------===//
patterns.add<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("PERMUTED", ctx)));
patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
patterns.add<GenericOpInterchangePattern>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgTransformationFilter(ArrayRef<Identifier>{},
@ -551,7 +546,7 @@ static void applyInterchangePattern(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet interchangePattern(context);
interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
interchangePattern.add<GenericOpInterchangePattern>(
context, interchangeVector,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("interchange", context)));