forked from OSchip/llvm-project
Add tblgen::Attribute to wrap around TableGen Attr defs
This CL added a tblgen::Attribute class to wrap around raw TableGen Record getValue*() calls on Attr defs, which will provide a nicer API for handling TableGen Record. PiperOrigin-RevId: 228581107
This commit is contained in:
parent
6ce30becd7
commit
9b034f0bfd
|
@ -0,0 +1,81 @@
|
|||
//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Attribute wrapper to simplify using TableGen Record defining a MLIR
|
||||
// Attribute.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Type.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
class DefInit;
|
||||
class Record;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Attribute defined
|
||||
// in TableGen. This class should closely reflect what is defined as class
|
||||
// `Attr` in TableGen.
|
||||
class Attribute {
|
||||
public:
|
||||
explicit Attribute(const llvm::Record &def);
|
||||
explicit Attribute(const llvm::Record *def) : Attribute(*def) {}
|
||||
explicit Attribute(const llvm::DefInit *init);
|
||||
|
||||
// Returns true if this attribute is a derived attribute (i.e., a subclass
|
||||
// of `DrivedAttr`).
|
||||
bool isDerivedAttr() const;
|
||||
|
||||
// Returns the type of this attribute.
|
||||
Type getType() const;
|
||||
|
||||
// Returns true if this attribute has storage type set.
|
||||
bool hasStorageType() const;
|
||||
|
||||
// Returns the storage type if set. Returns the default storage type
|
||||
// ("Attribute") otherwise.
|
||||
StringRef getStorageType() const;
|
||||
|
||||
// Returns the return type for this attribute.
|
||||
StringRef getReturnType() const;
|
||||
|
||||
// Returns the template getter method call which reads this attribute's
|
||||
// storage and returns the value as of the desired return type.
|
||||
// The call will contain a `{0}` which will be expanded to this attribute.
|
||||
StringRef getConvertFromStorageCall() const;
|
||||
|
||||
// Returns the code body for derived attribute. Aborts if this is not a
|
||||
// derived attribute.
|
||||
StringRef getDerivedCodeBody() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this attribute.
|
||||
const llvm::Record &def;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_ATTRIBUTE_H_
|
|
@ -23,6 +23,7 @@
|
|||
#define MLIR_TABLEGEN_OPERATOR_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
@ -57,26 +58,25 @@ public:
|
|||
// Returns the C++ class name of the op with namespace added.
|
||||
std::string qualifiedCppClassName() const;
|
||||
|
||||
struct Attribute {
|
||||
struct NamedAttribute {
|
||||
std::string getName() const;
|
||||
StringRef getReturnType() const;
|
||||
StringRef getStorageType() const;
|
||||
|
||||
llvm::StringInit *name;
|
||||
llvm::Record *record;
|
||||
bool isDerived;
|
||||
Attribute attr;
|
||||
};
|
||||
|
||||
// Op attribute interators.
|
||||
using attribute_iterator = Attribute *;
|
||||
using attribute_iterator = NamedAttribute *;
|
||||
attribute_iterator attribute_begin();
|
||||
attribute_iterator attribute_end();
|
||||
llvm::iterator_range<attribute_iterator> getAttributes();
|
||||
|
||||
// Op attribute accessors.
|
||||
int getNumAttributes() const { return attributes.size(); }
|
||||
Attribute &getAttribute(int index) { return attributes[index]; }
|
||||
const Attribute &getAttribute(int index) const { return attributes[index]; }
|
||||
NamedAttribute &getAttribute(int index) { return attributes[index]; }
|
||||
const NamedAttribute &getAttribute(int index) const {
|
||||
return attributes[index];
|
||||
}
|
||||
|
||||
struct Operand {
|
||||
bool hasMatcher() const;
|
||||
|
@ -99,7 +99,7 @@ public:
|
|||
const Operand &getOperand(int index) const { return operands[index]; }
|
||||
|
||||
// Op argument (attribute or operand) accessors.
|
||||
using Argument = llvm::PointerUnion<Attribute *, Operand *>;
|
||||
using Argument = llvm::PointerUnion<NamedAttribute *, Operand *>;
|
||||
Argument getArg(int index);
|
||||
StringRef getArgName(int index) const;
|
||||
int getNumArgs() const { return operands.size() + attributes.size(); }
|
||||
|
@ -115,7 +115,7 @@ private:
|
|||
SmallVector<Operand, 4> operands;
|
||||
|
||||
// The attributes of the op.
|
||||
SmallVector<Attribute, 4> attributes;
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
|
||||
// The start of native attributes, which are specified when creating the op
|
||||
// as a part of the op's definition.
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
//===- Attribute.cpp - Attribute wrapper class ------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Attribute wrapper to simplify using TableGen Record defining a MLIR
|
||||
// Attribute.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Returns the initializer's value as string if the given TableGen initializer
|
||||
// is a code or string initializer. Returns the empty StringRef otherwise.
|
||||
static StringRef getValueAsString(const llvm::Init *init) {
|
||||
if (const auto *code = dyn_cast<llvm::CodeInit>(init))
|
||||
return code->getValue().trim();
|
||||
else if (const auto *str = dyn_cast<llvm::StringInit>(init))
|
||||
return str->getValue().trim();
|
||||
return {};
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) {
|
||||
assert(def.isSubClassOf("Attr") &&
|
||||
"must be subclass of TableGen 'Attr' class");
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::DefInit *init)
|
||||
: Attribute(*init->getDef()) {}
|
||||
|
||||
bool tblgen::Attribute::isDerivedAttr() const {
|
||||
return def.isSubClassOf("DerivedAttr");
|
||||
}
|
||||
|
||||
tblgen::Type tblgen::Attribute::getType() const {
|
||||
return Type(def.getValueAsDef("type"));
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::hasStorageType() const {
|
||||
const auto *init = def.getValueInit("storageType");
|
||||
return !getValueAsString(init).empty();
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getStorageType() const {
|
||||
const auto *init = def.getValueInit("storageType");
|
||||
auto type = getValueAsString(init);
|
||||
if (type.empty())
|
||||
return "Attribute";
|
||||
return type;
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getReturnType() const {
|
||||
const auto *init = def.getValueInit("returnType");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getConvertFromStorageCall() const {
|
||||
const auto *init = def.getValueInit("convertFromStorage");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
||||
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||
return def.getValueAsString("body");
|
||||
}
|
|
@ -125,7 +125,7 @@ void tblgen::Operator::populateOperandsAndAttributes() {
|
|||
if (isDerived)
|
||||
PrintFatalError(def.getLoc(),
|
||||
"derived attributes not allowed in argument list");
|
||||
attributes.push_back({givenName, argDef, isDerived});
|
||||
attributes.push_back({givenName, Attribute(argDef)});
|
||||
}
|
||||
|
||||
// Handle derived attributes.
|
||||
|
@ -144,13 +144,12 @@ void tblgen::Operator::populateOperandsAndAttributes() {
|
|||
"unsupported attribute modelling, only single class expected");
|
||||
}
|
||||
attributes.push_back({cast<llvm::StringInit>(val.getNameInit()),
|
||||
cast<DefInit>(val.getValue())->getDef(),
|
||||
/*isDerived=*/true});
|
||||
Attribute(cast<DefInit>(val.getValue()))});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string tblgen::Operator::Attribute::getName() const {
|
||||
std::string tblgen::Operator::NamedAttribute::getName() const {
|
||||
std::string ret = name->getAsUnquotedString();
|
||||
// TODO(jpienaar): Revise this post dialect prefixing attribute discussion.
|
||||
auto split = StringRef(ret).split("__");
|
||||
|
@ -159,14 +158,6 @@ std::string tblgen::Operator::Attribute::getName() const {
|
|||
return llvm::join_items("$", split.first, split.second);
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::Attribute::getReturnType() const {
|
||||
return record->getValueAsString("returnType").trim();
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::Attribute::getStorageType() const {
|
||||
return record->getValueAsString("storageType").trim();
|
||||
}
|
||||
|
||||
bool tblgen::Operator::Operand::hasMatcher() const {
|
||||
return !tblgen::Type(defInit).getPredicate().isEmpty();
|
||||
}
|
||||
|
|
|
@ -61,17 +61,7 @@ static inline bool hasStringAttribute(const Record &record,
|
|||
return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
|
||||
}
|
||||
|
||||
// Returns `fieldName`'s value queried from `record` if `fieldName` is set as
|
||||
// an string in record; otherwise, returns `defaultVal`.
|
||||
static inline StringRef getAsStringOrDefault(const Record &record,
|
||||
StringRef fieldName,
|
||||
StringRef defaultVal) {
|
||||
return hasStringAttribute(record, fieldName)
|
||||
? record.getValueAsString(fieldName)
|
||||
: defaultVal;
|
||||
}
|
||||
|
||||
static std::string getAttributeName(const Operator::Attribute &attr) {
|
||||
static std::string getAttributeName(const Operator::NamedAttribute &attr) {
|
||||
return attr.name->getAsUnquotedString();
|
||||
}
|
||||
|
||||
|
@ -189,14 +179,14 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
|||
}
|
||||
|
||||
void OpEmitter::emitAttrGetters() {
|
||||
for (auto &attr : op.getAttributes()) {
|
||||
auto name = getAttributeName(attr);
|
||||
auto *def = attr.record;
|
||||
for (auto &namedAttr : op.getAttributes()) {
|
||||
auto name = getAttributeName(namedAttr);
|
||||
const auto &attr = namedAttr.attr;
|
||||
|
||||
// Emit the derived attribute body.
|
||||
if (attr.isDerived) {
|
||||
if (attr.isDerivedAttr()) {
|
||||
OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
|
||||
<< def->getValueAsString("body") << " }\n";
|
||||
<< attr.getDerivedCodeBody() << " }\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -206,8 +196,7 @@ void OpEmitter::emitAttrGetters() {
|
|||
// Return the queried attribute with the correct return type.
|
||||
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
|
||||
attr.getStorageType(), name);
|
||||
OUT(4) << "return "
|
||||
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
|
||||
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal)
|
||||
<< ";\n }\n";
|
||||
}
|
||||
}
|
||||
|
@ -258,12 +247,11 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Emit parameters for all attributes
|
||||
// TODO(antiagainst): Support default initializer for attributes
|
||||
for (const auto &attr : op.getAttributes()) {
|
||||
if (attr.isDerived)
|
||||
for (const auto &namedAttr : op.getAttributes()) {
|
||||
const auto &attr = namedAttr.attr;
|
||||
if (attr.isDerivedAttr())
|
||||
break;
|
||||
const Record &def = *attr.record;
|
||||
os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim()
|
||||
<< ' ' << getAttributeName(attr);
|
||||
os << ", " << attr.getStorageType() << ' ' << getAttributeName(namedAttr);
|
||||
}
|
||||
|
||||
os << ") {\n";
|
||||
|
@ -285,10 +273,10 @@ void OpEmitter::emitBuilder() {
|
|||
}
|
||||
|
||||
// Push all attributes to the result
|
||||
for (const auto &attr : op.getAttributes())
|
||||
if (!attr.isDerived)
|
||||
for (const auto &namedAttr : op.getAttributes())
|
||||
if (!namedAttr.attr.isDerivedAttr())
|
||||
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
|
||||
getAttributeName(attr));
|
||||
getAttributeName(namedAttr));
|
||||
OUT(2) << "}\n";
|
||||
|
||||
// 2. Aggregated parameters
|
||||
|
@ -368,12 +356,14 @@ void OpEmitter::emitVerifier() {
|
|||
|
||||
OUT(2) << "bool verify() const {\n";
|
||||
// Verify the attributes have the correct type.
|
||||
for (const auto &attr : op.getAttributes()) {
|
||||
if (attr.isDerived)
|
||||
for (const auto &namedAttr : op.getAttributes()) {
|
||||
const auto &attr = namedAttr.attr;
|
||||
|
||||
if (attr.isDerivedAttr())
|
||||
continue;
|
||||
|
||||
auto name = getAttributeName(attr);
|
||||
if (!hasStringAttribute(*attr.record, "storageType")) {
|
||||
auto name = getAttributeName(namedAttr);
|
||||
if (!attr.hasStorageType()) {
|
||||
OUT(4) << "if (!this->getAttr(\"" << name
|
||||
<< "\")) return emitOpError(\"requires attribute '" << name
|
||||
<< "'\");\n";
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::Operator;
|
||||
using mlir::tblgen::Type;
|
||||
|
||||
|
@ -90,10 +91,10 @@ private:
|
|||
} // end namespace
|
||||
|
||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||
Record *attr = constAttr->getValueAsDef("attr");
|
||||
Attribute attr(constAttr->getValueAsDef("attr"));
|
||||
auto value = constAttr->getValue("value");
|
||||
Type type(attr->getValueAsDef("type"));
|
||||
auto storageType = attr->getValueAsString("storageType").trim();
|
||||
Type type = attr.getType();
|
||||
auto storageType = attr.getStorageType();
|
||||
|
||||
// For attributes stored as strings we do not need to query builder etc.
|
||||
if (storageType == "StringAttr") {
|
||||
|
@ -183,7 +184,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
}
|
||||
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
|
||||
if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -194,10 +195,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
if (opArg.is<Operator::Operand *>())
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getOperand(" << i << ");\n";
|
||||
if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
|
||||
if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getAttrOfType<" << attr->getStorageType()
|
||||
<< ">(\"" << attr->getName() << "\");\n";
|
||||
<< "->getAttrOfType<"
|
||||
<< namedAttr->attr.getStorageType() << ">(\""
|
||||
<< namedAttr->getName() << "\");\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -234,8 +236,8 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
for (auto &arg : boundArguments) {
|
||||
if (arg.second.isAttr()) {
|
||||
DefInit *defInit = cast<DefInit>(arg.second.init);
|
||||
os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
|
||||
<< " " << arg.first() << ";\n";
|
||||
os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
|
||||
<< ";\n";
|
||||
} else {
|
||||
os.indent(4) << "Value* " << arg.first() << ";\n";
|
||||
}
|
||||
|
@ -311,7 +313,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<Operator::Attribute *>())
|
||||
if (!argument.is<Operator::NamedAttribute *>())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("expected attribute ") + Twine(i));
|
||||
|
||||
|
|
Loading…
Reference in New Issue