[mlir][linalg] Add optional output operand to PadTensorOp
This optional operand will be used for tiling in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D105459

commit 5da010af9a (parent 3469a8e03b)
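The patch threads the new optional `output` tensor through the ODS definition of PadTensorOp (first three hunks), the C++ builders and verifier, a new canonicalization pattern, and three tests. As a minimal sketch of what this looks like at the C++ API level (assumed context: an OpBuilder `b`, a Location `loc`, and suitable operand values; this mirrors the build(...) calls added below rather than quoting the patch):

// Sketch, not from the patch: creating linalg.pad_tensor through the
// ODS-generated builder. The trailing argument is the new optional
// output operand; an empty Value() keeps the old behavior.
Value padded = b.create<linalg::PadTensorOp>(
    loc, resultType, source, lowValues, highValues,
    b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
    /*output=*/Value());

// With an output tensor, the op prints with a trailing `into %outputTensor`
// clause, and tensor.dim of the result can fold to the output operand.
Value paddedInto = b.create<linalg::PadTensorOp>(
    loc, resultType, source, lowValues, highValues,
    b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
    /*output=*/outputTensor);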
@@ -146,6 +146,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
       dimension, i.e `low`.
     * high: A list contains the padding along the end of each
       dimension, i.e. `high`.
+    * output: An optional output operand.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
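The documented shape rule deserves a worked example; below is a hypothetical standalone helper (not part of the patch) computing the result shape as `low` + `dim` + `high` per dimension.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Result extent along each dimension i is low[i] + sourceShape[i] + high[i].
std::vector<int64_t> inferPaddedShape(const std::vector<int64_t> &sourceShape,
                                      const std::vector<int64_t> &low,
                                      const std::vector<int64_t> &high) {
  assert(sourceShape.size() == low.size() && low.size() == high.size());
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < sourceShape.size(); ++i)
    result.push_back(low[i] + sourceShape[i] + high[i]);
  return result;
}

// E.g. a 3x4 source with low[1, 2] and high[2, 3] pads to 6x9, which is
// exactly the shape the roundtrip test at the end of this diff expects.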
@@ -194,16 +195,21 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
                        Variadic<Index>:$low,
                        Variadic<Index>:$high,
                        I64ArrayAttr:$static_low,
-                       I64ArrayAttr:$static_high);
+                       I64ArrayAttr:$static_high,
+                       Optional<AnyTensor>:$output);
 
   let regions = (region SizedRegion<1>:$region);
 
   let results = (outs AnyTensor:$result);
 
+  // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
-    $source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
+    $source
+    `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+    (`into` $output^ )?
     $region attr-dict `:` type($source) `to` type($result)
+    custom<InferType>(ref($output), type($output), ref(type($result)))
   }];
 
   let extraClassDeclaration = [{
@@ -292,7 +298,12 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with mixed static and dynamic entries and custom
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<Value>":$low, "ArrayRef<Value>":$high, "ArrayAttr":$staticLow,
+      "ArrayAttr":$staticHigh)>
   ];
 
   let hasCanonicalizer = 1;
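The remaining C++ hunks live in the Linalg op implementation. Before them, a hedged usage sketch of the builder declared just above (`b`, `loc`, `source`, `resultType`, and the pad sizes are assumed; dynamic positions in the static arrays are conventionally marked with ShapedType::kDynamicSize):

// Sketch, not from the patch: using the new mixed static/dynamic builder.
SmallVector<Value> dynLow = {/* dynamic low paddings */};
SmallVector<Value> dynHigh = {/* dynamic high paddings */};
auto padOp = b.create<linalg::PadTensorOp>(
    loc, resultType, source, dynLow, dynHigh,
    b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));

The output operand is left absent here; the C++ definition added further down forwards to the full builder with /*output=*/{}.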
@@ -855,6 +855,19 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
 // PadTensorOp
 //===----------------------------------------------------------------------===//
 
+// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
+// supports optional types.
+void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
+                    Type typeToInfer, Type typeToInferFrom) {}
+
+ParseResult parseInferType(OpAsmParser &parser,
+                           Optional<OpAsmParser::OperandType> optOperand,
+                           Type &typeToInfer, Type typeToInferFrom) {
+  if (optOperand)
+    typeToInfer = typeToInferFrom;
+  return success();
+}
+
 static LogicalResult verify(PadTensorOp op) {
   auto sourceType = op.source().getType().cast<RankedTensorType>();
   auto resultType = op.result().getType().cast<RankedTensorType>();
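printInferType deliberately prints nothing: the type to infer is the result type, which the format already prints after `to`. parseInferType then copies that type onto the output operand when one was parsed. Restated as a hedged standalone predicate (an illustration, not code from the patch):

#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"

// The invariant custom<InferType> establishes at parse time, and the
// verifier below re-checks: an absent output is fine; a present output
// must have exactly the result type.
static bool outputTypeMatchesResult(mlir::Value output, mlir::Type resultType) {
  return !output || output.getType() == resultType;
}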
@@ -870,6 +883,9 @@ static LogicalResult verify(PadTensorOp op) {
            << resultType << " does not match the inferred type "
            << expectedType;
   }
+  if (op.output() && op.output().getType() != op.getResultType()) {
+    op.emitError("expected that output operand type equals result type");
+  }
 
   auto &region = op.region();
   unsigned rank = resultType.getRank();
@@ -916,7 +932,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticHigh), /*output=*/Value());
   result.addAttributes(attrs);
 }
 
@@ -953,7 +969,15 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+        /*output=*/Value());
+}
+
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                        Value source, ArrayRef<Value> low, ArrayRef<Value> high,
+                        ArrayAttr staticLow, ArrayAttr staticHigh) {
+  build(b, result, resultType, source, low, high, staticLow, staticHigh,
+        /*output=*/{});
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
@@ -1038,11 +1062,25 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
   }
 };
 
+// Fold tensor.dim(pad_tensor(%input, %output)) to tensor.dim(%output).
+struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto padTensorOp = dimOp.source().getDefiningOp<PadTensorOp>();
+    if (!padTensorOp || !padTensorOp.output())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, padTensorOp.output(),
+                                               dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldStaticZeroPadding>(context);
+  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,
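The fold is sound because the verifier above forces the output operand's type to equal the result type, so their shapes agree dimension by dimension. The three hunks that follow add tests: the canonicalization above, the verifier/parser diagnostic, and a parse/print roundtrip. For completeness, a hedged sketch of driving the new pattern outside the canonicalizer pass (`ctx` and `funcOp` are assumed; this is the usual greedy-driver idiom, not code from the patch):

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect PadTensorOp's canonicalization patterns, now including
// FoldToDimOfOutputOperand, and apply them until fixpoint.
RewritePatternSet patterns(ctx);
linalg::PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  funcOp.emitError("greedy pattern application did not converge");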
@@ -902,3 +902,21 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_pad_tensor(
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+//       CHECK:   return %[[RESULT]]
+func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                        %pad_value: f32) -> index {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
+  return %r : index
+}
@@ -584,6 +584,18 @@ func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> t
 
 // -----
 
+// expected-note@+1 {{prior use here}}
+func @pad_output_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32, %output: tensor<?x6x6x7xf32>) -> tensor<?x?x?x8xf32> {
+  // expected-error @+1 {{use of value '%output' expects different type than prior uses: 'tensor<?x5x6x7xf32>' vs 'tensor<?x6x6x7xf32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, 1, 1, 1] high[2, 2, 2, 2] into %output {
+    ^bb0(%arg3: index, %arg4: index):  // no predecessors
+      linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x5x6x7xf32>
+  return %0 : tensor<?x5x6x7xf32>
+}
+
+// -----
+
 func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
   // expected-error @+1 {{expected the block to have 2 arguments}}
   %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
@@ -51,6 +51,24 @@ func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
 
 // -----
 
+func @pad_static_with_output(%arg0: tensor<3x4xf32>,
+                             %out_tensor : tensor<6x9xf32>,
+                             %pad_value: f32)
+    -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] into %out_tensor {
+    ^bb0(%arg1 : index, %arg2 : index):
+      linalg.yield %pad_value : f32
+  } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]*]]: tensor<3x4xf32>,
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]*]]: tensor<6x9xf32>,
+//       CHECK:   linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3] into %[[ARG1]]
+//       CHECK:     : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
 func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
                        %pad_value: f32) -> tensor<?x?xf32> {
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {