[mlir][linalg] refactor the result handling during vectorization.

Return the vectorization results using a vector passed by reference instead of returning them embedded in a structure.

Differential Revision: https://reviews.llvm.org/D98182
This commit is contained in:
Tobias Gysi 2021-03-08 16:08:40 +00:00
parent e31c77b182
commit c1a4cd551f
3 changed files with 32 additions and 35 deletions

View File

@ -263,12 +263,8 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
OperationFolder *folder = nullptr);
/// Emit a suitable vector form for a Linalg op with fully static shape.
struct VectorizedLinalgOp {
SmallVector<Value> tensorResults;
VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
};
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
Operation *op);
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>

View File

@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
if (!res)
SmallVector<Value> newResults;
if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
return failure();
if (!res->tensorResults.empty())
rewriter.replaceOp(op, res->tensorResults);
if (!newResults.empty())
rewriter.replaceOp(op, newResults);
else
rewriter.eraseOp(op);
return success();

View File

@ -139,16 +139,16 @@ using CustomVectorizationHook = std::function<VectorizationResult(
Operation *, const BlockAndValueMapping &)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `results`.
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations: this is the purpose of
/// the `results` argument to capture such values and make them available for
/// RAUW to the vectorization algorithm.
/// This function is meant to be used as a CustomVectorizationHook.
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
SmallVectorImpl<Value> &results) {
SmallVectorImpl<Value> &newResults) {
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
if (!yieldOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
@ -156,10 +156,10 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
// TODO: Scan for an opportunity for reuse.
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value result = buildVectorWrite(builder, vectorValue,
linalgOp.getOutput(outputs.index()));
if (result)
results.push_back(result);
Value newResult = buildVectorWrite(builder, vectorValue,
linalgOp.getOutput(outputs.index()));
if (newResult)
newResults.push_back(newResult);
}
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
/// TODO: Reuse opportunities for RAR dependencies.
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
/// 5. Iteratively call vectorizeOneOp on the region operations.
static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp,
LogicalResult vectorizeAsLinalgGeneric(
OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
// 1. Certain Linalg ops do not have a region but only a region builder.
// If so, build the region so we can vectorize.
@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
}
// 4. Register CustomVectorizationHook for yieldOp.
SmallVector<Value> results;
CustomVectorizationHook vectorizeYield =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
};
// Append the vectorizeYield hook.
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
return llvm::None;
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
}
}
return VectorizedLinalgOp{{results}};
return success();
}
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
LinalgOp linalgOp) {
static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
assert(isaContractionOpInterface(linalgOp) &&
"expected vectorizeContraction preconditions to be met");
Location loc = linalgOp.getLoc();
@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
linalgOp.indexing_maps(), linalgOp.iterator_types());
return VectorizationResult{VectorizationStatus::NewOp, contract};
};
return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
{vectorizeContraction});
}
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
return success(isaContractionOpInterface(linalgOp));
}
Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
Operation *op) {
LogicalResult
mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults) {
if (failed(vectorizeLinalgOpPrecondition(op)))
return llvm::None;
return failure();
edsc::ScopedContext scope(builder, op->getLoc());
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic: " << *op);
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
}
return vectorizeContraction(builder, cast<LinalgOp>(op));
return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
}
//----------------------------------------------------------------------------//