forked from OSchip/llvm-project
[mlir][linalg] Cleanup LinalgOp usage in op declarations.
Replace the uses of deprecated Structured Op Interface methods in LinalgOps.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103506
parent 2cf78d4ead
commit 8fb6c31cbb
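Background for the diff below: the deprecated Structured Op Interface accessors (getInputs, getOutputs, getShapedOperands, getInput(i), getInputIndexingMap(i), getOutputShapedType(i), ...) return values or shaped types by position, while the replacement methods hand out OpOperand pointers that are then queried for their value, element type, shape, or tied indexing map. The following is a minimal illustrative sketch of the new style, not code from the patch; the helper name walkOperandsExample and the exact include set are assumptions:

    #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/TypeUtilities.h"

    using namespace mlir;

    // Hypothetical helper showing the OpOperand-based accessors the patch switches to.
    static void walkOperandsExample(linalg::LinalgOp op) {
      for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
        Value v = opOperand->get();                            // underlying SSA value
        Type elementType = getElementTypeOrSelf(v.getType());  // element type of the tensor/memref
        AffineMap map = op.getTiedIndexingMap(opOperand);      // indexing map tied to this operand
        (void)v; (void)elementType; (void)map;                 // silence unused-variable warnings
      }
    }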
@@ -375,11 +375,12 @@ ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
 void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
 
 static LogicalResult verify(CopyOp op) {
-  auto outputViewType = op.getOutputShapedType(0);
-  auto inputViewType = op.getInputShapedType(0);
-  if (inputViewType.getElementType() != outputViewType.getElementType())
+  OpOperand *output = op.getOutputOperand(0);
+  OpOperand *input = op.getInputOperand(0);
+  if (getElementTypeOrSelf(input->get().getType()) !=
+      getElementTypeOrSelf(output->get().getType()))
     return op.emitOpError("expects views of the same type");
-  if (inputViewType.getRank() != outputViewType.getRank())
+  if (op.getRank(input) != op.getRank(output))
     return op.emitOpError("expects views of the same rank");
   auto rank = op.getNumParallelLoops();
   auto inputPermutationMap = op.inputPermutation();
@@ -449,11 +450,11 @@ ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
 void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
 
 static LogicalResult verify(FillOp op) {
-  auto viewType = op.getOutputShapedType(0);
-  auto fillType = op.value().getType();
-  if (viewType.getElementType() != fillType)
+  OpOperand *output = op.getOutputOperand(0);
+  Type fillType = op.value().getType();
+  if (getElementTypeOrSelf(output->get().getType()) != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
-  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+  if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
     return op.emitOpError(
         "expected fill op with no result value to use memref type");
   }
@@ -739,11 +740,13 @@ struct ConvertIndexedToGenericOp : OpRewritePattern<IndexedGenericOp> {
 
     // Create a generic replacement operation and clone the body.
     rewriter.setInsertionPointAfter(indexedOp);
+    SmallVector<Value> inputOperands = indexedOp.getInputOperands();
+    SmallVector<Value> outputOperands = indexedOp.getOutputOperands();
     SmallVector<StringRef> iterators = llvm::to_vector<4>(
         indexedOp.iterator_types().getAsValueRange<StringAttr>());
     GenericOp genericOp = rewriter.create<GenericOp>(
-        indexedOp.getLoc(), indexedOp->getResultTypes(), indexedOp.getInputs(),
-        indexedOp.getOutputs(), indexedOp.getIndexingMaps(), iterators);
+        indexedOp.getLoc(), indexedOp->getResultTypes(), inputOperands,
+        outputOperands, indexedOp.getIndexingMaps(), iterators);
     Region &genericRegion = genericOp.region();
     Region &indexedRegion = indexedOp.region();
     rewriter.cloneRegionBefore(indexedRegion, genericRegion,
@@ -2107,21 +2110,21 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
 
 // Check the operand number and types must match the element types of the
 // LinalgOp interface's shaped operands.
-static LogicalResult verifyYield(linalg::YieldOp op,
-                                 LinalgOp linalgOpInterface) {
-  auto nOutputs = linalgOpInterface.getNumOutputs();
-  if (op.getNumOperands() != nOutputs)
+static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
+  if (op.getNumOperands() != linalgOp.getNumOutputs())
     return op.emitOpError("expected number of yield values (")
-           << nOutputs << ") to match the number of operands of the enclosing "
+           << linalgOp.getNumOutputs()
+           << ") to match the number of operands of the enclosing "
            << "LinalgOp (" << op.getNumOperands() << ")";
 
-  for (unsigned i = 0; i != nOutputs; ++i) {
-    auto elementType =
-        linalgOpInterface.getOutputShapedType(i).getElementType();
-    if (op.getOperand(i).getType() != elementType)
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    OpOperand *outputOperand =
+        linalgOp.getOutputOperand(opOperand.getOperandNumber());
+    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+    if (opOperand.get().getType() != elementType)
       return op.emitOpError("type of yield operand ")
-             << (i + 1) << " (" << op.getOperand(i).getType()
-             << ") doesn't match "
+             << (opOperand.getOperandNumber() + 1) << " ("
+             << opOperand.get().getType() << ") doesn't match "
              << "the element type of the enclosing linalg.generic op ("
              << elementType << ")";
   }
@@ -3096,14 +3099,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    for (Value v : op.getShapedOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       // Linalg "inputs" may be either tensor or memref type.
       // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
-      auto mt = v.getType().dyn_cast<MemRefType>();
+      auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
       if (!mt)
         continue;
-      if (llvm::is_contained(mt.getShape(), 0)) {
+      if (llvm::is_contained(op.getShape(opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
@@ -3119,10 +3122,10 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
                                 PatternRewriter &rewriter) const override {
     // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
-        llvm::any_of(op.getShapedOperands(), [&](Value v) {
-          if (v.isa<BlockArgument>())
+        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
+          if (opOperand->get().isa<BlockArgument>())
             return false;
-          auto castOp = v.getDefiningOp<tensor::CastOp>();
+          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
           return castOp && canFoldIntoConsumerOp(castOp);
         });
     if (!hasTensorCastOperand)
@@ -3133,16 +3136,18 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
     SmallVector<Value, 4> newOperands;
     newOperands.reserve(op->getNumOperands());
     // Inputs may fold.
-    for (Value v : op.getInputs()) {
-      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
-      newOperands.push_back(
-          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+    for (OpOperand *opOperand : op.getInputOperands()) {
+      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
+                                ? tensorCastOp.source()
+                                : opOperand->get());
     }
     // Init tensors may fold, in which case the resultType must also change.
-    for (Value v : op.getOutputs()) {
-      auto tensorCastOp = v.getDefiningOp<tensor::CastOp>();
+    for (OpOperand *opOperand : op.getOutputOperands()) {
+      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+      newOperands.push_back(fold ? tensorCastOp.getOperand()
+                                 : opOperand->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
     auto extraOperands = op.getAssumedNonShapedOperands();
@@ -3189,18 +3194,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
     // in the case of duplicated inputs, the canonical input could be some other
     // input `< i`. That is, a later input will have some earlier input as its
     // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
     // For later remapping tasks like deduplicating payload block arguments,
     // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
     // convenient.
-    SmallVector<int, 6> canonicalInputIndices;
-    for (int i = 0, e = op.getNumInputs(); i != e; i++) {
-      Value input = op.getInput(i);
-      AffineMap indexingMap = op.getInputIndexingMap(i);
+    SmallVector<unsigned> canonicalInputIndices;
+    for (OpOperand *opOperand : op.getInputOperands()) {
+      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
       // STL-like maps have a convenient behavior for our use case here. In the
       // case of duplicate keys, the insertion is rejected, and the returned
       // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert({{input, indexingMap}, i});
+      auto pair = canonicalInput.insert(
+          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
       canonicalInputIndices.push_back(pair.first->second);
     }
 
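A note on the canonicalInput map in the hunk above: the "STL-like" behavior the comment relies on is that llvm::SmallDenseMap::insert returns an (iterator, inserted) pair and rejects duplicate keys, so pair.first->second is the operand number of the first input seen with the same (value, indexing map) key. A standalone sketch of just that behavior, with an int key purely for illustration:

    #include "llvm/ADT/DenseMap.h"
    #include <cassert>

    int main() {
      llvm::SmallDenseMap<int, unsigned> canonical;
      auto first = canonical.insert({42, 0}); // inserted: 42 -> 0
      auto dup = canonical.insert({42, 7});   // rejected: 42 still maps to 0
      assert(first.second && !dup.second);    // second insert reports "not inserted"
      assert(dup.first->second == 0);         // iterator points at the existing entry
      return 0;
    }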
@@ -3209,26 +3214,29 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
       return failure();
 
     // The operands for the newly canonicalized op.
-    SmallVector<Value, 6> newOperands;
-    for (auto v : llvm::enumerate(op.getInputs()))
-      if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
-        newOperands.push_back(v.value());
-    llvm::append_range(newOperands, op.getOutputs());
+    SmallVector<Value> newOperands;
+    for (OpOperand *opOperand : op.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newOperands.push_back(opOperand->get());
+    SmallVector<Value> outputOperands = op.getOutputOperands();
+    llvm::append_range(newOperands, outputOperands);
     llvm::append_range(newOperands, op.getAssumedNonShapedOperands());
 
+    // Repair the indexing maps by filtering out the ones that have been
+    // eliminated.
+    SmallVector<AffineMap> newIndexingMaps;
+    for (OpOperand *opOperand : op.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+    for (OpOperand *opOperand : op.getOutputOperands())
+      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
+
     // Clone the old op with new operands.
     Operation *newOp =
         op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
     auto newLinalgOp = cast<LinalgOp>(newOp);
-
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap, 6> newIndexingMaps;
-    for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
-      if (canonicalInputIndices[i] == i)
-        newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
-    for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
-      newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
     newOp->setAttr("indexing_maps",
                    rewriter.getAffineMapArrayAttr(newIndexingMaps));
 
@@ -3243,18 +3251,18 @@ struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
     // Repair the payload entry block by RAUW'ing redundant arguments and
     // erasing them.
     Block &payload = newOp->getRegion(0).front();
-    for (int i = 0, e = op.getNumInputs(); i < e; i++) {
+    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
+    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
       // Iterate in reverse, so that we erase later args first, preventing the
       // argument list from shifting unexpectedly and invalidating all our
       // indices.
-      int reversed = e - i - 1;
-      int canonicalIndex = canonicalInputIndices[reversed];
-      if (canonicalInputIndices[reversed] == reversed)
+      unsigned operandNumber = opOperand->getOperandNumber();
+      if (canonicalInputIndices[operandNumber] == operandNumber)
         continue;
-      payload.getArgument(bbArgBaseOffset + reversed)
-          .replaceAllUsesWith(
-              payload.getArgument(bbArgBaseOffset + canonicalIndex));
-      payload.eraseArgument(bbArgBaseOffset + reversed);
+      payload.getArgument(bbArgBaseOffset + operandNumber)
+          .replaceAllUsesWith(payload.getArgument(
+              bbArgBaseOffset + canonicalInputIndices[operandNumber]));
+      payload.eraseArgument(bbArgBaseOffset + operandNumber);
     }
 
     rewriter.replaceOp(op, newOp->getResults());
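One detail the last hunk preserves from the old code is the reverse iteration: Block::eraseArgument(i) shifts every argument after position i down by one, so erasing later arguments first keeps the bbArgBaseOffset + operandNumber positions of the not-yet-visited, earlier arguments valid. Below is a restatement of that step as a standalone helper, assuming the surrounding pattern's variables; the function name dedupPayloadArgs is invented for illustration:

    #include "mlir/IR/Operation.h"
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"

    using namespace mlir;

    // Invented helper: RAUW and erase the payload block arguments of duplicated
    // inputs, walking backwards so earlier argument positions stay stable.
    static void dedupPayloadArgs(Block &payload, unsigned bbArgBaseOffset,
                                 llvm::ArrayRef<OpOperand *> inputOperands,
                                 llvm::ArrayRef<unsigned> canonicalInputIndices) {
      for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
        unsigned operandNumber = opOperand->getOperandNumber();
        if (canonicalInputIndices[operandNumber] == operandNumber)
          continue; // this input is its own canonical representative
        payload.getArgument(bbArgBaseOffset + operandNumber)
            .replaceAllUsesWith(payload.getArgument(
                bbArgBaseOffset + canonicalInputIndices[operandNumber]));
        payload.eraseArgument(bbArgBaseOffset + operandNumber);
      }
    }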