forked from OSchip/llvm-project
[mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612). Differential Revision: https://reviews.llvm.org/D102245
This commit is contained in:
parent
a4db7025a9
commit
06bb9cf30d
|
@ -213,8 +213,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
|||
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
|
||||
/// integers, in the range 0..`op.rank` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation).
|
||||
void interchange(PatternRewriter &rewriter, LinalgOp op,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Callback function type used to perform the allocation for the promoted
|
||||
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
|
||||
|
@ -363,11 +363,11 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
|
|||
// Preconditions that ensure the corresponding transformation succeeds and can
|
||||
// be applied as a rewrite pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
|
||||
/// and `iterator_types` permutated according to `permutation`.
|
||||
/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
|
||||
/// permutated according to `permutation`.
|
||||
LogicalResult
|
||||
interchangeGenericLinalgOpPrecondition(Operation *op,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
interchangeGenericOpPrecondition(GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
||||
/// Promote std.subviews feeding linalg operations.
|
||||
LogicalResult promoteSubviewsPrecondition(Operation *op,
|
||||
|
@ -630,18 +630,18 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
|
|||
};
|
||||
|
||||
///
|
||||
/// Linalg interchange patterns.
|
||||
/// Linalg generic interchage pattern.
|
||||
///
|
||||
/// Apply the `interchange` transformation as a pattern.
|
||||
/// `filter` controls LinalgTransformMarker matching and update when specified.
|
||||
/// See `interchange` for more details.
|
||||
struct LinalgBaseInterchangePattern : public RewritePattern {
|
||||
LinalgBaseInterchangePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector,
|
||||
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
GenericOpInterchangePattern(
|
||||
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1);
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
private:
|
||||
|
@ -651,16 +651,6 @@ private:
|
|||
SmallVector<unsigned, 8> interchangeVector;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
|
||||
LinalgInterchangePattern(
|
||||
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
|
||||
LinalgTransformationFilter filter = LinalgTransformationFilter(),
|
||||
PatternBenefit benefit = 1)
|
||||
: LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
|
||||
interchangeVector, filter, benefit) {}
|
||||
};
|
||||
|
||||
///
|
||||
/// Linalg promotion patterns.
|
||||
///
|
||||
|
|
|
@ -32,68 +32,65 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
|
||||
Operation *op, ArrayRef<unsigned> interchangeVector) {
|
||||
// Transformation applies to generic ops only.
|
||||
if (!isa<GenericOp, IndexedGenericOp>(op))
|
||||
return failure();
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
|
||||
GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
|
||||
// Interchange vector must be non-empty and match the number of loops.
|
||||
if (interchangeVector.empty() ||
|
||||
linalgOp.getNumLoops() != interchangeVector.size())
|
||||
genericOp.getNumLoops() != interchangeVector.size())
|
||||
return failure();
|
||||
// Permutation map must be invertible.
|
||||
if (!inversePermutation(
|
||||
AffineMap::getPermutationMap(interchangeVector, op->getContext())))
|
||||
if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
|
||||
genericOp.getContext())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
|
||||
ArrayRef<unsigned> interchangeVector) {
|
||||
void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
|
||||
GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector) {
|
||||
// 1. Compute the inverse permutation map.
|
||||
MLIRContext *context = op.getContext();
|
||||
MLIRContext *context = genericOp.getContext();
|
||||
AffineMap permutationMap = inversePermutation(
|
||||
AffineMap::getPermutationMap(interchangeVector, context));
|
||||
assert(permutationMap && "expected permutation to be invertible");
|
||||
assert(interchangeVector.size() == op.getNumLoops() &&
|
||||
assert(interchangeVector.size() == genericOp.getNumLoops() &&
|
||||
"expected interchange vector to have entry for every loop");
|
||||
|
||||
// 2. Compute the interchanged indexing maps.
|
||||
SmallVector<Attribute, 4> newIndexingMaps;
|
||||
ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
|
||||
for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
|
||||
ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
|
||||
for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) {
|
||||
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
|
||||
if (!permutationMap.isEmpty())
|
||||
m = m.compose(permutationMap);
|
||||
newIndexingMaps.push_back(AffineMapAttr::get(m));
|
||||
}
|
||||
op->setAttr(getIndexingMapsAttrName(),
|
||||
ArrayAttr::get(context, newIndexingMaps));
|
||||
genericOp->setAttr(getIndexingMapsAttrName(),
|
||||
ArrayAttr::get(context, newIndexingMaps));
|
||||
|
||||
// 3. Compute the interchanged iterator types.
|
||||
ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
|
||||
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
|
||||
SmallVector<Attribute, 4> itTypesVector;
|
||||
llvm::append_range(itTypesVector, itTypes);
|
||||
applyPermutationToVector(itTypesVector, interchangeVector);
|
||||
op->setAttr(getIteratorTypesAttrName(),
|
||||
ArrayAttr::get(context, itTypesVector));
|
||||
genericOp->setAttr(getIteratorTypesAttrName(),
|
||||
ArrayAttr::get(context, itTypesVector));
|
||||
|
||||
// 4. Transform the index operations by applying the permutation map.
|
||||
if (op.hasIndexSemantics()) {
|
||||
if (genericOp.hasIndexSemantics()) {
|
||||
// TODO: Remove the assertion and add a getBody() method to LinalgOp
|
||||
// interface once every LinalgOp has a body.
|
||||
assert(op->getNumRegions() == 1 &&
|
||||
op->getRegion(0).getBlocks().size() == 1 &&
|
||||
assert(genericOp->getNumRegions() == 1 &&
|
||||
genericOp->getRegion(0).getBlocks().size() == 1 &&
|
||||
"expected generic operation to have one block.");
|
||||
Block &block = op->getRegion(0).front();
|
||||
Block &block = genericOp->getRegion(0).front();
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
for (IndexOp indexOp :
|
||||
llvm::make_early_inc_range(block.getOps<IndexOp>())) {
|
||||
rewriter.setInsertionPoint(indexOp);
|
||||
SmallVector<Value> allIndices;
|
||||
allIndices.reserve(op.getNumLoops());
|
||||
llvm::transform(llvm::seq<uint64_t>(0, op.getNumLoops()),
|
||||
allIndices.reserve(genericOp.getNumLoops());
|
||||
llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
|
||||
std::back_inserter(allIndices), [&](uint64_t dim) {
|
||||
return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
|
||||
});
|
||||
|
|
|
@ -393,30 +393,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Linalg base interchange pattern.
|
||||
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
|
||||
StringRef opName, MLIRContext *context,
|
||||
ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
|
||||
PatternBenefit benefit)
|
||||
: RewritePattern(opName, benefit, context, {}), filter(filter),
|
||||
/// Linalg generic interchange pattern.
|
||||
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
|
||||
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
|
||||
LinalgTransformationFilter filter, PatternBenefit benefit)
|
||||
: OpRewritePattern(context, benefit), filter(filter),
|
||||
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
|
||||
|
||||
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
|
||||
GenericOp genericOp, PatternRewriter &rewriter) const {
|
||||
if (failed(filter.checkAndNotify(rewriter, genericOp)))
|
||||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
|
||||
if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
|
||||
return failure();
|
||||
|
||||
// TODO: figure out how this interplays with named ops. In particular this
|
||||
// should break the named op property.
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
interchange(rewriter, linalgOp, interchangeVector);
|
||||
rewriter.updateRootInPlace(genericOp, [&]() {
|
||||
interchangeGenericOp(rewriter, genericOp, interchangeVector);
|
||||
// New filter if specified.
|
||||
filter.replaceLinalgTransformationFilter(rewriter, op);
|
||||
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -125,37 +125,6 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
|
||||
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
|
||||
|
||||
#indexed_matmul_trait = {
|
||||
args_in = 2,
|
||||
args_out = 1,
|
||||
indexing_maps = #matmul_accesses,
|
||||
library_call = "linalg_matmul_indexed",
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
func @permute_generic_indexed(
|
||||
%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
linalg.indexed_generic #indexed_matmul_trait
|
||||
ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?x?xf32, offset: ?, strides: [?, 1]>)
|
||||
outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
|
||||
^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
|
||||
%d = mulf %a, %b: f32
|
||||
%e = addf %c, %d: f32
|
||||
linalg.yield %e: f32
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @permute_generic_indexed
|
||||
// CHECK: linalg.indexed_generic {
|
||||
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
|
||||
// CHECK-SAME: library_call = "linalg_matmul_indexed"}
|
||||
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
|
||||
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
|
||||
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
|
||||
|
||||
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>) {
|
||||
|
|
|
@ -194,14 +194,9 @@ static void applyPatterns(FuncOp funcOp) {
|
|||
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Linalg generic permutation patterns.
|
||||
// Linalg generic interchange pattern.
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.add<LinalgInterchangePattern<GenericOp>>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("PERMUTED", ctx)));
|
||||
patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
|
||||
patterns.add<GenericOpInterchangePattern>(
|
||||
ctx,
|
||||
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
|
@ -551,7 +546,7 @@ static void applyInterchangePattern(FuncOp funcOp,
|
|||
ArrayRef<unsigned> interchangeVector) {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
RewritePatternSet interchangePattern(context);
|
||||
interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
|
||||
interchangePattern.add<GenericOpInterchangePattern>(
|
||||
context, interchangeVector,
|
||||
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
||||
Identifier::get("interchange", context)));
|
||||
|
|
Loading…
Reference in New Issue