diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 5c0d1dc3a2fa..8f422d284df6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -263,12 +263,8 @@ Optional promoteSubViews(OpBuilder &b, LinalgOp op, OperationFolder *folder = nullptr); /// Emit a suitable vector form for a Linalg op with fully static shape. -struct VectorizedLinalgOp { - SmallVector tensorResults; - VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default; -}; -Optional vectorizeLinalgOp(OpBuilder &builder, - Operation *op); +LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, + SmallVectorImpl &newResults); /// Emits a loop nest of `LoopTy` with the proper body for `op`. template diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index dd92ccd838cd..7f604807030d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); - Optional res = vectorizeLinalgOp(rewriter, op); - if (!res) + SmallVector newResults; + if (failed(vectorizeLinalgOp(rewriter, op, newResults))) return failure(); - if (!res->tensorResults.empty()) - rewriter.replaceOp(op, res->tensorResults); + if (!newResults.empty()) + rewriter.replaceOp(op, newResults); else rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index f471ab0ebd75..48b6165d7b68 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -139,16 +139,16 @@ using CustomVectorizationHook = std::function; /// Helper function to vectorize the terminator of a `linalgOp`. New result -/// vector values are appended to `results`. -/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm -/// that it should not try to map produced operations: this is the purpose of -/// the `results` argument to capture such values and make them available for -/// RAUW to the vectorization algorithm. -/// This function is meant to be used as a CustomVectorizationHook. +/// vector values are appended to `newResults`. Return +/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it +/// should not try to map produced operations and instead return the results +/// using the `newResults` vector making them available to the +/// vectorization algorithm for RAUW. This function is meant to be used as a +/// CustomVectorizationHook. static VectorizationResult vectorizeLinalgYield(OpBuilder &builder, Operation *op, const BlockAndValueMapping &bvm, LinalgOp linalgOp, - SmallVectorImpl &results) { + SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); if (!yieldOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; @@ -156,10 +156,10 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op, // TODO: Scan for an opportunity for reuse. // TODO: use a map. Value vectorValue = bvm.lookup(outputs.value()); - Value result = buildVectorWrite(builder, vectorValue, - linalgOp.getOutput(outputs.index())); - if (result) - results.push_back(result); + Value newResult = buildVectorWrite(builder, vectorValue, + linalgOp.getOutput(outputs.index())); + if (newResult) + newResults.push_back(newResult); } return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; } @@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op, /// TODO: Reuse opportunities for RAR dependencies. /// 4. Register CustomVectorizationHook for YieldOp to capture the results. /// 5. Iteratively call vectorizeOneOp on the region operations. -static Optional vectorizeAsLinalgGeneric( - OpBuilder &builder, LinalgOp linalgOp, +LogicalResult vectorizeAsLinalgGeneric( + OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl &newResults, ArrayRef customVectorizationHooks = {}) { // 1. Certain Linalg ops do not have a region but only a region builder. // If so, build the region so we can vectorize. @@ -290,11 +290,10 @@ static Optional vectorizeAsLinalgGeneric( } // 4. Register CustomVectorizationHook for yieldOp. - SmallVector results; CustomVectorizationHook vectorizeYield = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { - return vectorizeLinalgYield(builder, op, bvm, linalgOp, results); + return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults); }; // Append the vectorizeYield hook. auto hooks = llvm::to_vector<4>(customVectorizationHooks); @@ -305,7 +304,7 @@ static Optional vectorizeAsLinalgGeneric( VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op); - return llvm::None; + return failure(); } if (result.status == VectorizationStatus::NewOp) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: " @@ -314,7 +313,7 @@ static Optional vectorizeAsLinalgGeneric( } } - return VectorizedLinalgOp{{results}}; + return success(); } /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp. @@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) { return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); } -static Optional vectorizeContraction(OpBuilder &builder, - LinalgOp linalgOp) { +static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp, + SmallVectorImpl &newResults) { assert(isaContractionOpInterface(linalgOp) && "expected vectorizeContraction preconditions to be met"); Location loc = linalgOp.getLoc(); @@ -383,7 +382,8 @@ static Optional vectorizeContraction(OpBuilder &builder, linalgOp.indexing_maps(), linalgOp.iterator_types()); return VectorizationResult{VectorizationStatus::NewOp, contract}; }; - return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction}); + return vectorizeAsLinalgGeneric(builder, linalgOp, newResults, + {vectorizeContraction}); } LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { @@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { return success(isaContractionOpInterface(linalgOp)); } -Optional mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, - Operation *op) { +LogicalResult +mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op, + SmallVectorImpl &newResults) { if (failed(vectorizeLinalgOpPrecondition(op))) - return llvm::None; + return failure(); edsc::ScopedContext scope(builder, op->getLoc()); if (isElementwise(op)) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " << "Vectorize linalg op as a generic: " << *op); - return vectorizeAsLinalgGeneric(builder, cast(op)); + return vectorizeAsLinalgGeneric(builder, cast(op), newResults); } - return vectorizeContraction(builder, cast(op)); + return vectorizeContraction(builder, cast(op), newResults); } //----------------------------------------------------------------------------//