Improvements to Op trait implementation:

- Generalize TwoOperands and TwoResults to NOperands and NResults, which can
   be used for any fixed N.
 - Rename OpImpl namespace to OpTrait, OpImpl::Base to OpBase, and TraitImpl to
   TraitBase to better reflect what these are.

PiperOrigin-RevId: 206588634
This commit is contained in:
Chris Lattner 2018-07-30 08:48:18 -07:00 committed by jpienaar
parent 775130b6b9
commit 467c5cb3ba
3 changed files with 88 additions and 71 deletions

View File

@ -15,9 +15,10 @@
// limitations under the License.
// =============================================================================
//
// This file implements helper classes for implementing the "Op" types. Most of
// this goes into the mlir::OpImpl namespace since they are only used by code
// that is defining the op implementations, not by clients.
// This file implements helper classes for implementing the "Op" types. This
// includes the OpBase type, which is the base class for Op class definitions,
// as well as number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types are to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
@ -85,11 +86,6 @@ struct OpAsmParserResult {
attributes(attributes.begin(), attributes.end()) {}
};
//===----------------------------------------------------------------------===//
// OpImpl Types
//===----------------------------------------------------------------------===//
namespace OpImpl {
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
@ -97,7 +93,7 @@ namespace OpImpl {
///
/// This also has the fallback implementations of customization hooks for when
/// they aren't customized.
class BaseState {
class OpBaseState {
public:
/// Return the operation that this refers to.
const Operation *getOperation() const { return state; }
@ -133,7 +129,7 @@ protected:
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here.
explicit BaseState(const Operation *state)
explicit OpBaseState(const Operation *state)
: state(const_cast<Operation *>(state)) {}
private:
@ -144,11 +140,11 @@ private:
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits>
class Base : public BaseState, public Traits<ConcreteType>... {
class OpBase : public OpBaseState, public Traits<ConcreteType>... {
public:
/// Return the operation that this refers to.
const Operation *getOperation() const { return BaseState::getOperation(); }
Operation *getOperation() { return BaseState::getOperation(); }
const Operation *getOperation() const { return OpBaseState::getOperation(); }
Operation *getOperation() { return OpBaseState::getOperation(); }
/// Return true if this "op class" can match against the specified operation.
/// This hook can be overridden with a more specific implementation in
@ -182,7 +178,7 @@ public:
// TODO: Provide a dump() method.
protected:
explicit Base(const Operation *state) : BaseState(state) {}
explicit OpBase(const Operation *state) : OpBaseState(state) {}
private:
template <typename... Types>
@ -210,24 +206,30 @@ private:
};
};
//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//
namespace OpTrait {
/// Helper class for implementing traits. Clients are not expected to interact
/// with this directly, so its members are all protected.
template <typename ConcreteType, template <typename> class TraitType>
class TraitImpl {
class TraitBase {
protected:
/// Return the ultimate Operation being worked on.
Operation *getOperation() {
// We have to cast up to the trait type, then to the concrete type, then to
// the BaseState class in explicit hops because the concrete type will
// multiply derive from the (content free) TraitImpl class, and we need to
// multiply derive from the (content free) TraitBase class, and we need to
// be able to disambiguate the path for the C++ compiler.
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
auto *concrete = static_cast<ConcreteType *>(trait);
auto *base = static_cast<BaseState *>(concrete);
auto *base = static_cast<OpBaseState *>(concrete);
return base->getOperation();
}
const Operation *getOperation() const {
return const_cast<TraitImpl *>(this)->getOperation();
return const_cast<TraitBase *>(this)->getOperation();
}
/// Provide default implementations of trait hooks. This allows traits to
@ -238,7 +240,7 @@ protected:
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
template <typename ConcreteType>
class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> {
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
public:
static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 0)
@ -255,7 +257,7 @@ private:
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
template <typename ConcreteType>
class OneOperand : public TraitImpl<ConcreteType, OneOperand> {
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
const SSAValue *getOperand() const {
return this->getOperation()->getOperand(0);
@ -274,10 +276,15 @@ public:
}
};
/// This class provides the API for ops that are known to have exactly two
/// SSA operands.
/// This class provides the API for ops that are known to have a specified
/// number of operands. This is used as a trait like this:
///
/// class FooOp : public OpBase<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N> class NOperands {
public:
template <typename ConcreteType>
class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> {
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
public:
const SSAValue *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i);
@ -292,16 +299,18 @@ public:
}
static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 2)
return "requires two operands";
// TODO(clattner): Allow verifier to return non-constant string.
if (op->getNumOperands() != N)
return "incorrect number of operands";
return nullptr;
}
};
};
/// This class provides the API for ops which have an unknown number of
/// SSA operands.
template <typename ConcreteType>
class VariadicOperands : public TraitImpl<ConcreteType, VariadicOperands> {
class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> {
public:
unsigned getNumOperands() const {
return this->getOperation()->getNumOperands();
@ -345,7 +354,7 @@ public:
/// This class provides return value APIs for ops that are known to have a
/// single result.
template <typename ConcreteType>
class OneResult : public TraitImpl<ConcreteType, OneResult> {
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
SSAValue *getResult() { return this->getOperation()->getResult(0); }
const SSAValue *getResult() const {
@ -361,30 +370,39 @@ public:
}
};
/// This class provides the API for ops that are known to have exactly two
/// results.
/// This class provides the API for ops that are known to have a specified
/// number of results. This is used as a trait like this:
///
/// class FooOp : public OpBase<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N> class NResults {
public:
template <typename ConcreteType>
class TwoResults : public TraitImpl<ConcreteType, TwoResults> {
class Impl : public TraitBase<ConcreteType, NResults<N>::Impl> {
public:
const SSAValue *getResult(unsigned i) const {
return this->getOperation()->getResult(i);
}
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
SSAValue *getResult(unsigned i) {
return this->getOperation()->getResult(i);
}
Type *getType(unsigned i) const { return getResult(i)->getType(); }
static const char *verifyTrait(const Operation *op) {
if (op->getNumResults() != 2)
return "requires two results";
// TODO(clattner): Allow verifier to return non-constant string.
if (op->getNumResults() != N)
return "incorrect number of results";
return nullptr;
}
};
};
/// This class provides the API for ops which have an unknown number of
/// results.
template <typename ConcreteType>
class VariadicResults : public TraitImpl<ConcreteType, VariadicResults> {
class VariadicResults : public TraitBase<ConcreteType, VariadicResults> {
public:
unsigned getNumResults() const {
return this->getOperation()->getNumResults();
@ -401,7 +419,7 @@ public:
}
};
} // end namespace OpImpl
} // end namespace OpTrait
} // end namespace mlir

View File

@ -37,7 +37,7 @@ class OperationSet;
/// %2 = addf %0, %1 : f32
///
class AddFOp
: public OpImpl::Base<AddFOp, OpImpl::TwoOperands, OpImpl::OneResult> {
: public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult> {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "addf"; }
@ -48,7 +48,7 @@ public:
private:
friend class Operation;
explicit AddFOp(const Operation *state) : Base(state) {}
explicit AddFOp(const Operation *state) : OpBase(state) {}
};
/// The "affine_apply" operation applies an affine map to a list of operands,
@ -65,9 +65,8 @@ private:
/// #map42 = (d0)->(d0+1)
/// %y = affine_apply #map42(%x)
///
class AffineApplyOp
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
OpImpl::VariadicResults> {
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
@ -84,7 +83,7 @@ public:
private:
friend class Operation;
explicit AffineApplyOp(const Operation *state) : Base(state) {}
explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
};
/// The "constant" operation requires a single attribute named "value".
@ -94,7 +93,8 @@ private:
/// %2 = "constant"(){value: @foo} : (f32)->f32
///
class ConstantOp
: public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> {
: public OpBase<ConstantOp, OpTrait::ZeroOperands, OpTrait::OneResult/*,
OpTrait::HasAttributeBase<"foo">::Impl*/> {
public:
Attribute *getValue() const { return getAttr("value"); }
@ -106,7 +106,7 @@ public:
protected:
friend class Operation;
explicit ConstantOp(const Operation *state) : Base(state) {}
explicit ConstantOp(const Operation *state) : OpBase(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
@ -133,8 +133,7 @@ private:
///
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
///
class DimOp
: public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
class DimOp : public OpBase<DimOp, OpTrait::OneOperand, OpTrait::OneResult> {
public:
/// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const {
@ -151,7 +150,7 @@ public:
private:
friend class Operation;
explicit DimOp(const Operation *state) : Base(state) {}
explicit DimOp(const Operation *state) : OpBase(state) {}
};
/// The "load" op reads an element from a memref specified by an index list. The
@ -163,7 +162,7 @@ private:
/// %3 = load %0[%1, %1] : memref<4x4xi32>
///
class LoadOp
: public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> {
: public OpBase<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
public:
SSAValue *getMemRef() { return getOperand(0); }
const SSAValue *getMemRef() const { return getOperand(0); }
@ -185,7 +184,7 @@ public:
private:
friend class Operation;
explicit LoadOp(const Operation *state) : Base(state) {}
explicit LoadOp(const Operation *state) : OpBase(state) {}
};
/// Install the standard operations in the specified operation set.

View File

@ -27,13 +27,13 @@ using llvm::StringMap;
OpAsmParser::~OpAsmParser() {}
// The fallback for the printer is to reject the short form.
OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) {
OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
parser->emitError(parser->getNameLoc(), "has no concise form");
return {};
}
// The fallback for the printer is to print it the longhand form.
void OpImpl::BaseState::print(OpAsmPrinter *p) const {
void OpBaseState::print(OpAsmPrinter *p) const {
p->printDefaultOp(getOperation());
}