forked from OSchip/llvm-project
[mlir][linalg] refactor the result handling during vectorization.
Return the vectorization results using a vector passed by reference instead of returning them embedded in a structure. Differential Revision: https://reviews.llvm.org/D98182
This commit is contained in:
parent
e31c77b182
commit
c1a4cd551f
|
@ -263,12 +263,8 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
|
|||
OperationFolder *folder = nullptr);
|
||||
|
||||
/// Emit a suitable vector form for a Linalg op with fully static shape.
|
||||
struct VectorizedLinalgOp {
|
||||
SmallVector<Value> tensorResults;
|
||||
VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
|
||||
};
|
||||
Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
|
||||
Operation *op);
|
||||
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
|
||||
SmallVectorImpl<Value> &newResults);
|
||||
|
||||
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
|
||||
template <typename LoopTy>
|
||||
|
|
|
@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
|
|||
return failure();
|
||||
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
|
||||
return failure();
|
||||
Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
|
||||
if (!res)
|
||||
SmallVector<Value> newResults;
|
||||
if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
|
||||
return failure();
|
||||
if (!res->tensorResults.empty())
|
||||
rewriter.replaceOp(op, res->tensorResults);
|
||||
if (!newResults.empty())
|
||||
rewriter.replaceOp(op, newResults);
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
|
|
|
@ -139,16 +139,16 @@ using CustomVectorizationHook = std::function<VectorizationResult(
|
|||
Operation *, const BlockAndValueMapping &)>;
|
||||
|
||||
/// Helper function to vectorize the terminator of a `linalgOp`. New result
|
||||
/// vector values are appended to `results`.
|
||||
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
|
||||
/// that it should not try to map produced operations: this is the purpose of
|
||||
/// the `results` argument to capture such values and make them available for
|
||||
/// RAUW to the vectorization algorithm.
|
||||
/// This function is meant to be used as a CustomVectorizationHook.
|
||||
/// vector values are appended to `newResults`. Return
|
||||
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
|
||||
/// should not try to map produced operations and instead return the results
|
||||
/// using the `newResults` vector making them available to the
|
||||
/// vectorization algorithm for RAUW. This function is meant to be used as a
|
||||
/// CustomVectorizationHook.
|
||||
static VectorizationResult
|
||||
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
|
||||
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
|
||||
SmallVectorImpl<Value> &results) {
|
||||
SmallVectorImpl<Value> &newResults) {
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
|
||||
if (!yieldOp)
|
||||
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
||||
|
@ -156,10 +156,10 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
|
|||
// TODO: Scan for an opportunity for reuse.
|
||||
// TODO: use a map.
|
||||
Value vectorValue = bvm.lookup(outputs.value());
|
||||
Value result = buildVectorWrite(builder, vectorValue,
|
||||
linalgOp.getOutput(outputs.index()));
|
||||
if (result)
|
||||
results.push_back(result);
|
||||
Value newResult = buildVectorWrite(builder, vectorValue,
|
||||
linalgOp.getOutput(outputs.index()));
|
||||
if (newResult)
|
||||
newResults.push_back(newResult);
|
||||
}
|
||||
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
|
||||
}
|
||||
|
@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
|
|||
/// TODO: Reuse opportunities for RAR dependencies.
|
||||
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
|
||||
/// 5. Iteratively call vectorizeOneOp on the region operations.
|
||||
static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
||||
OpBuilder &builder, LinalgOp linalgOp,
|
||||
LogicalResult vectorizeAsLinalgGeneric(
|
||||
OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
|
||||
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
|
||||
// 1. Certain Linalg ops do not have a region but only a region builder.
|
||||
// If so, build the region so we can vectorize.
|
||||
|
@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
|||
}
|
||||
|
||||
// 4. Register CustomVectorizationHook for yieldOp.
|
||||
SmallVector<Value> results;
|
||||
CustomVectorizationHook vectorizeYield =
|
||||
[&](Operation *op,
|
||||
const BlockAndValueMapping &bvm) -> VectorizationResult {
|
||||
return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
|
||||
return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
|
||||
};
|
||||
// Append the vectorizeYield hook.
|
||||
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
|
||||
|
@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
|||
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
|
||||
if (result.status == VectorizationStatus::Failure) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
|
||||
return llvm::None;
|
||||
return failure();
|
||||
}
|
||||
if (result.status == VectorizationStatus::NewOp) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
|
||||
|
@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
|
|||
}
|
||||
}
|
||||
|
||||
return VectorizedLinalgOp{{results}};
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
||||
|
@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
|
|||
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
|
||||
}
|
||||
|
||||
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
|
||||
LinalgOp linalgOp) {
|
||||
static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
|
||||
SmallVectorImpl<Value> &newResults) {
|
||||
assert(isaContractionOpInterface(linalgOp) &&
|
||||
"expected vectorizeContraction preconditions to be met");
|
||||
Location loc = linalgOp.getLoc();
|
||||
|
@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
|
|||
linalgOp.indexing_maps(), linalgOp.iterator_types());
|
||||
return VectorizationResult{VectorizationStatus::NewOp, contract};
|
||||
};
|
||||
return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
|
||||
return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
|
||||
{vectorizeContraction});
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
|
@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
|||
return success(isaContractionOpInterface(linalgOp));
|
||||
}
|
||||
|
||||
Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
|
||||
Operation *op) {
|
||||
LogicalResult
|
||||
mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
|
||||
SmallVectorImpl<Value> &newResults) {
|
||||
if (failed(vectorizeLinalgOpPrecondition(op)))
|
||||
return llvm::None;
|
||||
return failure();
|
||||
|
||||
edsc::ScopedContext scope(builder, op->getLoc());
|
||||
if (isElementwise(op)) {
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
<< "Vectorize linalg op as a generic: " << *op);
|
||||
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
|
||||
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
|
||||
}
|
||||
|
||||
return vectorizeContraction(builder, cast<LinalgOp>(op));
|
||||
return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
|
|
Loading…
Reference in New Issue