Add default attr value & define tf.AvgPool op and use pattern for rewrite.

Add default values to attributes, to allow attribute being left unspecified.  The attr getter will always return an attribute so callers need not check for it, if the attribute is not set then the default will be returned (at present the default will be constructed upon query but this will be changed).

Add op definition for tf.AvgPool in ops.td, rewrite matcher using pattern using attribute matching & transforms. Adding some helper functions to make it simpler.

Handle attributes with dialect prefix and map them to getter without dialect prefix.

Note: VerifyAvgPoolOp could probably be autogenerated by know given the predicate specification on attributes, but deferring that to a follow up.
PiperOrigin-RevId: 230364857
This commit is contained in:
Jacques Pienaar 2019-01-22 10:26:09 -08:00 committed by jpienaar
parent d2aaa175ca
commit 34c6f8c6e4
4 changed files with 62 additions and 8 deletions

View File

@ -248,8 +248,9 @@ class Attr<Pred condition = CPred<"true">> : AttrConstraint<condition> {
// 'builder.getStringAttr("foo")'.
code constBuilderCall = ?;
// TODO(jpienaar): Add predicate to verify the validity of Attr so
// that verification can be generated.
// Default value for attribute.
// Requires a constBuilderCall defined.
string defaultValue = ?;
}
// A generic attribute that must be constructed around a specific type.

View File

@ -72,6 +72,14 @@ public:
// the constant value.
StringRef getConstBuilderTemplate() const;
// Returns whether this attribute has a default value.
bool hasDefaultValue() const;
// Returns the template that can be used to produce the default value of
// the attribute.
// Syntax: {0} should be replaced with a builder.
std::string getDefaultValueTemplate() const;
StringRef getTableGenDefName() const;
// Returns the code body for derived attribute. Aborts if this is not a

View File

@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
@ -82,6 +83,18 @@ StringRef tblgen::Attribute::getConstBuilderTemplate() const {
return getValueAsString(init);
}
bool tblgen::Attribute::hasDefaultValue() const {
const auto *init = def->getValueInit("defaultValue");
return !getValueAsString(init).empty();
}
std::string tblgen::Attribute::getDefaultValueTemplate() const {
assert(isConstBuildable() && "requiers constBuilderCall");
const auto *init = def->getValueInit("defaultValue");
return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}",
getValueAsString(init));
}
StringRef tblgen::Attribute::getTableGenDefName() const {
return def->getName();
}

View File

@ -179,20 +179,41 @@ void OpEmitter::emitAttrGetters() {
auto name = namedAttr.getName();
const auto &attr = namedAttr.attr;
// Determine the name of the attribute getter. The name matches the
// attribute name excluding dialect prefix.
StringRef getter = name;
auto it = getter.rfind('$');
if (it != StringRef::npos)
getter = getter.substr(it + 1);
// Emit the derived attribute body.
if (attr.isDerivedAttr()) {
OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
OUT(2) << attr.getReturnType() << ' ' << getter << "() const {"
<< attr.getDerivedCodeBody() << " }\n";
continue;
}
// Emit normal emitter.
OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n";
// Return the queried attribute with the correct return type.
std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()",
attr.getStorageType(), name);
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal)
OUT(4) << "auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValue()) {
// Returns the default value if not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
OUT(4) << "if (!attr)\n";
OUT(6) << "return "
<< formatv(
attr.getConvertFromStorageCall(),
formatv(
attr.getDefaultValueTemplate(),
"mlir::Builder(this->getInstruction()->getContext())"))
<< ";\n";
}
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), "attr")
<< ";\n }\n";
}
}
@ -359,25 +380,36 @@ void OpEmitter::emitVerifier() {
continue;
auto name = namedAttr.getName();
if (!attr.hasStorageType()) {
if (!attr.hasStorageType() && !attr.hasDefaultValue()) {
// TODO: Some verification can be done even without storage type.
OUT(4) << "if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";
continue;
}
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
if (attr.hasDefaultValue()) {
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
// TODO: verify the debug value is valid (perhaps in debug mode only).
OUT(4) << "if (this->getAttr(\"" << name << "\")) {\n";
}
OUT(6) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
<< attr.getReturnType() << " attribute '" << name << "'\");\n";
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
OUT(4) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
OUT(6) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
"failed to satisfy constraint of {2}\");\n",
formatv(attrPred.getCondition(),
formatv("this->getAttr(\"{0}\")", name)),
name, attr.getTableGenDefName());
}
if (attr.hasDefaultValue())
OUT(4) << "}\n";
}
// TODO: Handle variadic.