[mlir][Linalg] NFC - Modernize transformation APIs.
Differential Revision: https://reviews.llvm.org/D116665
This commit is contained in:
parent 9aa017342c
commit 9a7d111f4f
@@ -44,12 +44,12 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
-/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
+/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns for vectorizing low-D convolution ops. This is a step in
+/// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
 void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
@@ -91,7 +91,7 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 /// canonicalizations of named ops into another named op.
 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
 
-/// Populates the given list with patterns to bufferize linalg ops.
+/// Populate the given list with patterns to bufferize linalg ops.
 void populateLinalgBufferizePatterns(
     bufferization::BufferizeTypeConverter &converter,
     RewritePatternSet &patterns);
@@ -124,7 +124,7 @@ struct LinalgElementwiseFusionOptions {
     return *this;
   }
 
-  /// Function that allows the caller to control when to stop fusion. Once a
+  /// Function to allow the caller to control when to stop fusion. Once a
   /// producer is deemed fusable with the consumer (structurally), this callback
   /// can be used to abort the fusion based on non-structural constraints. This
   /// is the hook for cost models to control the amount of fusion done.
@@ -149,7 +149,7 @@ void populateElementwiseOpsFusionPatterns(
 /// more fusion opportunities.
 void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 
-/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
+/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
 /// the new ordering of the loop nest. The length of `interchangeVector`
@@ -157,7 +157,7 @@ void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
 /// An empty vector is interpreted as the identity permutation and the
 /// transformation returns early.
 ///
-/// Returns a struct containing the tiled loops in the specified order
+/// Return a struct containing the tiled loops in the specified order
 /// and the cloned op if successful, llvm::None otherwise.
 ///
 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
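As context for the permutation convention used here and by `interchangeGenericOp` below: entry i of the interchange vector gives the new position of loop i, so `[1,2,0]` encodes `(i,j,k) -> (j,k,i)`. A minimal illustrative sketch (not part of this commit) showing how such a vector maps onto the AffineMap utilities the implementation relies on:

    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    void permutationSketch(mlir::MLIRContext *context) {
      // interchangeVector[i] is the new position of loop i:
      // {1, 2, 0} encodes (i, j, k) -> (j, k, i).
      unsigned interchangeVector[] = {1, 2, 0};
      mlir::AffineMap permutationMap =
          mlir::AffineMap::getPermutationMap(interchangeVector, context);
      // A permutation map is always invertible; composing indexing maps with
      // the inverse is how the transformation rewires loop access.
      mlir::AffineMap inverse = mlir::inversePermutation(permutationMap);
      (void)inverse;
    }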
@@ -237,7 +237,7 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
                      const LinalgTilingOptions &tilingOptions);
 
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
+/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
 /// the index accesses of `op`. This is an in-place transformation controlled by
 /// `interchangeVector`. An empty vector is interpreted as the identity
 /// permutation and the transformation returns early.
@@ -246,12 +246,15 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`op.rank` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
-void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
-                          ArrayRef<unsigned> interchangeVector);
+FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
+                                          GenericOp genericOp,
+                                          ArrayRef<unsigned> interchangeVector);
 
-/// Creates a GenericOp from the given named operation `namedOp`. Assumes
-/// `namedOp` is not a GenericOp and has a region builder.
-GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
+/// Create a GenericOp from the given named operation `namedOp` and replace
+/// namedOp.
+/// Return failure if `namedOp` is a GenericOp or misses a region builder.
+FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
+                                       LinalgOp namedOp);
 
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
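Caller-side, the move from `void`/`GenericOp` returns to `FailureOr` means precondition violations surface as recoverable failures instead of asserts. A hypothetical caller under the new signatures (a sketch, not code from this commit):

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    LogicalResult interchangeSketch(RewriterBase &rewriter,
                                    linalg::GenericOp genericOp) {
      // Permute loops (i, j, k) -> (j, k, i); preconditions are checked
      // inside the transform, which reports failure instead of asserting.
      FailureOr<linalg::GenericOp> maybeInterchanged =
          linalg::interchangeGenericOp(rewriter, genericOp, {1, 2, 0});
      if (failed(maybeInterchanged))
        return failure();
      // On success, the result carries the (in-place updated) op.
      linalg::GenericOp interchanged = *maybeInterchanged;
      (void)interchanged;
      return success();
    }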
@@ -346,7 +349,7 @@ struct LinalgPromotionOptions {
   }
 };
 
-/// Creates a new buffer using the `allocationFn` provided. The size of this
+/// Create a new buffer using the `allocationFn` provided. The size of this
 /// buffer is the smallest constant bounding size along each dimension that can
 /// be computed for the size of the result of `subView`. Returns the allocated
 /// buffer as `fullLocalView` and the view that matches the size of the result
@@ -360,7 +363,7 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
                           const AllocBufferCallbackFn &allocationFn,
                           DataLayout &layout);
 
-/// Promotes the `subViews` into a new buffer allocated at the insertion point
+/// Promote the `subViews` into a new buffer allocated at the insertion point
 /// `b`. Promotion occurs in 3 steps:
 ///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
 ///   2. Take a full view on the buffer.
@@ -368,24 +371,23 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
 /// Infers statically sized buffers from subViews unless `dynamicBuffers` is
 /// true.
 ///
-/// Returns the modified linalg op (the modification happens in place) as well
+/// Return the modified linalg op (the modification happens in place) as well
 /// as all the copy ops created.
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
-                                SmallVectorImpl<Value> &newResults);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
                                        LinalgOp linalgOp);
 
-/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
+/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
                                                LinalgOp linalgOp);
 
-/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
+/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
                                              LinalgOp linalgOp);
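The same failure-propagating style now runs through promotion and vectorization. A sketch chaining the two declarations above (the function name is hypothetical):

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    LogicalResult promoteAndVectorizeSketch(RewriterBase &rewriter,
                                            linalg::LinalgOp op) {
      // Promotion still takes an OpBuilder (a RewriterBase is one) and
      // returns FailureOr<LinalgOp>.
      FailureOr<linalg::LinalgOp> promoted = linalg::promoteSubViews(
          rewriter, op, linalg::LinalgPromotionOptions());
      if (failed(promoted))
        return failure();
      // `vectorize` now takes the rewriter and op directly and performs the
      // replacement (or erasure) itself; no result vector is threaded through.
      return linalg::vectorize(rewriter, *promoted);
    }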
@@ -393,28 +395,10 @@ FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
 // Preconditions that ensure the corresponding transformation succeeds and can
 // be applied as a rewrite pattern.
 //===----------------------------------------------------------------------===//
-/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
-/// permutated according to `permutation`.
-LogicalResult
-interchangeGenericOpPrecondition(GenericOp genericOp,
-                                 ArrayRef<unsigned> interchangeVector);
-
-/// Generalize named operations to generic operations.
-LogicalResult generalizeNamedOpPrecondition(Operation *op);
-
-/// Promote std.subviews feeding linalg operations.
+/// Promote memref.subviews feeding linalg-on-buffers operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
-/// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
-
-/// Return success if `op` can be vectorized assuming it is static. This allows
-/// checking if an op will be vectorizable once all the dimensions are folded to
-/// static values.
-/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
-LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
@@ -610,7 +594,7 @@ struct LinalgTilingOptions {
 RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
 void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 
-/// Base pattern that applied the tiling transformation specified by `options`.
+/// Base pattern that applies the tiling transformation specified by `options`.
 /// Abort and return failure in 2 cases:
 /// 1. if the tiling specification is invalid and tiling fails to occur.
 /// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
@@ -812,9 +796,9 @@ private:
 };
 
 ///
-/// Linalg generic interchage pattern.
+/// Linalg generic interchange pattern.
 ///
-/// Apply the `interchange` transformation as a pattern.
+/// Apply the `interchange` transformation on a RewriterBase.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `interchange` for more details.
 struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
@@ -909,13 +893,11 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
 ///
 /// Linalg vectorization patterns.
 ///
-/// Apply the `vectorizeLinalgOp` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
-
 /// Empty for now, used for SFINAE purposes only.
 struct LinalgVectorizationOptions {};
 
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `vectorizeLinalgOp` for more details.
 struct LinalgBaseVectorizationPattern : public RewritePattern {
   /// MatchAnyOpTag-based constructor with a mandatory `filter`.
   LinalgBaseVectorizationPattern(MLIRContext *context,
@@ -29,7 +29,7 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
+static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
   LinalgOp namedOp = dyn_cast<LinalgOp>(op);
   // Check if the operation is a LinalgOp but not a GenericOp.
   if (!namedOp || isa<GenericOp>(op))
|
@ -40,8 +40,11 @@ LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
|
||||
LinalgOp namedOp) {
|
||||
FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
|
||||
LinalgOp namedOp) {
|
||||
if (failed(generalizeNamedOpPrecondition(namedOp)))
|
||||
return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
|
||||
|
||||
SmallVector<Value> inputOperands = namedOp.getInputOperands();
|
||||
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
|
||||
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
|
||||
|
@ -58,6 +61,7 @@ GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
|
|||
outputOperands, indexingMaps, iterators);
|
||||
rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
|
||||
genericOp.region().begin());
|
||||
rewriter.replaceOp(namedOp, genericOp->getResults());
|
||||
return genericOp;
|
||||
}
|
||||
|
||||
|
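Note the consolidation: `generalizeNamedOp` now checks its own precondition and performs the `replaceOp` itself, so callers only handle the returned `FailureOr`. A sketch under the new signature (illustrative only, not code from this commit):

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

    using namespace mlir;

    LogicalResult generalizeSketch(RewriterBase &rewriter,
                                   linalg::LinalgOp namedOp) {
      // Fails if `namedOp` is already a GenericOp or misses a region builder.
      FailureOr<linalg::GenericOp> generic =
          linalg::generalizeNamedOp(rewriter, namedOp);
      if (failed(generic))
        return failure();
      // `namedOp` has already been replaced; continue with *generic.
      return success();
    }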
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -30,8 +31,9 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
-    GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
+static LogicalResult
+interchangeGenericOpPrecondition(GenericOp genericOp,
+                                 ArrayRef<unsigned> interchangeVector) {
   // Interchange vector must be non-empty and match the number of loops.
   if (interchangeVector.empty() ||
       genericOp.getNumLoops() != interchangeVector.size())
@@ -43,31 +45,38 @@ LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
   return success();
 }
 
-void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
-                                        GenericOp genericOp,
-                                        ArrayRef<unsigned> interchangeVector) {
-  // 1. Compute the inverse permutation map.
+FailureOr<GenericOp>
+mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
+                                   ArrayRef<unsigned> interchangeVector) {
+  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+    return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
+
+  // 1. Compute the inverse permutation map, it must be non-null since the
+  // preconditions are satisfied.
   MLIRContext *context = genericOp.getContext();
   AffineMap permutationMap = inversePermutation(
       AffineMap::getPermutationMap(interchangeVector, context));
-  assert(permutationMap && "expected permutation to be invertible");
-  assert(interchangeVector.size() == genericOp.getNumLoops() &&
-         "expected interchange vector to have entry for every loop");
+  assert(permutationMap && "unexpected null map");
+
+  // Start a guarded inplace update.
+  rewriter.startRootUpdate(genericOp);
+  auto guard =
+      llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
 
   // 2. Compute the interchanged indexing maps.
-  SmallVector<Attribute, 4> newIndexingMaps;
+  SmallVector<AffineMap> newIndexingMaps;
   for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
     AffineMap m = genericOp.getTiedIndexingMap(opOperand);
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
-    newIndexingMaps.push_back(AffineMapAttr::get(m));
+    newIndexingMaps.push_back(m);
   }
   genericOp->setAttr(getIndexingMapsAttrName(),
-                     ArrayAttr::get(context, newIndexingMaps));
+                     rewriter.getAffineMapArrayAttr(newIndexingMaps));
 
   // 3. Compute the interchanged iterator types.
   ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
-  SmallVector<Attribute, 4> itTypesVector;
+  SmallVector<Attribute> itTypesVector;
   llvm::append_range(itTypesVector, itTypes);
   SmallVector<int64_t> permutation(interchangeVector.begin(),
                                    interchangeVector.end());
@@ -91,4 +100,6 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
         indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
   }
-}
+
+  return genericOp;
+}
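The `llvm::make_scope_exit` guard introduced above (hence the new ScopeExit.h include) guarantees `finalizeRootUpdate` runs on every path out of the function, including the final `return genericOp;`. The idiom in isolation, as a sketch:

    #include "mlir/IR/PatternMatch.h"
    #include "llvm/ADT/ScopeExit.h"

    void guardedUpdateSketch(mlir::RewriterBase &rewriter, mlir::Operation *op) {
      // Announce an in-place mutation to the rewriter.
      rewriter.startRootUpdate(op);
      // The guard fires when it goes out of scope, so the matching
      // finalizeRootUpdate cannot be skipped by an early return.
      auto guard =
          llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(op); });
      // ... mutate attributes or regions of `op` here ...
    }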
@@ -137,7 +137,7 @@ struct SimplifyDepthwiseConvQOp
 struct LinalgNamedOpConversionPass
     : public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
   LinalgNamedOpConversionPass() = default;
-  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
+  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -623,16 +623,14 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
     GenericOp genericOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, genericOp)))
     return failure();
-  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
+
+  FailureOr<GenericOp> transformedOp =
+      interchangeGenericOp(rewriter, genericOp, interchangeVector);
+  if (failed(transformedOp))
     return failure();
 
-  // TODO: figure out how this interplays with named ops. In particular this
-  // should break the named op property.
-  rewriter.updateRootInPlace(genericOp, [&]() {
-    interchangeGenericOp(rewriter, genericOp, interchangeVector);
-    // New filter if specified.
-    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
-  });
+  // New filter if specified.
+  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
   return success();
 }
@@ -652,12 +650,10 @@ LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, op)))
     return failure();
-  if (failed(generalizeNamedOpPrecondition(op)))
+  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, op);
+  if (failed(genericOp))
     return failure();
 
-  GenericOp genericOp = generalizeNamedOp(rewriter, op);
-  rewriter.replaceOp(op, genericOp.getResults());
-  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
+  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
   return success();
 }
@@ -708,19 +704,13 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
 
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
+  // TODO: Interface-based rewrite.
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
+  if (failed(filter.checkAndNotify(rewriter, op)))
     return failure();
-  SmallVector<Value> newResults;
-  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
-    return failure();
-  if (!newResults.empty())
-    rewriter.replaceOp(op, newResults);
-  else
-    rewriter.eraseOp(op);
-  return success();
+  return vectorize(rewriter, linalgOp);
 }
 
 LogicalResult mlir::linalg::applyStagedPatterns(
@@ -758,8 +748,8 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
   return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
 }
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
-/// with pad_val) and GenericOp (to copy contents).
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
+/// initialize with pad_val) and GenericOp (to copy contents).
 LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
@@ -597,8 +597,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
@@ -620,8 +619,7 @@ mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
-  auto linalgOp = cast<linalg::LinalgOp>(op);
+static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
   // All types must be static shape to go to vector.
   if (linalgOp.hasDynamicShape()) {
     LDBG("precondition failed: dynamic shape");
@@ -630,31 +628,32 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   return vectorizeStaticLinalgOpPrecondition(linalgOp);
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
-                                SmallVectorImpl<Value> &newResults) {
-  if (failed(vectorizeLinalgOpPrecondition(op)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
+                                      LinalgOp linalgOp) {
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
     return failure();
 
-  auto linalgOp = cast<LinalgOp>(op);
-
-  // TODO: isaConvolutionOpInterface that can also infer from generic features.
-  // But we will still need stride/dilation attributes that will be annoying to
-  // reverse-engineer...
-  if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
-    FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
-    if (failed(resultOrFail))
+  SmallVector<Value> results;
+  // TODO: isaConvolutionOpInterface that can also infer from generic
+  // features. Will require stride/dilation attributes inference.
+  if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    LDBG("Vectorize as a conv: " << linalgOp);
+    FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
+    if (failed(convOr))
       return failure();
-    Operation *newOp = *resultOrFail;
-    llvm::append_range(newResults, newOp->getResults());
-    return success();
+    llvm::append_range(results, (*convOr)->getResults());
+  } else {
+    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
+    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
+      return failure();
   }
 
-  LDBG(""
-       << "Vectorize linalg op as a generic by broadcasting to "
-          "maximal common shape: "
-       << *op);
-  return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
+  if (!results.empty())
+    rewriter.replaceOp(linalgOp, results);
+  else
+    rewriter.eraseOp(linalgOp);
+
+  return success();
 }
 
 //----------------------------------------------------------------------------//
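With the precondition check and the replace/erase logic folded into `vectorize`, a rewrite pattern wrapping it reduces to a single call, as the `LinalgBaseVectorizationPattern::matchAndRewrite` change earlier in this diff shows. A standalone sketch (the pattern name is hypothetical, not part of the commit):

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    struct VectorizeSketchPattern
        : public OpInterfaceRewritePattern<linalg::LinalgOp> {
      using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

      LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                    PatternRewriter &rewriter) const override {
        // `vectorize` checks its own preconditions and replaces (or erases)
        // the op on success, so the pattern simply forwards its result.
        return linalg::vectorize(rewriter, op);
      }
    };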
@@ -666,8 +665,9 @@ static int64_t getIntFromAttr(Attribute attr) {
   return attr.cast<IntegerAttr>().getInt();
 }
 
-/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
-/// are converted to ConstantIndexOps. Other attribute types are not supported.
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
 static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
   SmallVector<Value> result;
@@ -691,9 +691,9 @@ struct GenericPadTensorOpVectorizationPattern
   GenericPadTensorOpVectorizationPattern(MLIRContext *context,
                                          PatternBenefit benefit = 1)
       : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a PadTensorOp's source. This is possible if each
-  /// dimension size is statically know in the source type or the result type
-  /// (or both).
+  /// Vectorize the copying of a PadTensorOp's source. This is possible if
+  /// each dimension size is statically know in the source type or the result
+  /// type (or both).
   static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
                                         PadTensorOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
@@ -718,13 +718,14 @@ struct GenericPadTensorOpVectorizationPattern
     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
       if (!sourceType.isDynamicDim(i)) {
         vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are out-of-
-        // bounds.
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
         readInBounds.push_back(true);
         writeInBounds.push_back(true);
       } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is. Vectorize
-        // with size of result shape. This may be larger than the source size.
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
         vecShape.push_back(resultType.getDimSize(i));
         // Read may be out-of-bounds because the result size could be larger
         // than the source size.
@@ -749,8 +750,8 @@ struct GenericPadTensorOpVectorizationPattern
         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
         ArrayRef<bool>{readInBounds});
 
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
-    // tensor, write directly to the FillOp's operand.
+    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
+    // entire tensor, write directly to the FillOp's operand.
     if (llvm::equal(vecShape, resultType.getShape()) &&
         llvm::all_of(writeInBounds, [](bool b) { return b; }))
       if (auto fill = dest.getDefiningOp<FillOp>())
@@ -766,8 +767,8 @@ struct GenericPadTensorOpVectorizationPattern
   }
 };
 
-/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
-/// operation type OpTy.
+/// Base pattern for rewriting PadTensorOps whose result is consumed by a
+/// given operation type OpTy.
 template <typename OpTy>
 struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -837,10 +838,10 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 };
 
 /// Rewrite use of PadTensorOp result in TransferWriteOp.
-/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
-/// where the same amount of padding is immediately removed again after the
-/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
-/// value and apply out-of-bounds masking. E.g.:
+/// This pattern rewrites TransferWriteOps that write to a padded tensor
+/// value, where the same amount of padding is immediately removed again after
+/// the write. In such cases, the TransferWriteOp can write to the non-padded
+/// tensor value and apply out-of-bounds masking. E.g.:
 /// ```
 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
 ///     : tensor<...> to tensor<?x?xf32>
@@ -854,17 +855,19 @@ struct PadTensorOpVectorizationWithTransferReadPattern
 /// ```
 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
 ///     : tensor<...> to tensor<?x?xf32>
-/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
+/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
+/// tensor<?x?xf32>
 /// ```
 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
-/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
-/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
-/// %r's old dimensions.
+/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
+/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
+/// from %r's old dimensions.
 ///
 /// This rewrite is possible if:
 /// - Low padding is static 0.
 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
-///   ExtractSliceOp trims the same amount of padding that was added beforehand.
+///   ExtractSliceOp trims the same amount of padding that was added
+///   beforehand.
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
@@ -922,8 +925,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
   /// sizes may turn out to be equal at runtime.
   bool hasSameTensorSize(Value beforePadding,
                          tensor::ExtractSliceOp afterTrimming) const {
-    // If the input to PadTensorOp is a CastOp, try with with both CastOp result
-    // and CastOp operand.
+    // If the input to PadTensorOp is a CastOp, try with with both CastOp
+    // result and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
       if (hasSameTensorSize(castOp.source(), afterTrimming))
        return true;
@@ -950,8 +953,9 @@ struct PadTensorOpVectorizationWithTransferWritePattern
     if (t1.getNumDynamicDims() == 0)
       return true;
 
-    // All dynamic sizes must be the same. The only supported case at the moment
-    // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
+    // All dynamic sizes must be the same. The only supported case at the
+    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
+    // thereof).
 
     // Apart from CastOp, only ExtractSliceOp is supported.
     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
@@ -1062,7 +1066,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
     // InsertSliceOp.
     rewriter.setInsertionPoint(insertOp);
 
-    // Generate TransferReadOp: Read entire source tensor and add high padding.
+    // Generate TransferReadOp: Read entire source tensor and add high
+    // padding.
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
@@ -1224,9 +1229,9 @@ void mlir::linalg::populateConvVectorizationPatterns(
 // Forwarding patterns
 //----------------------------------------------------------------------------//
 
-/// Check whether there is any interleaved use of any `values` between `firstOp`
-/// and `secondOp`. Conservatively return `true` if any op or value is in a
-/// different block.
+/// Check whether there is any interleaved use of any `values` between
+/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
+/// is in a different block.
 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                     ValueRange values) {
   if (firstOp->getBlock() != secondOp->getBlock() ||
@@ -1252,7 +1257,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
   return false;
 }
 
-/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+/// Return the unique subview use of `v` if it is indeed unique, null
+/// otherwise.
 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
   memref::SubViewOp subViewOp;
   for (auto &u : v.getUses()) {
@@ -1307,7 +1313,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     return failure();
   LDBG("with copy " << *copyOp);
 
-  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+  // Find the fill into `viewOrAlloc` without interleaved uses before the
+  // copy.
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
@@ -1468,7 +1475,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
   /// ```
   /// kw is always unrolled.
-  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
   FailureOr<Operation *> conv() {
     if (!valid)
       return failure();
@@ -1483,7 +1491,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
-    // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+    // When strideW == 1, we can batch the contiguous loads and avoid
+    // unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
     Type lhsEltType = lhsShapedType.getElementType();
@@ -1500,7 +1509,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
     VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
 
-    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0].
+    // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
+    // 0].
     Value lhs = builder.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
@@ -1591,7 +1601,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   /// ```
   /// kw is always unrolled.
-  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
   FailureOr<Operation *> dilatedConv() {
     if (!valid)
       return failure();
@@ -1605,7 +1616,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
-    // When strideW == 1, we can batch the contiguous loads and avoid unrolling
+    // When strideW == 1, we can batch the contiguous loads and avoid
+    // unrolling
     int64_t wSizeStep = strideW == 1 ? wSize : 1;
 
     Type lhsEltType = lhsShapedType.getElementType();
@@ -1621,7 +1633,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
     VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
 
-    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
+    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
+    // 0].
     Value lhs = builder.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].