forked from OSchip/llvm-project
[mlir][Linalg] Add a build method for linalg.pad_tensor
Add a build method that pads the source with a scalar value. Reviewed By: nicolasvasilache, antiagainst Differential Revision: https://reviews.llvm.org/D96343
This commit is contained in:
parent
dd719fda76
commit
e8d31754a2
|
@ -224,6 +224,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||||
// size is met).
|
// size is met).
|
||||||
static linalg::PadTensorOp createPadHighOp(
|
static linalg::PadTensorOp createPadHighOp(
|
||||||
Type type, Value source, Value pad, Location loc, OpBuilder & builder);
|
Type type, Value source, Value pad, Location loc, OpBuilder & builder);
|
||||||
|
|
||||||
|
// Return a PadTensorOp that pads `source to `type` size with `pad` value.
|
||||||
|
// I.e., a block will be created and the `pad` value will be yielded
|
||||||
|
// directly. If the type passed is nullptr, it is inferred.
|
||||||
|
static linalg::PadTensorOp createPadScalarOp(
|
||||||
|
Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
|
||||||
|
ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
|
@ -234,7 +241,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
|
||||||
// Build a PadTensorOp with all dynamic entries.
|
// Build a PadTensorOp with all dynamic entries.
|
||||||
OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
|
OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
|
||||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
|
||||||
// Build a PadTensorOp with with mixed static and dynamic entries and custom
|
// Build a PadTensorOp with mixed static and dynamic entries and custom
|
||||||
// result type. If the type passed is nullptr, it is inferred.
|
// result type. If the type passed is nullptr, it is inferred.
|
||||||
OpBuilderDAG<(ins "Type":$resultType, "Value":$source,
|
OpBuilderDAG<(ins "Type":$resultType, "Value":$source,
|
||||||
"ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
|
"ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
|
||||||
|
|
|
@ -780,6 +780,24 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
|
||||||
b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
|
b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
|
||||||
|
ArrayRef<OpFoldResult> low,
|
||||||
|
ArrayRef<OpFoldResult> high,
|
||||||
|
Location loc, OpBuilder &builder) {
|
||||||
|
auto padTensorOp =
|
||||||
|
builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
|
||||||
|
int rank = padTensorOp.getResultType().getRank();
|
||||||
|
SmallVector<Type, 4> blockArgTypes;
|
||||||
|
blockArgTypes.assign(rank, builder.getIndexType());
|
||||||
|
auto ®ion = padTensorOp.region();
|
||||||
|
// `builder.createBlock` changes the insertion point within the block. Create
|
||||||
|
// a guard to reset the insertion point of the builder after it is destroyed.
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
builder.createBlock(®ion, region.end(), blockArgTypes);
|
||||||
|
builder.create<linalg::YieldOp>(loc, pad);
|
||||||
|
return padTensorOp;
|
||||||
|
}
|
||||||
|
|
||||||
PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
|
PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
|
||||||
Location loc, OpBuilder &builder) {
|
Location loc, OpBuilder &builder) {
|
||||||
SmallVector<OpFoldResult, 4> low, high;
|
SmallVector<OpFoldResult, 4> low, high;
|
||||||
|
@ -794,17 +812,8 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
|
||||||
high.push_back(highValue);
|
high.push_back(highValue);
|
||||||
low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
|
low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
|
||||||
}
|
}
|
||||||
auto padTensorOp =
|
return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc,
|
||||||
builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
|
builder);
|
||||||
SmallVector<Type, 4> blockArgTypes;
|
|
||||||
blockArgTypes.assign(rank, builder.getIndexType());
|
|
||||||
auto ®ion = padTensorOp.region();
|
|
||||||
// `builder.createBlock` changes the insertion point within the block. Create
|
|
||||||
// a guard to reset the insertion point of the builder after it is destroyed.
|
|
||||||
OpBuilder::InsertionGuard guard(builder);
|
|
||||||
builder.createBlock(®ion, region.end(), blockArgTypes);
|
|
||||||
builder.create<linalg::YieldOp>(loc, pad);
|
|
||||||
return padTensorOp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue