From 085b687fbdf5d4be6ba3e0280e0610e612e8ab4e Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 28 Oct 2018 10:03:19 -0700 Subject: [PATCH] Add support for walking the use list of an SSAValue and converting owners to Operation*'s, simplifying some code in GreedyPatternRewriteDriver.cpp. Also add print/dump methods on Operation. PiperOrigin-RevId: 219045764 --- mlir/include/mlir/IR/Instructions.h | 28 ++++---- mlir/include/mlir/IR/Operation.h | 25 ++++++- mlir/include/mlir/IR/Statements.h | 2 + mlir/include/mlir/IR/UseDefLists.h | 3 +- mlir/lib/IR/Operation.cpp | 32 +++++++++ .../Utils/GreedyPatternRewriteDriver.cpp | 70 ++++++------------- 6 files changed, 91 insertions(+), 69 deletions(-) diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 64507a6f5066..e74c5616a451 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -206,11 +206,13 @@ public: ArrayRef attributes, MLIRContext *context); + using Instruction::dump; using Instruction::emitError; using Instruction::emitNote; using Instruction::emitWarning; using Instruction::getContext; using Instruction::getLoc; + using Instruction::print; OperationInst *clone() const; @@ -341,8 +343,8 @@ public: llvm::iplist::iterator iterator); /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Instruction *inst) { - return inst->getKind() == Kind::Operation; + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::OperationInst; } static bool classof(const Operation *op) { return op->getOperationKind() == OperationKind::Instruction; @@ -433,8 +435,8 @@ public: ArrayRef getBasicBlockOperands() const { return dest; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Instruction *inst) { - return inst->getKind() == Kind::Branch; + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::BranchInst; } private: @@ -479,10 +481,7 @@ public: unsigned getNumOperands() const { return operands.size(); } - // - // Accessors for operands to the 'true' destination - // - + // Accessors for operands to the 'true' destination. CFGValue *getTrueOperand(unsigned idx) { return getTrueInstOperand(idx).get(); } @@ -530,10 +529,7 @@ public: /// Add a list of values to the operand list. void addTrueOperands(ArrayRef values); - // - // Accessors for operands to the 'false' destination - // - + // Accessors for operands to the 'false' destination. CFGValue *getFalseOperand(unsigned idx) { return getFalseInstOperand(idx).get(); } @@ -592,8 +588,8 @@ public: ArrayRef getBasicBlockOperands() const { return dests; } /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Instruction *inst) { - return inst->getKind() == Kind::CondBranch; + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::CondBranchInst; } private: @@ -631,8 +627,8 @@ public: void destroy(); /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(const Instruction *inst) { - return inst->getKind() == Kind::Return; + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::ReturnInst; } private: diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index a0e820966c56..2d294e1bb1a9 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -29,6 +29,7 @@ template class OpPointer; template class OperandIterator; template class ResultIterator; class Function; +class IROperandOwner; class Instruction; class Statement; @@ -232,12 +233,16 @@ public: // Returns whether the operation is commutative. bool isCommutative() const { - return getAbstractOperation()->hasProperty(OperationProperty::Commutative); + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::Commutative); + return false; } // Returns whether the operation has side-effects. bool hasNoSideEffect() const { - return getAbstractOperation()->hasProperty(OperationProperty::NoSideEffect); + if (auto *absOp = getAbstractOperation()) + return absOp->hasProperty(OperationProperty::NoSideEffect); + return false; } /// Remove this operation from its parent block and delete it. @@ -251,9 +256,13 @@ public: bool constantFold(ArrayRef operands, SmallVectorImpl &results) const; + void print(raw_ostream &os) const; + void dump() const; + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Instruction *inst); static bool classof(const Statement *stmt); + static bool classof(const IROperandOwner *ptr); protected: Operation(bool isInstruction, OperationName name, @@ -410,4 +419,16 @@ inline auto Operation::getResults() const } } // end namespace mlir +/// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an +/// IROperandOwner* to Operation*. This can't be done with a simple pointer to +/// pointer cast because the pointer adjustment depends on whether the Owner is +/// dynamically an Instruction or Statement, because of multiple inheritance. +namespace llvm { +template <> +struct cast_convert_val { + static mlir::Operation *doit(const mlir::IROperandOwner *value); +}; +} // namespace llvm + #endif diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 2ad815f1c5fc..7e7a49ffa15f 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -51,10 +51,12 @@ public: /// Return the context this operation is associated with. MLIRContext *getContext() const; + using Statement::dump; using Statement::emitError; using Statement::emitNote; using Statement::emitWarning; using Statement::getLoc; + using Statement::print; /// Check if this statement is a return statement. bool isReturn() const; diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 3f191e0e3f21..4a2594774a38 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -29,6 +29,7 @@ namespace mlir { class IROperand; +class IROperandOwner; template class SSAValueUseIterator; class IRObjectWithUseList { @@ -43,7 +44,7 @@ public: /// Returns true if this value has exactly one use. inline bool hasOneUse() const; - using use_iterator = SSAValueUseIterator; + using use_iterator = SSAValueUseIterator; using use_range = llvm::iterator_range; inline use_iterator use_begin() const; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 99da81df9901..2ed09b83b53c 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -237,6 +237,18 @@ bool Operation::constantFold(ArrayRef operands, return true; } +void Operation::print(raw_ostream &os) const { + if (auto *inst = llvm::dyn_cast(this)) + return inst->print(os); + return llvm::cast(this)->print(os); +} + +void Operation::dump() const { + if (auto *inst = llvm::dyn_cast(this)) + return inst->dump(); + return llvm::cast(this)->dump(); +} + /// Methods for support type inquiry through isa, cast, and dyn_cast. bool Operation::classof(const Instruction *inst) { return inst->getKind() == Instruction::Kind::Operation; @@ -244,6 +256,26 @@ bool Operation::classof(const Instruction *inst) { bool Operation::classof(const Statement *stmt) { return stmt->getKind() == Statement::Kind::Operation; } +bool Operation::classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::OperationInst || + ptr->getKind() == IROperandOwner::Kind::OperationStmt; +} + +/// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an +/// IROperandOwner* to Operation*. This can't be done with a simple pointer to +/// pointer cast because the pointer adjustment depends on whether the Owner is +/// dynamically an Instruction or Statement, because of multiple inheritance. +Operation * +llvm::cast_convert_val::doit(const mlir::IROperandOwner + *value) { + const Operation *op; + if (auto *ptr = dyn_cast(value)) + op = ptr; + else + op = cast(value); + return const_cast(op); +} //===----------------------------------------------------------------------===// // OpState trait class. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 30034b6fce51..cdf5b7166a08 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -98,6 +98,22 @@ public: driver.removeFromWorklist(op); } + // When the root of a pattern is about to be replaced, it can trigger + // simplifications to its users - make sure to add them to the worklist + // before the root is changed. + void notifyRootReplaced(Operation *op) override { + for (auto *result : op->getResults()) + // TODO: Add a result->getUsers() iterator. + for (auto &user : result->getUses()) { + if (auto *op = dyn_cast(user.getOwner())) + driver.addToWorklist(op); + } + + // TODO: Walk the operand list dropping them as we go. If any of them + // drop to zero uses, then add them to the worklist to allow them to be + // deleted as dead. + } + GreedyPatternRewriteDriver &driver; }; @@ -206,22 +222,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, // Add all the users of the result to the worklist so we make sure to // revisit them. // - // TODO: This is super gross. SSAValue use iterators should have an - // "owner" that can be downcasted to operation and other things. This - // will require a rejiggering of the class hierarchies. - if (auto *stmt = dyn_cast(op)) { - // TODO: Add a result->getUsers() iterator. - for (auto &operand : stmt->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } - } else { - auto *inst = cast(op); - // TODO: Add a result->getUsers() iterator. - for (auto &operand : inst->getResult(i)->getUses()) { - if (auto *op = dyn_cast(operand.getOwner())) - addToWorklist(op); - } + // TODO: Add a result->getUsers() iterator. + for (auto &operand : op->getResult(i)->getUses()) { + if (auto *op = dyn_cast(operand.getOwner())) + addToWorklist(op); } res->replaceAllUsesWith(cstValue); @@ -268,23 +272,6 @@ static void processMLFunction(MLFunction *fn, OwningPatternList &&patterns) { return result; } - // When the root of a pattern is about to be replaced, it can trigger - // simplifications to its users - make sure to add them to the worklist - // before the root is changed. - void notifyRootReplaced(Operation *op) override { - auto *opStmt = cast(op); - for (auto *result : opStmt->getResults()) - // TODO: Add a result->getUsers() iterator. - for (auto &user : result->getUses()) { - if (auto *op = dyn_cast(user.getOwner())) - driver.addToWorklist(op); - } - - // TODO: Walk the operand list dropping them as we go. If any of them - // drop to zero uses, then add them to the worklist to allow them to be - // deleted as dead. - } - void setInsertionPoint(Operation *op) override { // Any new operations should be added before this statement. builder.setInsertionPoint(cast(op)); @@ -316,23 +303,6 @@ static void processCFGFunction(CFGFunction *fn, OwningPatternList &&patterns) { return result; } - // When the root of a pattern is about to be replaced, it can trigger - // simplifications to its users - make sure to add them to the worklist - // before the root is changed. - void notifyRootReplaced(Operation *op) override { - auto *opStmt = cast(op); - for (auto *result : opStmt->getResults()) - // TODO: Add a result->getUsers() iterator. - for (auto &user : result->getUses()) { - if (auto *op = dyn_cast(user.getOwner())) - driver.addToWorklist(op); - } - - // TODO: Walk the operand list dropping them as we go. If any of them - // drop to zero uses, then add them to the worklist to allow them to be - // deleted as dead. - } - void setInsertionPoint(Operation *op) override { // Any new operations should be added before this instruction. builder.setInsertionPoint(cast(op));