[mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D120627
This commit is contained in:
Lei Zhang 2022-03-03 09:44:40 -05:00
parent 5aeaabf35e
commit 7d249dfd7d
3 changed files with 16 additions and 15 deletions

View File

@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
/// source, if the producer is a `linalg` operation with all parallel iterator
/// types.
void populateFusePadTensorWithProducerLinalgOpPatterns(
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns);
/// Patterns to convert from one named op to another. These can be seen as

View File

@ -1,4 +1,4 @@
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that interchange a generic op -> pad_tensor
// pattern into extract_slice -> generic_op.
// This file implements patterns that interchange a linalg.generic -> tensor.pad
// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
// op chain.
//
//===----------------------------------------------------------------------===//
@ -17,7 +18,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
@ -25,7 +25,7 @@ namespace {
///
/// ```mlir
/// %0 = linalg. ...
/// %1 = linalg.pad_tensor %0 ...
/// %1 = tensor.pad %0 ...
/// ```
///
/// can be replaced with
@ -40,6 +40,7 @@ namespace {
/// if the `linalg.generic` has all parallel iterator types.
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
// Only works on padding op that sets the padded value to a constant.
@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// This pattern could work for any Linalg op. For now restrict it to generic
// ops.
Value source = padOp.source();
auto linalgOp = source.getDefiningOp<GenericOp>();
auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
if (!linalgOp) {
return rewriter.notifyMatchFailure(
padOp, "expected source to be linalg.generic op");
@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// Create the tensor of same size as output of the pad op.
RankedTensorType padResultType = padOp.getResultType();
auto resultSizes = getAsOpFoldResult(resultShape[0]);
auto initTensor = rewriter.create<InitTensorOp>(
auto initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizes, padResultType.getElementType());
// Fill the tensor with the pad value.
// TODO: There is an option to fill only the boundaries. For now just
// filling the whole tensor.
auto fillTensor =
rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());
// Construct a slice of the fill result that is to be replaced with the
// result of the generic op. The low pad values are the offsets, the size of
@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
loc, fillTensor.getResult(0), offsets, sizes, strides);
// Clone the generic op.
auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
auto clonedOp =
cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
clonedOp.setOutputOperand(resultNumber, slice.getResult());
// Insert it back into the result of the fill.
@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
};
} // namespace
void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
RewritePatternSet &patterns) {
patterns.add<FusePadOp>(patterns.getContext());
}

View File

@ -34,7 +34,7 @@ struct TestPadFusionPass
MLIRContext *context = &getContext();
FuncOp funcOp = getOperation();
RewritePatternSet patterns(context);
linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
return signalPassFailure();