forked from OSchip/llvm-project
[mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect
Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D120627
This commit is contained in:
parent
5aeaabf35e
commit
7d249dfd7d
|
@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
|
|||
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
|
||||
/// source, if the producer is a `linalg` operation with all parallel iterator
|
||||
/// types.
|
||||
void populateFusePadTensorWithProducerLinalgOpPatterns(
|
||||
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
|
||||
/// if the producer is a `linalg` operation with all parallel iterator types.
|
||||
void populateFuseTensorPadWithProducerLinalgOpPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Patterns to convert from one named op to another. These can be seen as
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
|
||||
//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,8 +6,9 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements patterns that intechanges a generic op -> pad_tensor
|
||||
// pattern into extract_slice -> generic_op.
|
||||
// This file implements patterns that intechanges a linalg.generic -> tensor.pad
|
||||
// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
|
||||
// op chain.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -17,7 +18,6 @@
|
|||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -25,7 +25,7 @@ namespace {
|
|||
///
|
||||
/// ```mlir
|
||||
/// %0 = linalg. ...
|
||||
/// %1 = linalg.pad_tensor %0 ...
|
||||
/// %1 = tensor.pad %0 ...
|
||||
/// ```
|
||||
///
|
||||
/// can be replaced with
|
||||
|
@ -40,6 +40,7 @@ namespace {
|
|||
/// if the `linalg.generic` has all parallel iterator types.
|
||||
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
|
||||
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only works on padding op that sets the padded value to a constant.
|
||||
|
@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
|
|||
// This pattern could work for any Linalg op. For now restrict it to generic
|
||||
// ops.
|
||||
Value source = padOp.source();
|
||||
auto linalgOp = source.getDefiningOp<GenericOp>();
|
||||
auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
|
||||
if (!linalgOp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
padOp, "expected source to be linalg.generic op");
|
||||
|
@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
|
|||
// Create the tensor of same size as output of the pad op.
|
||||
RankedTensorType padResultType = padOp.getResultType();
|
||||
auto resultSizes = getAsOpFoldResult(resultShape[0]);
|
||||
auto initTensor = rewriter.create<InitTensorOp>(
|
||||
auto initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, resultSizes, padResultType.getElementType());
|
||||
|
||||
// Fill the tensor with the pad value.
|
||||
// TODO: There is an option to fill only the boundaries. For now just
|
||||
// filling the whole tensor.
|
||||
auto fillTensor =
|
||||
rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
|
||||
rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());
|
||||
|
||||
// Construct a slice of the fill result that is to be replaced with the
|
||||
// result of the generic op. The low pad values are the offsets, the size of
|
||||
|
@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
|
|||
loc, fillTensor.getResult(0), offsets, sizes, strides);
|
||||
|
||||
// Clone the generic op.
|
||||
auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
|
||||
auto clonedOp =
|
||||
cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
|
||||
clonedOp.setOutputOperand(resultNumber, slice.getResult());
|
||||
|
||||
// Insert it back into the result of the fill.
|
||||
|
@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
|
||||
void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<FusePadOp>(patterns.getContext());
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ struct TestPadFusionPass
|
|||
MLIRContext *context = &getContext();
|
||||
FuncOp funcOp = getOperation();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
|
||||
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
Loading…
Reference in New Issue