Improvements to Op trait implementation:

- Generalize TwoOperands and TwoResults to NOperands and NResults, which can
   be used for any fixed N.
 - Rename OpImpl namespace to OpTrait, OpImpl::Base to OpBase, and TraitImpl to
   TraitBase to better reflect what these are.

PiperOrigin-RevId: 206588634
This commit is contained in:
Chris Lattner 2018-07-30 08:48:18 -07:00 committed by jpienaar
parent 775130b6b9
commit 467c5cb3ba
3 changed files with 88 additions and 71 deletions

View File

@ -15,9 +15,10 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
// //
// This file implements helper classes for implementing the "Op" types. Most of // This file implements helper classes for implementing the "Op" types. This
// this goes into the mlir::OpImpl namespace since they are only used by code // includes the OpBase type, which is the base class for Op class definitions,
// that is defining the op implementations, not by clients. // as well as number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
// //
// The purpose of these types are to allow light-weight implementation of // The purpose of these types are to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate. // concrete ops (like DimOp) with very little boilerplate.
@ -85,11 +86,6 @@ struct OpAsmParserResult {
attributes(attributes.begin(), attributes.end()) {} attributes(attributes.begin(), attributes.end()) {}
}; };
//===----------------------------------------------------------------------===//
// OpImpl Types
//===----------------------------------------------------------------------===//
namespace OpImpl {
/// This is the concrete base class that holds the operation pointer and has /// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them /// non-generic methods that only depend on State (to avoid having them
@ -97,7 +93,7 @@ namespace OpImpl {
/// ///
/// This also has the fallback implementations of customization hooks for when /// This also has the fallback implementations of customization hooks for when
/// they aren't customized. /// they aren't customized.
class BaseState { class OpBaseState {
public: public:
/// Return the operation that this refers to. /// Return the operation that this refers to.
const Operation *getOperation() const { return state; } const Operation *getOperation() const { return state; }
@ -133,7 +129,7 @@ protected:
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes, /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here. /// so we can cast it away here.
explicit BaseState(const Operation *state) explicit OpBaseState(const Operation *state)
: state(const_cast<Operation *>(state)) {} : state(const_cast<Operation *>(state)) {}
private: private:
@ -144,11 +140,11 @@ private:
/// argument 'ConcreteType' should be the concrete type by CRTP and the others /// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern. /// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits> template <typename ConcreteType, template <typename T> class... Traits>
class Base : public BaseState, public Traits<ConcreteType>... { class OpBase : public OpBaseState, public Traits<ConcreteType>... {
public: public:
/// Return the operation that this refers to. /// Return the operation that this refers to.
const Operation *getOperation() const { return BaseState::getOperation(); } const Operation *getOperation() const { return OpBaseState::getOperation(); }
Operation *getOperation() { return BaseState::getOperation(); } Operation *getOperation() { return OpBaseState::getOperation(); }
/// Return true if this "op class" can match against the specified operation. /// Return true if this "op class" can match against the specified operation.
/// This hook can be overridden with a more specific implementation in /// This hook can be overridden with a more specific implementation in
@ -182,7 +178,7 @@ public:
// TODO: Provide a dump() method. // TODO: Provide a dump() method.
protected: protected:
explicit Base(const Operation *state) : BaseState(state) {} explicit OpBase(const Operation *state) : OpBaseState(state) {}
private: private:
template <typename... Types> template <typename... Types>
@ -210,24 +206,30 @@ private:
}; };
}; };
//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//
namespace OpTrait {
/// Helper class for implementing traits. Clients are not expected to interact /// Helper class for implementing traits. Clients are not expected to interact
/// with this directly, so its members are all protected. /// with this directly, so its members are all protected.
template <typename ConcreteType, template <typename> class TraitType> template <typename ConcreteType, template <typename> class TraitType>
class TraitImpl { class TraitBase {
protected: protected:
/// Return the ultimate Operation being worked on. /// Return the ultimate Operation being worked on.
Operation *getOperation() { Operation *getOperation() {
// We have to cast up to the trait type, then to the concrete type, then to // We have to cast up to the trait type, then to the concrete type, then to
// the BaseState class in explicit hops because the concrete type will // the BaseState class in explicit hops because the concrete type will
// multiply derive from the (content free) TraitImpl class, and we need to // multiply derive from the (content free) TraitBase class, and we need to
// be able to disambiguate the path for the C++ compiler. // be able to disambiguate the path for the C++ compiler.
auto *trait = static_cast<TraitType<ConcreteType> *>(this); auto *trait = static_cast<TraitType<ConcreteType> *>(this);
auto *concrete = static_cast<ConcreteType *>(trait); auto *concrete = static_cast<ConcreteType *>(trait);
auto *base = static_cast<BaseState *>(concrete); auto *base = static_cast<OpBaseState *>(concrete);
return base->getOperation(); return base->getOperation();
} }
const Operation *getOperation() const { const Operation *getOperation() const {
return const_cast<TraitImpl *>(this)->getOperation(); return const_cast<TraitBase *>(this)->getOperation();
} }
/// Provide default implementations of trait hooks. This allows traits to /// Provide default implementations of trait hooks. This allows traits to
@ -238,7 +240,7 @@ protected:
/// This class provides the API for ops that are known to have exactly one /// This class provides the API for ops that are known to have exactly one
/// SSA operand. /// SSA operand.
template <typename ConcreteType> template <typename ConcreteType>
class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> { class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
public: public:
static const char *verifyTrait(const Operation *op) { static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 0) if (op->getNumOperands() != 0)
@ -255,7 +257,7 @@ private:
/// This class provides the API for ops that are known to have exactly one /// This class provides the API for ops that are known to have exactly one
/// SSA operand. /// SSA operand.
template <typename ConcreteType> template <typename ConcreteType>
class OneOperand : public TraitImpl<ConcreteType, OneOperand> { class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public: public:
const SSAValue *getOperand() const { const SSAValue *getOperand() const {
return this->getOperation()->getOperand(0); return this->getOperation()->getOperand(0);
@ -274,10 +276,15 @@ public:
} }
}; };
/// This class provides the API for ops that are known to have exactly two /// This class provides the API for ops that are known to have a specified
/// SSA operands. /// number of operands. This is used as a trait like this:
///
/// class FooOp : public OpBase<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N> class NOperands {
public:
template <typename ConcreteType> template <typename ConcreteType>
class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> { class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
public: public:
const SSAValue *getOperand(unsigned i) const { const SSAValue *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i); return this->getOperation()->getOperand(i);
@ -292,16 +299,18 @@ public:
} }
static const char *verifyTrait(const Operation *op) { static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 2) // TODO(clattner): Allow verifier to return non-constant string.
return "requires two operands"; if (op->getNumOperands() != N)
return "incorrect number of operands";
return nullptr; return nullptr;
} }
}; };
};
/// This class provides the API for ops which have an unknown number of /// This class provides the API for ops which have an unknown number of
/// SSA operands. /// SSA operands.
template <typename ConcreteType> template <typename ConcreteType>
class VariadicOperands : public TraitImpl<ConcreteType, VariadicOperands> { class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> {
public: public:
unsigned getNumOperands() const { unsigned getNumOperands() const {
return this->getOperation()->getNumOperands(); return this->getOperation()->getNumOperands();
@ -345,7 +354,7 @@ public:
/// This class provides return value APIs for ops that are known to have a /// This class provides return value APIs for ops that are known to have a
/// single result. /// single result.
template <typename ConcreteType> template <typename ConcreteType>
class OneResult : public TraitImpl<ConcreteType, OneResult> { class OneResult : public TraitBase<ConcreteType, OneResult> {
public: public:
SSAValue *getResult() { return this->getOperation()->getResult(0); } SSAValue *getResult() { return this->getOperation()->getResult(0); }
const SSAValue *getResult() const { const SSAValue *getResult() const {
@ -361,30 +370,39 @@ public:
} }
}; };
/// This class provides the API for ops that are known to have exactly two /// This class provides the API for ops that are known to have a specified
/// results. /// number of results. This is used as a trait like this:
///
/// class FooOp : public OpBase<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N> class NResults {
public:
template <typename ConcreteType> template <typename ConcreteType>
class TwoResults : public TraitImpl<ConcreteType, TwoResults> { class Impl : public TraitBase<ConcreteType, NResults<N>::Impl> {
public: public:
const SSAValue *getResult(unsigned i) const { const SSAValue *getResult(unsigned i) const {
return this->getOperation()->getResult(i); return this->getOperation()->getResult(i);
} }
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); } SSAValue *getResult(unsigned i) {
return this->getOperation()->getResult(i);
}
Type *getType(unsigned i) const { return getResult(i)->getType(); } Type *getType(unsigned i) const { return getResult(i)->getType(); }
static const char *verifyTrait(const Operation *op) { static const char *verifyTrait(const Operation *op) {
if (op->getNumResults() != 2) // TODO(clattner): Allow verifier to return non-constant string.
return "requires two results"; if (op->getNumResults() != N)
return "incorrect number of results";
return nullptr; return nullptr;
} }
}; };
};
/// This class provides the API for ops which have an unknown number of /// This class provides the API for ops which have an unknown number of
/// results. /// results.
template <typename ConcreteType> template <typename ConcreteType>
class VariadicResults : public TraitImpl<ConcreteType, VariadicResults> { class VariadicResults : public TraitBase<ConcreteType, VariadicResults> {
public: public:
unsigned getNumResults() const { unsigned getNumResults() const {
return this->getOperation()->getNumResults(); return this->getOperation()->getNumResults();
@ -401,7 +419,7 @@ public:
} }
}; };
} // end namespace OpImpl } // end namespace OpTrait
} // end namespace mlir } // end namespace mlir

