diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h new file mode 100644 index 000000000000..3dd22ac79b82 --- /dev/null +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -0,0 +1,81 @@ +//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Attribute wrapper to simplify using TableGen Record defining a MLIR +// Attribute. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_ +#define MLIR_TABLEGEN_ATTRIBUTE_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Type.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class DefInit; +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class providing helper methods for accessing MLIR Attribute defined +// in TableGen. This class should closely reflect what is defined as class +// `Attr` in TableGen. +class Attribute { +public: + explicit Attribute(const llvm::Record &def); + explicit Attribute(const llvm::Record *def) : Attribute(*def) {} + explicit Attribute(const llvm::DefInit *init); + + // Returns true if this attribute is a derived attribute (i.e., a subclass + // of `DrivedAttr`). + bool isDerivedAttr() const; + + // Returns the type of this attribute. + Type getType() const; + + // Returns true if this attribute has storage type set. + bool hasStorageType() const; + + // Returns the storage type if set. Returns the default storage type + // ("Attribute") otherwise. + StringRef getStorageType() const; + + // Returns the return type for this attribute. + StringRef getReturnType() const; + + // Returns the template getter method call which reads this attribute's + // storage and returns the value as of the desired return type. + // The call will contain a `{0}` which will be expanded to this attribute. + StringRef getConvertFromStorageCall() const; + + // Returns the code body for derived attribute. Aborts if this is not a + // derived attribute. + StringRef getDerivedCodeBody() const; + +private: + // The TableGen definition of this attribute. + const llvm::Record &def; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_ATTRIBUTE_H_ diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 54f404a060d1..c940ae56b1df 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -23,6 +23,7 @@ #define MLIR_TABLEGEN_OPERATOR_H_ #include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Attribute.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -57,26 +58,25 @@ public: // Returns the C++ class name of the op with namespace added. std::string qualifiedCppClassName() const; - struct Attribute { + struct NamedAttribute { std::string getName() const; - StringRef getReturnType() const; - StringRef getStorageType() const; llvm::StringInit *name; - llvm::Record *record; - bool isDerived; + Attribute attr; }; // Op attribute interators. - using attribute_iterator = Attribute *; + using attribute_iterator = NamedAttribute *; attribute_iterator attribute_begin(); attribute_iterator attribute_end(); llvm::iterator_range getAttributes(); // Op attribute accessors. int getNumAttributes() const { return attributes.size(); } - Attribute &getAttribute(int index) { return attributes[index]; } - const Attribute &getAttribute(int index) const { return attributes[index]; } + NamedAttribute &getAttribute(int index) { return attributes[index]; } + const NamedAttribute &getAttribute(int index) const { + return attributes[index]; + } struct Operand { bool hasMatcher() const; @@ -99,7 +99,7 @@ public: const Operand &getOperand(int index) const { return operands[index]; } // Op argument (attribute or operand) accessors. - using Argument = llvm::PointerUnion; + using Argument = llvm::PointerUnion; Argument getArg(int index); StringRef getArgName(int index) const; int getNumArgs() const { return operands.size() + attributes.size(); } @@ -115,7 +115,7 @@ private: SmallVector operands; // The attributes of the op. - SmallVector attributes; + SmallVector attributes; // The start of native attributes, which are specified when creating the op // as a part of the op's definition. diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp new file mode 100644 index 000000000000..ff49dc0fe0f8 --- /dev/null +++ b/mlir/lib/TableGen/Attribute.cpp @@ -0,0 +1,80 @@ +//===- Attribute.cpp - Attribute wrapper class ------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Attribute wrapper to simplify using TableGen Record defining a MLIR +// Attribute. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Operator.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; + +// Returns the initializer's value as string if the given TableGen initializer +// is a code or string initializer. Returns the empty StringRef otherwise. +static StringRef getValueAsString(const llvm::Init *init) { + if (const auto *code = dyn_cast(init)) + return code->getValue().trim(); + else if (const auto *str = dyn_cast(init)) + return str->getValue().trim(); + return {}; +} + +tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) { + assert(def.isSubClassOf("Attr") && + "must be subclass of TableGen 'Attr' class"); +} + +tblgen::Attribute::Attribute(const llvm::DefInit *init) + : Attribute(*init->getDef()) {} + +bool tblgen::Attribute::isDerivedAttr() const { + return def.isSubClassOf("DerivedAttr"); +} + +tblgen::Type tblgen::Attribute::getType() const { + return Type(def.getValueAsDef("type")); +} + +bool tblgen::Attribute::hasStorageType() const { + const auto *init = def.getValueInit("storageType"); + return !getValueAsString(init).empty(); +} + +StringRef tblgen::Attribute::getStorageType() const { + const auto *init = def.getValueInit("storageType"); + auto type = getValueAsString(init); + if (type.empty()) + return "Attribute"; + return type; +} + +StringRef tblgen::Attribute::getReturnType() const { + const auto *init = def.getValueInit("returnType"); + return getValueAsString(init); +} + +StringRef tblgen::Attribute::getConvertFromStorageCall() const { + const auto *init = def.getValueInit("convertFromStorage"); + return getValueAsString(init); +} + +StringRef tblgen::Attribute::getDerivedCodeBody() const { + assert(isDerivedAttr() && "only derived attribute has 'body' field"); + return def.getValueAsString("body"); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 04f1457c6a46..595cf8a59b75 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -125,7 +125,7 @@ void tblgen::Operator::populateOperandsAndAttributes() { if (isDerived) PrintFatalError(def.getLoc(), "derived attributes not allowed in argument list"); - attributes.push_back({givenName, argDef, isDerived}); + attributes.push_back({givenName, Attribute(argDef)}); } // Handle derived attributes. @@ -144,13 +144,12 @@ void tblgen::Operator::populateOperandsAndAttributes() { "unsupported attribute modelling, only single class expected"); } attributes.push_back({cast(val.getNameInit()), - cast(val.getValue())->getDef(), - /*isDerived=*/true}); + Attribute(cast(val.getValue()))}); } } } -std::string tblgen::Operator::Attribute::getName() const { +std::string tblgen::Operator::NamedAttribute::getName() const { std::string ret = name->getAsUnquotedString(); // TODO(jpienaar): Revise this post dialect prefixing attribute discussion. auto split = StringRef(ret).split("__"); @@ -159,14 +158,6 @@ std::string tblgen::Operator::Attribute::getName() const { return llvm::join_items("$", split.first, split.second); } -StringRef tblgen::Operator::Attribute::getReturnType() const { - return record->getValueAsString("returnType").trim(); -} - -StringRef tblgen::Operator::Attribute::getStorageType() const { - return record->getValueAsString("storageType").trim(); -} - bool tblgen::Operator::Operand::hasMatcher() const { return !tblgen::Type(defInit).getPredicate().isEmpty(); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index ca5602935ad5..ae3cb555e19d 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -61,17 +61,7 @@ static inline bool hasStringAttribute(const Record &record, return isa(valueInit) || isa(valueInit); } -// Returns `fieldName`'s value queried from `record` if `fieldName` is set as -// an string in record; otherwise, returns `defaultVal`. -static inline StringRef getAsStringOrDefault(const Record &record, - StringRef fieldName, - StringRef defaultVal) { - return hasStringAttribute(record, fieldName) - ? record.getValueAsString(fieldName) - : defaultVal; -} - -static std::string getAttributeName(const Operator::Attribute &attr) { +static std::string getAttributeName(const Operator::NamedAttribute &attr) { return attr.name->getAsUnquotedString(); } @@ -189,14 +179,14 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { } void OpEmitter::emitAttrGetters() { - for (auto &attr : op.getAttributes()) { - auto name = getAttributeName(attr); - auto *def = attr.record; + for (auto &namedAttr : op.getAttributes()) { + auto name = getAttributeName(namedAttr); + const auto &attr = namedAttr.attr; // Emit the derived attribute body. - if (attr.isDerived) { + if (attr.isDerivedAttr()) { OUT(2) << attr.getReturnType() << ' ' << name << "() const {" - << def->getValueAsString("body") << " }\n"; + << attr.getDerivedCodeBody() << " }\n"; continue; } @@ -206,8 +196,7 @@ void OpEmitter::emitAttrGetters() { // Return the queried attribute with the correct return type. std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")", attr.getStorageType(), name); - OUT(4) << "return " - << formatv(def->getValueAsString("convertFromStorage"), attrVal) + OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal) << ";\n }\n"; } } @@ -258,12 +247,11 @@ void OpEmitter::emitBuilder() { // Emit parameters for all attributes // TODO(antiagainst): Support default initializer for attributes - for (const auto &attr : op.getAttributes()) { - if (attr.isDerived) + for (const auto &namedAttr : op.getAttributes()) { + const auto &attr = namedAttr.attr; + if (attr.isDerivedAttr()) break; - const Record &def = *attr.record; - os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim() - << ' ' << getAttributeName(attr); + os << ", " << attr.getStorageType() << ' ' << getAttributeName(namedAttr); } os << ") {\n"; @@ -285,10 +273,10 @@ void OpEmitter::emitBuilder() { } // Push all attributes to the result - for (const auto &attr : op.getAttributes()) - if (!attr.isDerived) + for (const auto &namedAttr : op.getAttributes()) + if (!namedAttr.attr.isDerivedAttr()) OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n", - getAttributeName(attr)); + getAttributeName(namedAttr)); OUT(2) << "}\n"; // 2. Aggregated parameters @@ -368,12 +356,14 @@ void OpEmitter::emitVerifier() { OUT(2) << "bool verify() const {\n"; // Verify the attributes have the correct type. - for (const auto &attr : op.getAttributes()) { - if (attr.isDerived) + for (const auto &namedAttr : op.getAttributes()) { + const auto &attr = namedAttr.attr; + + if (attr.isDerivedAttr()) continue; - auto name = getAttributeName(attr); - if (!hasStringAttribute(*attr.record, "storageType")) { + auto name = getAttributeName(namedAttr); + if (!attr.hasStorageType()) { OUT(4) << "if (!this->getAttr(\"" << name << "\")) return emitOpError(\"requires attribute '" << name << "'\");\n"; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 4e8c2fdf15b8..b4e38954ec79 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -37,6 +37,7 @@ using namespace llvm; using namespace mlir; +using mlir::tblgen::Attribute; using mlir::tblgen::Operator; using mlir::tblgen::Type; @@ -90,10 +91,10 @@ private: } // end namespace void Pattern::emitAttributeValue(Record *constAttr) { - Record *attr = constAttr->getValueAsDef("attr"); + Attribute attr(constAttr->getValueAsDef("attr")); auto value = constAttr->getValue("value"); - Type type(attr->getValueAsDef("type")); - auto storageType = attr->getValueAsString("storageType").trim(); + Type type = attr.getType(); + auto storageType = attr.getStorageType(); // For attributes stored as strings we do not need to query builder etc. if (storageType == "StringAttr") { @@ -183,7 +184,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, } // TODO(jpienaar): Verify attributes. - if (auto *attr = opArg.dyn_cast()) { + if (auto *attr = opArg.dyn_cast()) { } } @@ -194,10 +195,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth, if (opArg.is()) os.indent(indent) << "state->" << name << " = op" << depth << "->getOperand(" << i << ");\n"; - if (auto attr = opArg.dyn_cast()) { + if (auto namedAttr = opArg.dyn_cast()) { os.indent(indent) << "state->" << name << " = op" << depth - << "->getAttrOfType<" << attr->getStorageType() - << ">(\"" << attr->getName() << "\");\n"; + << "->getAttrOfType<" + << namedAttr->attr.getStorageType() << ">(\"" + << namedAttr->getName() << "\");\n"; } } } @@ -234,8 +236,8 @@ void Pattern::emit(StringRef rewriteName) { for (auto &arg : boundArguments) { if (arg.second.isAttr()) { DefInit *defInit = cast(arg.second.init); - os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim() - << " " << arg.first() << ";\n"; + os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first() + << ";\n"; } else { os.indent(4) << "Value* " << arg.first() << ";\n"; } @@ -311,7 +313,7 @@ void Pattern::emit(StringRef rewriteName) { // TODO(jpienaar): Refactor out into map to avoid recomputing these. auto argument = resultOp.getArg(i); - if (!argument.is()) + if (!argument.is()) PrintFatalError(pattern->getLoc(), Twine("expected attribute ") + Twine(i));