[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface

Differential revision: https://reviews.llvm.org/D117323
This commit is contained in:
Nicolas Vasilache 2022-01-18 08:54:42 +00:00
parent fd1dce35bd
commit efdd4c169d
1 changed file with 21 additions and 22 deletions

View File

@ -43,8 +43,9 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
static FailureOr<Operation *>
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
LinalgOp convOp);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
SmallVector<Value> results;
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
LDBG("Vectorize as a conv: " << linalgOp);
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
if (failed(convOr))
return failure();
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
} else {
if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
return failure();
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
return failure();
@ -1640,40 +1640,39 @@ private:
};
} // namespace
/// Helper function to vectorize a `linalgOp` with convolution semantics.
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *>
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
// TODO: these are legitimately part of ConvolutionOpInterface.
auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
// matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
Conv1DNwcGenerator e(b, op, stride, dilation);
auto res = e.generateConv();
if (succeeded(res))
return res;
return e.generateDilatedConv();
}
struct VectorizeConvolution
: public OpInterfaceRewritePattern<ConvolutionOpInterface> {
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
FailureOr<Operation *> resultOrFail =
vectorizeConvolution(rewriter, convOp);
FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
if (failed(resultOrFail))
return failure();
Operation *newOp = *resultOrFail;
if (newOp->getNumResults() == 0) {
rewriter.eraseOp(convOp.getOperation());
rewriter.eraseOp(op.getOperation());
return success();
}
assert(newOp->getNumResults() == 1 && "expected single result");
rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
return success();
}
};