forked from OSchip/llvm-project
[mlir][sparse] merge ifs in new sparse rewriting rules
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D120500
This commit is contained in:
parent
180c9f9efe
commit
8e4f8d3532
|
@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
|
|||
// Helper method to find zero or empty initialization.
|
||||
static bool isEmptyInit(OpOperand *op) {
|
||||
Value val = op->get();
|
||||
if (matchPattern(val, m_Zero()))
|
||||
return true;
|
||||
if (matchPattern(val, m_AnyZeroFloat()))
|
||||
return true;
|
||||
if (val.getDefiningOp<InitTensorOp>())
|
||||
return true;
|
||||
if (val.getDefiningOp<InitOp>())
|
||||
return true;
|
||||
return false;
|
||||
return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
|
||||
val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
|
||||
}
|
||||
|
||||
// Helper to detect sampling operation.
|
||||
|
@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Check consumer.
|
||||
if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
|
||||
op.getNumResults() != 1)
|
||||
return failure();
|
||||
if (op.getNumParallelLoops() != op.getNumLoops())
|
||||
return failure();
|
||||
if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
|
||||
op.getNumResults() != 1 ||
|
||||
op.getNumParallelLoops() != op.getNumLoops() ||
|
||||
!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
|
||||
!op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
|
||||
!op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
|
||||
return failure();
|
||||
|
@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
|
|||
// Check producer.
|
||||
auto prod = dyn_cast_or_null<GenericOp>(
|
||||
op.getInputOperand(other)->get().getDefiningOp());
|
||||
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
|
||||
return failure();
|
||||
if (!prod.getResult(0).hasOneUse())
|
||||
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
|
||||
!prod.getResult(0).hasOneUse())
|
||||
return failure();
|
||||
// Sampling consumer and sum of multiplication chain producer.
|
||||
if (!isEmptyInit(op.getOutputOperand(0)) ||
|
||||
!isEmptyInit(prod.getOutputOperand(0)))
|
||||
return failure();
|
||||
if (!isSampling(op) || !isSumOfMul(prod))
|
||||
!isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
|
||||
!isSumOfMul(prod))
|
||||
return failure();
|
||||
// Modify operand structure of producer and consumer.
|
||||
Location loc = prod.getLoc();
|
||||
|
|
Loading…
Reference in New Issue