forked from OSchip/llvm-project
[mlir][linalg] Expose flag to control nofold attribute when padding.
Setting the nofold attribute enables packing an operand. At the moment, the attribute is set by default. The pack introduces a callback to control the flag. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111718
This commit is contained in:
parent
0b48b015b5
commit
a8f69be61f
|
@ -452,6 +452,10 @@ using TileSizeComputationFunction =
|
|||
using PaddingValueComputationFunction =
|
||||
std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
|
||||
|
||||
/// Callback returning true if the pad tensor operation defining the given
|
||||
/// OpOperand shall be marked as nofold to enable packing.
|
||||
using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
|
||||
|
||||
struct LinalgTilingOptions {
|
||||
/// Computation function that returns the tile sizes for each operation.
|
||||
/// Delayed construction of constant tile sizes should occur to interoperate
|
||||
|
@ -526,6 +530,18 @@ struct LinalgTilingOptions {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Callback returning true if the pad tensor operation defining the given
|
||||
/// OpOperand shall be marked as nofold to enable packing. A padding operation
|
||||
/// is only marked nofold if `paddingNoFoldComputationFunction` is set and
|
||||
/// returns true. Otherwise, the nofold attribute is set to false.
|
||||
PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
|
||||
|
||||
LinalgTilingOptions &
|
||||
setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
|
||||
paddingNoFoldComputationFunction = std::move(fun);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Peel the specified loops.
|
||||
SmallVector<int64_t> peeledLoops;
|
||||
|
||||
|
@ -999,6 +1015,7 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
|
|||
LogicalResult
|
||||
rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
|
||||
const PaddingValueComputationFunction &paddingFunc,
|
||||
const PaddingNoFoldComputationFunction &nofoldFunc,
|
||||
LinalgOp &paddedOp);
|
||||
|
||||
using OptimizeCopyFn =
|
||||
|
|
|
@ -153,7 +153,8 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
|
|||
/// padded to a static shape.
|
||||
static LogicalResult padOperandToSmallestStaticBoundingBox(
|
||||
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
|
||||
const PaddingValueComputationFunction &paddingFunc, Value &result) {
|
||||
const PaddingValueComputationFunction &paddingFunc,
|
||||
const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
|
||||
// Can't pad scalars.
|
||||
if (opToPad.getShape(opOperand).empty())
|
||||
return success();
|
||||
|
@ -181,15 +182,17 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
|
|||
}
|
||||
auto staticTensorType = RankedTensorType::get(
|
||||
staticSizes, getElementTypeOrSelf(opOperand->get()));
|
||||
bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
|
||||
result = linalg::PadTensorOp::createPadHighOp(
|
||||
staticTensorType, opOperand->get(), paddingValue.getValue(),
|
||||
/*nofold=*/true, opToPad->getLoc(), rewriter);
|
||||
/*nofold=*/nofold, opToPad->getLoc(), rewriter);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
|
||||
const PaddingValueComputationFunction &paddingFunc,
|
||||
const PaddingNoFoldComputationFunction &nofoldFunc,
|
||||
LinalgOp &paddedOp) {
|
||||
Location loc = opToPad->getLoc();
|
||||
|
||||
|
@ -208,7 +211,8 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
|
|||
// If padding was requested but the shape cannot be bounded statically then
|
||||
// the pattern fails to apply.
|
||||
if (failed(padOperandToSmallestStaticBoundingBox(
|
||||
rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
|
||||
rewriter, opToPad, opOperand, paddingFunc, nofoldFunc,
|
||||
paddedOperand)))
|
||||
return failure();
|
||||
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
|
||||
}
|
||||
|
@ -341,9 +345,9 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
|
|||
// Try to pad on the fly by rewriting res->op as a padded op. If successful,
|
||||
// `res.op` is rewritten in static form with padded operands.
|
||||
LinalgOp paddedOp;
|
||||
if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
|
||||
options.paddingValueComputationFunction,
|
||||
paddedOp))) {
|
||||
if (succeeded(rewriteAsPaddedOp(
|
||||
rewriter, res->op, options.paddingValueComputationFunction,
|
||||
options.paddingNoFoldComputationFunction, paddedOp))) {
|
||||
filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
|
||||
res->op = paddedOp;
|
||||
result = *res;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 nofold-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 nofold-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
|
||||
|
||||
// CHECK-LABEL: func @matmul_tensors(
|
||||
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
|
||||
|
@ -24,7 +24,7 @@ func @matmul_tensors(
|
|||
// CHECK: : tensor<?x?xi8> to tensor<2x4xi8>
|
||||
// CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
|
||||
// CHECK: : tensor<?x?xi8> to tensor<4x3xi8>
|
||||
// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
|
||||
// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
|
||||
// CHECK: : tensor<?x?xi32> to tensor<2x3xi32>
|
||||
// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
|
||||
// CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
|
|
|
@ -113,6 +113,10 @@ struct TestLinalgTransforms
|
|||
*this, "padded-operands",
|
||||
llvm::cl::desc("Operands to pad when test-tile-pattern"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
||||
ListOption<int64_t> nofoldOperands{
|
||||
*this, "nofold-operands",
|
||||
llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
|
||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
||||
ListOption<int64_t> peeledLoops{
|
||||
*this, "peeled-loops",
|
||||
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
|
||||
|
@ -581,6 +585,7 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
|
|||
static void applyTilePattern(FuncOp funcOp, std::string loopType,
|
||||
ArrayRef<int64_t> tileSizes,
|
||||
ArrayRef<int64_t> paddedOperands,
|
||||
ArrayRef<int64_t> nofoldOperands,
|
||||
ArrayRef<int64_t> peeledLoops,
|
||||
bool scalarizeDynamicDims) {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
|
@ -608,7 +613,13 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
|
|||
return failure();
|
||||
return getNeutralOfLinalgOp(b, opOperand);
|
||||
};
|
||||
auto nofoldFunc = [&](OpOperand &opOperand) {
|
||||
if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
|
||||
linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
|
||||
}
|
||||
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
|
||||
linalg::LinalgTilingPattern<linalg::GenericOp>>(
|
||||
|
@ -743,9 +754,11 @@ void TestLinalgTransforms::runOnFunction() {
|
|||
skipPartial);
|
||||
if (testTilePattern)
|
||||
return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
|
||||
peeledLoops, /*scalarizeDynamicDims=*/false);
|
||||
nofoldOperands, peeledLoops,
|
||||
/*scalarizeDynamicDims=*/false);
|
||||
if (testTileScalarizeDynamicDims)
|
||||
return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
|
||||
nofoldOperands,
|
||||
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
|
||||
if (testHoistPadding) {
|
||||
getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
|
||||
|
|
Loading…
Reference in New Issue