Add tblgen::Attribute to wrap around TableGen Attr defs

This CL added a tblgen::Attribute class to wrap around raw TableGen
Record getValue*() calls on Attr defs, which will provide a nicer
API for handling TableGen Record.

PiperOrigin-RevId: 228581107
This commit is contained in:
Lei Zhang 2019-01-09 13:50:20 -08:00 committed by jpienaar
parent 6ce30becd7
commit 9b034f0bfd
6 changed files with 206 additions and 62 deletions

View File

@ -0,0 +1,81 @@
//===- Attribute.h - Attribute wrapper class --------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_ATTRIBUTE_H_
#define MLIR_TABLEGEN_ATTRIBUTE_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class DefInit;
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class providing helper methods for accessing MLIR Attribute defined
// in TableGen. This class should closely reflect what is defined as class
// `Attr` in TableGen.
class Attribute {
public:
explicit Attribute(const llvm::Record &def);
explicit Attribute(const llvm::Record *def) : Attribute(*def) {}
explicit Attribute(const llvm::DefInit *init);
// Returns true if this attribute is a derived attribute (i.e., a subclass
// of `DrivedAttr`).
bool isDerivedAttr() const;
// Returns the type of this attribute.
Type getType() const;
// Returns true if this attribute has storage type set.
bool hasStorageType() const;
// Returns the storage type if set. Returns the default storage type
// ("Attribute") otherwise.
StringRef getStorageType() const;
// Returns the return type for this attribute.
StringRef getReturnType() const;
// Returns the template getter method call which reads this attribute's
// storage and returns the value as of the desired return type.
// The call will contain a `{0}` which will be expanded to this attribute.
StringRef getConvertFromStorageCall() const;
// Returns the code body for derived attribute. Aborts if this is not a
// derived attribute.
StringRef getDerivedCodeBody() const;
private:
// The TableGen definition of this attribute.
const llvm::Record &def;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_ATTRIBUTE_H_

View File

@ -23,6 +23,7 @@
#define MLIR_TABLEGEN_OPERATOR_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Attribute.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@ -57,26 +58,25 @@ public:
// Returns the C++ class name of the op with namespace added.
std::string qualifiedCppClassName() const;
struct Attribute {
struct NamedAttribute {
std::string getName() const;
StringRef getReturnType() const;
StringRef getStorageType() const;
llvm::StringInit *name;
llvm::Record *record;
bool isDerived;
Attribute attr;
};
// Op attribute interators.
using attribute_iterator = Attribute *;
using attribute_iterator = NamedAttribute *;
attribute_iterator attribute_begin();
attribute_iterator attribute_end();
llvm::iterator_range<attribute_iterator> getAttributes();
// Op attribute accessors.
int getNumAttributes() const { return attributes.size(); }
Attribute &getAttribute(int index) { return attributes[index]; }
const Attribute &getAttribute(int index) const { return attributes[index]; }
NamedAttribute &getAttribute(int index) { return attributes[index]; }
const NamedAttribute &getAttribute(int index) const {
return attributes[index];
}
struct Operand {
bool hasMatcher() const;
@ -99,7 +99,7 @@ public:
const Operand &getOperand(int index) const { return operands[index]; }
// Op argument (attribute or operand) accessors.
using Argument = llvm::PointerUnion<Attribute *, Operand *>;
using Argument = llvm::PointerUnion<NamedAttribute *, Operand *>;
Argument getArg(int index);
StringRef getArgName(int index) const;
int getNumArgs() const { return operands.size() + attributes.size(); }
@ -115,7 +115,7 @@ private:
SmallVector<Operand, 4> operands;
// The attributes of the op.
SmallVector<Attribute, 4> attributes;
SmallVector<NamedAttribute, 4> attributes;
// The start of native attributes, which are specified when creating the op
// as a part of the op's definition.

View File

@ -0,0 +1,80 @@
//===- Attribute.cpp - Attribute wrapper class ------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Attribute wrapper to simplify using TableGen Record defining a MLIR
// Attribute.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
// Returns the initializer's value as string if the given TableGen initializer
// is a code or string initializer. Returns the empty StringRef otherwise.
static StringRef getValueAsString(const llvm::Init *init) {
if (const auto *code = dyn_cast<llvm::CodeInit>(init))
return code->getValue().trim();
else if (const auto *str = dyn_cast<llvm::StringInit>(init))
return str->getValue().trim();
return {};
}
tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) {
assert(def.isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class");
}
tblgen::Attribute::Attribute(const llvm::DefInit *init)
: Attribute(*init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const {
return def.isSubClassOf("DerivedAttr");
}
tblgen::Type tblgen::Attribute::getType() const {
return Type(def.getValueAsDef("type"));
}
bool tblgen::Attribute::hasStorageType() const {
const auto *init = def.getValueInit("storageType");
return !getValueAsString(init).empty();
}
StringRef tblgen::Attribute::getStorageType() const {
const auto *init = def.getValueInit("storageType");
auto type = getValueAsString(init);
if (type.empty())
return "Attribute";
return type;
}
StringRef tblgen::Attribute::getReturnType() const {
const auto *init = def.getValueInit("returnType");
return getValueAsString(init);
}
StringRef tblgen::Attribute::getConvertFromStorageCall() const {
const auto *init = def.getValueInit("convertFromStorage");
return getValueAsString(init);
}
StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def.getValueAsString("body");
}

View File

@ -125,7 +125,7 @@ void tblgen::Operator::populateOperandsAndAttributes() {
if (isDerived)
PrintFatalError(def.getLoc(),
"derived attributes not allowed in argument list");
attributes.push_back({givenName, argDef, isDerived});
attributes.push_back({givenName, Attribute(argDef)});
}
// Handle derived attributes.
@ -144,13 +144,12 @@ void tblgen::Operator::populateOperandsAndAttributes() {
"unsupported attribute modelling, only single class expected");
}
attributes.push_back({cast<llvm::StringInit>(val.getNameInit()),
cast<DefInit>(val.getValue())->getDef(),
/*isDerived=*/true});
Attribute(cast<DefInit>(val.getValue()))});
}
}
}
std::string tblgen::Operator::Attribute::getName() const {
std::string tblgen::Operator::NamedAttribute::getName() const {
std::string ret = name->getAsUnquotedString();
// TODO(jpienaar): Revise this post dialect prefixing attribute discussion.
auto split = StringRef(ret).split("__");
@ -159,14 +158,6 @@ std::string tblgen::Operator::Attribute::getName() const {
return llvm::join_items("$", split.first, split.second);
}
StringRef tblgen::Operator::Attribute::getReturnType() const {
return record->getValueAsString("returnType").trim();
}
StringRef tblgen::Operator::Attribute::getStorageType() const {
return record->getValueAsString("storageType").trim();
}
bool tblgen::Operator::Operand::hasMatcher() const {
return !tblgen::Type(defInit).getPredicate().isEmpty();
}

View File

@ -61,17 +61,7 @@ static inline bool hasStringAttribute(const Record &record,
return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
}
// Returns `fieldName`'s value queried from `record` if `fieldName` is set as
// an string in record; otherwise, returns `defaultVal`.
static inline StringRef getAsStringOrDefault(const Record &record,
StringRef fieldName,
StringRef defaultVal) {
return hasStringAttribute(record, fieldName)
? record.getValueAsString(fieldName)
: defaultVal;
}
static std::string getAttributeName(const Operator::Attribute &attr) {
static std::string getAttributeName(const Operator::NamedAttribute &attr) {
return attr.name->getAsUnquotedString();
}
@ -189,14 +179,14 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
}
void OpEmitter::emitAttrGetters() {
for (auto &attr : op.getAttributes()) {
auto name = getAttributeName(attr);
auto *def = attr.record;
for (auto &namedAttr : op.getAttributes()) {
auto name = getAttributeName(namedAttr);
const auto &attr = namedAttr.attr;
// Emit the derived attribute body.
if (attr.isDerived) {
if (attr.isDerivedAttr()) {
OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
<< def->getValueAsString("body") << " }\n";
<< attr.getDerivedCodeBody() << " }\n";
continue;
}
@ -206,8 +196,7 @@ void OpEmitter::emitAttrGetters() {
// Return the queried attribute with the correct return type.
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
attr.getStorageType(), name);
OUT(4) << "return "
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal)
<< ";\n }\n";
}
}
@ -258,12 +247,11 @@ void OpEmitter::emitBuilder() {
// Emit parameters for all attributes
// TODO(antiagainst): Support default initializer for attributes
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr())
break;
const Record &def = *attr.record;
os << ", " << getAsStringOrDefault(def, "storageType", "Attribute").trim()
<< ' ' << getAttributeName(attr);
os << ", " << attr.getStorageType() << ' ' << getAttributeName(namedAttr);
}
os << ") {\n";
@ -285,10 +273,10 @@ void OpEmitter::emitBuilder() {
}
// Push all attributes to the result
for (const auto &attr : op.getAttributes())
if (!attr.isDerived)
for (const auto &namedAttr : op.getAttributes())
if (!namedAttr.attr.isDerivedAttr())
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
getAttributeName(attr));
getAttributeName(namedAttr));
OUT(2) << "}\n";
// 2. Aggregated parameters
@ -368,12 +356,14 @@ void OpEmitter::emitVerifier() {
OUT(2) << "bool verify() const {\n";
// Verify the attributes have the correct type.
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
for (const auto &namedAttr : op.getAttributes()) {
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr())
continue;
auto name = getAttributeName(attr);
if (!hasStringAttribute(*attr.record, "storageType")) {
auto name = getAttributeName(namedAttr);
if (!attr.hasStorageType()) {
OUT(4) << "if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";

View File

@ -37,6 +37,7 @@
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Attribute;
using mlir::tblgen::Operator;
using mlir::tblgen::Type;
@ -90,10 +91,10 @@ private:
} // end namespace
void Pattern::emitAttributeValue(Record *constAttr) {
Record *attr = constAttr->getValueAsDef("attr");
Attribute attr(constAttr->getValueAsDef("attr"));
auto value = constAttr->getValue("value");
Type type(attr->getValueAsDef("type"));
auto storageType = attr->getValueAsString("storageType").trim();
Type type = attr.getType();
auto storageType = attr.getStorageType();
// For attributes stored as strings we do not need to query builder etc.
if (storageType == "StringAttr") {
@ -183,7 +184,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
}
// TODO(jpienaar): Verify attributes.
if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
}
}
@ -194,10 +195,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
if (opArg.is<Operator::Operand *>())
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << i << ");\n";
if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getAttrOfType<" << attr->getStorageType()
<< ">(\"" << attr->getName() << "\");\n";
<< "->getAttrOfType<"
<< namedAttr->attr.getStorageType() << ">(\""
<< namedAttr->getName() << "\");\n";
}
}
}
@ -234,8 +236,8 @@ void Pattern::emit(StringRef rewriteName) {
for (auto &arg : boundArguments) {
if (arg.second.isAttr()) {
DefInit *defInit = cast<DefInit>(arg.second.init);
os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim()
<< " " << arg.first() << ";\n";
os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
<< ";\n";
} else {
os.indent(4) << "Value* " << arg.first() << ";\n";
}
@ -311,7 +313,7 @@ void Pattern::emit(StringRef rewriteName) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<Operator::Attribute *>())
if (!argument.is<Operator::NamedAttribute *>())
PrintFatalError(pattern->getLoc(),
Twine("expected attribute ") + Twine(i));