forked from OSchip/llvm-project
[mlir][linalg] Specialize LinalgOp canonicalization patterns (NFC).
Specialize the DeduplicateInputs and RemoveIdentityLinalgOps patterns for GenericOp instead of implementing them for the LinalgOp interface. This revision is based on https://reviews.llvm.org/D105622, which moves the logic to erase identity CopyOps into a separate pattern. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105291
This commit is contained in:
parent
b0ef3d8f66
commit
09635dc7bf
|
@ -658,6 +658,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
|
|||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -671,6 +671,138 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
|
|||
|
||||
/// Verifier hook for GenericOp; forwards to the shared templated
/// implementation used by all generic-op variants.
static LogicalResult verify(GenericOp op) {
  return verifyGenericOp(op);
}
|
||||
|
||||
namespace {
|
||||
// Deduplicate redundant args of a linalg generic op.
|
||||
// An arg is redundant if it has the same Value and indexing map as another.
|
||||
struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Associate each input to an equivalent "canonical" input that has the same
|
||||
// Value and indexing map.
|
||||
//
|
||||
// In the non-duplicate case, input `i` will have canonical input `i`. But
|
||||
// in the case of duplicated inputs, the canonical input could be some other
|
||||
// input `< i`. That is, a later input will have some earlier input as its
|
||||
// canonical input.
|
||||
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
|
||||
// For later remapping tasks like deduplicating payload block arguments,
|
||||
// having a simple "inputIndex -> canonicalInputIndex" integer mapping is
|
||||
// convenient.
|
||||
SmallVector<unsigned> canonicalInputIndices;
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
||||
// STL-like maps have a convenient behavior for our use case here. In the
|
||||
// case of duplicate keys, the insertion is rejected, and the returned
|
||||
// iterator gives access to the value already in the map.
|
||||
auto pair = canonicalInput.insert(
|
||||
{{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
|
||||
canonicalInputIndices.push_back(pair.first->second);
|
||||
}
|
||||
|
||||
// If there are no duplicate args, then bail out.
|
||||
if (canonicalInput.size() == genericOp.getNumInputs())
|
||||
return failure();
|
||||
|
||||
// The operands for the newly canonicalized op.
|
||||
SmallVector<Value> newInputOperands;
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands())
|
||||
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
|
||||
opOperand->getOperandNumber())
|
||||
newInputOperands.push_back(opOperand->get());
|
||||
|
||||
// Repair the indexing maps by filtering out the ones that have been
|
||||
// eliminated.
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands())
|
||||
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
|
||||
opOperand->getOperandNumber())
|
||||
newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
|
||||
for (OpOperand *opOperand : genericOp.getOutputOperands())
|
||||
newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
|
||||
|
||||
// Clone the old op with new operands.
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
auto newOp = rewriter.create<GenericOp>(
|
||||
genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
|
||||
outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
||||
genericOp.iterator_types(), genericOp.docAttr(),
|
||||
genericOp.library_callAttr());
|
||||
rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
|
||||
newOp.region().begin());
|
||||
|
||||
// Repair the payload entry block by RAUW'ing redundant arguments and
|
||||
// erasing them.
|
||||
Block &payload = newOp.region().front();
|
||||
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
||||
for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
|
||||
// Iterate in reverse, so that we erase later args first, preventing the
|
||||
// argument list from shifting unexpectedly and invalidating all our
|
||||
// indices.
|
||||
unsigned operandNumber = opOperand->getOperandNumber();
|
||||
if (canonicalInputIndices[operandNumber] == operandNumber)
|
||||
continue;
|
||||
payload.getArgument(operandNumber)
|
||||
.replaceAllUsesWith(
|
||||
payload.getArgument(canonicalInputIndices[operandNumber]));
|
||||
payload.eraseArgument(operandNumber);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(genericOp, newOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Remove generic operations (on tensors) that are just copying
|
||||
/// the values from inputs to the results. Requirements are
|
||||
/// 1) All iterator types are parallel
|
||||
/// 2) The body contains just a yield operation with the yielded values being
|
||||
/// the arguments corresponding to the operands.
|
||||
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
// Check all indexing maps are identity.
|
||||
if (llvm::any_of(genericOp.getIndexingMaps(),
|
||||
[](AffineMap map) { return !map.isIdentity(); }))
|
||||
return failure();
|
||||
|
||||
// Check that the body of the linalg operation is just a linalg.yield
|
||||
// operation.
|
||||
Block &body = genericOp.region().front();
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return failure();
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
||||
if (!yieldOp)
|
||||
return failure();
|
||||
|
||||
// Get the argument number of the returned values. That is the operand
|
||||
// number to use for replacing uses of this operation.
|
||||
SmallVector<Value> returnedArgs;
|
||||
for (Value yieldVal : yieldOp.values()) {
|
||||
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
|
||||
if (!yieldArg || yieldArg.getOwner() != &body)
|
||||
return failure();
|
||||
unsigned argumentNumber = yieldArg.getArgNumber();
|
||||
returnedArgs.push_back(genericOp->getOperand(argumentNumber));
|
||||
}
|
||||
if (returnedArgs.size() != genericOp->getNumResults())
|
||||
return failure();
|
||||
rewriter.replaceOp(genericOp, returnedArgs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Registers the GenericOp-specific canonicalization patterns: input
/// deduplication and erasure of identity (pure copy) generic ops.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InitTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2539,143 +2671,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Deduplicate redundant args of a linalg op.
// An arg is redundant if it has the same Value and indexing map as another.
// NOTE(review): this interface-based pattern appears superseded by the
// GenericOp-specialized DeduplicateGenericOpInputs above — confirm it is
// no longer registered before keeping both.
struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // This pattern reduces the number of arguments of an op, which breaks
    // the invariants of semantically charged named ops.
    if (!isa<GenericOp>(op))
      return failure();

    // Associate each input to an equivalent "canonical" input that has the same
    // Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some other
    // input `< i`. That is, a later input will have some earlier input as its
    // canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<unsigned> canonicalInputIndices;
    for (OpOperand *opOperand : op.getInputOperands()) {
      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert(
          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    if (canonicalInput.size() == op.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op: the surviving inputs
    // followed by all (unchanged) outputs.
    SmallVector<Value> newOperands;
    for (OpOperand *opOperand : op.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newOperands.push_back(opOperand->get());
    SmallVector<Value> outputOperands = op.getOutputOperands();
    llvm::append_range(newOperands, outputOperands);

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    SmallVector<AffineMap> newIndexingMaps;
    for (OpOperand *opOperand : op.getInputOperands())
      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
          opOperand->getOperandNumber())
        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
    for (OpOperand *opOperand : op.getOutputOperands())
      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));

    // Clone the old op with new operands. The clone copies all attributes
    // verbatim, so the stale indexing maps must be overwritten afterwards.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
    auto newLinalgOp = cast<LinalgOp>(newOp);
    newOp->setAttr("indexing_maps",
                   rewriter.getAffineMapArrayAttr(newIndexingMaps));

    // Set the number of inputs to the new value. The `clone` call above kept
    // the value from the original op.
    newLinalgOp.setNumInputs(canonicalInput.size());

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp->getRegion(0).front();
    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      unsigned operandNumber = opOperand->getOperandNumber();
      if (canonicalInputIndices[operandNumber] == operandNumber)
        continue;
      payload.getArgument(operandNumber)
          .replaceAllUsesWith(
              payload.getArgument(canonicalInputIndices[operandNumber]));
      payload.eraseArgument(operandNumber);
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
|
||||
|
||||
/// Remove generic operations (on tensors) that are just copying
|
||||
/// the values from inputs to the results. Requirements are
|
||||
/// 1) All iterator types are parallel
|
||||
/// 2) The body contains just a yield operation with the yielded values being
|
||||
/// the arguments corresponding to the operands.
|
||||
struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
|
||||
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!isa<GenericOp>(op))
|
||||
return failure();
|
||||
if (!op.hasTensorSemantics())
|
||||
return failure();
|
||||
// Check all indexing maps are identity.
|
||||
if (llvm::any_of(op.getIndexingMaps(),
|
||||
[](AffineMap map) { return !map.isIdentity(); }))
|
||||
return failure();
|
||||
|
||||
// Check that the body of the linalg operation is just a linalg.yield
|
||||
// operation.
|
||||
Block &body = op->getRegion(0).front();
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return failure();
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
||||
if (!yieldOp)
|
||||
return failure();
|
||||
|
||||
// Get the argument number of the returned values. That is the operand
|
||||
// number to use for replacing uses of this operation.
|
||||
SmallVector<Value, 4> returnedArgs;
|
||||
for (Value yieldVal : yieldOp.values()) {
|
||||
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
|
||||
if (!yieldArg || yieldArg.getOwner() != &body)
|
||||
return failure();
|
||||
unsigned argumentNumber = yieldArg.getArgNumber();
|
||||
returnedArgs.push_back(op->getOperand(argumentNumber));
|
||||
}
|
||||
if (returnedArgs.size() != op.getOperation()->getNumResults())
|
||||
return failure();
|
||||
rewriter.replaceOp(op, returnedArgs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
#define LINALGOP_FOLDERS(XXX) \
|
||||
LogicalResult XXX::fold(ArrayRef<Attribute>, \
|
||||
SmallVectorImpl<OpFoldResult> &) { \
|
||||
|
@ -2699,6 +2694,5 @@ LINALGOP_FOLDERS(GenericOp)
|
|||
|
||||
/// Registers the dialect-wide canonicalization patterns shared by all linalg
/// ops. DeduplicateInputs and RemoveIdentityLinalgOps are intentionally NOT
/// registered here any more: they were specialized for GenericOp and are now
/// added via GenericOp::getCanonicalizationPatterns. The fused diff text
/// contained both the old and the new registration; keeping both would
/// reference the deleted pattern classes and double-register the survivors.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
}
|
||||
|
|
Loading…
Reference in New Issue