forked from OSchip/llvm-project
1231 lines
47 KiB
C++
1231 lines
47 KiB
C++
//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements helper classes for implementing the "Op" types. This
|
|
// includes the Op type, which is the base class for Op class definitions,
|
|
// as well as number of traits in the OpTrait namespace that provide a
|
|
// declarative way to specify properties of Ops.
|
|
//
|
|
// The purpose of these types are to allow light-weight implementation of
|
|
// concrete ops (like DimOp) with very little boilerplate.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef MLIR_IR_OPDEFINITION_H
|
|
#define MLIR_IR_OPDEFINITION_H
|
|
|
|
#include "mlir/IR/Operation.h"
|
|
#include "llvm/Support/PointerLikeTypeTraits.h"
|
|
#include <type_traits>
|
|
|
|
namespace mlir {
|
|
class Builder;
|
|
|
|
namespace OpTrait {
|
|
template <typename ConcreteType> class OneResult;
|
|
}
|
|
|
|
/// This class represents success/failure for operation parsing. It is
|
|
/// essentially a simple wrapper class around LogicalResult that allows for
|
|
/// explicit conversion to bool. This allows for the parser to chain together
|
|
/// parse rules without the clutter of "failed/succeeded".
|
|
class ParseResult : public LogicalResult {
|
|
public:
|
|
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
|
|
|
|
// Allow diagnostics emitted during parsing to be converted to failure.
|
|
ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
|
|
ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
|
|
|
|
/// Failure is true in a boolean context.
|
|
explicit operator bool() const { return failed(*this); }
|
|
};
|
|
/// This class implements `Optional` functionality for ParseResult. We don't
|
|
/// directly use Optional here, because it provides an implicit conversion
|
|
/// to 'bool' which we want to avoid. This class is used to implement tri-state
|
|
/// 'parseOptional' functions that may have a failure mode when parsing that
|
|
/// shouldn't be attributed to "not present".
|
|
class OptionalParseResult {
|
|
public:
|
|
OptionalParseResult() = default;
|
|
OptionalParseResult(LogicalResult result) : impl(result) {}
|
|
OptionalParseResult(ParseResult result) : impl(result) {}
|
|
OptionalParseResult(const InFlightDiagnostic &)
|
|
: OptionalParseResult(failure()) {}
|
|
OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
|
|
|
|
/// Returns true if we contain a valid ParseResult value.
|
|
bool hasValue() const { return impl.hasValue(); }
|
|
|
|
/// Access the internal ParseResult value.
|
|
ParseResult getValue() const { return impl.getValue(); }
|
|
ParseResult operator*() const { return getValue(); }
|
|
|
|
private:
|
|
Optional<ParseResult> impl;
|
|
};
|
|
|
|
// These functions are out-of-line utilities, which avoids them being template
|
|
// instantiated/duplicated.
|
|
namespace impl {
|
|
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
|
|
/// region's only block if it does not have a terminator already. If the region
|
|
/// is empty, insert a new block first. `buildTerminatorOp` should return the
|
|
/// terminator operation to insert.
|
|
void ensureRegionTerminator(Region ®ion, Location loc,
|
|
function_ref<Operation *()> buildTerminatorOp);
|
|
/// Templated version that fills the generates the provided operation type.
|
|
template <typename OpTy>
|
|
void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) {
|
|
ensureRegionTerminator(region, loc, [&] {
|
|
OperationState state(loc, OpTy::getOperationName());
|
|
OpTy::build(&builder, state);
|
|
return Operation::create(state);
|
|
});
|
|
}
|
|
} // namespace impl
|
|
|
|
/// This is the concrete base class that holds the operation pointer and has
|
|
/// non-generic methods that only depend on State (to avoid having them
|
|
/// instantiated on template types that don't affect them.
|
|
///
|
|
/// This also has the fallback implementations of customization hooks for when
|
|
/// they aren't customized.
|
|
class OpState {
|
|
public:
|
|
/// Ops are pointer-like, so we allow implicit conversion to bool.
|
|
operator bool() { return getOperation() != nullptr; }
|
|
|
|
/// This implicitly converts to Operation*.
|
|
operator Operation *() const { return state; }
|
|
|
|
/// Return the operation that this refers to.
|
|
Operation *getOperation() { return state; }
|
|
|
|
/// Returns the closest surrounding operation that contains this operation
|
|
/// or nullptr if this is a top-level operation.
|
|
Operation *getParentOp() { return getOperation()->getParentOp(); }
|
|
|
|
/// Return the closest surrounding parent operation that is of type 'OpTy'.
|
|
template <typename OpTy> OpTy getParentOfType() {
|
|
return getOperation()->getParentOfType<OpTy>();
|
|
}
|
|
|
|
/// Return the context this operation belongs to.
|
|
MLIRContext *getContext() { return getOperation()->getContext(); }
|
|
|
|
/// Print the operation to the given stream.
|
|
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
|
|
state->print(os, flags);
|
|
}
|
|
void print(raw_ostream &os, AsmState &asmState,
|
|
OpPrintingFlags flags = llvm::None) {
|
|
state->print(os, asmState, flags);
|
|
}
|
|
|
|
/// Dump this operation.
|
|
void dump() { state->dump(); }
|
|
|
|
/// The source location the operation was defined or derived from.
|
|
Location getLoc() { return state->getLoc(); }
|
|
void setLoc(Location loc) { state->setLoc(loc); }
|
|
|
|
/// Return all of the attributes on this operation.
|
|
ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
|
|
|
|
/// A utility iterator that filters out non-dialect attributes.
|
|
using dialect_attr_iterator = Operation::dialect_attr_iterator;
|
|
using dialect_attr_range = Operation::dialect_attr_range;
|
|
|
|
/// Return a range corresponding to the dialect attributes for this operation.
|
|
dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); }
|
|
dialect_attr_iterator dialect_attr_begin() {
|
|
return state->dialect_attr_begin();
|
|
}
|
|
dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); }
|
|
|
|
/// Return an attribute with the specified name.
|
|
Attribute getAttr(StringRef name) { return state->getAttr(name); }
|
|
|
|
/// If the operation has an attribute of the specified type, return it.
|
|
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
|
|
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
|
}
|
|
|
|
/// If the an attribute exists with the specified name, change it to the new
|
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
|
void setAttr(Identifier name, Attribute value) {
|
|
state->setAttr(name, value);
|
|
}
|
|
void setAttr(StringRef name, Attribute value) {
|
|
setAttr(Identifier::get(name, getContext()), value);
|
|
}
|
|
|
|
/// Set the attributes held by this operation.
|
|
void setAttrs(ArrayRef<NamedAttribute> attributes) {
|
|
state->setAttrs(attributes);
|
|
}
|
|
void setAttrs(NamedAttributeList newAttrs) { state->setAttrs(newAttrs); }
|
|
|
|
/// Set the dialect attributes for this operation, and preserve all dependent.
|
|
template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
|
|
state->setDialectAttrs(std::move(attrs));
|
|
}
|
|
|
|
/// Remove the attribute with the specified name if it exists. The return
|
|
/// value indicates whether the attribute was present or not.
|
|
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
|
|
return state->removeAttr(name);
|
|
}
|
|
NamedAttributeList::RemoveResult removeAttr(StringRef name) {
|
|
return state->removeAttr(Identifier::get(name, getContext()));
|
|
}
|
|
|
|
/// Return true if there are no users of any results of this operation.
|
|
bool use_empty() { return state->use_empty(); }
|
|
|
|
/// Remove this operation from its parent block and delete it.
|
|
void erase() { state->erase(); }
|
|
|
|
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
|
/// convenient for verifiers.
|
|
InFlightDiagnostic emitOpError(const Twine &message = {});
|
|
|
|
/// Emit an error about fatal conditions with this operation, reporting up to
|
|
/// any diagnostic handlers that may be listening.
|
|
InFlightDiagnostic emitError(const Twine &message = {});
|
|
|
|
/// Emit a warning about this operation, reporting up to any diagnostic
|
|
/// handlers that may be listening.
|
|
InFlightDiagnostic emitWarning(const Twine &message = {});
|
|
|
|
/// Emit a remark about this operation, reporting up to any diagnostic
|
|
/// handlers that may be listening.
|
|
InFlightDiagnostic emitRemark(const Twine &message = {});
|
|
|
|
/// Walk the operation in postorder, calling the callback for each nested
|
|
/// operation(including this one).
|
|
/// See Operation::walk for more details.
|
|
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
|
|
RetT walk(FnT &&callback) {
|
|
return state->walk(std::forward<FnT>(callback));
|
|
}
|
|
|
|
// These are default implementations of customization hooks.
|
|
public:
|
|
/// This hook returns any canonicalization pattern rewrites that the operation
|
|
/// supports, for use by the canonicalization pass.
|
|
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|
MLIRContext *context) {}
|
|
|
|
protected:
|
|
/// If the concrete type didn't implement a custom verifier hook, just fall
|
|
/// back to this one which accepts everything.
|
|
LogicalResult verify() { return success(); }
|
|
|
|
/// Unless overridden, the custom assembly form of an op is always rejected.
|
|
/// Op implementations should implement this to return failure.
|
|
/// On success, they should fill in result with the fields to use.
|
|
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
|
|
|
// The fallback for the printer is to print it the generic assembly form.
|
|
void print(OpAsmPrinter &p);
|
|
|
|
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
|
|
/// so we can cast it away here.
|
|
explicit OpState(Operation *state) : state(state) {}
|
|
|
|
private:
|
|
Operation *state;
|
|
};
|
|
|
|
// Allow comparing operators.
|
|
inline bool operator==(OpState lhs, OpState rhs) {
|
|
return lhs.getOperation() == rhs.getOperation();
|
|
}
|
|
inline bool operator!=(OpState lhs, OpState rhs) {
|
|
return lhs.getOperation() != rhs.getOperation();
|
|
}
|
|
|
|
/// This class represents a single result from folding an operation.
|
|
class OpFoldResult : public PointerUnion<Attribute, Value> {
|
|
using PointerUnion<Attribute, Value>::PointerUnion;
|
|
};
|
|
|
|
/// This template defines the foldHook as used by AbstractOperation.
|
|
///
|
|
/// The default implementation uses a general fold method that can be defined on
|
|
/// custom ops which can return multiple results.
|
|
template <typename ConcreteType, bool isSingleResult, typename = void>
|
|
class FoldingHook {
|
|
public:
|
|
/// This is an implementation detail of the constant folder hook for
|
|
/// AbstractOperation.
|
|
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
|
SmallVectorImpl<OpFoldResult> &results) {
|
|
return cast<ConcreteType>(op).fold(operands, results);
|
|
}
|
|
|
|
/// This hook implements a generalized folder for this operation. Operations
|
|
/// can implement this to provide simplifications rules that are applied by
|
|
/// the Builder::createOrFold API and the canonicalization pass.
|
|
///
|
|
/// This is an intentionally limited interface - implementations of this hook
|
|
/// can only perform the following changes to the operation:
|
|
///
|
|
/// 1. They can leave the operation alone and without changing the IR, and
|
|
/// return failure.
|
|
/// 2. They can mutate the operation in place, without changing anything else
|
|
/// in the IR. In this case, return success.
|
|
/// 3. They can return a list of existing values that can be used instead of
|
|
/// the operation. In this case, fill in the results list and return
|
|
/// success. The caller will remove the operation and use those results
|
|
/// instead.
|
|
///
|
|
/// This allows expression of some simple in-place canonicalizations (e.g.
|
|
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
|
|
/// generalized constant folding.
|
|
///
|
|
/// If not overridden, this fallback implementation always fails to fold.
|
|
///
|
|
LogicalResult fold(ArrayRef<Attribute> operands,
|
|
SmallVectorImpl<OpFoldResult> &results) {
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// This template specialization defines the foldHook as used by
|
|
/// AbstractOperation for single-result operations. This gives the hook a nicer
|
|
/// signature that is easier to implement.
|
|
template <typename ConcreteType, bool isSingleResult>
|
|
class FoldingHook<ConcreteType, isSingleResult,
|
|
typename std::enable_if<isSingleResult>::type> {
|
|
public:
|
|
/// If the operation returns a single value, then the Op can be implicitly
|
|
/// converted to an Value. This yields the value of the only result.
|
|
operator Value() {
|
|
return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
|
|
}
|
|
|
|
/// This is an implementation detail of the constant folder hook for
|
|
/// AbstractOperation.
|
|
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
|
SmallVectorImpl<OpFoldResult> &results) {
|
|
auto result = cast<ConcreteType>(op).fold(operands);
|
|
if (!result)
|
|
return failure();
|
|
|
|
// Check if the operation was folded in place. In this case, the operation
|
|
// returns itself.
|
|
if (result.template dyn_cast<Value>() != op->getResult(0))
|
|
results.push_back(result);
|
|
return success();
|
|
}
|
|
|
|
/// This hook implements a generalized folder for this operation. Operations
|
|
/// can implement this to provide simplifications rules that are applied by
|
|
/// the Builder::createOrFold API and the canonicalization pass.
|
|
///
|
|
/// This is an intentionally limited interface - implementations of this hook
|
|
/// can only perform the following changes to the operation:
|
|
///
|
|
/// 1. They can leave the operation alone and without changing the IR, and
|
|
/// return nullptr.
|
|
/// 2. They can mutate the operation in place, without changing anything else
|
|
/// in the IR. In this case, return the operation itself.
|
|
/// 3. They can return an existing SSA value that can be used instead of
|
|
/// the operation. In this case, return that value. The caller will
|
|
/// remove the operation and use that result instead.
|
|
///
|
|
/// This allows expression of some simple in-place canonicalizations (e.g.
|
|
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
|
|
/// generalized constant folding.
|
|
///
|
|
/// If not overridden, this fallback implementation always fails to fold.
|
|
///
|
|
OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation Trait Types
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace OpTrait {
|
|
|
|
// These functions are out-of-line implementations of the methods in the
|
|
// corresponding trait classes. This avoids them being template
|
|
// instantiated/duplicated.
|
|
namespace impl {
|
|
LogicalResult verifyZeroOperands(Operation *op);
|
|
LogicalResult verifyOneOperand(Operation *op);
|
|
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
|
|
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
|
|
LogicalResult verifyOperandsAreFloatLike(Operation *op);
|
|
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
|
|
LogicalResult verifySameTypeOperands(Operation *op);
|
|
LogicalResult verifyZeroResult(Operation *op);
|
|
LogicalResult verifyOneResult(Operation *op);
|
|
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
|
|
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
|
|
LogicalResult verifySameOperandsShape(Operation *op);
|
|
LogicalResult verifySameOperandsAndResultShape(Operation *op);
|
|
LogicalResult verifySameOperandsElementType(Operation *op);
|
|
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
|
|
LogicalResult verifySameOperandsAndResultType(Operation *op);
|
|
LogicalResult verifyResultsAreBoolLike(Operation *op);
|
|
LogicalResult verifyResultsAreFloatLike(Operation *op);
|
|
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
|
|
LogicalResult verifyIsTerminator(Operation *op);
|
|
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
|
|
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
|
|
} // namespace impl
|
|
|
|
/// Helper class for implementing traits. Clients are not expected to interact
|
|
/// with this directly, so its members are all protected.
|
|
template <typename ConcreteType, template <typename> class TraitType>
|
|
class TraitBase {
|
|
protected:
|
|
/// Return the ultimate Operation being worked on.
|
|
Operation *getOperation() {
|
|
// We have to cast up to the trait type, then to the concrete type, then to
|
|
// the BaseState class in explicit hops because the concrete type will
|
|
// multiply derive from the (content free) TraitBase class, and we need to
|
|
// be able to disambiguate the path for the C++ compiler.
|
|
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
|
|
auto *concrete = static_cast<ConcreteType *>(trait);
|
|
auto *base = static_cast<OpState *>(concrete);
|
|
return base->getOperation();
|
|
}
|
|
|
|
/// Provide default implementations of trait hooks. This allows traits to
|
|
/// provide exactly the overrides they care about.
|
|
static LogicalResult verifyTrait(Operation *op) { return success(); }
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
namespace detail {
|
|
/// Utility trait base that provides accessors for derived traits that have
|
|
/// multiple operands.
|
|
template <typename ConcreteType, template <typename> class TraitType>
|
|
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
|
|
using operand_iterator = Operation::operand_iterator;
|
|
using operand_range = Operation::operand_range;
|
|
using operand_type_iterator = Operation::operand_type_iterator;
|
|
using operand_type_range = Operation::operand_type_range;
|
|
|
|
/// Return the number of operands.
|
|
unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
|
|
|
|
/// Return the operand at index 'i'.
|
|
Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
|
|
|
|
/// Set the operand at index 'i' to 'value'.
|
|
void setOperand(unsigned i, Value value) {
|
|
this->getOperation()->setOperand(i, value);
|
|
}
|
|
|
|
/// Operand iterator access.
|
|
operand_iterator operand_begin() {
|
|
return this->getOperation()->operand_begin();
|
|
}
|
|
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
|
|
operand_range getOperands() { return this->getOperation()->getOperands(); }
|
|
|
|
/// Operand type access.
|
|
operand_type_iterator operand_type_begin() {
|
|
return this->getOperation()->operand_type_begin();
|
|
}
|
|
operand_type_iterator operand_type_end() {
|
|
return this->getOperation()->operand_type_end();
|
|
}
|
|
operand_type_range getOperandTypes() {
|
|
return this->getOperation()->getOperandTypes();
|
|
}
|
|
};
|
|
} // end namespace detail
|
|
|
|
/// This class provides the API for ops that are known to have no
|
|
/// SSA operand.
|
|
template <typename ConcreteType>
|
|
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyZeroOperands(op);
|
|
}
|
|
|
|
private:
|
|
// Disable these.
|
|
void getOperand() {}
|
|
void setOperand() {}
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to have exactly one
|
|
/// SSA operand.
|
|
template <typename ConcreteType>
|
|
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
|
|
public:
|
|
Value getOperand() { return this->getOperation()->getOperand(0); }
|
|
|
|
void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
|
|
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyOneOperand(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to have a specified
|
|
/// number of operands. This is used as a trait like this:
|
|
///
|
|
/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
|
|
///
|
|
template <unsigned N> class NOperands {
|
|
public:
|
|
static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
|
|
|
|
template <typename ConcreteType>
|
|
class Impl
|
|
: public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyNOperands(op, N);
|
|
}
|
|
};
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to have a at least a
|
|
/// specified number of operands. This is used as a trait like this:
|
|
///
|
|
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
|
|
///
|
|
template <unsigned N> class AtLeastNOperands {
|
|
public:
|
|
template <typename ConcreteType>
|
|
class Impl : public detail::MultiOperandTraitBase<ConcreteType,
|
|
AtLeastNOperands<N>::Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyAtLeastNOperands(op, N);
|
|
}
|
|
};
|
|
};
|
|
|
|
/// This class provides the API for ops which have an unknown number of
|
|
/// SSA operands.
|
|
template <typename ConcreteType>
|
|
class VariadicOperands
|
|
: public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
|
|
|
|
/// This class provides return value APIs for ops that are known to have
|
|
/// zero results.
|
|
template <typename ConcreteType>
|
|
class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyZeroResult(op);
|
|
}
|
|
};
|
|
|
|
namespace detail {
|
|
/// Utility trait base that provides accessors for derived traits that have
|
|
/// multiple results.
|
|
template <typename ConcreteType, template <typename> class TraitType>
|
|
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
|
using result_iterator = Operation::result_iterator;
|
|
using result_range = Operation::result_range;
|
|
using result_type_iterator = Operation::result_type_iterator;
|
|
using result_type_range = Operation::result_type_range;
|
|
|
|
/// Return the number of results.
|
|
unsigned getNumResults() { return this->getOperation()->getNumResults(); }
|
|
|
|
/// Return the result at index 'i'.
|
|
Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
|
|
|
|
/// Replace all uses of results of this operation with the provided 'values'.
|
|
/// 'values' may correspond to an existing operation, or a range of 'Value'.
|
|
template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
|
|
this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
|
|
}
|
|
|
|
/// Return the type of the `i`-th result.
|
|
Type getType(unsigned i) { return getResult(i).getType(); }
|
|
|
|
/// Result iterator access.
|
|
result_iterator result_begin() {
|
|
return this->getOperation()->result_begin();
|
|
}
|
|
result_iterator result_end() { return this->getOperation()->result_end(); }
|
|
result_range getResults() { return this->getOperation()->getResults(); }
|
|
|
|
/// Result type access.
|
|
result_type_iterator result_type_begin() {
|
|
return this->getOperation()->result_type_begin();
|
|
}
|
|
result_type_iterator result_type_end() {
|
|
return this->getOperation()->result_type_end();
|
|
}
|
|
result_type_range getResultTypes() {
|
|
return this->getOperation()->getResultTypes();
|
|
}
|
|
};
|
|
} // end namespace detail
|
|
|
|
/// This class provides return value APIs for ops that are known to have a
|
|
/// single result.
|
|
template <typename ConcreteType>
|
|
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
|
public:
|
|
Value getResult() { return this->getOperation()->getResult(0); }
|
|
Type getType() { return getResult().getType(); }
|
|
|
|
/// Replace all uses of 'this' value with the new value, updating anything in
|
|
/// the IR that uses 'this' to use the other value instead. When this returns
|
|
/// there are zero uses of 'this'.
|
|
void replaceAllUsesWith(Value newValue) {
|
|
getResult().replaceAllUsesWith(newValue);
|
|
}
|
|
|
|
/// Replace all uses of 'this' value with the result of 'op'.
|
|
void replaceAllUsesWith(Operation *op) {
|
|
this->getOperation()->replaceAllUsesWith(op);
|
|
}
|
|
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyOneResult(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to have a specified
|
|
/// number of results. This is used as a trait like this:
|
|
///
|
|
/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
|
|
///
|
|
template <unsigned N> class NResults {
|
|
public:
|
|
static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
|
|
|
|
template <typename ConcreteType>
|
|
class Impl
|
|
: public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyNResults(op, N);
|
|
}
|
|
};
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to have at least a
|
|
/// specified number of results. This is used as a trait like this:
|
|
///
|
|
/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
|
|
///
|
|
template <unsigned N> class AtLeastNResults {
|
|
public:
|
|
template <typename ConcreteType>
|
|
class Impl : public detail::MultiResultTraitBase<ConcreteType,
|
|
AtLeastNResults<N>::Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyAtLeastNResults(op, N);
|
|
}
|
|
};
|
|
};
|
|
|
|
/// This class provides the API for ops which have an unknown number of
|
|
/// results.
|
|
template <typename ConcreteType>
|
|
class VariadicResults
|
|
: public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
|
|
|
|
/// This class provides verification for ops that are known to have the same
|
|
/// operand shape: all operands are scalars, vectors/tensors of the same
|
|
/// shape.
|
|
template <typename ConcreteType>
|
|
class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameOperandsShape(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides verification for ops that are known to have the same
|
|
/// operand and result shape: both are scalars, vectors/tensors of the same
|
|
/// shape.
|
|
template <typename ConcreteType>
|
|
class SameOperandsAndResultShape
|
|
: public TraitBase<ConcreteType, SameOperandsAndResultShape> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameOperandsAndResultShape(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides verification for ops that are known to have the same
|
|
/// operand element type (or the type itself if it is scalar).
|
|
///
|
|
template <typename ConcreteType>
|
|
class SameOperandsElementType
|
|
: public TraitBase<ConcreteType, SameOperandsElementType> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameOperandsElementType(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides verification for ops that are known to have the same
|
|
/// operand and result element type (or the type itself if it is scalar).
|
|
///
|
|
template <typename ConcreteType>
|
|
class SameOperandsAndResultElementType
|
|
: public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameOperandsAndResultElementType(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides verification for ops that are known to have the same
|
|
/// operand and result type.
|
|
///
|
|
/// Note: this trait subsumes the SameOperandsAndResultShape and
|
|
/// SameOperandsAndResultElementType traits.
|
|
template <typename ConcreteType>
|
|
class SameOperandsAndResultType
|
|
: public TraitBase<ConcreteType, SameOperandsAndResultType> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameOperandsAndResultType(op);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that any results of the specified op have a boolean
|
|
/// type, a vector thereof, or a tensor thereof.
|
|
template <typename ConcreteType>
|
|
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyResultsAreBoolLike(op);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that any results of the specified op have a floating
|
|
/// point type, a vector thereof, or a tensor thereof.
|
|
template <typename ConcreteType>
|
|
class ResultsAreFloatLike
|
|
: public TraitBase<ConcreteType, ResultsAreFloatLike> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyResultsAreFloatLike(op);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that any results of the specified op have a signless
|
|
/// integer or index type, a vector thereof, or a tensor thereof.
|
|
template <typename ConcreteType>
|
|
class ResultsAreSignlessIntegerLike
|
|
: public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyResultsAreSignlessIntegerLike(op);
|
|
}
|
|
};
|
|
|
|
/// This class adds property that the operation is commutative.
|
|
template <typename ConcreteType>
|
|
class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
|
|
public:
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return static_cast<AbstractOperation::OperationProperties>(
|
|
OperationProperty::Commutative);
|
|
}
|
|
};
|
|
|
|
/// This class adds property that the operation has no side effects.
|
|
template <typename ConcreteType>
|
|
class HasNoSideEffect : public TraitBase<ConcreteType, HasNoSideEffect> {
|
|
public:
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return static_cast<AbstractOperation::OperationProperties>(
|
|
OperationProperty::NoSideEffect);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that all operands of the specified op have a float type,
|
|
/// a vector thereof, or a tensor thereof.
|
|
template <typename ConcreteType>
|
|
class OperandsAreFloatLike
|
|
: public TraitBase<ConcreteType, OperandsAreFloatLike> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyOperandsAreFloatLike(op);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that all operands of the specified op have a signless
|
|
/// integer or index type, a vector thereof, or a tensor thereof.
|
|
template <typename ConcreteType>
|
|
class OperandsAreSignlessIntegerLike
|
|
: public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyOperandsAreSignlessIntegerLike(op);
|
|
}
|
|
};
|
|
|
|
/// This class verifies that all operands of the specified op have the same
|
|
/// type.
|
|
template <typename ConcreteType>
|
|
class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifySameTypeOperands(op);
|
|
}
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to be terminators.
|
|
template <typename ConcreteType>
|
|
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
|
|
public:
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return static_cast<AbstractOperation::OperationProperties>(
|
|
OperationProperty::Terminator);
|
|
}
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return impl::verifyIsTerminator(op);
|
|
}
|
|
|
|
unsigned getNumSuccessors() {
|
|
return this->getOperation()->getNumSuccessors();
|
|
}
|
|
unsigned getNumSuccessorOperands(unsigned index) {
|
|
return this->getOperation()->getNumSuccessorOperands(index);
|
|
}
|
|
|
|
Block *getSuccessor(unsigned index) {
|
|
return this->getOperation()->getSuccessor(index);
|
|
}
|
|
|
|
void setSuccessor(Block *block, unsigned index) {
|
|
return this->getOperation()->setSuccessor(block, index);
|
|
}
|
|
|
|
void addSuccessorOperand(unsigned index, Value value) {
|
|
return this->getOperation()->addSuccessorOperand(index, value);
|
|
}
|
|
void addSuccessorOperands(unsigned index, ArrayRef<Value> values) {
|
|
return this->getOperation()->addSuccessorOperand(index, values);
|
|
}
|
|
};
|
|
|
|
/// This class provides the API for ops that are known to be isolated from
|
|
/// above.
|
|
template <typename ConcreteType>
|
|
class IsIsolatedFromAbove
|
|
: public TraitBase<ConcreteType, IsIsolatedFromAbove> {
|
|
public:
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return static_cast<AbstractOperation::OperationProperties>(
|
|
OperationProperty::IsolatedFromAbove);
|
|
}
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
for (auto ®ion : op->getRegions())
|
|
if (!region.isIsolatedFromAbove(op->getLoc()))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This class provides APIs and verifiers for ops with regions having a single
|
|
/// block that must terminate with `TerminatorOpType`.
|
|
template <typename TerminatorOpType> struct SingleBlockImplicitTerminator {
|
|
template <typename ConcreteType>
|
|
class Impl : public TraitBase<ConcreteType, Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
|
|
Region ®ion = op->getRegion(i);
|
|
|
|
// Empty regions are fine.
|
|
if (region.empty())
|
|
continue;
|
|
|
|
// Non-empty regions must contain a single basic block.
|
|
if (std::next(region.begin()) != region.end())
|
|
return op->emitOpError("expects region #")
|
|
<< i << " to have 0 or 1 blocks";
|
|
|
|
Block &block = region.front();
|
|
if (block.empty())
|
|
return op->emitOpError() << "expects a non-empty block";
|
|
Operation &terminator = block.back();
|
|
if (isa<TerminatorOpType>(terminator))
|
|
continue;
|
|
|
|
return op->emitOpError("expects regions to end with '" +
|
|
TerminatorOpType::getOperationName() +
|
|
"', found '" +
|
|
terminator.getName().getStringRef() + "'")
|
|
.attachNote()
|
|
<< "in custom textual format, the absence of terminator implies "
|
|
"'"
|
|
<< TerminatorOpType::getOperationName() << '\'';
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Ensure that the given region has the terminator required by this trait.
|
|
static void ensureTerminator(Region ®ion, Builder &builder,
|
|
Location loc) {
|
|
::mlir::impl::template ensureRegionTerminator<TerminatorOpType>(
|
|
region, builder, loc);
|
|
}
|
|
};
|
|
};
|
|
|
|
/// This class provides a verifier for ops that are expecting a specific parent.
|
|
template <typename ParentOpType> struct HasParent {
|
|
template <typename ConcreteType>
|
|
class Impl : public TraitBase<ConcreteType, Impl> {
|
|
public:
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
if (isa<ParentOpType>(op->getParentOp()))
|
|
return success();
|
|
return op->emitOpError() << "expects parent op '"
|
|
<< ParentOpType::getOperationName() << "'";
|
|
}
|
|
};
|
|
};
|
|
|
|
/// A trait for operations that have an attribute specifying operand segments.
|
|
///
|
|
/// Certain operations can have multiple variadic operands and their size
|
|
/// relationship is not always known statically. For such cases, we need
|
|
/// a per-op-instance specification to divide the operands into logical groups
|
|
/// or segments. This can be modeled by attributes. The attribute will be named
|
|
/// as `operand_segment_sizes`.
|
|
///
|
|
/// This trait verifies the attribute for specifying operand segments has
|
|
/// the correct type (1D vector) and values (non-negative), etc.
|
|
template <typename ConcreteType>
|
|
class AttrSizedOperandSegments
|
|
: public TraitBase<ConcreteType, AttrSizedOperandSegments> {
|
|
public:
|
|
static StringRef getOperandSegmentSizeAttr() {
|
|
return "operand_segment_sizes";
|
|
}
|
|
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
|
|
op, getOperandSegmentSizeAttr());
|
|
}
|
|
};
|
|
|
|
/// Similar to AttrSizedOperandSegments but used for results.
|
|
template <typename ConcreteType>
|
|
class AttrSizedResultSegments
|
|
: public TraitBase<ConcreteType, AttrSizedResultSegments> {
|
|
public:
|
|
static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
|
|
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return ::mlir::OpTrait::impl::verifyResultSizeAttr(
|
|
op, getResultSegmentSizeAttr());
|
|
}
|
|
};
|
|
|
|
} // end namespace OpTrait
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operation Definition classes
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// This provides public APIs that all operations should have. The template
|
|
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
|
|
/// are base classes by the policy pattern.
|
|
template <typename ConcreteType, template <typename T> class... Traits>
|
|
class Op : public OpState,
|
|
public Traits<ConcreteType>...,
|
|
public FoldingHook<ConcreteType,
|
|
llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
|
|
Traits<ConcreteType>...>::value> {
|
|
public:
|
|
/// Return if this operation contains the provided trait.
|
|
template <template <typename T> class Trait>
|
|
static constexpr bool hasTrait() {
|
|
return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
|
|
}
|
|
|
|
/// Return the operation that this refers to.
|
|
Operation *getOperation() { return OpState::getOperation(); }
|
|
|
|
/// Create a deep copy of this operation.
|
|
ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
|
|
|
|
/// Create a partial copy of this operation without traversing into attached
|
|
/// regions. The new operation will have the same number of regions as the
|
|
/// original one, but they will be left empty.
|
|
ConcreteType cloneWithoutRegions() {
|
|
return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
|
|
}
|
|
|
|
/// Return the dialect that this refers to.
|
|
Dialect *getDialect() { return getOperation()->getDialect(); }
|
|
|
|
/// Return the parent Region of this operation.
|
|
Region *getParentRegion() { return getOperation()->getParentRegion(); }
|
|
|
|
/// Return true if this "op class" can match against the specified operation.
|
|
static bool classof(Operation *op) {
|
|
if (auto *abstractOp = op->getAbstractOperation())
|
|
return ClassID::getID<ConcreteType>() == abstractOp->classID;
|
|
return op->getName().getStringRef() == ConcreteType::getOperationName();
|
|
}
|
|
|
|
/// This is the hook used by the AsmParser to parse the custom form of this
|
|
/// op from an .mlir file. Op implementations should provide a parse method,
|
|
/// which returns failure. On success, they should return fill in result with
|
|
/// the fields to use.
|
|
static ParseResult parseAssembly(OpAsmParser &parser,
|
|
OperationState &result) {
|
|
return ConcreteType::parse(parser, result);
|
|
}
|
|
|
|
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
|
|
/// Op implementations should provide a print method.
|
|
static void printAssembly(Operation *op, OpAsmPrinter &p) {
|
|
auto opPointer = dyn_cast<ConcreteType>(op);
|
|
assert(opPointer &&
|
|
"op's name does not match name of concrete type instantiated with");
|
|
opPointer.print(p);
|
|
}
|
|
|
|
/// This is the hook that checks whether or not this operation is well
|
|
/// formed according to the invariants of its opcode. It delegates to the
|
|
/// Traits for their policy implementations, and allows the user to specify
|
|
/// their own verify() method.
|
|
///
|
|
/// On success this returns false; on failure it emits an error to the
|
|
/// diagnostic subsystem and returns true.
|
|
static LogicalResult verifyInvariants(Operation *op) {
|
|
return failure(
|
|
failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
|
|
failed(cast<ConcreteType>(op).verify()));
|
|
}
|
|
|
|
// Returns the properties of an operation by combining the properties of the
|
|
// traits of the op.
|
|
static AbstractOperation::OperationProperties getOperationProperties() {
|
|
return BaseProperties<Traits<ConcreteType>...>::getTraitProperties();
|
|
}
|
|
|
|
/// Expose the type we are instantiated on to template machinery that may want
|
|
/// to introspect traits on this operation.
|
|
using ConcreteOpType = ConcreteType;
|
|
|
|
/// This is a public constructor. Any op can be initialized to null.
|
|
explicit Op() : OpState(nullptr) {}
|
|
Op(std::nullptr_t) : OpState(nullptr) {}
|
|
|
|
/// This is a public constructor to enable access via the llvm::cast family of
|
|
/// methods. This should not be used directly.
|
|
explicit Op(Operation *state) : OpState(state) {}
|
|
|
|
/// Methods for supporting PointerLikeTypeTraits.
|
|
const void *getAsOpaquePointer() const {
|
|
return static_cast<const void *>((Operation *)*this);
|
|
}
|
|
static ConcreteOpType getFromOpaquePointer(const void *pointer) {
|
|
return ConcreteOpType(
|
|
reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
|
|
}
|
|
|
|
private:
|
|
template <typename... Types> struct BaseVerifier;
|
|
|
|
template <typename First, typename... Rest>
|
|
struct BaseVerifier<First, Rest...> {
|
|
static LogicalResult verifyTrait(Operation *op) {
|
|
return failure(failed(First::verifyTrait(op)) ||
|
|
failed(BaseVerifier<Rest...>::verifyTrait(op)));
|
|
}
|
|
};
|
|
|
|
template <typename...> struct BaseVerifier {
|
|
static LogicalResult verifyTrait(Operation *op) { return success(); }
|
|
};
|
|
|
|
template <typename... Types> struct BaseProperties;
|
|
|
|
template <typename First, typename... Rest>
|
|
struct BaseProperties<First, Rest...> {
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return First::getTraitProperties() |
|
|
BaseProperties<Rest...>::getTraitProperties();
|
|
}
|
|
};
|
|
|
|
template <typename...> struct BaseProperties {
|
|
static AbstractOperation::OperationProperties getTraitProperties() {
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
/// Returns true if this operation contains the trait for the given classID.
|
|
static bool hasTrait(ClassID *traitID) {
|
|
return llvm::is_contained(llvm::makeArrayRef({ClassID::getID<Traits>()...}),
|
|
traitID);
|
|
}
|
|
|
|
/// Returns an opaque pointer to a concept instance of the interface with the
|
|
/// given ID if one was registered to this operation.
|
|
static void *getRawInterface(ClassID *id) {
|
|
return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
|
|
}
|
|
|
|
struct InterfaceLookup {
|
|
/// Trait to check if T provides a static 'getInterfaceID' method.
|
|
template <typename T, typename... Args>
|
|
using has_get_interface_id = decltype(T::getInterfaceID());
|
|
|
|
/// If 'T' is the same interface as 'interfaceID' return the concept
|
|
/// instance.
|
|
template <typename T>
|
|
static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
|
|
void *>::type
|
|
lookup(ClassID *interfaceID) {
|
|
return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
|
|
}
|
|
|
|
/// 'T' is known to not be an interface, return nullptr.
|
|
template <typename T>
|
|
static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
|
|
void *>::type
|
|
lookup(ClassID *) {
|
|
return nullptr;
|
|
}
|
|
|
|
template <typename T, typename T2, typename... Ts>
|
|
static void *lookup(ClassID *interfaceID) {
|
|
auto *concept = lookup<T>(interfaceID);
|
|
return concept ? concept : lookup<T2, Ts...>(interfaceID);
|
|
}
|
|
};
|
|
|
|
/// Allow access to 'hasTrait' and 'getRawInterface'.
|
|
friend AbstractOperation;
|
|
};
|
|
|
|
/// This class represents the base of an operation interface. Operation
|
|
/// interfaces provide access to derived *Op properties through an opaquely
|
|
/// Operation instance. Derived interfaces must also provide a 'Traits' class
|
|
/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an
|
|
/// abstract virtual interface, where as the 'Model' class implements this
|
|
/// interface for a specific derived *Op type. Both of these classes *must* not
|
|
/// contain non-static data. A simple example is shown below:
|
|
///
|
|
/// struct ExampleOpInterfaceTraits {
|
|
/// struct Concept {
|
|
/// virtual unsigned getNumInputs(Operation *op) = 0;
|
|
/// };
|
|
/// template <typename OpT> class Model {
|
|
/// unsigned getNumInputs(Operation *op) final {
|
|
/// return cast<OpT>(op).getNumInputs();
|
|
/// }
|
|
/// };
|
|
/// };
|
|
///
|
|
template <typename ConcreteType, typename Traits>
|
|
class OpInterface : public Op<ConcreteType> {
|
|
public:
|
|
using Concept = typename Traits::Concept;
|
|
template <typename T> using Model = typename Traits::template Model<T>;
|
|
|
|
OpInterface(Operation *op = nullptr)
|
|
: Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
|
|
assert((!op || impl) &&
|
|
"instantiating an interface with an unregistered operation");
|
|
}
|
|
|
|
/// Support 'classof' by checking if the given operation defines the concrete
|
|
/// interface.
|
|
static bool classof(Operation *op) { return getInterfaceFor(op); }
|
|
|
|
/// Define an accessor for the ID of this interface.
|
|
static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
|
|
|
|
/// This is a special trait that registers a given interface with an
|
|
/// operation.
|
|
template <typename ConcreteOp>
|
|
struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
|
|
/// Define an accessor for the ID of this interface.
|
|
static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
|
|
|
|
/// Provide an accessor to a static instance of the interface model for the
|
|
/// concrete operation type.
|
|
/// The implementation is inspired from Sean Parent's concept-based
|
|
/// polymorphism. A key difference is that the set of classes erased is
|
|
/// statically known, which alleviates the need for using dynamic memory
|
|
/// allocation.
|
|
/// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
|
|
/// virtual table and generate a singleton object for each instantiation of
|
|
/// this class.
|
|
static Concept &instance() {
|
|
static Model<ConcreteOp> singleton;
|
|
return singleton;
|
|
}
|
|
};
|
|
|
|
protected:
|
|
/// Get the raw concept in the correct derived concept type.
|
|
Concept *getImpl() { return impl; }
|
|
|
|
private:
|
|
/// Returns the impl interface instance for the given operation.
|
|
static Concept *getInterfaceFor(Operation *op) {
|
|
// Access the raw interface from the abstract operation.
|
|
auto *abstractOp = op->getAbstractOperation();
|
|
return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
|
|
}
|
|
|
|
/// A pointer to the impl concept object.
|
|
Concept *impl;
|
|
};
|
|
|
|
// These functions are out-of-line implementations of the methods in UnaryOp and
|
|
// BinaryOp, which avoids them being template instantiated/duplicated.
|
|
namespace impl {
|
|
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
|
|
OperationState &result);
|
|
|
|
void buildBinaryOp(Builder *builder, OperationState &result, Value lhs,
|
|
Value rhs);
|
|
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
|
OperationState &result);
|
|
|
|
// Prints the given binary `op` in custom assembly form if both the two operands
|
|
// and the result have the same time. Otherwise, prints the generic assembly
|
|
// form.
|
|
void printOneResultOp(Operation *op, OpAsmPrinter &p);
|
|
} // namespace impl
|
|
|
|
// These functions are out-of-line implementations of the methods in CastOp,
|
|
// which avoids them being template instantiated/duplicated.
|
|
namespace impl {
|
|
void buildCastOp(Builder *builder, OperationState &result, Value source,
|
|
Type destType);
|
|
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
|
|
void printCastOp(Operation *op, OpAsmPrinter &p);
|
|
Value foldCastOp(Operation *op);
|
|
} // namespace impl
|
|
} // end namespace mlir
|
|
|
|
#endif
|