forked from OSchip/llvm-project
[mlir][Linalg] NFC - Drop vectorization reliance on ConvolutionOpInterface
Differential revision: https://reviews.llvm.org/D117323
This commit is contained in:
parent
fd1dce35bd
commit
efdd4c169d
|
@ -43,8 +43,9 @@ using namespace mlir::linalg;
|
|||
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
||||
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
|
||||
|
||||
static FailureOr<Operation *>
|
||||
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp);
|
||||
/// Try to vectorize `convOp` as a convolution.
|
||||
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
|
||||
LinalgOp convOp);
|
||||
|
||||
/// Return the unique instance of OpType in `block` if it is indeed unique.
|
||||
/// Return null if none or more than 1 instances exist.
|
||||
|
@ -636,13 +637,12 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
|
|||
SmallVector<Value> results;
|
||||
// TODO: isaConvolutionOpInterface that can also infer from generic
|
||||
// features. Will require stride/dilation attributes inference.
|
||||
if (auto convOp = dyn_cast<ConvolutionOpInterface>(linalgOp.getOperation())) {
|
||||
LDBG("Vectorize as a conv: " << linalgOp);
|
||||
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, convOp);
|
||||
if (failed(convOr))
|
||||
return failure();
|
||||
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
|
||||
if (succeeded(convOr)) {
|
||||
llvm::append_range(results, (*convOr)->getResults());
|
||||
} else {
|
||||
if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
|
||||
return failure();
|
||||
LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
|
||||
if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
|
||||
return failure();
|
||||
|
@ -1640,40 +1640,39 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
/// Helper function to vectorize a `linalgOp` with convolution semantics.
|
||||
/// Helper function to vectorize a LinalgOp with convolution semantics.
|
||||
// TODO: extend the generic vectorization to support windows and drop this.
|
||||
static FailureOr<Operation *>
|
||||
vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
|
||||
// TODO: these are legitimately part of ConvolutionOpInterface.
|
||||
auto strides = convOp->getAttrOfType<DenseIntElementsAttr>("strides");
|
||||
auto dilations = convOp->getAttrOfType<DenseIntElementsAttr>("dilations");
|
||||
static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
|
||||
// The ConvolutionOpInterface gives us guarantees of existence for
|
||||
// strides/dilations. However, we do not need to rely on those, we can simply
|
||||
// use them if present, otherwise use the default and let the generic conv.
|
||||
// matcher in the ConvGenerator succeed or fail.
|
||||
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
|
||||
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
|
||||
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
|
||||
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
|
||||
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
|
||||
Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
|
||||
Conv1DNwcGenerator e(b, op, stride, dilation);
|
||||
auto res = e.generateConv();
|
||||
if (succeeded(res))
|
||||
return res;
|
||||
return e.generateDilatedConv();
|
||||
}
|
||||
|
||||
struct VectorizeConvolution
|
||||
: public OpInterfaceRewritePattern<ConvolutionOpInterface> {
|
||||
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
|
||||
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ConvolutionOpInterface convOp,
|
||||
LogicalResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
FailureOr<Operation *> resultOrFail =
|
||||
vectorizeConvolution(rewriter, convOp);
|
||||
FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
|
||||
if (failed(resultOrFail))
|
||||
return failure();
|
||||
Operation *newOp = *resultOrFail;
|
||||
if (newOp->getNumResults() == 0) {
|
||||
rewriter.eraseOp(convOp.getOperation());
|
||||
rewriter.eraseOp(op.getOperation());
|
||||
return success();
|
||||
}
|
||||
assert(newOp->getNumResults() == 1 && "expected single result");
|
||||
rewriter.replaceOp(convOp.getOperation(), newOp->getResult(0));
|
||||
rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue