[mlir][Linalg] Fix padding related bugs.
This revision fixes the fact that the padding transformation did not have enough information to set the proper type for the padding value. Additionally, the verifier for Yield in the presence of PadTensorOp is fixed to properly report an incorrect number of results or operands. Previously, the error was silently ignored, which made the core issue difficult to debug.

Differential Revision: https://reviews.llvm.org/D96264
Parent: 5112035751
Commit: d57a305fdf
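For context, the pre-fix padding callback only received the Operation* being padded, so the test harness (changed at the bottom of this diff) derived the padding value's type from the op's first result. That guess is wrong whenever an operand's element type differs from the result's, e.g. a matmul with i8 inputs accumulating into i32. A restated sketch of the removed helper with the problem annotated (the comments are editorial, not from the source):

// Pre-fix helper from the test harness below: only the op is visible.
static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) {
  // The type is taken from the op's result (i32 for matmul_i8_i8_i32), even
  // when the operand being padded is one of the i8 inputs, so the padding
  // constant can end up with the wrong element type.
  auto t = op->getResult(0).getType().cast<ShapedType>().getElementType();
  return b.create<ConstantOp>(op->getLoc(), t, b.getZeroAttr(t));
}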
@@ -376,8 +376,10 @@ enum class LinalgTilingLoopType {
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
 
+/// Specify the padding value for an OpOperand. This should be a function of
+/// both the operation and the operand type.
 using PaddingValueComputationFunction =
-    std::function<Value(OpBuilder &, Operation *)>;
+    std::function<Value(OpBuilder &, OpOperand &)>;
 
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
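With the new OpOperand-based signature, a callback can take the element type from the operand it is asked to pad and can still reach the owning op. A minimal sketch (the helper name is made up; it assumes ConstantOp from the standard dialect and getElementTypeOrSelf from mlir/IR/TypeUtilities.h, and mirrors getNeutralOfLinalgOp further down in this diff):

// Zero-valued padding whose type follows the padded operand, not the result:
// i8 inputs get an i8 zero, an i32 output gets an i32 zero.
static Value getZeroPaddingValue(OpBuilder &b, OpOperand &opOperand) {
  Type t = getElementTypeOrSelf(opOperand.get().getType());
  // The OpOperand still gives access to the owning op (and the operand's
  // position within it) if the padding value should depend on them.
  return b.create<ConstantOp>(opOperand.getOwner()->getLoc(), t,
                              b.getZeroAttr(t));
}

// Wiring it into the tiling options, as the test pattern at the end of this
// diff does:
linalg::LinalgTilingOptions options =
    linalg::LinalgTilingOptions()
        .setTileSizes({2, 3, 4})
        .setPaddingValueComputationFunction(getZeroPaddingValue);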
@@ -1373,10 +1373,13 @@ static LogicalResult verify(linalg::YieldOp op) {
     return verifyYield(op, cast<LinalgOp>(parentOp));
 
   if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
-    return success(
-        op.getNumOperands() == 1 &&
-        op.getOperand(0).getType() ==
-            padTensorOp.getType().cast<ShapedType>().getElementType());
+    if (op.getNumOperands() != 1)
+      return op.emitOpError("expected single yield operand (got ")
+             << op->getNumOperands() << ")";
+    if (op.getOperand(0).getType() !=
+        padTensorOp.getType().cast<ShapedType>().getElementType())
+      return op.emitOpError("expected yield type to match shape element type");
+    return success();
   }
 
   return op.emitOpError("expected parent op with LinalgOp interface");
@@ -127,13 +127,13 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
 /// created PadTensorOp.
 /// Return failure if the operand cannot be padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
-    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
+    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
     const LinalgTilingOptions &options, Value &result) {
-  auto tensorType = operand.getType().cast<RankedTensorType>();
+  auto tensorType = operand.get().getType().cast<RankedTensorType>();
   // Already static shape, no need to pad.
   if (tensorType.hasStaticShape())
     return success();
-  auto subtensor = operand.getDefiningOp<SubTensorOp>();
+  auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
   // Not a subtensor, cannot construct a static bounding box.
   if (!subtensor)
     return failure();
@@ -152,11 +152,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
           opToPad, "No constant bounding box can be found for padding");
     staticSizes.push_back(indexAttr.getInt());
   }
-  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
+  Value pad = options.paddingValueComputationFunction(rewriter, operand);
   auto staticTensorType =
       RankedTensorType::get(staticSizes, tensorType.getElementType());
-  result = linalg::PadTensorOp::createPadHighOp(staticTensorType, operand, pad,
-                                                opToPad->getLoc(), rewriter);
+  result = linalg::PadTensorOp::createPadHighOp(
+      staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
   return success();
 }
 
@@ -180,26 +180,26 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   // Set IP after op because we also take the dims of the original output.
   rewriter.setInsertionPointAfter(opToPad);
   // Make a copy of the shaped operands and update it.
-  SmallVector<Value> operands = opToPad.getShapedOperands();
-  for (Value &v : operands) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(opToPad.getNumShapedOperands());
+  for (OpOperand &operand : opToPad.getShapedOpOperands()) {
     Value paddedOperand;
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
-    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
+    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, operand,
                                                      options, paddedOperand))) {
       return failure();
     }
-    // Update v if we indeed got a padded operand.
-    v = paddedOperand ? paddedOperand : v;
+    newOperands.push_back(paddedOperand ? paddedOperand : operand.get());
   }
 
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
+      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
   ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
-  operands.append(otherOperands.begin(), otherOperands.end());
+  newOperands.append(otherOperands.begin(), otherOperands.end());
   linalg::LinalgOp paddedOp =
-      opToPad.clone(rewriter, loc, resultTensorTypes, operands);
+      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
 
   // Recover the subtensor out of the new static results. This keeps the
   // original linalg op around because it uses the dims of the original results.
@@ -646,6 +646,28 @@ func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
 
 // -----
 
+func @pad_num_yields(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+3 {{op expected single yield operand (got 2)}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: index, %arg3: index):  // no predecessors
+    linalg.yield %arg1, %arg1 : i32, i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_yield_type(%arg0: tensor<?x4xi32>, %arg1: i8) -> tensor<?x9xi32> {
+  // expected-error @+3 {{op expected yield type to match shape element type}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: index, %arg3: index):  // no predecessors
+    linalg.yield %arg1 : i8
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
 func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
 {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
@@ -1,41 +1,41 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
 
 // CHECK-LABEL: func @matmul_tensors(
-// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
 func @matmul_tensors(
-  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-    -> tensor<?x?xf32> {
-// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+  %arg0: tensor<?x?xi8>, %arg1: tensor<?x?xi8>, %arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
+// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xi8> to tensor<?x?xi8>
+// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
 
 // Dynamic op has been canonicalized away.
-// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xf32>
+// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
 // CHECK: %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
 // CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xf32> to tensor<4x3xf32>
+// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
 // CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%c0, %c0] high[%{{.*}}, %{{.*}}]
-// CHECK: : tensor<?x?xf32> to tensor<2x3xf32>
-// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>)
-// CHECK-SAME: outs(%[[pC]] : tensor<2x3xf32>) -> tensor<2x3xf32>
-// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
-// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
-// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
-// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
-// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
-  %0 = linalg.matmul {__internal_linalg_transform__ = "tile-and-pad"}
-      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
-     outs(%arg2: tensor<?x?xf32>)
-    -> tensor<?x?xf32>
+// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
+// CHECK: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
+// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xi32> to tensor<?x?xi32>
+// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xi32> into tensor<?x?xi32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xi32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xi32>
+  %0 = linalg.matmul_i8_i8_i32 {__internal_linalg_transform__ = "tile-and-pad"}
+      ins(%arg0, %arg1: tensor<?x?xi8>, tensor<?x?xi8>)
+     outs(%arg2: tensor<?x?xi32>)
+    -> tensor<?x?xi32>
 
-// CHECK: return %[[TD0]] : tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
+// CHECK: return %[[TD0]] : tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
 }
@@ -508,9 +508,9 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
 
 // For now, just assume it is the zero of type.
 // In the future, it should be the zero of type + op.
-static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) {
-  auto t = op->getResult(0).getType().cast<ShapedType>().getElementType();
-  return b.create<ConstantOp>(op->getLoc(), t, b.getZeroAttr(t));
+static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
+  auto t = getElementTypeOrSelf(op.get().getType());
+  return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
 static void applyTileAndPadPattern(FuncOp funcOp) {
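The "zero of type + op" comment above is left as future work; a hypothetical sketch of what the OpOperand-based signature makes possible (the isa<> dispatch below is illustrative and not part of this change):

// Hypothetical follow-up: choose the neutral element from the consuming op,
// not just from the element type.
static Value getNeutralElement(OpBuilder &b, OpOperand &opOperand) {
  Type t = getElementTypeOrSelf(opOperand.get().getType());
  Location loc = opOperand.getOwner()->getLoc();
  // Ops that accumulate with '+' (matmul and friends) are neutral at zero; a
  // max-style reduction would instead want the smallest value of `t`.
  if (isa<linalg::MatmulOp, linalg::MatmulI8I8I32Op>(opOperand.getOwner()))
    return b.create<ConstantOp>(loc, t, b.getZeroAttr(t));
  // Default: fall back to zero, as the current helper does.
  return b.create<ConstantOp>(loc, t, b.getZeroAttr(t));
}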
@@ -520,7 +520,7 @@ static void applyTileAndPadPattern(FuncOp funcOp) {
       linalg::LinalgTilingOptions()
           .setTileSizes({2, 3, 4})
           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+  tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(
          Identifier::get("tile-and-pad", context)));