forked from OSchip/llvm-project
[mlir][linalg] Cleanup LinalgOp usage in vectorization (NFC).
Replace the uses of deprecated Structured Op Interface methods in Vectorization.cpp. This patch is based on https://reviews.llvm.org/D103394. Differential Revision: https://reviews.llvm.org/D103410
This commit is contained in:
parent
a3b8695bf5
commit
912ebf60b1
|
@ -116,14 +116,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
|
|||
/// Linalg. This limitation is motivated by the fact that e.g.
|
||||
/// min(max(X)) != max(min(X))
|
||||
// TODO: use in LinalgOp verification, there is a circular dependency atm.
|
||||
static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
|
||||
static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
||||
auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
|
||||
unsigned yieldNum =
|
||||
outputOperand.getOperandNumber() - linalgOp.getNumInputs();
|
||||
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
|
||||
llvm::SetVector<Operation *> backwardSlice, forwardSlice;
|
||||
BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
|
||||
outputOperand.getOperandNumber());
|
||||
outputOperand->getOperandNumber());
|
||||
Value yieldVal = yieldOp->getOperand(yieldNum);
|
||||
getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
|
||||
return op->getParentOp() == linalgOp;
|
||||
|
@ -186,16 +186,15 @@ getKindForOp(Operation *reductionOp) {
|
|||
/// return a new vector.broadcast to `shape`.
|
||||
/// Otherwise, just return value.
|
||||
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
|
||||
Value value, OpOperand &outputOperand) {
|
||||
assert(targetVectorType.getShape() ==
|
||||
outputOperand.get().getType().cast<ShapedType>().getShape());
|
||||
Value value, OpOperand *outputOperand) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
||||
assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
|
||||
auto vecType = value.getType().dyn_cast<VectorType>();
|
||||
if (!vecType || vecType.getShape() == targetVectorType.getShape())
|
||||
return value;
|
||||
// At this point, we know we need to reduce. Detect the reduction operator.
|
||||
// TODO: Use the generic reduction detection util.
|
||||
Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
|
||||
unsigned pos = 0;
|
||||
MLIRContext *ctx = b.getContext();
|
||||
SmallVector<AffineExpr> exprs;
|
||||
|
@ -235,23 +234,22 @@ static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
|
|||
/// currently being vectorized. If `dest` has null rank, build a memref.store.
|
||||
/// Return the produced value or null if no value is produced.
|
||||
static Value buildVectorWrite(OpBuilder &b, Value value,
|
||||
OpOperand &outputOperand) {
|
||||
OpOperand *outputOperand) {
|
||||
Operation *write;
|
||||
Location loc = value.getLoc();
|
||||
auto shapedType = outputOperand.get().getType().cast<ShapedType>();
|
||||
if (VectorType vectorType =
|
||||
extractVectorTypeFromShapedValue(outputOperand.get())) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
|
||||
AffineMap map = reindexIndexingMap(
|
||||
linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
|
||||
SmallVector<Value> indices(shapedType.getRank(),
|
||||
extractVectorTypeFromShapedValue(outputOperand->get())) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
||||
AffineMap map =
|
||||
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
|
||||
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
|
||||
b.create<ConstantIndexOp>(loc, 0));
|
||||
value = broadcastIfNeeded(b, value, vectorType.getShape());
|
||||
value = reduceIfNeeded(b, vectorType, value, outputOperand);
|
||||
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
|
||||
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
|
||||
indices, map);
|
||||
} else {
|
||||
write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
|
||||
write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
|
||||
}
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
|
||||
if (!write->getResults().empty())
|
||||
|
@ -284,7 +282,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
|
|||
// TODO: use a map.
|
||||
Value vectorValue = bvm.lookup(outputs.value());
|
||||
Value newResult = buildVectorWrite(
|
||||
b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
|
||||
b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
|
||||
if (newResult)
|
||||
newResults.push_back(newResult);
|
||||
}
|
||||
|
@ -422,8 +420,8 @@ static bool isElementwise(Operation *op) {
|
|||
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
|
||||
return false;
|
||||
// TODO: relax the restrictions on indexing map.
|
||||
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
|
||||
if (!linalgOp.getOutputIndexingMap(i).isIdentity())
|
||||
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
|
||||
if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
|
||||
return false;
|
||||
}
|
||||
if (linalgOp->getNumRegions() != 1)
|
||||
|
@ -479,36 +477,37 @@ LogicalResult vectorizeAsLinalgGeneric(
|
|||
|
||||
// 3. Turn all BBArgs into vector.transfer_read / load.
|
||||
SmallVector<AffineMap> indexings;
|
||||
for (auto bbarg : block.getArguments()) {
|
||||
Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
|
||||
ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
|
||||
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
||||
BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
|
||||
// TODO: 0-d vectors.
|
||||
if (shapedType.getShape().empty()) {
|
||||
Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
|
||||
if (linalgOp.getShape(opOperand).empty()) {
|
||||
Value loaded =
|
||||
b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
|
||||
<< bbarg.getArgNumber() << "): " << loaded);
|
||||
bvm.map(bbarg, loaded);
|
||||
bvm.map(shapedArg, loaded);
|
||||
bvm.map(opOperand->get(), loaded);
|
||||
continue;
|
||||
}
|
||||
AffineMap map;
|
||||
VectorType vectorType;
|
||||
if (broadcastToMaximalCommonShape) {
|
||||
map = inverseAndBroadcastProjectedPermuation(
|
||||
linalgOp.getIndexingMap(bbarg.getArgNumber()));
|
||||
vectorType =
|
||||
VectorType::get(commonVectorShape, shapedType.getElementType());
|
||||
linalgOp.getTiedIndexingMap(opOperand));
|
||||
vectorType = VectorType::get(
|
||||
commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
|
||||
} else {
|
||||
map = inversePermutation(
|
||||
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
|
||||
vectorType = VectorType::get(map.compose(shapedType.getShape()),
|
||||
shapedType.getElementType());
|
||||
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
|
||||
vectorType =
|
||||
VectorType::get(map.compose(linalgOp.getShape(opOperand)),
|
||||
getElementTypeOrSelf(opOperand->get().getType()));
|
||||
}
|
||||
Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
|
||||
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
|
||||
<< bbarg.getArgNumber() << "): " << vectorRead);
|
||||
bvm.map(bbarg, vectorRead);
|
||||
bvm.map(shapedArg, vectorRead);
|
||||
bvm.map(opOperand->get(), vectorRead);
|
||||
}
|
||||
|
||||
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
|
||||
|
@ -562,7 +561,8 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
|
|||
const BlockAndValueMapping &bvm) -> VectorizationResult {
|
||||
if (!isa<MulIOp, MulFOp>(op))
|
||||
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
||||
auto outShape = linalgOp.getOutputShapedType(0).getShape();
|
||||
ArrayRef<int64_t> outShape =
|
||||
linalgOp.getShape(linalgOp.getOutputOperand(0));
|
||||
auto vType = outShape.empty()
|
||||
? op->getResult(0).getType()
|
||||
: VectorType::get(outShape, op->getResult(0).getType());
|
||||
|
@ -574,13 +574,14 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
|
|||
// TODO: consider dropping contraction special casing altogether, this will
|
||||
// require more advanced canonicalizations involving vector.multi_reduction
|
||||
// that are not yet available.
|
||||
SmallVector<AffineMap> indexingMaps{
|
||||
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
|
||||
.compose(linalgOp.getIndexingMap(0)),
|
||||
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
|
||||
.compose(linalgOp.getIndexingMap(1)),
|
||||
inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
|
||||
.compose(linalgOp.getIndexingMap(2))};
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
llvm::transform(linalgOp.getIndexingMaps(),
|
||||
std::back_inserter(indexingMaps),
|
||||
[](AffineMap indexingMap) {
|
||||
return inversePermutation(reindexIndexingMap(indexingMap))
|
||||
.compose(indexingMap);
|
||||
});
|
||||
Operation *contract = b.create<vector::ContractionOp>(
|
||||
loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
|
||||
b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
|
||||
|
@ -601,8 +602,8 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
|
|||
static LogicalResult reductionPreconditions(LinalgOp op) {
|
||||
if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
|
||||
return failure();
|
||||
for (auto &operand : op.getOutputOpOperands()) {
|
||||
Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
|
||||
for (OpOperand *opOperand : op.getOutputOperands()) {
|
||||
Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
|
||||
if (!getKindForOp(reductionOp))
|
||||
return failure();
|
||||
}
|
||||
|
@ -612,12 +613,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
|
|||
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
|
||||
auto linalgOp = cast<linalg::LinalgOp>(op);
|
||||
// All types must be static shape to go to vector.
|
||||
for (Value operand : linalgOp.getShapedOperands())
|
||||
if (!operand.getType().cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
|
||||
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
|
||||
return failure();
|
||||
if (linalgOp.hasDynamicShape())
|
||||
return failure();
|
||||
if (isElementwise(op))
|
||||
return success();
|
||||
if (isaContractionOpInterface(linalgOp))
|
||||
|
@ -722,13 +719,14 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
|
|||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
ShapedType inShapeType = op.getInputShapedType(0);
|
||||
ShapedType kShapeType = op.getInputShapedType(1);
|
||||
OpOperand *input = op.getInputOperand(0);
|
||||
OpOperand *kernel = op.getInputOperand(1);
|
||||
OpOperand *output = op.getOutputOperand(0);
|
||||
ArrayRef<int64_t> inShape = op.getShape(input);
|
||||
ArrayRef<int64_t> kShape = op.getShape(kernel);
|
||||
|
||||
ArrayRef<int64_t> inShape = inShapeType.getShape();
|
||||
ArrayRef<int64_t> kShape = kShapeType.getShape();
|
||||
|
||||
if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
|
||||
if (llvm::any_of(inShape, ShapedType::isDynamic) ||
|
||||
llvm::any_of(kShape, ShapedType::isDynamic))
|
||||
return failure();
|
||||
|
||||
SmallVector<AffineExpr, 4> mapping;
|
||||
|
@ -747,22 +745,18 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
Value input = op.getInput(0);
|
||||
Value kernel = op.getInput(1);
|
||||
Value output = op.getOutputBuffer(0);
|
||||
|
||||
unsigned rank = inShapeType.getRank();
|
||||
unsigned numDims = mapping.size();
|
||||
Type elemType = inShapeType.getElementType();
|
||||
int64_t rank = op.getRank(input);
|
||||
int64_t numDims = mapping.size();
|
||||
Type elemType = getElementTypeOrSelf(input->get().getType());
|
||||
|
||||
auto map = AffineMap::get(rank, 0, mapping, context);
|
||||
SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
auto vecType = VectorType::get(vectorDims, elemType);
|
||||
|
||||
auto inputVec =
|
||||
rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
|
||||
auto kernelVec =
|
||||
rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
|
||||
auto inputVec = rewriter.create<vector::TransferReadOp>(
|
||||
loc, vecType, input->get(), zeros, map);
|
||||
auto kernelVec = rewriter.create<vector::TransferReadOp>(
|
||||
loc, vecType, kernel->get(), zeros, map);
|
||||
|
||||
auto acc = rewriter.create<ConstantOp>(loc, elemType,
|
||||
rewriter.getZeroAttr(elemType));
|
||||
|
@ -779,7 +773,8 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
|
|||
rewriter.getAffineMapArrayAttr(indexingMaps),
|
||||
rewriter.getStrArrayAttr(iteratorTypes));
|
||||
|
||||
rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
|
||||
rewriter.create<memref::StoreOp>(loc, result, output->get(),
|
||||
ValueRange(zeros));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
@ -939,7 +934,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||
CopyOp copyOp;
|
||||
for (auto &u : subView.getUses()) {
|
||||
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
|
||||
if (newCopyOp.getOutputBuffer(0) != subView)
|
||||
assert(newCopyOp.output().getType().isa<MemRefType>());
|
||||
if (newCopyOp.output() != subView)
|
||||
continue;
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
<< "copy candidate " << *newCopyOp);
|
||||
|
@ -958,7 +954,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||
FillOp maybeFillOp;
|
||||
for (auto &u : viewOrAlloc.getUses()) {
|
||||
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
|
||||
if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
|
||||
assert(newFillOp.output().getType().isa<MemRefType>());
|
||||
if (newFillOp.output() != viewOrAlloc)
|
||||
continue;
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
||||
<< "fill candidate " << *newFillOp);
|
||||
|
@ -976,7 +973,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||
<< "with maybeFillOp " << *maybeFillOp);
|
||||
|
||||
// `in` is the subview that linalg.copy reads. Replace it.
|
||||
Value in = copyOp.getInput(0);
|
||||
Value in = copyOp.input();
|
||||
|
||||
// linalg.copy + linalg.fill can be used to create a padded local buffer.
|
||||
// The `masked` attribute is only valid on this padded buffer.
|
||||
|
@ -1014,7 +1011,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||
CopyOp copyOp;
|
||||
for (auto &u : subViewOp.getResult().getUses()) {
|
||||
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
|
||||
if (newCopyOp.getInput(0) != subView)
|
||||
if (newCopyOp.getInputOperand(0)->get() != subView)
|
||||
continue;
|
||||
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
|
||||
continue;
|
||||
|
@ -1026,7 +1023,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||
return failure();
|
||||
|
||||
// `out` is the subview that the copy writes into; it is the value we replace.
|
||||
Value out = copyOp.getOutputBuffer(0);
|
||||
assert(copyOp.output().getType().isa<MemRefType>());
|
||||
Value out = copyOp.output();
|
||||
|
||||
// Forward vector.transfer into copy.
|
||||
// linalg.copy + linalg.fill can be used to create a padded local buffer.
|
||||
|
|
Loading…
Reference in New Issue