[mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).

Replace the uses of deprecated Structured Op Interface methods in Vectorization.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103410
This commit is contained in:
Tobias Gysi 2021-06-01 07:48:03 +00:00
parent a3b8695bf5
commit 912ebf60b1
1 changed files with 71 additions and 73 deletions

View File

@ -116,14 +116,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
/// Linalg. This limitation is motivated by the fact that e.g.
/// min(max(X)) != max(min(X))
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
unsigned yieldNum =
outputOperand.getOperandNumber() - linalgOp.getNumInputs();
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
llvm::SetVector<Operation *> backwardSlice, forwardSlice;
BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
outputOperand.getOperandNumber());
outputOperand->getOperandNumber());
Value yieldVal = yieldOp->getOperand(yieldNum);
getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
return op->getParentOp() == linalgOp;
@ -186,16 +186,15 @@ getKindForOp(Operation *reductionOp) {
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
Value value, OpOperand &outputOperand) {
assert(targetVectorType.getShape() ==
outputOperand.get().getType().cast<ShapedType>().getShape());
Value value, OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
return value;
// At this point, we know we need to reduce. Detect the reduction operator.
// TODO: Use the generic reduction detection util.
Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@ -235,23 +234,22 @@ static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
/// currently being vectorized. If `dest` has null rank, build an memref.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &b, Value value,
OpOperand &outputOperand) {
OpOperand *outputOperand) {
Operation *write;
Location loc = value.getLoc();
auto shapedType = outputOperand.get().getType().cast<ShapedType>();
if (VectorType vectorType =
extractVectorTypeFromShapedValue(outputOperand.get())) {
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
AffineMap map = reindexIndexingMap(
linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
SmallVector<Value> indices(shapedType.getRank(),
extractVectorTypeFromShapedValue(outputOperand->get())) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
b.create<ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(b, value, vectorType.getShape());
value = reduceIfNeeded(b, vectorType, value, outputOperand);
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
}
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
if (!write->getResults().empty())
@ -284,7 +282,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
// TODO: use a map.
Value vectorValue = bvm.lookup(outputs.value());
Value newResult = buildVectorWrite(
b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
if (newResult)
newResults.push_back(newResult);
}
@ -422,8 +420,8 @@ static bool isElementwise(Operation *op) {
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;
// TODO: relax the restrictions on indexing map.
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
if (!linalgOp.getOutputIndexingMap(i).isIdentity())
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
return false;
}
if (linalgOp->getNumRegions() != 1)
@ -479,36 +477,37 @@ LogicalResult vectorizeAsLinalgGeneric(
// 3. Turn all BBArgs into vector.transfer_read / load.
SmallVector<AffineMap> indexings;
for (auto bbarg : block.getArguments()) {
Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
// TODO: 0-d vectors.
if (shapedType.getShape().empty()) {
Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
if (linalgOp.getShape(opOperand).empty()) {
Value loaded =
b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << loaded);
bvm.map(bbarg, loaded);
bvm.map(shapedArg, loaded);
bvm.map(opOperand->get(), loaded);
continue;
}
AffineMap map;
VectorType vectorType;
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getIndexingMap(bbarg.getArgNumber()));
vectorType =
VectorType::get(commonVectorShape, shapedType.getElementType());
linalgOp.getTiedIndexingMap(opOperand));
vectorType = VectorType::get(
commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
vectorType = VectorType::get(map.compose(shapedType.getShape()),
shapedType.getElementType());
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
vectorType =
VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get().getType()));
}
Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
<< bbarg.getArgNumber() << "): " << vectorRead);
bvm.map(bbarg, vectorRead);
bvm.map(shapedArg, vectorRead);
bvm.map(opOperand->get(), vectorRead);
}
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@ -562,7 +561,8 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
const BlockAndValueMapping &bvm) -> VectorizationResult {
if (!isa<MulIOp, MulFOp>(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto outShape = linalgOp.getOutputShapedType(0).getShape();
ArrayRef<int64_t> outShape =
linalgOp.getShape(linalgOp.getOutputOperand(0));
auto vType = outShape.empty()
? op->getResult(0).getType()
: VectorType::get(outShape, op->getResult(0).getType());
@ -574,13 +574,14 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
// TODO: consider dropping contraction special casing altogether, this will
// require more advanced canonicalizations involving vector.multi_reduction
// that are not yet available.
SmallVector<AffineMap> indexingMaps{
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
.compose(linalgOp.getIndexingMap(0)),
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
.compose(linalgOp.getIndexingMap(1)),
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
.compose(linalgOp.getIndexingMap(2))};
SmallVector<AffineMap> indexingMaps;
indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
llvm::transform(linalgOp.getIndexingMaps(),
std::back_inserter(indexingMaps),
[](AffineMap indexingMap) {
return inversePermutation(reindexIndexingMap(indexingMap))
.compose(indexingMap);
});
Operation *contract = b.create<vector::ContractionOp>(
loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
@ -601,8 +602,8 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
return failure();
for (auto &operand : op.getOutputOpOperands()) {
Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
for (OpOperand *opOperand : op.getOutputOperands()) {
Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
if (!getKindForOp(reductionOp))
return failure();
}
@ -612,12 +613,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
for (Value operand : linalgOp.getShapedOperands())
if (!operand.getType().cast<ShapedType>().hasStaticShape())
return failure();
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
if (linalgOp.hasDynamicShape())
return failure();
if (isElementwise(op))
return success();
if (isaContractionOpInterface(linalgOp))
@ -722,13 +719,14 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
ShapedType inShapeType = op.getInputShapedType(0);
ShapedType kShapeType = op.getInputShapedType(1);
OpOperand *input = op.getInputOperand(0);
OpOperand *kernel = op.getInputOperand(1);
OpOperand *output = op.getOutputOperand(0);
ArrayRef<int64_t> inShape = op.getShape(input);
ArrayRef<int64_t> kShape = op.getShape(kernel);
ArrayRef<int64_t> inShape = inShapeType.getShape();
ArrayRef<int64_t> kShape = kShapeType.getShape();
if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
if (llvm::any_of(inShape, ShapedType::isDynamic) ||
llvm::any_of(kShape, ShapedType::isDynamic))
return failure();
SmallVector<AffineExpr, 4> mapping;
@ -747,22 +745,18 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
}
}
Value input = op.getInput(0);
Value kernel = op.getInput(1);
Value output = op.getOutputBuffer(0);
unsigned rank = inShapeType.getRank();
unsigned numDims = mapping.size();
Type elemType = inShapeType.getElementType();
int64_t rank = op.getRank(input);
int64_t numDims = mapping.size();
Type elemType = getElementTypeOrSelf(input->get().getType());
auto map = AffineMap::get(rank, 0, mapping, context);
SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
auto vecType = VectorType::get(vectorDims, elemType);
auto inputVec =
rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
auto kernelVec =
rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
auto inputVec = rewriter.create<vector::TransferReadOp>(
loc, vecType, input->get(), zeros, map);
auto kernelVec = rewriter.create<vector::TransferReadOp>(
loc, vecType, kernel->get(), zeros, map);
auto acc = rewriter.create<ConstantOp>(loc, elemType,
rewriter.getZeroAttr(elemType));
@ -779,7 +773,8 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
rewriter.getAffineMapArrayAttr(indexingMaps),
rewriter.getStrArrayAttr(iteratorTypes));
rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
rewriter.create<memref::StoreOp>(loc, result, output->get(),
ValueRange(zeros));
rewriter.eraseOp(op);
return success();
}
@ -939,7 +934,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
if (newCopyOp.getOutputBuffer(0) != subView)
assert(newCopyOp.output().getType().isa<MemRefType>());
if (newCopyOp.output() != subView)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "copy candidate " << *newCopyOp);
@ -958,7 +954,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
assert(newFillOp.output().getType().isa<MemRefType>());
if (newFillOp.output() != viewOrAlloc)
continue;
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
<< "fill candidate " << *newFillOp);
@ -976,7 +973,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
<< "with maybeFillOp " << *maybeFillOp);
// `in` is the subview that linalg.copy reads. Replace it.
Value in = copyOp.getInput(0);
Value in = copyOp.input();
// linalg.copy + linalg.fill can be used to create a padded local buffer.
// The `masked` attribute is only valid on this padded buffer.
@ -1014,7 +1011,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
CopyOp copyOp;
for (auto &u : subViewOp.getResult().getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
if (newCopyOp.getInput(0) != subView)
if (newCopyOp.getInputOperand(0)->get() != subView)
continue;
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
continue;
@ -1026,7 +1023,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
return failure();
// `out` is the subview copied into that we replace.
Value out = copyOp.getOutputBuffer(0);
assert(copyOp.output().getType().isa<MemRefType>());
Value out = copyOp.output();
// Forward vector.transfer into copy.
// linalg.copy + linalg.fill can be used to create a padded local buffer.