[mlir][Linalg] NFC - Modernize transformation APIs.

Differential Revision: https://reviews.llvm.org/D116665
This commit is contained in:
Nicolas Vasilache 2022-01-05 10:51:42 -05:00
parent 9aa017342c
commit 9a7d111f4f
6 changed files with 149 additions and 149 deletions

View File

@ -44,12 +44,12 @@ bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
/// Populates patterns for vectorizing low-D convolution ops. This is a step in
/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
@ -91,7 +91,7 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
/// Populates the given list with patterns to bufferize linalg ops.
/// Populate the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
@ -124,7 +124,7 @@ struct LinalgElementwiseFusionOptions {
return *this;
}
/// Function that allows the caller to control when to stop fusion. Once a
/// Function to allow the caller to control when to stop fusion. Once a
/// producer is deemed fusable with the consumer (structurally), this callback
/// can be used to abort the fusion based on non-structural constraints. This
/// is the hook for cost models to control the amount of fusion done.
@ -149,7 +149,7 @@ void populateElementwiseOpsFusionPatterns(
/// more fusion opportunities.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
@ -157,7 +157,7 @@ void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// Returns a struct containing the tiled loops in the specified order
/// Return a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
@ -237,7 +237,7 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions);
/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
/// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
@ -246,12 +246,15 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
/// `namedOp` is not a GenericOp and has a region builder.
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
/// Create a GenericOp from the given named operation `namedOp` and replace
/// namedOp.
/// Return failure if `namedOp` is a GenericOp or misses a region builder.
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@ -346,7 +349,7 @@ struct LinalgPromotionOptions {
}
};
/// Creates a new buffer using the `allocationFn` provided. The size of this
/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
@ -360,7 +363,7 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
const AllocBufferCallbackFn &allocationFn,
DataLayout &layout);
/// Promotes the `subViews` into a new buffer allocated at the insertion point
/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
/// 2. Take a full view on the buffer.
@ -368,24 +371,23 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
/// Returns the modified linalg op (the modification happens in place) as well
/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
/// Emit a suitable vector form for a Linalg op with fully static shape.
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
@ -393,28 +395,10 @@ FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
/// permutated according to `permutation`.
LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Generalize named operations to generic operations.
LogicalResult generalizeNamedOpPrecondition(Operation *op);
/// Promote std.subviews feeding linalg operations.
/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
/// Return success if the operation can be vectorized.
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
/// Return success if `op` can be vectorized assuming it is static. This allows
/// checking if an op will be vectorizable once all the dimensions are folded to
/// static values.
/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
@ -610,7 +594,7 @@ struct LinalgTilingOptions {
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// Base pattern that applied the tiling transformation specified by `options`.
/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
/// 1. if the tiling specification is invalid and tiling fails to occur.
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
@ -812,9 +796,9 @@ private:
};
///
/// Linalg generic interchage pattern.
/// Linalg generic interchange pattern.
///
/// Apply the `interchange` transformation as a pattern.
/// Apply the `interchange` transformation on a RewriterBase.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
@ -909,13 +893,11 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
///
/// Linalg vectorization patterns.
///
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct LinalgBaseVectorizationPattern : public RewritePattern {
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
LinalgBaseVectorizationPattern(MLIRContext *context,

View File

@ -29,7 +29,7 @@
using namespace mlir;
using namespace mlir::linalg;
LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
static LogicalResult generalizeNamedOpPrecondition(Operation *op) {
LinalgOp namedOp = dyn_cast<LinalgOp>(op);
// Check if the operation is a LinalgOp but not a GenericOp.
if (!namedOp || isa<GenericOp>(op))
@ -40,8 +40,11 @@ LogicalResult mlir::linalg::generalizeNamedOpPrecondition(Operation *op) {
return success();
}
GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
LinalgOp namedOp) {
FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp) {
if (failed(generalizeNamedOpPrecondition(namedOp)))
return rewriter.notifyMatchFailure(namedOp, "preconditions not met");
SmallVector<Value> inputOperands = namedOp.getInputOperands();
SmallVector<Value> outputOperands = namedOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
@ -58,6 +61,7 @@ GenericOp mlir::linalg::generalizeNamedOp(PatternRewriter &rewriter,
outputOperands, indexingMaps, iterators);
rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
genericOp.region().begin());
rewriter.replaceOp(namedOp, genericOp->getResults());
return genericOp;
}

View File

@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@ -30,8 +31,9 @@
using namespace mlir;
using namespace mlir::linalg;
LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
static LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector) {
// Interchange vector must be non-empty and match the number of loops.
if (interchangeVector.empty() ||
genericOp.getNumLoops() != interchangeVector.size())
@ -43,31 +45,38 @@ LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
return success();
}
void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
GenericOp genericOp,
ArrayRef<unsigned> interchangeVector) {
// 1. Compute the inverse permutation map.
FailureOr<GenericOp>
mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector) {
if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
return rewriter.notifyMatchFailure(genericOp, "preconditions not met");
// 1. Compute the inverse permutation map, it must be non-null since the
// preconditions are satisfied.
MLIRContext *context = genericOp.getContext();
AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
assert(interchangeVector.size() == genericOp.getNumLoops() &&
"expected interchange vector to have entry for every loop");
assert(permutationMap && "unexpected null map");
// Start a guarded inplace update.
rewriter.startRootUpdate(genericOp);
auto guard =
llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
// 2. Compute the interchanged indexing maps.
SmallVector<Attribute, 4> newIndexingMaps;
SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
AffineMap m = genericOp.getTiedIndexingMap(opOperand);
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
newIndexingMaps.push_back(m);
}
genericOp->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
rewriter.getAffineMapArrayAttr(newIndexingMaps));
// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
SmallVector<Attribute> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
@ -91,4 +100,6 @@ void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
}
}
return genericOp;
}

