forked from OSchip/llvm-project
Improvements to Op trait implementation:
- Generalize TwoOperands and TwoResults to NOperands and NResults, which can be used for any fixed N. - Rename OpImpl namespace to OpTrait, OpImpl::Base to OpBase, and TraitImpl to TraitBase to better reflect what these are. PiperOrigin-RevId: 206588634
This commit is contained in:
parent
775130b6b9
commit
467c5cb3ba
|
@ -15,9 +15,10 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements helper classes for implementing the "Op" types. Most of
|
||||
// this goes into the mlir::OpImpl namespace since they are only used by code
|
||||
// that is defining the op implementations, not by clients.
|
||||
// This file implements helper classes for implementing the "Op" types. This
|
||||
// includes the OpBase type, which is the base class for Op class definitions,
|
||||
// as well as number of traits in the OpTrait namespace that provide a
|
||||
// declarative way to specify properties of Ops.
|
||||
//
|
||||
// The purpose of these types are to allow light-weight implementation of
|
||||
// concrete ops (like DimOp) with very little boilerplate.
|
||||
|
@ -85,11 +86,6 @@ struct OpAsmParserResult {
|
|||
attributes(attributes.begin(), attributes.end()) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpImpl Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace OpImpl {
|
||||
|
||||
/// This is the concrete base class that holds the operation pointer and has
|
||||
/// non-generic methods that only depend on State (to avoid having them
|
||||
|
@ -97,7 +93,7 @@ namespace OpImpl {
|
|||
///
|
||||
/// This also has the fallback implementations of customization hooks for when
|
||||
/// they aren't customized.
|
||||
class BaseState {
|
||||
class OpBaseState {
|
||||
public:
|
||||
/// Return the operation that this refers to.
|
||||
const Operation *getOperation() const { return state; }
|
||||
|
@ -133,7 +129,7 @@ protected:
|
|||
|
||||
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
|
||||
/// so we can cast it away here.
|
||||
explicit BaseState(const Operation *state)
|
||||
explicit OpBaseState(const Operation *state)
|
||||
: state(const_cast<Operation *>(state)) {}
|
||||
|
||||
private:
|
||||
|
@ -144,11 +140,11 @@ private:
|
|||
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
|
||||
/// are base classes by the policy pattern.
|
||||
template <typename ConcreteType, template <typename T> class... Traits>
|
||||
class Base : public BaseState, public Traits<ConcreteType>... {
|
||||
class OpBase : public OpBaseState, public Traits<ConcreteType>... {
|
||||
public:
|
||||
/// Return the operation that this refers to.
|
||||
const Operation *getOperation() const { return BaseState::getOperation(); }
|
||||
Operation *getOperation() { return BaseState::getOperation(); }
|
||||
const Operation *getOperation() const { return OpBaseState::getOperation(); }
|
||||
Operation *getOperation() { return OpBaseState::getOperation(); }
|
||||
|
||||
/// Return true if this "op class" can match against the specified operation.
|
||||
/// This hook can be overridden with a more specific implementation in
|
||||
|
@ -182,7 +178,7 @@ public:
|
|||
// TODO: Provide a dump() method.
|
||||
|
||||
protected:
|
||||
explicit Base(const Operation *state) : BaseState(state) {}
|
||||
explicit OpBase(const Operation *state) : OpBaseState(state) {}
|
||||
|
||||
private:
|
||||
template <typename... Types>
|
||||
|
@ -210,24 +206,30 @@ private:
|
|||
};
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation Trait Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace OpTrait {
|
||||
|
||||
/// Helper class for implementing traits. Clients are not expected to interact
|
||||
/// with this directly, so its members are all protected.
|
||||
template <typename ConcreteType, template <typename> class TraitType>
|
||||
class TraitImpl {
|
||||
class TraitBase {
|
||||
protected:
|
||||
/// Return the ultimate Operation being worked on.
|
||||
Operation *getOperation() {
|
||||
// We have to cast up to the trait type, then to the concrete type, then to
|
||||
// the BaseState class in explicit hops because the concrete type will
|
||||
// multiply derive from the (content free) TraitImpl class, and we need to
|
||||
// multiply derive from the (content free) TraitBase class, and we need to
|
||||
// be able to disambiguate the path for the C++ compiler.
|
||||
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
|
||||
auto *concrete = static_cast<ConcreteType *>(trait);
|
||||
auto *base = static_cast<BaseState *>(concrete);
|
||||
auto *base = static_cast<OpBaseState *>(concrete);
|
||||
return base->getOperation();
|
||||
}
|
||||
const Operation *getOperation() const {
|
||||
return const_cast<TraitImpl *>(this)->getOperation();
|
||||
return const_cast<TraitBase *>(this)->getOperation();
|
||||
}
|
||||
|
||||
/// Provide default implementations of trait hooks. This allows traits to
|
||||
|
@ -238,7 +240,7 @@ protected:
|
|||
/// This class provides the API for ops that are known to have exactly one
|
||||
/// SSA operand.
|
||||
template <typename ConcreteType>
|
||||
class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> {
|
||||
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
|
||||
public:
|
||||
static const char *verifyTrait(const Operation *op) {
|
||||
if (op->getNumOperands() != 0)
|
||||
|
@ -255,7 +257,7 @@ private:
|
|||
/// This class provides the API for ops that are known to have exactly one
|
||||
/// SSA operand.
|
||||
template <typename ConcreteType>
|
||||
class OneOperand : public TraitImpl<ConcreteType, OneOperand> {
|
||||
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
|
||||
public:
|
||||
const SSAValue *getOperand() const {
|
||||
return this->getOperation()->getOperand(0);
|
||||
|
@ -274,10 +276,15 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have exactly two
|
||||
/// SSA operands.
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of operands. This is used as a trait like this:
|
||||
///
|
||||
/// class FooOp : public OpBase<FooOp, OpTrait::NOperands<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class NOperands {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> {
|
||||
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
|
||||
public:
|
||||
const SSAValue *getOperand(unsigned i) const {
|
||||
return this->getOperation()->getOperand(i);
|
||||
|
@ -292,16 +299,18 @@ public:
|
|||
}
|
||||
|
||||
static const char *verifyTrait(const Operation *op) {
|
||||
if (op->getNumOperands() != 2)
|
||||
return "requires two operands";
|
||||
// TODO(clattner): Allow verifier to return non-constant string.
|
||||
if (op->getNumOperands() != N)
|
||||
return "incorrect number of operands";
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// This class provides the API for ops which have an unknown number of
|
||||
/// SSA operands.
|
||||
template <typename ConcreteType>
|
||||
class VariadicOperands : public TraitImpl<ConcreteType, VariadicOperands> {
|
||||
class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> {
|
||||
public:
|
||||
unsigned getNumOperands() const {
|
||||
return this->getOperation()->getNumOperands();
|
||||
|
@ -345,7 +354,7 @@ public:
|
|||
/// This class provides return value APIs for ops that are known to have a
|
||||
/// single result.
|
||||
template <typename ConcreteType>
|
||||
class OneResult : public TraitImpl<ConcreteType, OneResult> {
|
||||
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||
public:
|
||||
SSAValue *getResult() { return this->getOperation()->getResult(0); }
|
||||
const SSAValue *getResult() const {
|
||||
|
@ -361,30 +370,39 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to have exactly two
|
||||
/// results.
|
||||
/// This class provides the API for ops that are known to have a specified
|
||||
/// number of results. This is used as a trait like this:
|
||||
///
|
||||
/// class FooOp : public OpBase<FooOp, OpTrait::NResults<2>::Impl> {
|
||||
///
|
||||
template <unsigned N> class NResults {
|
||||
public:
|
||||
template <typename ConcreteType>
|
||||
class TwoResults : public TraitImpl<ConcreteType, TwoResults> {
|
||||
class Impl : public TraitBase<ConcreteType, NResults<N>::Impl> {
|
||||
public:
|
||||
const SSAValue *getResult(unsigned i) const {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
||||
SSAValue *getResult(unsigned i) {
|
||||
return this->getOperation()->getResult(i);
|
||||
}
|
||||
|
||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
||||
|
||||
static const char *verifyTrait(const Operation *op) {
|
||||
if (op->getNumResults() != 2)
|
||||
return "requires two results";
|
||||
// TODO(clattner): Allow verifier to return non-constant string.
|
||||
if (op->getNumResults() != N)
|
||||
return "incorrect number of results";
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/// This class provides the API for ops which have an unknown number of
|
||||
/// results.
|
||||
template <typename ConcreteType>
|
||||
class VariadicResults : public TraitImpl<ConcreteType, VariadicResults> {
|
||||
class VariadicResults : public TraitBase<ConcreteType, VariadicResults> {
|
||||
public:
|
||||
unsigned getNumResults() const {
|
||||
return this->getOperation()->getNumResults();
|
||||
|
@ -401,7 +419,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
} // end namespace OpImpl
|
||||
} // end namespace OpTrait
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class OperationSet;
|
|||
/// %2 = addf %0, %1 : f32
|
||||
///
|
||||
class AddFOp
|
||||
: public OpImpl::Base<AddFOp, OpImpl::TwoOperands, OpImpl::OneResult> {
|
||||
: public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult> {
|
||||
public:
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "addf"; }
|
||||
|
@ -48,7 +48,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit AddFOp(const Operation *state) : Base(state) {}
|
||||
explicit AddFOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "affine_apply" operation applies an affine map to a list of operands,
|
||||
|
@ -65,9 +65,8 @@ private:
|
|||
/// #map42 = (d0)->(d0+1)
|
||||
/// %y = affine_apply #map42(%x)
|
||||
///
|
||||
class AffineApplyOp
|
||||
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
|
||||
OpImpl::VariadicResults> {
|
||||
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
|
||||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
|
@ -84,7 +83,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit AffineApplyOp(const Operation *state) : Base(state) {}
|
||||
explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "constant" operation requires a single attribute named "value".
|
||||
|
@ -94,7 +93,8 @@ private:
|
|||
/// %2 = "constant"(){value: @foo} : (f32)->f32
|
||||
///
|
||||
class ConstantOp
|
||||
: public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> {
|
||||
: public OpBase<ConstantOp, OpTrait::ZeroOperands, OpTrait::OneResult/*,
|
||||
OpTrait::HasAttributeBase<"foo">::Impl*/> {
|
||||
public:
|
||||
Attribute *getValue() const { return getAttr("value"); }
|
||||
|
||||
|
@ -106,7 +106,7 @@ public:
|
|||
|
||||
protected:
|
||||
friend class Operation;
|
||||
explicit ConstantOp(const Operation *state) : Base(state) {}
|
||||
explicit ConstantOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// This is a refinement of the "constant" op for the case where it is
|
||||
|
@ -133,8 +133,7 @@ private:
|
|||
///
|
||||
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
|
||||
///
|
||||
class DimOp
|
||||
: public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
|
||||
class DimOp : public OpBase<DimOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
||||
public:
|
||||
/// This returns the dimension number that the 'dim' is inspecting.
|
||||
unsigned getIndex() const {
|
||||
|
@ -151,7 +150,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit DimOp(const Operation *state) : Base(state) {}
|
||||
explicit DimOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "load" op reads an element from a memref specified by an index list. The
|
||||
|
@ -163,7 +162,7 @@ private:
|
|||
/// %3 = load %0[%1, %1] : memref<4x4xi32>
|
||||
///
|
||||
class LoadOp
|
||||
: public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> {
|
||||
: public OpBase<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
||||
public:
|
||||
SSAValue *getMemRef() { return getOperand(0); }
|
||||
const SSAValue *getMemRef() const { return getOperand(0); }
|
||||
|
@ -185,7 +184,7 @@ public:
|
|||
|
||||
private:
|
||||
friend class Operation;
|
||||
explicit LoadOp(const Operation *state) : Base(state) {}
|
||||
explicit LoadOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// Install the standard operations in the specified operation set.
|
||||
|
|
|
@ -27,13 +27,13 @@ using llvm::StringMap;
|
|||
OpAsmParser::~OpAsmParser() {}
|
||||
|
||||
// The fallback for the printer is to reject the short form.
|
||||
OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) {
|
||||
OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
|
||||
parser->emitError(parser->getNameLoc(), "has no concise form");
|
||||
return {};
|
||||
}
|
||||
|
||||
// The fallback for the printer is to print it the longhand form.
|
||||
void OpImpl::BaseState::print(OpAsmPrinter *p) const {
|
||||
void OpBaseState::print(OpAsmPrinter *p) const {
|
||||
p->printDefaultOp(getOperation());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue