From e8d31754a285c0f2f8d6625f108c1204420cb290 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Tue, 9 Feb 2021 10:19:06 -0800 Subject: [PATCH] [mlir][Linalg] Add a build method for linalg.pad_tensor Add a build method that pads the source with a scalar value. Reviewed By: nicolasvasilache, antiagainst Differential Revision: https://reviews.llvm.org/D96343 --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 9 +++++- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 31 ++++++++++++------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index bc98a04a19f2..a40d425f7f2e 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -224,6 +224,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", // size is met). static linalg::PadTensorOp createPadHighOp( Type type, Value source, Value pad, Location loc, OpBuilder & builder); + + // Return a PadTensorOp that pads `source to `type` size with `pad` value. + // I.e., a block will be created and the `pad` value will be yielded + // directly. If the type passed is nullptr, it is inferred. + static linalg::PadTensorOp createPadScalarOp( + Type type, Value source, Value pad, ArrayRef low, + ArrayRef high, Location loc, OpBuilder & builder); }]; let builders = [ @@ -234,7 +241,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", // Build a PadTensorOp with all dynamic entries. OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high, CArg<"ArrayRef", "{}">:$attrs)>, - // Build a PadTensorOp with with mixed static and dynamic entries and custom + // Build a PadTensorOp with mixed static and dynamic entries and custom // result type. If the type passed is nullptr, it is inferred. OpBuilderDAG<(ins "Type":$resultType, "Value":$source, "ArrayRef":$low, "ArrayRef":$high, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d5d7cbb1e8a7..96acbd4c1949 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -780,6 +780,24 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType, b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh)); } +PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad, + ArrayRef low, + ArrayRef high, + Location loc, OpBuilder &builder) { + auto padTensorOp = + builder.create(loc, type, source, low, high); + int rank = padTensorOp.getResultType().getRank(); + SmallVector blockArgTypes; + blockArgTypes.assign(rank, builder.getIndexType()); + auto ®ion = padTensorOp.region(); + // `builder.createBlock` changes the insertion point within the block. Create + // a guard to reset the insertion point of the builder after it is destroyed. + OpBuilder::InsertionGuard guard(builder); + builder.createBlock(®ion, region.end(), blockArgTypes); + builder.create(loc, pad); + return padTensorOp; +} + PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad, Location loc, OpBuilder &builder) { SmallVector low, high; @@ -794,17 +812,8 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad, high.push_back(highValue); low.push_back(builder.createOrFold(loc, 0)); } - auto padTensorOp = - builder.create(loc, type, source, low, high); - SmallVector blockArgTypes; - blockArgTypes.assign(rank, builder.getIndexType()); - auto ®ion = padTensorOp.region(); - // `builder.createBlock` changes the insertion point within the block. Create - // a guard to reset the insertion point of the builder after it is destroyed. - OpBuilder::InsertionGuard guard(builder); - builder.createBlock(®ion, region.end(), blockArgTypes); - builder.create(loc, pad); - return padTensorOp; + return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc, + builder); } //===----------------------------------------------------------------------===//