forked from OSchip/llvm-project
[mlir][linalg] Refactor PadTensorOpVectorizationPattern (NFC)
* Rename PadTensorOpVectorizationPattern to GenericPadTensorOpVectorizationPattern. * Make GenericPadTensorOpVectorizationPattern a private pattern, to be instantiated via populatePadTensorOpVectorizationPatterns. * Factor out parts of PadTensorOpVectorizationPattern into helper functions. This commit prepares PadTensorOpVectorizationPattern for a series of subsequent commits that add more specialized PadTensorOp vectorization patterns. Differential Revision: https://reviews.llvm.org/D103681
This commit is contained in:
parent
50c0aaed47
commit
e789efc92a
|
@ -880,14 +880,14 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
|
||||
/// it needs a specific pattern to vectorize.
|
||||
struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
|
||||
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PadTensorOp padOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
/// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
|
||||
/// These patterns are meant to apply in a complementary fashion. Benefits
|
||||
/// are used to encode a certain ordering of pattern application. To avoid
|
||||
/// scattering magic constants throughout the code base, the patterns must be
|
||||
/// added with this function. `baseBenefit` can be used to offset the benefit
|
||||
/// of all PadTensorOp vectorization patterns by a certain value.
|
||||
void populatePadTensorOpVectorizationPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
|
||||
|
||||
/// Match and rewrite for the pattern:
|
||||
/// ```
|
||||
|
|
|
@ -650,66 +650,81 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
|
|||
// Misc. vectorization patterns.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
/// Given a block, return the Value that the block yields if that Value is
|
||||
/// constant. In this context, "constant" means "defined outside of the block".
|
||||
/// Should not be called on blocks that yield more than one value.
|
||||
///
|
||||
/// Values are considered constant in two cases:
|
||||
/// - A basic block argument from a different block.
|
||||
/// - A value defined outside of the block.
|
||||
///
|
||||
/// If the yielded value is not constant, an empty Value is returned.
|
||||
static Value getConstantYieldValueFromBlock(Block &block) {
|
||||
auto yieldOp = cast<YieldOp>(block.getTerminator());
|
||||
assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
|
||||
Value result = yieldOp.values().front();
|
||||
Operation *definingOp = result.getDefiningOp();
|
||||
|
||||
// Check if yield value is defined inside the block.
|
||||
if (definingOp && definingOp->getBlock() == &block)
|
||||
return Value();
|
||||
// Check if the yield value is a BB arg of the block.
|
||||
if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
|
||||
return Value();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
|
||||
/// TransferWriteOp. For now, this only applies when all low and high paddings
|
||||
/// are determined to be zero.
|
||||
LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
|
||||
linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
|
||||
// Helper function to determine whether an OpFoldResult is not a zero Index.
|
||||
auto isNotZeroIndex = [](OpFoldResult ofr) {
|
||||
if (Attribute attr = ofr.dyn_cast<Attribute>())
|
||||
return attr.cast<IntegerAttr>().getInt() != 0;
|
||||
Value v = ofr.get<Value>();
|
||||
if (auto constOp = v.getDefiningOp<ConstantOp>())
|
||||
if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
|
||||
return intAttr.getValue().getSExtValue() != 0;
|
||||
return true;
|
||||
};
|
||||
struct GenericPadTensorOpVectorizationPattern
|
||||
: public OpRewritePattern<PadTensorOp> {
|
||||
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
|
||||
|
||||
auto resultShapedType = padOp.result().getType().cast<ShapedType>();
|
||||
// Bail on non-static shapes.
|
||||
if (!resultShapedType.hasStaticShape())
|
||||
return failure();
|
||||
LogicalResult matchAndRewrite(PadTensorOp padOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
/// Given an OpFoldResult, return true if its value is guaranteed to be a
|
||||
/// zero integer.
|
||||
auto isZeroInt = [&](OpFoldResult ofr) {
|
||||
return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
|
||||
// Low padding must be static 0.
|
||||
if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
|
||||
// High padding must be static 0.
|
||||
if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
|
||||
// Pad value must be a constant.
|
||||
auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
|
||||
if (!padValue) return failure();
|
||||
|
||||
// If any pad_low is not a static 0, needs a mask. Bail for now.
|
||||
if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
|
||||
return failure();
|
||||
VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
|
||||
if (!vectorType)
|
||||
return failure();
|
||||
// Bail on non-static shapes.
|
||||
auto resultShapedType = padOp.result().getType().cast<ShapedType>();
|
||||
if (!resultShapedType.hasStaticShape())
|
||||
return failure();
|
||||
VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
|
||||
if (!vectorType)
|
||||
return failure();
|
||||
|
||||
// Only support padding with a constant for now, i.e. either:
|
||||
// 1. A BBarg from a different block.
|
||||
// 2. A value defined outside of the current block.
|
||||
Block &block = padOp.region().front();
|
||||
auto yieldOp = cast<YieldOp>(block.getTerminator());
|
||||
assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
|
||||
Value padValue = yieldOp.values().front();
|
||||
Operation *definingOp = padValue.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &block)
|
||||
return failure();
|
||||
if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
|
||||
return failure();
|
||||
// Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
|
||||
// TransferWriteOp@[0..0].
|
||||
SmallVector<Value> indices(
|
||||
resultShapedType.getRank(),
|
||||
rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
|
||||
Value read = rewriter.create<vector::TransferReadOp>(
|
||||
padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
|
||||
Value init = rewriter.create<InitTensorOp>(
|
||||
padOp.getLoc(), resultShapedType.getShape(),
|
||||
resultShapedType.getElementType());
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
|
||||
indices);
|
||||
|
||||
// TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
|
||||
if (llvm::any_of(padOp.getMixedHighPad(),
|
||||
[&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
|
||||
// TransferWriteOp@[0..0].
|
||||
SmallVector<Value> indices(
|
||||
resultShapedType.getRank(),
|
||||
rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
|
||||
Value read = rewriter.create<vector::TransferReadOp>(
|
||||
padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
|
||||
Value init =
|
||||
rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
|
||||
resultShapedType.getElementType());
|
||||
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
|
||||
indices);
|
||||
|
||||
return success();
|
||||
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
|
||||
patterns.add<GenericPadTensorOpVectorizationPattern>(
|
||||
patterns.getContext(), baseBenefit);
|
||||
}
|
||||
|
||||
// TODO: cleanup all the convolution vectorization patterns.
|
||||
|
|
|
@ -508,7 +508,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
|
|||
funcOp.getContext(),
|
||||
LinalgTransformationFilter()
|
||||
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
|
||||
patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
|
||||
populatePadTensorOpVectorizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue