From 34c6f8c6e4df968aa0feb63a685789b3456fcbb8 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 22 Jan 2019 10:26:09 -0800 Subject: [PATCH] Add default attr value & define tf.AvgPool op and use pattern for rewrite. Add default values to attributes, to allow attribute being left unspecified. The attr getter will always return an attribute so callers need not check for it, if the attribute is not set then the default will be returned (at present the default will be constructed upon query but this will be changed). Add op definition for tf.AvgPool in ops.td, rewrite matcher using pattern using attribute matching & transforms. Adding some helper functions to make it simpler. Handle attributes with dialect prefix and map them to getter without dialect prefix. Note: VerifyAvgPoolOp could probably be autogenerated by know given the predicate specification on attributes, but deferring that to a follow up. PiperOrigin-RevId: 230364857 --- mlir/include/mlir/IR/op_base.td | 5 ++- mlir/include/mlir/TableGen/Attribute.h | 8 ++++ mlir/lib/TableGen/Attribute.cpp | 13 ++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 44 ++++++++++++++++++--- 4 files changed, 62 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 9982fd49b27e..a79eec8a235a 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -248,8 +248,9 @@ class Attr> : AttrConstraint { // 'builder.getStringAttr("foo")'. code constBuilderCall = ?; - // TODO(jpienaar): Add predicate to verify the validity of Attr so - // that verification can be generated. + // Default value for attribute. + // Requires a constBuilderCall defined. + string defaultValue = ?; } // A generic attribute that must be constructed around a specific type. diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index a720b1be6fb4..64c2db02cb9e 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -72,6 +72,14 @@ public: // the constant value. StringRef getConstBuilderTemplate() const; + // Returns whether this attribute has a default value. + bool hasDefaultValue() const; + + // Returns the template that can be used to produce the default value of + // the attribute. + // Syntax: {0} should be replaced with a builder. + std::string getDefaultValueTemplate() const; + StringRef getTableGenDefName() const; // Returns the code body for derived attribute. Aborts if this is not a diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index cbcc4c05eb9f..42dd333b3f74 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Operator.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" using namespace mlir; @@ -82,6 +83,18 @@ StringRef tblgen::Attribute::getConstBuilderTemplate() const { return getValueAsString(init); } +bool tblgen::Attribute::hasDefaultValue() const { + const auto *init = def->getValueInit("defaultValue"); + return !getValueAsString(init).empty(); +} + +std::string tblgen::Attribute::getDefaultValueTemplate() const { + assert(isConstBuildable() && "requiers constBuilderCall"); + const auto *init = def->getValueInit("defaultValue"); + return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}", + getValueAsString(init)); +} + StringRef tblgen::Attribute::getTableGenDefName() const { return def->getName(); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index ead70a35deef..236fdb37de61 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -179,20 +179,41 @@ void OpEmitter::emitAttrGetters() { auto name = namedAttr.getName(); const auto &attr = namedAttr.attr; + // Determine the name of the attribute getter. The name matches the + // attribute name excluding dialect prefix. + StringRef getter = name; + auto it = getter.rfind('$'); + if (it != StringRef::npos) + getter = getter.substr(it + 1); + // Emit the derived attribute body. if (attr.isDerivedAttr()) { - OUT(2) << attr.getReturnType() << ' ' << name << "() const {" + OUT(2) << attr.getReturnType() << ' ' << getter << "() const {" << attr.getDerivedCodeBody() << " }\n"; continue; } // Emit normal emitter. - OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n"; + OUT(2) << attr.getReturnType() << ' ' << getter << "() const {\n"; // Return the queried attribute with the correct return type. std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()", attr.getStorageType(), name); - OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal) + OUT(4) << "auto attr = " << attrVal << ";\n"; + if (attr.hasDefaultValue()) { + // Returns the default value if not set. + // TODO: this is inefficient, we are recreating the attribute for every + // call. This should be set instead. + OUT(4) << "if (!attr)\n"; + OUT(6) << "return " + << formatv( + attr.getConvertFromStorageCall(), + formatv( + attr.getDefaultValueTemplate(), + "mlir::Builder(this->getInstruction()->getContext())")) + << ";\n"; + } + OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), "attr") << ";\n }\n"; } } @@ -359,25 +380,36 @@ void OpEmitter::emitVerifier() { continue; auto name = namedAttr.getName(); - if (!attr.hasStorageType()) { + if (!attr.hasStorageType() && !attr.hasDefaultValue()) { + // TODO: Some verification can be done even without storage type. OUT(4) << "if (!this->getAttr(\"" << name << "\")) return emitOpError(\"requires attribute '" << name << "'\");\n"; continue; } - OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" + if (attr.hasDefaultValue()) { + // If the attribute has a default value, then only verify the predicate if + // set. This does effectively assume that the default value is valid. + // TODO: verify the debug value is valid (perhaps in debug mode only). + OUT(4) << "if (this->getAttr(\"" << name << "\")) {\n"; + } + + OUT(6) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" << attr.getStorageType() << ">()) return emitOpError(\"requires " << attr.getReturnType() << " attribute '" << name << "'\");\n"; auto attrPred = attr.getPredicate(); if (!attrPred.isNull()) { - OUT(4) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' " + OUT(6) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' " "failed to satisfy constraint of {2}\");\n", formatv(attrPred.getCondition(), formatv("this->getAttr(\"{0}\")", name)), name, attr.getTableGenDefName()); } + + if (attr.hasDefaultValue()) + OUT(4) << "}\n"; } // TODO: Handle variadic.