View File

@ -37,7 +37,7 @@ class OperationSet;
/// %2 = addf %0, %1 : f32 /// %2 = addf %0, %1 : f32
/// ///
class AddFOp class AddFOp
: public OpImpl::Base<AddFOp, OpImpl::TwoOperands, OpImpl::OneResult> { : public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult> {
public: public:
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "addf"; } static StringRef getOperationName() { return "addf"; }
@ -48,7 +48,7 @@ public:
private: private:
friend class Operation; friend class Operation;
explicit AddFOp(const Operation *state) : Base(state) {} explicit AddFOp(const Operation *state) : OpBase(state) {}
}; };
/// The "affine_apply" operation applies an affine map to a list of operands, /// The "affine_apply" operation applies an affine map to a list of operands,
@ -65,9 +65,8 @@ private:
/// #map42 = (d0)->(d0+1) /// #map42 = (d0)->(d0+1)
/// %y = affine_apply #map42(%x) /// %y = affine_apply #map42(%x)
/// ///
class AffineApplyOp class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands, OpTrait::VariadicResults> {
OpImpl::VariadicResults> {
public: public:
// Returns the affine map to be applied by this operation. // Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const { AffineMap *getAffineMap() const {
@ -84,7 +83,7 @@ public:
private: private:
friend class Operation; friend class Operation;
explicit AffineApplyOp(const Operation *state) : Base(state) {} explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
}; };
/// The "constant" operation requires a single attribute named "value". /// The "constant" operation requires a single attribute named "value".
@ -94,7 +93,8 @@ private:
/// %2 = "constant"(){value: @foo} : (f32)->f32 /// %2 = "constant"(){value: @foo} : (f32)->f32
/// ///
class ConstantOp class ConstantOp
: public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> { : public OpBase<ConstantOp, OpTrait::ZeroOperands, OpTrait::OneResult/*,
OpTrait::HasAttributeBase<"foo">::Impl*/> {
public: public:
Attribute *getValue() const { return getAttr("value"); } Attribute *getValue() const { return getAttr("value"); }
@ -106,7 +106,7 @@ public:
protected: protected:
friend class Operation; friend class Operation;
explicit ConstantOp(const Operation *state) : Base(state) {} explicit ConstantOp(const Operation *state) : OpBase(state) {}
}; };
/// This is a refinement of the "constant" op for the case where it is /// This is a refinement of the "constant" op for the case where it is
@ -133,8 +133,7 @@ private:
/// ///
/// %1 = dim %0, 2 : tensor<?x?x?xf32> /// %1 = dim %0, 2 : tensor<?x?x?xf32>
/// ///
class DimOp class DimOp : public OpBase<DimOp, OpTrait::OneOperand, OpTrait::OneResult> {
: public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
public: public:
/// This returns the dimension number that the 'dim' is inspecting. /// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const { unsigned getIndex() const {
@ -151,7 +150,7 @@ public:
private: private:
friend class Operation; friend class Operation;
explicit DimOp(const Operation *state) : Base(state) {} explicit DimOp(const Operation *state) : OpBase(state) {}
}; };
/// The "load" op reads an element from a memref specified by an index list. The /// The "load" op reads an element from a memref specified by an index list. The
@ -163,7 +162,7 @@ private:
/// %3 = load %0[%1, %1] : memref<4x4xi32> /// %3 = load %0[%1, %1] : memref<4x4xi32>
/// ///
class LoadOp class LoadOp
: public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> { : public OpBase<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public: public:
SSAValue *getMemRef() { return getOperand(0); } SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); } const SSAValue *getMemRef() const { return getOperand(0); }
@ -185,7 +184,7 @@ public:
private: private:
friend class Operation; friend class Operation;
explicit LoadOp(const Operation *state) : Base(state) {} explicit LoadOp(const Operation *state) : OpBase(state) {}
}; };
/// Install the standard operations in the specified operation set. /// Install the standard operations in the specified operation set.

View File

@ -27,13 +27,13 @@ using llvm::StringMap;
OpAsmParser::~OpAsmParser() {} OpAsmParser::~OpAsmParser() {}
// The fallback for the printer is to reject the short form. // The fallback for the printer is to reject the short form.
OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) { OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
parser->emitError(parser->getNameLoc(), "has no concise form"); parser->emitError(parser->getNameLoc(), "has no concise form");
return {}; return {};
} }
// The fallback for the printer is to print it the longhand form. // The fallback for the printer is to print it the longhand form.
void OpImpl::BaseState::print(OpAsmPrinter *p) const { void OpBaseState::print(OpAsmPrinter *p) const {
p->printDefaultOp(getOperation()); p->printDefaultOp(getOperation());
} }