[mlir][sparse] merge ifs in new sparse rewriting rules

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D120500
This commit is contained in:
Aart Bik 2022-02-24 09:40:01 -08:00
parent 180c9f9efe
commit 8e4f8d3532
1 changed files with 9 additions and 20 deletions

View File

@ -45,15 +45,8 @@ static bool isSparseTensor(OpOperand *op) {
// Helper method to find zero or empty initialization.
static bool isEmptyInit(OpOperand *op) {
Value val = op->get();
if (matchPattern(val, m_Zero()))
return true;
if (matchPattern(val, m_AnyZeroFloat()))
return true;
if (val.getDefiningOp<InitTensorOp>())
return true;
if (val.getDefiningOp<InitOp>())
return true;
return false;
return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
val.getDefiningOp<InitTensorOp>() || val.getDefiningOp<InitOp>();
}
// Helper to detect sampling operation.
@ -123,11 +116,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
// Check consumer.
if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
op.getNumResults() != 1)
return failure();
if (op.getNumParallelLoops() != op.getNumLoops())
return failure();
if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
op.getNumResults() != 1 ||
op.getNumParallelLoops() != op.getNumLoops() ||
!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
!op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
!op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
return failure();
@ -143,15 +134,13 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
// Check producer.
auto prod = dyn_cast_or_null<GenericOp>(
op.getInputOperand(other)->get().getDefiningOp());
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
return failure();
if (!prod.getResult(0).hasOneUse())
if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
if (!isEmptyInit(op.getOutputOperand(0)) ||
!isEmptyInit(prod.getOutputOperand(0)))
return failure();
if (!isSampling(op) || !isSumOfMul(prod))
!isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
!isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
Location loc = prod.getLoc();