Enhance the customizable "Op" implementations in a bunch of ways:

- Op classes can now provide customized matchers, allowing specializations
   beyond just a name match.
 - We now provide default implementations of verify/print hooks, so Op classes
   only need to implement them if they're doing custom stuff, and only have to
   implement the ones they're interested in.
 - "Base" now takes a variadic list of template template arguments, allowing
   concrete Op types to avoid passing the Concrete type multiple times.
 - Add new ZeroOperands trait.
 - Add verification hooks to Zero/One/Two operands and OneResult to check that
   ops using them are correctly formed.
 - Implement getOperand hooks to zero/one/two operand traits, and
   getResult/getType hook to OneResult trait.
 - Add a new "constant" op to show some of this off, with a specialization for
   the constant case.

This patch also splits op validity checks out to a new test/IR/invalid-ops.mlir
file.

This stubs out support for default asmprinter support.  My next planned patch
building on top of this will make asmprinter hooks real and will revise this.

PiperOrigin-RevId: 205833214
This commit is contained in:
Chris Lattner 2018-07-24 08:34:58 -07:00 committed by jpienaar
parent aaeb8daa50
commit 0ab2e2536a
12 changed files with 292 additions and 71 deletions

View File

@ -148,7 +148,7 @@ public:
/// a null OpPointer on failure.
template <typename OpClass>
OpPointer<OpClass> getAs() {
bool isMatch = getName().is(OpClass::getOperationName());
bool isMatch = OpClass::isClassFor(this);
return OpPointer<OpClass>(OpClass(isMatch ? this : nullptr));
}
@ -157,7 +157,7 @@ public:
/// a null ConstOpPointer on failure.
template <typename OpClass>
ConstOpPointer<OpClass> getAs() const {
bool isMatch = getName().is(OpClass::getOperationName());
bool isMatch = OpClass::isClassFor(this);
return ConstOpPointer<OpClass>(OpClass(isMatch ? this : nullptr));
}

View File

