[mlir][linalg] Cleanup LinalgOp usage in op declarations.

Replace the uses of deprecated Structured Op Interface methods in LinalgOps.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103506
This commit is contained in:
Tobias Gysi 2021-06-03 13:33:39 +00:00
parent 2cf78d4ead
commit 8fb6c31cbb
1 changed files with 70 additions and 62 deletions

View File

@ -375,11 +375,12 @@ ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputShapedType(0);
auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
OpOperand *output = op.getOutputOperand(0);
OpOperand *input = op.getInputOperand(0);
if (getElementTypeOrSelf(input->get().getType()) !=
getElementTypeOrSelf(output->get().getType()))
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
if (op.getRank(input) != op.getRank(output))
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
@ -449,11 +450,11 @@ ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
OpOperand *output = op.getOutputOperand(0);
Type fillType = op.value().getType();
if (getElementTypeOrSelf(output->get().getType()) != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
@ -739,11 +740,13 @@ struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
// Create a generic replacement operation and clone the body.
rewriter.setInsertionPointAfter(indexedOp);
SmallVector<Value> inputOperands = indexedOp.getInputOperands();
SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
SmallVector<StringRef> iterators = llvm::to_vector<4>(
indexedOp.iterator_types().getAsValueRange<StringAttr>());
GenericOp genericOp = rewriter.create<GenericOp>(
indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(),
indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators);
indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
outputOperands, indexedOp.getIndexingMaps(), iterators);
Region &genericRegion = genericOp.region();
Region &indexedRegion = indexedOp.region();
rewriter.cloneRegionBefore(indexedRegion, genericRegion,
@ -2107,21 +2110,21 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op,
LinalgOp linalgOpInterface) {
auto nOutputs = linalgOpInterface.getNumOutputs();
if (op.getNumOperands() != nOutputs)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
if (op.getNumOperands() != linalgOp.getNumOutputs())
return op.emitOpError("expected number of yield values (")
<< nOutputs << ") to match the number of operands of the enclosing "
<< linalgOp.getNumOutputs()
<< ") to match the number of operands of the enclosing "
<< "LinalgOp (" << op.getNumOperands() << ")";
for (unsigned i = 0; i != nOutputs; ++i) {
auto elementType =
linalgOpInterface.getOutputShapedType(i).getElementType();
if (op.getOperand(i).getType() != elementType)
for (OpOperand &opOperand : op->getOpOperands()) {
OpOperand *outputOperand =
linalgOp.getOutputOperand(opOperand.getOperandNumber());
Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
if (opOperand.get().getType() != elementType)
return op.emitOpError("type of yield operand ")
<< (i + 1) << " (" << op.getOperand(i).getType()
<< ") doesn't match "
<< (opOperand.getOperandNumber() + 1) << " ("
<< opOperand.get().getType() << ") doesn't match "
<< "the element type of the enclosing linalg.generic op ("
<< elementType << ")";
}
@ -3096,14 +3099,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
for (Value v : op.getShapedOperands()) {
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
auto mt = v.getType().dyn_cast<MemRefType>();
auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
if (!mt)
continue;
if (llvm::is_contained(mt.getShape(), 0)) {
if (llvm::is_contained(op.getShape(opOperand), 0)) {
rewriter.eraseOp(op);
return success();
}
@ -3119,10 +3122,10 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
PatternRewriter &rewriter) const override {
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op.getShapedOperands(), [&](Value v) {
if (v.isa<BlockArgument>())
llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
if (opOperand->get().isa<BlockArgument>())
return false;
auto castOp = v.getDefiningOp<tensor::CastOp>();
auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
if (!hasTensorCastOperand)
@ -3133,16 +3136,18 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
// Inputs may fold.
for (Value v : op.getInputs()) {
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
newOperands.push_back(
canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
for (OpOperand *opOperand : op.getInputOperands()) {
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
? tensorCastOp.source()
: opOperand->get());
}
// Init tensors may fold, in which case the resultType must also change.
for (Value v : op.getOutputs()) {
auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
for (OpOperand *opOperand : op.getOutputOperands()) {
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
newOperands.push_back(fold ? tensorCastOp.getOperand()
: opOperand->get());
newResultTypes.push_back(newOperands.back().getType());
}
auto extraOperands = op.getAssumedNonShapedOperands();
@ -3189,18 +3194,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// in the case of duplicated inputs, the canonical input could be some other
// input `< i`. That is, a later input will have some earlier input as its
// canonical input.
llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
// For later remapping tasks like deduplicating payload block arguments,
// having a simple "inputIndex -> canonicalInputIndex" integer mapping is
// convenient.
SmallVector<int, 6> canonicalInputIndices;
for (int i = 0, e = op.getNumInputs(); i != e; i++) {
Value input = op.getInput(i);
AffineMap indexingMap = op.getInputIndexingMap(i);
SmallVector<unsigned> canonicalInputIndices;
for (OpOperand *opOperand : op.getInputOperands()) {
AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
// STL-like maps have a convenient behavior for our use case here. In the
// case of duplicate keys, the insertion is rejected, and the returned
// iterator gives access to the value already in the map.
auto pair = canonicalInput.insert({{input, indexingMap}, i});
auto pair = canonicalInput.insert(
{{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
canonicalInputIndices.push_back(pair.first->second);
}
@ -3209,26 +3214,29 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
return failure();
// The operands for the newly canonicalized op.
SmallVector<Value, 6> newOperands;
for (auto v : llvm::enumerate(op.getInputs()))
if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
newOperands.push_back(v.value());
llvm::append_range(newOperands, op.getOutputs());
SmallVector<Value> newOperands;
for (OpOperand *opOperand : op.getInputOperands())
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
opOperand->getOperandNumber())
newOperands.push_back(opOperand->get());
SmallVector<Value> outputOperands = op.getOutputOperands();
llvm::append_range(newOperands, outputOperands);
llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
// Repair the indexing maps by filtering out the ones that have been
// eliminated.
SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : op.getInputOperands())
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
opOperand->getOperandNumber())
newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
for (OpOperand *opOperand : op.getOutputOperands())
newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
// Clone the old op with new operands.
Operation *newOp =
op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
auto newLinalgOp = cast<LinalgOp>(newOp);
// Repair the indexing maps by filtering out the ones that have been
// eliminated.
SmallVector<AffineMap, 6> newIndexingMaps;
for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
if (canonicalInputIndices[i] == i)
newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
newOp->setAttr("indexing_maps",
rewriter.getAffineMapArrayAttr(newIndexingMaps));
@ -3243,18 +3251,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp->getRegion(0).front();
for (int i = 0, e = op.getNumInputs(); i < e; i++) {
SmallVector<OpOperand *> inputOperands = op.getInputOperands();
for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
// Iterate in reverse, so that we erase later args first, preventing the
// argument list from shifting unexpectedly and invalidating all our
// indices.
int reversed = e - i - 1;
int canonicalIndex = canonicalInputIndices[reversed];
if (canonicalInputIndices[reversed] == reversed)
unsigned operandNumber = opOperand->getOperandNumber();
if (canonicalInputIndices[operandNumber] == operandNumber)
continue;
payload.getArgument(bbArgBaseOffset + reversed)
.replaceAllUsesWith(
payload.getArgument(bbArgBaseOffset + canonicalIndex));
payload.eraseArgument(bbArgBaseOffset + reversed);
payload.getArgument(bbArgBaseOffset + operandNumber)
.replaceAllUsesWith(payload.getArgument(
bbArgBaseOffset + canonicalInputIndices[operandNumber]));
payload.eraseArgument(bbArgBaseOffset + operandNumber);
}
rewriter.replaceOp(op, newOp->getResults());