View File

@ -137,7 +137,7 @@ struct SimplifyDepthwiseConvQOp
struct LinalgNamedOpConversionPass
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
LinalgNamedOpConversionPass() = default;
LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;
void runOnOperation() override {
Operation *op = getOperation();

View File

@ -623,16 +623,14 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
GenericOp genericOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
FailureOr<GenericOp> transformedOp =
interchangeGenericOp(rewriter, genericOp, interchangeVector);
if (failed(transformedOp))
return failure();
// TODO: figure out how this interplays with named ops. In particular this
// should break the named op property.
rewriter.updateRootInPlace(genericOp, [&]() {
interchangeGenericOp(rewriter, genericOp, interchangeVector);
// New filter if specified.
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
});
// New filter if specified.
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
return success();
}
@ -652,12 +650,10 @@ LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
if (failed(generalizeNamedOpPrecondition(op)))
FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, op);
if (failed(genericOp))
return failure();
GenericOp genericOp = generalizeNamedOp(rewriter, op);
rewriter.replaceOp(op, genericOp.getResults());
filter.replaceLinalgTransformationFilter(rewriter, genericOp);
filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
return success();
}
@ -708,19 +704,13 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
// TODO: Interface-based rewrite.
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
SmallVector<Value> newResults;
if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
return failure();
if (!newResults.empty())
rewriter.replaceOp(op, newResults);
else
rewriter.eraseOp(op);
return success();
return vectorize(rewriter, linalgOp);
}
LogicalResult mlir::linalg::applyStagedPatterns(
@ -758,8 +748,8 @@ static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
}
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
/// with pad_val) and GenericOp (to copy contents).
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with pad_val) and GenericOp (to copy contents).
LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {

View File

@ -597,8 +597,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return success();
}
LogicalResult
mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
if (isElementwise(op))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
@ -620,8 +619,7 @@ mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
return success();
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
// All types must be static shape to go to vector.
if (linalgOp.hasDynamicShape()) {
LDBG("precondition failed: dynamic shape");
@ -630,31 +628,32 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return vectorizeStaticLinalgOpPrecondition(linalgOp);
}
LogicalResult
mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
SmallVectorImpl<Value> &newResults) {
if (failed(vectorizeLinalgOpPrecondition(op)))
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
LinalgOp linalgOp) {
if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
return failure();
auto linalgOp = cast<LinalgOp>(op);
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...
if (auto convOp = dyn_cast<ConvolutionOpInterface>(op)) {
FailureOr<Operation *> resultOrFail = vectorizeConvolution(b, convOp);
if (failed(resultOrFail))
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
LDBG("Vectorize as a conv: " << linalgOp);
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
if (failed(convOr))
return failure();
llvm::append_range(results, (*convOr)->getResults());
} else {
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
Operation *newOp = *resultOrFail;
llvm::append_range(newResults, newOp->getResults());
return success();
}
LDBG(""
<< "Vectorize linalg op as a generic by broadcasting to "
"maximal common shape: "
<< *op);
return vectorizeAsLinalgGeneric(b, linalgOp, newResults);
if (!results.empty())
rewriter.replaceOp(linalgOp, results);
else
rewriter.eraseOp(linalgOp);
return success();
}
//----------------------------------------------------------------------------//
@ -666,8 +665,9 @@ static int64_t getIntFromAttr(Attribute attr) {
return attr.cast<IntegerAttr>().getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
/// are converted to ConstantIndexOps. Other attribute types are not supported.
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
/// not supported.
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
@ -691,9 +691,9 @@ struct GenericPadTensorOpVectorizationPattern
GenericPadTensorOpVectorizationPattern(MLIRContext *context,
PatternBenefit benefit = 1)
: GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
/// Vectorize the copying of a PadTensorOp's source. This is possible if each
/// dimension size is statically know in the source type or the result type
/// (or both).
/// Vectorize the copying of a PadTensorOp's source. This is possible if
/// each dimension size is statically know in the source type or the result
/// type (or both).
static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
PadTensorOp padOp, Value dest) {
auto sourceType = padOp.getSourceType();
@ -718,13 +718,14 @@ struct GenericPadTensorOpVectorizationPattern
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
// Source shape is statically known: Neither read nor write are out-of-
// bounds.
// Source shape is statically known: Neither read nor write are
// out-of- bounds.
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
// Source shape is not statically known, but result shape is. Vectorize
// with size of result shape. This may be larger than the source size.
// Source shape is not statically known, but result shape is.
// Vectorize with size of result shape. This may be larger than the
// source size.
vecShape.push_back(resultType.getDimSize(i));
// Read may be out-of-bounds because the result size could be larger
// than the source size.
@ -749,8 +750,8 @@ struct GenericPadTensorOpVectorizationPattern
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
ArrayRef<bool>{readInBounds});
// If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
// tensor, write directly to the FillOp's operand.
// If `dest` is a FillOp and the TransferWriteOp would overwrite the
// entire tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape()) &&
llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
@ -766,8 +767,8 @@ struct GenericPadTensorOpVectorizationPattern
}
};
/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
/// operation type OpTy.
/// Base pattern for rewriting PadTensorOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@ -837,10 +838,10 @@ struct PadTensorOpVectorizationWithTransferReadPattern
};
/// Rewrite use of PadTensorOp result in TransferWriteOp.
/// This pattern rewrites TransferWriteOps that write to a padded tensor value,
/// where the same amount of padding is immediately removed again after the
/// write. In such cases, the TransferWriteOp can write to the non-padded tensor
/// value and apply out-of-bounds masking. E.g.:
/// This pattern rewrites TransferWriteOps that write to a padded tensor
/// value, where the same amount of padding is immediately removed again after
/// the write. In such cases, the TransferWriteOp can write to the non-padded
/// tensor value and apply out-of-bounds masking. E.g.:
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
@ -854,17 +855,19 @@ struct PadTensorOpVectorizationWithTransferReadPattern
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
/// : tensor<...> to tensor<?x?xf32>
/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
/// tensor<?x?xf32>
/// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the TensorPadOp (or an even
/// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
/// %r's old dimensions.
/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
/// from %r's old dimensions.
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
/// ExtractSliceOp trims the same amount of padding that was added beforehand.
/// ExtractSliceOp trims the same amount of padding that was added
/// beforehand.
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
@ -922,8 +925,8 @@ struct PadTensorOpVectorizationWithTransferWritePattern
/// sizes may turn out to be equal at runtime.
bool hasSameTensorSize(Value beforePadding,
tensor::ExtractSliceOp afterTrimming) const {
// If the input to PadTensorOp is a CastOp, try with with both CastOp result
// and CastOp operand.
// If the input to PadTensorOp is a CastOp, try with with both CastOp
// result and CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
if (hasSameTensorSize(castOp.source(), afterTrimming))
return true;
@ -950,8 +953,9 @@ struct PadTensorOpVectorizationWithTransferWritePattern
if (t1.getNumDynamicDims() == 0)
return true;
// All dynamic sizes must be the same. The only supported case at the moment
// is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
// All dynamic sizes must be the same. The only supported case at the
// moment is when `beforePadding` is an ExtractSliceOp (or a cast
// thereof).
// Apart from CastOp, only ExtractSliceOp is supported.
auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
@ -1062,7 +1066,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
// InsertSliceOp.
rewriter.setInsertionPoint(insertOp);
// Generate TransferReadOp: Read entire source tensor and add high padding.
// Generate TransferReadOp: Read entire source tensor and add high
// padding.
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
@ -1224,9 +1229,9 @@ void mlir::linalg::populateConvVectorizationPatterns(
// Forwarding patterns
//----------------------------------------------------------------------------//
/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
if (firstOp->getBlock() != secondOp->getBlock() ||
@ -1252,7 +1257,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
memref::SubViewOp subViewOp;
for (auto &u : v.getUses()) {
@ -1307,7 +1313,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
return failure();
LDBG("with copy " << *copyOp);
// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
// Find the fill into `viewOrAlloc` without interleaved uses before the
// copy.
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
@ -1468,7 +1475,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
/// ```
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
FailureOr<Operation *> conv() {
if (!valid)
return failure();
@ -1483,7 +1491,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
// When strideW == 1, we can batch the contiguous loads and avoid unrolling
// When strideW == 1, we can batch the contiguous loads and avoid
// unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
@ -1500,7 +1509,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
// Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0, 0].
// Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
// 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
@ -1591,7 +1601,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
/// ```
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
FailureOr<Operation *> dilatedConv() {
if (!valid)
return failure();
@ -1605,7 +1616,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
// When strideW == 1, we can batch the contiguous loads and avoid unrolling
// When strideW == 1, we can batch the contiguous loads and avoid
// unrolling
int64_t wSizeStep = strideW == 1 ? wSize : 1;
Type lhsEltType = lhsShapedType.getElementType();
@ -1621,7 +1633,8 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = builder.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c} @ [0, 0].