@ -30,6 +30,7 @@
#include "mlir/IR/Operation.h"
namespace mlir {
class Type;
/// This pointer represents a notional "Operation*" but where the actual
/// storage of the pointer is maintained in the templated "OpType" class.
@ -72,20 +73,25 @@ public:
namespace OpImpl {
/// This provides public APIs that all operations should have. The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, typename... Traits>
class Base : public Traits... {
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
///
/// This also has the fallback implementations of customization hooks for when
/// they aren't customized.
class BaseState {
public:
/// Return the operation that this refers to.
const Operation *getOperation() const { return state; }
Operation *getOperation() { return state; }
/// Return an attribute with the specified name.
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
/// If the operation has an attribute of the specified type, return it.
template <typename AttrClass>
AttrClass *getAttrOfType(StringRef name) const {
return dyn_cast_or_null<AttrClass>(state->getAttr(name));
return dyn_cast_or_null<AttrClass>(getAttr(name));
}
/// If the an attribute exists with the specified name, change it to the new
@ -94,6 +100,43 @@ public:
state->setAttr(name, value, context);
}
protected:
// These are default implementations of customization hooks.
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
const char *verify() const { return nullptr; }
// The fallback for the printer is to print it the longhand form.
void print(raw_ostream &os) const;
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here.
explicit BaseState(const Operation *state)
: state(const_cast<Operation *>(state)) {}
private:
Operation *state;
};
/// This provides public APIs that all operations should have. The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits>
class Base : public BaseState, public Traits<ConcreteType>... {
public:
/// Return the operation that this refers to.
const Operation *getOperation() const { return BaseState::getOperation(); }
Operation *getOperation() { return BaseState::getOperation(); }
/// Return true if this "op class" can match against the specified operation.
/// This hook can be overridden with a more specific implementation in
/// the subclass of Base.
///
static bool isClassFor(const Operation *op) {
return op->getName().is(ConcreteType::getOperationName());
}
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
/// Op implementations should provide a print method.
static void printAssembly(const Operation *op, raw_ostream &os) {
@ -104,16 +147,15 @@ public:
/// delegates to the Traits for their policy implementations, and allows the
/// user to specify their own verify() method.
static const char *verifyInvariants(const Operation *op) {
if (auto error = BaseVerifier<Traits...>::verifyBase(op))
if (auto error = BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op))
return error;
return op->getAs<ConcreteType>()->verify();
}
// TODO: Provide a dump() method.
protected:
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here.
explicit Base(const Operation *state)
: state(const_cast<Operation *>(state)) {}
explicit Base(const Operation *state) : BaseState(state) {}
private:
template <typename... Types>
@ -121,72 +163,129 @@ private:
template <typename First, typename... Rest>
struct BaseVerifier<First, Rest...> {
static const char *verifyBase(const Operation *op) {
if (auto error = First::verifyBase(op))
static const char *verifyTrait(const Operation *op) {
if (auto error = First::verifyTrait(op))
return error;
return BaseVerifier<Rest...>::verifyBase(op);
return BaseVerifier<Rest...>::verifyTrait(op);
}
};
template <typename First>
struct BaseVerifier<First> {
static const char *verifyBase(const Operation *op) {
return First::verifyBase(op);
static const char *verifyTrait(const Operation *op) {
return First::verifyTrait(op);
}
};
template <>
struct BaseVerifier<> {
static const char *verifyBase(const Operation *op) {
return nullptr;
}
static const char *verifyTrait(const Operation *op) { return nullptr; }
};
};
Operation *state;
/// Helper class for implementing traits. Clients are not expected to interact
/// with this directly, so its members are all protected.
template <typename ConcreteType, template <typename> class TraitType>
class TraitImpl {
protected:
/// Return the ultimate Operation being worked on.
Operation *getOperation() {
// We have to cast up to the trait type, then to the concrete type, then to
// the BaseState class in explicit hops because the concrete type will
// multiply derive from the (content free) TraitImpl class, and we need to
// be able to disambiguate the path for the C++ compiler.
auto *trait = static_cast<TraitType<ConcreteType> *>(this);
auto *concrete = static_cast<ConcreteType *>(trait);
auto *base = static_cast<BaseState *>(concrete);
return base->getOperation();
}
const Operation *getOperation() const {
return const_cast<TraitImpl *>(this)->getOperation();
}
/// Provide default implementations of trait hooks. This allows traits to
/// provide exactly the overrides they care about.
static const char *verifyTrait(const Operation *op) { return nullptr; }
};
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
template <typename ConcreteType> class OneOperand {
template <typename ConcreteType>
class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> {
public:
SSAValue *getOperand() const {
return static_cast<ConcreteType *>(this)->getOperand(0);
}
void setOperand(SSAValue *value) {
static_cast<ConcreteType *>(this)->setOperand(0, value);
static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 0)
return "requires zero operands";
return nullptr;
}
static const char *verifyBase(const Operation *op) {
// TODO: Check that op has one operand.
private:
// Disable these.
void getOperand() const {}
void setOperand() const {}
};
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
template <typename ConcreteType>
class OneOperand : public TraitImpl<ConcreteType, OneOperand> {
public:
const SSAValue *getOperand() const {
return this->getOperation()->getOperand(0);
}
SSAValue *getOperand() { return this->getOperation()->getOperand(0); }
void setOperand(SSAValue *value) {
this->getOperation()->setOperand(0, value);
}
static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 1)
return "requires a single operand";
return nullptr;
}
};
/// This class provides the API for ops that are known to have exactly two
/// SSA operands.
class TwoOperands {
template <typename ConcreteType>
class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> {
public:
void getOperand() const {
/// TODO.
}
void setOperand() {
/// TODO.
const SSAValue *getOperand(unsigned i) const {
return this->getOperation()->getOperand(i);
}
static const char *verifyBase(const Operation *op) {
// TODO: Check that op has two operands.
SSAValue *getOperand(unsigned i) {
return this->getOperation()->getOperand(i);
}
void setOperand(unsigned i, SSAValue *value) {
this->getOperation()->setOperand(i, value);
}
static const char *verifyTrait(const Operation *op) {
if (op->getNumOperands() != 2)
return "requires two operands";
return nullptr;
}
};
/// This class provides return value APIs for ops that are known to have a
/// single result.
class OneResult {
template <typename ConcreteType>
class OneResult : public TraitImpl<ConcreteType, OneResult> {
public:
// TODO: Implement results!
SSAValue *getResult() { return this->getOperation()->getResult(0); }
const SSAValue *getResult() const {
return this->getOperation()->getResult(0);
}
static const char *verifyBase(const Operation *op) {
// TODO: Check that op has one result.
Type *getType() const { return getResult()->getType(); }
static const char *verifyTrait(const Operation *op) {
if (op->getNumResults() != 1)
return "requires one result";
return nullptr;
}
};

View File

@ -38,13 +38,16 @@ class AbstractOperation {
public:
template <typename T>
static AbstractOperation get() {
return AbstractOperation(T::getOperationName(), T::printAssembly,
T::verifyInvariants);
return AbstractOperation(T::getOperationName(), T::isClassFor,
T::printAssembly, T::verifyInvariants);
}
/// This is the name of the operation.
const StringRef name;
/// Return true if this "op class" can match against the specified operation.
bool (&isClassFor)(const Operation *op);
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(const Operation *op, raw_ostream &os);
@ -55,10 +58,10 @@ public:
// TODO: Parsing hook.
private:
AbstractOperation(StringRef name,
AbstractOperation(StringRef name, bool (&isClassFor)(const Operation *op),
void (&printAssembly)(const Operation *op, raw_ostream &os),
const char *(&verifyInvariants)(const Operation *op))
: name(name), printAssembly(printAssembly),
: name(name), isClassFor(isClassFor), printAssembly(printAssembly),
verifyInvariants(verifyInvariants) {}
};

View File

@ -50,14 +50,54 @@ private:
explicit AddFOp(const Operation *state) : Base(state) {}
};
/// The "dim" builtin takes a memref or tensor operand and returns an
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
/// %1 = "constant"(){value: 42} : i32
/// %2 = "constant"(){value: @foo} : (f32)->f32
///
class ConstantOp
: public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> {
public:
Attribute *getValue() const { return getAttr("value"); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
const char *verify() const;
protected:
friend class Operation;
explicit ConstantOp(const Operation *state) : Base(state) {}
};
/// This is a refinement of the "constant" op for the case where it is
/// returning an integer value.
///
/// %1 = "constant"(){value: 42}
///
class ConstantIntOp : public ConstantOp {
public:
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
}
static bool isClassFor(const Operation *op);
private:
friend class Operation;
explicit ConstantIntOp(const Operation *state) : ConstantOp(state) {}
};
/// The "dim" operation takes a memref or tensor operand and returns an
/// "affineint". It requires a single integer attribute named "index". It
/// returns the size of the specified dimension. For example:
///
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
///
class DimOp
: public OpImpl::Base<DimOp, OpImpl::OneOperand<DimOp>, OpImpl::OneResult> {
: public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
public:
/// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const {

View File

@ -294,6 +294,8 @@ public:
return ArrayRef<int>(shapeElements, getSubclassData());
}
unsigned getRank() const { return getShape().size(); }
/// Returns the elemental type for this memref shape.
Type *getElementType() const { return elementType; }

View File

@ -16,11 +16,17 @@
// =============================================================================
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/OperationImpl.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using llvm::StringMap;
// The fallback for the printer is to print it the longhand form.
void OpImpl::BaseState::print(raw_ostream &os) const {
os << "FIXME: IMPLEMENT DEFAULT PRINTER";
}
static StringMap<AbstractOperation> &getImpl(void *pImpl) {
return *static_cast<StringMap<AbstractOperation> *>(pImpl);
}

View File

@ -17,6 +17,8 @@
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@ -32,24 +34,62 @@ const char *AddFOp::verify() const {
return nullptr;
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {
auto *value = getValue();
if (!value)
return "requires a 'value' attribute";
auto *type = this->getType();
if (isa<IntegerType>(type)) {
if (!isa<IntegerAttr>(value))
return "requires 'value' to be an integer for an integer result type";
return nullptr;
}
if (isa<FunctionType>(type)) {
// TODO: Verify a function attr.
}
return "requires a result type that aligns with the 'value' attribute";
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
isa<IntegerType>(op->getResult(0)->getType());
}
void DimOp::print(raw_ostream &os) const {
os << "dim xxx, " << getIndex() << " : sometype";
}
const char *DimOp::verify() const {
// TODO: Check that the operand has tensor or memref type.
// Check that we have an integer index operand.
auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return "'dim' op requires an integer attribute named 'index'";
return "requires an integer attribute named 'index'";
uint64_t index = (uint64_t)indexAttr->getValue();
// TODO: Check that the index is in range.
auto *type = getOperand()->getType();
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
if (index >= tensorType->getRank())
return "index is out of range";
} else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
if (index >= memrefType->getRank())
return "index is out of range";
} else if (isa<UnrankedTensorType>(type)) {
// ok, assumed to be in-range.
} else {
return "requires an operand with tensor or memref type";
}
return nullptr;
}
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, DimOp>(/*prefix=*/ "");
opSet.addOperations<AddFOp, ConstantOp, DimOp>(/*prefix=*/"");
}

View File

@ -216,7 +216,8 @@ bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
// See if we can get operation info for this.
if (auto *opInfo = inst.getAbstractOperation(fn.getContext())) {
if (auto errorMessage = opInfo->verifyInvariants(&inst))
return failure(errorMessage, inst);
return failure(Twine("'") + inst.getName().str() + "' op " + errorMessage,
inst);
}
return false;

View File

@ -1576,7 +1576,7 @@ FunctionParser::parseOperation(const CreateOperationFunction &createOpFunc) {
// source location.
if (auto *opInfo = op->getAbstractOperation(builder.getContext())) {
if (auto error = opInfo->verifyInvariants(op))
return emitError(loc, error);
return emitError(loc, Twine("'") + op->getName().str() + "' op " + error);
}
// If the instruction had a name, register it.

View File

@ -0,0 +1,35 @@
// TODO(andydavis) Resolve relative path issue w.r.t invoking mlir-opt in RUN
// statements (perhaps through using lit config substitutions).
//
// RUN: %S/../../mlir-opt %s -o - -check-parser-errors
cfgfunc @dim(tensor<1xf32>) {
bb(%0: tensor<1xf32>):
"dim"(%0){index: "xyz"} : (tensor<1xf32>)->i32 // expected-error {{'dim' op requires an integer attribute named 'index'}}
return
}
// -----
cfgfunc @dim2(tensor<1xf32>) {
bb(%0: tensor<1xf32>):
"dim"(){index: "xyz"} : ()->i32 // expected-error {{'dim' op requires a single operand}}
return
}
// -----
cfgfunc @dim3(tensor<1xf32>) {
bb(%0: tensor<1xf32>):
"dim"(%0){index: 1} : (tensor<1xf32>)->i32 // expected-error {{'dim' op index is out of range}}
return
}
// -----
cfgfunc @constant() {
bb:
%x = "constant"(){value: "xyz"} : () -> i32 // expected-error {{'constant' op requires 'value' to be an integer for an integer result type}}
return
}

View File

@ -175,14 +175,6 @@ mlfunc @non_statement() {
// -----
cfgfunc @malformed_dim() {
bb42:
"dim"(){index: "xyz"} : ()->i32 // expected-error {{'dim' op requires an integer attribute named 'index'}}
return
}
// -----
#map = (d0) -> (% // expected-error {{invalid SSA name}}
// -----
@ -197,8 +189,8 @@ bb40:
cfgfunc @redef() {
bb42:
%x = "dim"(){index: 0} : ()->i32 // expected-error {{previously defined here}}
%x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value '%x'}}
%x = "xxx"(){index: 0} : ()->i32 // expected-error {{previously defined here}}
%x = "xxx"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value '%x'}}
return
}

View File

@ -112,17 +112,17 @@ mlfunc @mlfunc_with_args(%a : f16) {
return %a // CHECK: return
}
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
cfgfunc @cfgfunc_with_ops() {
bb0:
// CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
cfgfunc @cfgfunc_with_ops(f32) {
bb0(%a : f32):
// CHECK: %1 = "getTensor"() : () -> tensor<4x4x?xf32>
%t = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: dim xxx, 2 : sometype
%a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
%t2 = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
// CHECK: addf xx, yy : sometype
"addf"() : () -> ()
%x = "addf"(%a, %a) : (f32,f32) -> (f32)
// CHECK: return
return
@ -187,6 +187,9 @@ bb42: // CHECK: bb0:
%f = "Const"(){value: 1} : () -> f32
// CHECK: addf xx, yy : sometype
"addf"(%f, %f) : (f32,f32) -> f32
// TODO: CHECK: FIXME: IMPLEMENT DEFAULT PRINTER
%x = "constant"(){value: 42} : () -> i32
return
}