forked from OSchip/llvm-project
Improvements to Op trait implementation:
- Generalize TwoOperands and TwoResults to NOperands and NResults, which can be used for any fixed N. - Rename OpImpl namespace to OpTrait, OpImpl::Base to OpBase, and TraitImpl to TraitBase to better reflect what these are. PiperOrigin-RevId: 206588634
This commit is contained in:
parent
775130b6b9
commit
467c5cb3ba
|
@ -15,9 +15,10 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
//
|
//
|
||||||
// This file implements helper classes for implementing the "Op" types. Most of
|
// This file implements helper classes for implementing the "Op" types. This
|
||||||
// this goes into the mlir::OpImpl namespace since they are only used by code
|
// includes the OpBase type, which is the base class for Op class definitions,
|
||||||
// that is defining the op implementations, not by clients.
|
// as well as number of traits in the OpTrait namespace that provide a
|
||||||
|
// declarative way to specify properties of Ops.
|
||||||
//
|
//
|
||||||
// The purpose of these types are to allow light-weight implementation of
|
// The purpose of these types are to allow light-weight implementation of
|
||||||
// concrete ops (like DimOp) with very little boilerplate.
|
// concrete ops (like DimOp) with very little boilerplate.
|
||||||
|
@ -85,11 +86,6 @@ struct OpAsmParserResult {
|
||||||
attributes(attributes.begin(), attributes.end()) {}
|
attributes(attributes.begin(), attributes.end()) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// OpImpl Types
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace OpImpl {
|
|
||||||
|
|
||||||
/// This is the concrete base class that holds the operation pointer and has
|
/// This is the concrete base class that holds the operation pointer and has
|
||||||
/// non-generic methods that only depend on State (to avoid having them
|
/// non-generic methods that only depend on State (to avoid having them
|
||||||
|
@ -97,7 +93,7 @@ namespace OpImpl {
|
||||||
///
|
///
|
||||||
/// This also has the fallback implementations of customization hooks for when
|
/// This also has the fallback implementations of customization hooks for when
|
||||||
/// they aren't customized.
|
/// they aren't customized.
|
||||||
class BaseState {
|
class OpBaseState {
|
||||||
public:
|
public:
|
||||||
/// Return the operation that this refers to.
|
/// Return the operation that this refers to.
|
||||||
const Operation *getOperation() const { return state; }
|
const Operation *getOperation() const { return state; }
|
||||||
|
@ -133,7 +129,7 @@ protected:
|
||||||
|
|
||||||
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
|
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
|
||||||
/// so we can cast it away here.
|
/// so we can cast it away here.
|
||||||
explicit BaseState(const Operation *state)
|
explicit OpBaseState(const Operation *state)
|
||||||
: state(const_cast<Operation *>(state)) {}
|
: state(const_cast<Operation *>(state)) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -144,11 +140,11 @@ private:
|
||||||
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
|
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
|
||||||
/// are base classes by the policy pattern.
|
/// are base classes by the policy pattern.
|
||||||
template <typename ConcreteType, template <typename T> class... Traits>
|
template <typename ConcreteType, template <typename T> class... Traits>
|
||||||
class Base : public BaseState, public Traits<ConcreteType>... {
|
class OpBase : public OpBaseState, public Traits<ConcreteType>... {
|
||||||
public:
|
public:
|
||||||
/// Return the operation that this refers to.
|
/// Return the operation that this refers to.
|
||||||
const Operation *getOperation() const { return BaseState::getOperation(); }
|
const Operation *getOperation() const { return OpBaseState::getOperation(); }
|
||||||
Operation *getOperation() { return BaseState::getOperation(); }
|
Operation *getOperation() { return OpBaseState::getOperation(); }
|
||||||
|
|
||||||
/// Return true if this "op class" can match against the specified operation.
|
/// Return true if this "op class" can match against the specified operation.
|
||||||
/// This hook can be overridden with a more specific implementation in
|
/// This hook can be overridden with a more specific implementation in
|
||||||
|
@ -182,7 +178,7 @@ public:
|
||||||
// TODO: Provide a dump() method.
|
// TODO: Provide a dump() method.
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit Base(const Operation *state) : BaseState(state) {}
|
explicit OpBase(const Operation *state) : OpBaseState(state) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename... Types>
|
template <typename... Types>
|
||||||
|
@ -210,24 +206,30 @@ private:
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Operation Trait Types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace OpTrait {
|
||||||
|
|
||||||
/// Helper class for implementing traits. Clients are not expected to interact
|
/// Helper class for implementing traits. Clients are not expected to interact
|
||||||
/// with this directly, so its members are all protected.
|
/// with this directly, so its members are all protected.
|
||||||
template <typename ConcreteType, template <typename> class TraitType>
|
template <typename ConcreteType, template <typename> class TraitType>
|
||||||
class TraitImpl {
|
class TraitBase {
|
||||||
protected:
|
protected:
|
||||||
/// Return the ultimate Operation being worked on.
|
/// Return the ultimate Operation being worked on.
|
||||||
Operation *getOperation() {
|
Operation *getOperation() {
|
||||||
// We have to cast up to the trait type, then to the concrete type, then to
|
// We have to cast up to the trait type, then to the concrete type, then to
|
||||||
// the BaseState class in explicit hops because the concrete type will
|
// the BaseState class in explicit hops because the concrete type will
|
||||||
// multiply derive from the (content free) TraitImpl class, and we need to
|
// multiply derive from the (content free) TraitBase class, and we need to
|
||||||
// be able to disambiguate the path for the C++ compiler.
|
// be able to disambiguate the path for the C++ compiler.
|
||||||
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
|
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
|
||||||
auto *concrete = static_cast<ConcreteType *>(trait);
|
auto *concrete = static_cast<ConcreteType *>(trait);
|
||||||
auto *base = static_cast<BaseState *>(concrete);
|
auto *base = static_cast<OpBaseState *>(concrete);
|
||||||
return base->getOperation();
|
return base->getOperation();
|
||||||
}
|
}
|
||||||
const Operation *getOperation() const {
|
const Operation *getOperation() const {
|
||||||
return const_cast<TraitImpl *>(this)->getOperation();
|
return const_cast<TraitBase *>(this)->getOperation();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provide default implementations of trait hooks. This allows traits to
|
/// Provide default implementations of trait hooks. This allows traits to
|
||||||
|
@ -238,7 +240,7 @@ protected:
|
||||||
/// This class provides the API for ops that are known to have exactly one
|
/// This class provides the API for ops that are known to have exactly one
|
||||||
/// SSA operand.
|
/// SSA operand.
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> {
|
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
|
||||||
public:
|
public:
|
||||||
static const char *verifyTrait(const Operation *op) {
|
static const char *verifyTrait(const Operation *op) {
|
||||||
if (op->getNumOperands() != 0)
|
if (op->getNumOperands() != 0)
|
||||||
|
@ -255,7 +257,7 @@ private:
|
||||||
/// This class provides the API for ops that are known to have exactly one
|
/// This class provides the API for ops that are known to have exactly one
|
||||||
/// SSA operand.
|
/// SSA operand.
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class OneOperand : public TraitImpl<ConcreteType, OneOperand> {
|
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
|
||||||
public:
|
public:
|
||||||
const SSAValue *getOperand() const {
|
const SSAValue *getOperand() const {
|
||||||
return this->getOperation()->getOperand(0);
|
return this->getOperation()->getOperand(0);
|
||||||
|
@ -274,10 +276,15 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This class provides the API for ops that are known to have exactly two
|
/// This class provides the API for ops that are known to have a specified
|
||||||
/// SSA operands.
|
/// number of operands. This is used as a trait like this:
|
||||||
|
///
|
||||||
|
/// class FooOp : public OpBase<FooOp, OpTrait::NOperands<2>::Impl> {
|
||||||
|
///
|
||||||
|
template <unsigned N> class NOperands {
|
||||||
|
public:
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> {
|
class Impl : public TraitBase<ConcreteType, NOperands<N>::Impl> {
|
||||||
public:
|
public:
|
||||||
const SSAValue *getOperand(unsigned i) const {
|
const SSAValue *getOperand(unsigned i) const {
|
||||||
return this->getOperation()->getOperand(i);
|
return this->getOperation()->getOperand(i);
|
||||||
|
@ -292,16 +299,18 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char *verifyTrait(const Operation *op) {
|
static const char *verifyTrait(const Operation *op) {
|
||||||
if (op->getNumOperands() != 2)
|
// TODO(clattner): Allow verifier to return non-constant string.
|
||||||
return "requires two operands";
|
if (op->getNumOperands() != N)
|
||||||
|
return "incorrect number of operands";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/// This class provides the API for ops which have an unknown number of
|
/// This class provides the API for ops which have an unknown number of
|
||||||
/// SSA operands.
|
/// SSA operands.
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class VariadicOperands : public TraitImpl<ConcreteType, VariadicOperands> {
|
class VariadicOperands : public TraitBase<ConcreteType, VariadicOperands> {
|
||||||
public:
|
public:
|
||||||
unsigned getNumOperands() const {
|
unsigned getNumOperands() const {
|
||||||
return this->getOperation()->getNumOperands();
|
return this->getOperation()->getNumOperands();
|
||||||
|
@ -345,7 +354,7 @@ public:
|
||||||
/// This class provides return value APIs for ops that are known to have a
|
/// This class provides return value APIs for ops that are known to have a
|
||||||
/// single result.
|
/// single result.
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class OneResult : public TraitImpl<ConcreteType, OneResult> {
|
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||||
public:
|
public:
|
||||||
SSAValue *getResult() { return this->getOperation()->getResult(0); }
|
SSAValue *getResult() { return this->getOperation()->getResult(0); }
|
||||||
const SSAValue *getResult() const {
|
const SSAValue *getResult() const {
|
||||||
|
@ -361,30 +370,39 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This class provides the API for ops that are known to have exactly two
|
/// This class provides the API for ops that are known to have a specified
|
||||||
/// results.
|
/// number of results. This is used as a trait like this:
|
||||||
|
///
|
||||||
|
/// class FooOp : public OpBase<FooOp, OpTrait::NResults<2>::Impl> {
|
||||||
|
///
|
||||||
|
template <unsigned N> class NResults {
|
||||||
|
public:
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class TwoResults : public TraitImpl<ConcreteType, TwoResults> {
|
class Impl : public TraitBase<ConcreteType, NResults<N>::Impl> {
|
||||||
public:
|
public:
|
||||||
const SSAValue *getResult(unsigned i) const {
|
const SSAValue *getResult(unsigned i) const {
|
||||||
return this->getOperation()->getResult(i);
|
return this->getOperation()->getResult(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
SSAValue *getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
SSAValue *getResult(unsigned i) {
|
||||||
|
return this->getOperation()->getResult(i);
|
||||||
|
}
|
||||||
|
|
||||||
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
Type *getType(unsigned i) const { return getResult(i)->getType(); }
|
||||||
|
|
||||||
static const char *verifyTrait(const Operation *op) {
|
static const char *verifyTrait(const Operation *op) {
|
||||||
if (op->getNumResults() != 2)
|
// TODO(clattner): Allow verifier to return non-constant string.
|
||||||
return "requires two results";
|
if (op->getNumResults() != N)
|
||||||
|
return "incorrect number of results";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
};
|
||||||
|
|
||||||
/// This class provides the API for ops which have an unknown number of
|
/// This class provides the API for ops which have an unknown number of
|
||||||
/// results.
|
/// results.
|
||||||
template <typename ConcreteType>
|
template <typename ConcreteType>
|
||||||
class VariadicResults : public TraitImpl<ConcreteType, VariadicResults> {
|
class VariadicResults : public TraitBase<ConcreteType, VariadicResults> {
|
||||||
public:
|
public:
|
||||||
unsigned getNumResults() const {
|
unsigned getNumResults() const {
|
||||||
return this->getOperation()->getNumResults();
|
return this->getOperation()->getNumResults();
|
||||||
|
@ -401,7 +419,7 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace OpImpl
|
} // end namespace OpTrait
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ class OperationSet;
|
||||||
/// %2 = addf %0, %1 : f32
|
/// %2 = addf %0, %1 : f32
|
||||||
///
|
///
|
||||||
class AddFOp
|
class AddFOp
|
||||||
: public OpImpl::Base<AddFOp, OpImpl::TwoOperands, OpImpl::OneResult> {
|
: public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult> {
|
||||||
public:
|
public:
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static StringRef getOperationName() { return "addf"; }
|
static StringRef getOperationName() { return "addf"; }
|
||||||
|
@ -48,7 +48,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
explicit AddFOp(const Operation *state) : Base(state) {}
|
explicit AddFOp(const Operation *state) : OpBase(state) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The "affine_apply" operation applies an affine map to a list of operands,
|
/// The "affine_apply" operation applies an affine map to a list of operands,
|
||||||
|
@ -65,9 +65,8 @@ private:
|
||||||
/// #map42 = (d0)->(d0+1)
|
/// #map42 = (d0)->(d0+1)
|
||||||
/// %y = affine_apply #map42(%x)
|
/// %y = affine_apply #map42(%x)
|
||||||
///
|
///
|
||||||
class AffineApplyOp
|
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
|
||||||
: public OpImpl::Base<AffineApplyOp, OpImpl::VariadicOperands,
|
OpTrait::VariadicResults> {
|
||||||
OpImpl::VariadicResults> {
|
|
||||||
public:
|
public:
|
||||||
// Returns the affine map to be applied by this operation.
|
// Returns the affine map to be applied by this operation.
|
||||||
AffineMap *getAffineMap() const {
|
AffineMap *getAffineMap() const {
|
||||||
|
@ -84,7 +83,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
explicit AffineApplyOp(const Operation *state) : Base(state) {}
|
explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The "constant" operation requires a single attribute named "value".
|
/// The "constant" operation requires a single attribute named "value".
|
||||||
|
@ -94,7 +93,8 @@ private:
|
||||||
/// %2 = "constant"(){value: @foo} : (f32)->f32
|
/// %2 = "constant"(){value: @foo} : (f32)->f32
|
||||||
///
|
///
|
||||||
class ConstantOp
|
class ConstantOp
|
||||||
: public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> {
|
: public OpBase<ConstantOp, OpTrait::ZeroOperands, OpTrait::OneResult/*,
|
||||||
|
OpTrait::HasAttributeBase<"foo">::Impl*/> {
|
||||||
public:
|
public:
|
||||||
Attribute *getValue() const { return getAttr("value"); }
|
Attribute *getValue() const { return getAttr("value"); }
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ public:
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
explicit ConstantOp(const Operation *state) : Base(state) {}
|
explicit ConstantOp(const Operation *state) : OpBase(state) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This is a refinement of the "constant" op for the case where it is
|
/// This is a refinement of the "constant" op for the case where it is
|
||||||
|
@ -133,8 +133,7 @@ private:
|
||||||
///
|
///
|
||||||
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
|
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
|
||||||
///
|
///
|
||||||
class DimOp
|
class DimOp : public OpBase<DimOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
||||||
: public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
|
|
||||||
public:
|
public:
|
||||||
/// This returns the dimension number that the 'dim' is inspecting.
|
/// This returns the dimension number that the 'dim' is inspecting.
|
||||||
unsigned getIndex() const {
|
unsigned getIndex() const {
|
||||||
|
@ -151,7 +150,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
explicit DimOp(const Operation *state) : Base(state) {}
|
explicit DimOp(const Operation *state) : OpBase(state) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The "load" op reads an element from a memref specified by an index list. The
|
/// The "load" op reads an element from a memref specified by an index list. The
|
||||||
|
@ -163,7 +162,7 @@ private:
|
||||||
/// %3 = load %0[%1, %1] : memref<4x4xi32>
|
/// %3 = load %0[%1, %1] : memref<4x4xi32>
|
||||||
///
|
///
|
||||||
class LoadOp
|
class LoadOp
|
||||||
: public OpImpl::Base<LoadOp, OpImpl::VariadicOperands, OpImpl::OneResult> {
|
: public OpBase<LoadOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
|
||||||
public:
|
public:
|
||||||
SSAValue *getMemRef() { return getOperand(0); }
|
SSAValue *getMemRef() { return getOperand(0); }
|
||||||
const SSAValue *getMemRef() const { return getOperand(0); }
|
const SSAValue *getMemRef() const { return getOperand(0); }
|
||||||
|
@ -185,7 +184,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
explicit LoadOp(const Operation *state) : Base(state) {}
|
explicit LoadOp(const Operation *state) : OpBase(state) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Install the standard operations in the specified operation set.
|
/// Install the standard operations in the specified operation set.
|
||||||
|
|
|
@ -27,13 +27,13 @@ using llvm::StringMap;
|
||||||
OpAsmParser::~OpAsmParser() {}
|
OpAsmParser::~OpAsmParser() {}
|
||||||
|
|
||||||
// The fallback for the printer is to reject the short form.
|
// The fallback for the printer is to reject the short form.
|
||||||
OpAsmParserResult OpImpl::BaseState::parse(OpAsmParser *parser) {
|
OpAsmParserResult OpBaseState::parse(OpAsmParser *parser) {
|
||||||
parser->emitError(parser->getNameLoc(), "has no concise form");
|
parser->emitError(parser->getNameLoc(), "has no concise form");
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// The fallback for the printer is to print it the longhand form.
|
// The fallback for the printer is to print it the longhand form.
|
||||||
void OpImpl::BaseState::print(OpAsmPrinter *p) const {
|
void OpBaseState::print(OpAsmPrinter *p) const {
|
||||||
p->printDefaultOp(getOperation());
|
p->printDefaultOp(getOperation());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue