[mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).

Specialize the DeduplicateInputs and RemoveIdentityLinalgOps patterns for GenericOp instead of implementing them for the LinalgOp interface.

This revision is based on https://reviews.llvm.org/D105622, which moves the logic to erase identity CopyOps into a separate pattern.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105291
This commit is contained in:
Tobias Gysi 2021-07-28 11:24:27 +00:00
parent b0ef3d8f66
commit 09635dc7bf
2 changed files with 134 additions and 139 deletions

View File

@@ -658,6 +658,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
let verifier = [{ return ::verify(*this); }];
let hasCanonicalizer = 1;
let hasFolder = 1;
}

View File

@@ -671,6 +671,138 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
namespace {
// Deduplicate redundant args of a linalg generic op.
// An arg is redundant if it has the same Value and indexing map as another.
struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Associate each input to an equivalent "canonical" input that has the same
    // Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some other
    // input `< i`. That is, a later input will have some earlier input as its
    // canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<unsigned> canonicalInputIndices;
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert(
          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    if (canonicalInput.size() == genericOp.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op: keep only the first
    // occurrence of each (Value, indexing map) pair.
    SmallVector<Value> newInputOperands;
    for (OpOperand *opOperand : genericOp.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newInputOperands.push_back(opOperand->get());

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    SmallVector<AffineMap> newIndexingMaps;
    for (OpOperand *opOperand : genericOp.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
    for (OpOperand *opOperand : genericOp.getOutputOperands())
      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));

    // Clone the old op with new operands.
    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
    auto newOp = rewriter.create<GenericOp>(
        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr());
    // Move the payload region over wholesale; redundant block arguments are
    // fixed up below.
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp.region().front();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      unsigned operandNumber = opOperand->getOperandNumber();
      if (canonicalInputIndices[operandNumber] == operandNumber)
        continue;
      payload.getArgument(operandNumber)
          .replaceAllUsesWith(
              payload.getArgument(canonicalInputIndices[operandNumber]));
      payload.eraseArgument(operandNumber);
    }

    rewriter.replaceOp(genericOp, newOp->getResults());
    return success();
  }
};
/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // This folding only applies on tensors.
    if (!genericOp.hasTensorSemantics())
      return failure();

    // Every indexing map must be the identity, i.e. every operand is visited
    // element-for-element in lockstep.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isIdentity(); }))
      return failure();

    // The payload must consist of nothing but a single linalg.yield.
    Block &payloadBlock = genericOp.region().front();
    if (!llvm::hasSingleElement(payloadBlock))
      return failure();
    auto terminator = dyn_cast<linalg::YieldOp>(payloadBlock.getTerminator());
    if (!terminator)
      return failure();

    // Map each yielded block argument back to the corresponding op operand;
    // those operands become the replacements for the op's results.
    SmallVector<Value> replacements;
    for (Value yielded : terminator.values()) {
      auto blockArg = yielded.dyn_cast<BlockArgument>();
      if (!blockArg || blockArg.getOwner() != &payloadBlock)
        return failure();
      replacements.push_back(genericOp->getOperand(blockArg.getArgNumber()));
    }
    if (replacements.size() != genericOp->getNumResults())
      return failure();

    rewriter.replaceOp(genericOp, replacements);
    return success();
  }
};
} // namespace
/// Registers the GenericOp-specific canonicalizations implemented above:
/// input deduplication and erasure of identity (copy-only) generics.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
}
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
@@ -2539,143 +2671,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
};
} // namespace
namespace {
// Deduplicate redundant args of a linalg op.
// An arg is redundant if it has the same Value and indexing map as another.
// NOTE(review): interface-based predecessor of DeduplicateGenericOpInputs;
// it matches LinalgOp but immediately bails out for anything but GenericOp.
struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // This pattern reduces the number of arguments of an op, which breaks
    // the invariants of semantically charged named ops.
    if (!isa<GenericOp>(op))
      return failure();

    // Associate each input to an equivalent "canonical" input that has the same
    // Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some other
    // input `< i`. That is, a later input will have some earlier input as its
    // canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<unsigned> canonicalInputIndices;
    for (OpOperand *opOperand : op.getInputOperands()) {
      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert(
          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    if (canonicalInput.size() == op.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op: deduplicated inputs
    // followed by all the original outputs.
    SmallVector<Value> newOperands;
    for (OpOperand *opOperand : op.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newOperands.push_back(opOperand->get());
    SmallVector<Value> outputOperands = op.getOutputOperands();
    llvm::append_range(newOperands, outputOperands);

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    SmallVector<AffineMap> newIndexingMaps;
    for (OpOperand *opOperand : op.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
    for (OpOperand *opOperand : op.getOutputOperands())
      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));

    // Clone the old op with new operands.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
    auto newLinalgOp = cast<LinalgOp>(newOp);
    newOp->setAttr("indexing_maps",
                   rewriter.getAffineMapArrayAttr(newIndexingMaps));

    // Set the number of inputs to the new value. The `clone` call above kept
    // the value from the original op.
    newLinalgOp.setNumInputs(canonicalInput.size());

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp->getRegion(0).front();
    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      unsigned operandNumber = opOperand->getOperandNumber();
      if (canonicalInputIndices[operandNumber] == operandNumber)
        continue;
      payload.getArgument(operandNumber)
          .replaceAllUsesWith(
              payload.getArgument(canonicalInputIndices[operandNumber]));
      payload.eraseArgument(operandNumber);
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
/// NOTE(review): interface-based predecessor of EraseIdentityGenericOp; it
/// matches LinalgOp but immediately bails out for anything but GenericOp.
struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    if (!isa<GenericOp>(op))
      return failure();
    // This folding only applies on tensors.
    if (!op.hasTensorSemantics())
      return failure();

    // Check all indexing maps are identity.
    if (llvm::any_of(op.getIndexingMaps(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = op->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value, 4> returnedArgs;
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      returnedArgs.push_back(op->getOperand(argumentNumber));
    }
    if (returnedArgs.size() != op.getOperation()->getNumResults())
      return failure();
    rewriter.replaceOp(op, returnedArgs);
    return success();
  }
};
} // namespace
#define LINALGOP_FOLDERS(XXX) \
LogicalResult XXX::fold(ArrayRef<Attribute>, \
SmallVectorImpl<OpFoldResult> &) { \
@@ -2699,6 +2694,5 @@ LINALGOP_FOLDERS(GenericOp)
/// Registers the interface-based canonicalization patterns that apply to all
/// Linalg ops. DeduplicateInputs and RemoveIdentityLinalgOps are no longer
/// registered here: they have GenericOp-specific replacements registered via
/// GenericOp::getCanonicalizationPatterns.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // NOTE: the previous version of this function also added DeduplicateInputs
  // and RemoveIdentityLinalgOps; keeping that line alongside this one would
  // double-register EraseDeadLinalgOp and FoldTensorCastOp and reference
  // patterns this change deletes.
  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
}