Match attributes in input pattern.

Bind attributes similar to operands. Use to rewrite leakyreulo and const rewrite pattern. The attribute type/attributes are not currently checked so should only be used where the attributes match due to the construction of the op.

To support current attribute namespacing, convert __ in attribute name to "$" for matching purposes ('$' is not valid character in variable in TableGen).

Some simplification to make it simpler to specify indented ostream and avoid so many spaces. The goal is not to have perfectly formatted code generated but good enough so that its still easy to read for a user.

PiperOrigin-RevId: 228183639
This commit is contained in:
Jacques Pienaar 2019-01-07 09:52:26 -08:00 committed by jpienaar
parent 8d849eb4b9
commit aae85ddce1
4 changed files with 99 additions and 64 deletions

View File

@ -55,6 +55,10 @@ public:
// Operations attribute accessors.
struct Attribute {
std::string getName() const;
StringRef getReturnType() const;
StringRef getStorageType() const;
llvm::StringInit *name;
llvm::Record *record;
bool isDerived;

View File

@ -141,6 +141,23 @@ void Operator::populateOperandsAndAttributes() {
}
}
std::string mlir::Operator::Attribute::getName() const {
std::string ret = name->getAsUnquotedString();
// TODO(jpienaar): Revise this post dialect prefixing attribute discussion.
auto split = StringRef(ret).split("__");
if (split.second.empty())
return ret;
return llvm::join_items("$", split.first, split.second);
}
StringRef mlir::Operator::Attribute::getReturnType() const {
return record->getValueAsString("returnType").trim();
}
StringRef mlir::Operator::Attribute::getStorageType() const {
return record->getValueAsString("storageType").trim();
}
bool mlir::Operator::Operand::hasMatcher() const {
llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue();
return !isa<llvm::UnsetInit>(matcher);

View File

@ -34,6 +34,9 @@ using namespace mlir;
static const char *const generatedArgName = "_arg";
// Helper macro that returns indented os.
#define OUT(X) os.indent((X))
// TODO(jpienaar): The builder body should probably be separate from the header.
// Variation of method in FormatVariadic.h which takes a StringRef as input
@ -164,8 +167,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
os << "> {\npublic:\n";
// Build operation name.
os << " static StringRef getOperationName() { return \""
<< emitter.op.getOperationName() << "\"; };\n";
OUT(2) << "static StringRef getOperationName() { return \""
<< emitter.op.getOperationName() << "\"; };\n";
emitter.emitNamedOperands();
emitter.emitBuilder();
@ -176,8 +179,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
emitter.emitCanonicalizationPatterns();
emitter.emitConstantFolder();
os << "private:\n friend class ::mlir::OperationInst;\n";
os << " explicit " << emitter.op.cppClassName()
os << "private:\n friend class ::mlir::OperationInst;\n"
<< " explicit " << emitter.op.cppClassName()
<< "(const OperationInst* state) : Op(state) {}\n};\n";
emitter.mapOverClassNamespaces(
[&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
@ -190,22 +193,20 @@ void OpEmitter::emitAttrGetters() {
// Emit the derived attribute body.
if (attr.isDerived) {
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
<< "() const {" << def->getValueAsString("body") << " }\n";
OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
<< def->getValueAsString("body") << " }\n";
continue;
}
// Emit normal emitter.
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
<< "() const {\n";
OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
// Return the queried attribute with the correct return type.
std::string attrVal =
formatv("this->getAttrOfType<{0}>(\"{1}\")",
def->getValueAsString("storageType").trim(), name);
os << " return "
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
<< ";\n }\n";
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
attr.getStorageType(), name);
OUT(4) << "return "
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
<< ";\n }\n";
}
}
@ -243,7 +244,7 @@ void OpEmitter::emitBuilder() {
// 1. Stand-alone parameters
std::vector<Record *> returnTypes = def.getValueAsListOfDefs("returnTypes");
os << " static void build(Builder* builder, OperationState* result";
OUT(2) << "static void build(Builder* builder, OperationState* result";
// Emit parameters for all return types
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
@ -267,7 +268,7 @@ void OpEmitter::emitBuilder() {
// Push all result types to the result
if (!returnTypes.empty()) {
os << " result->addTypes({returnType0";
OUT(4) << "result->addTypes({returnType0";
for (unsigned i = 1, e = returnTypes.size(); i != e; ++i)
os << ", returnType" << i;
os << "});\n\n";
@ -275,7 +276,7 @@ void OpEmitter::emitBuilder() {
// Push all operands to the result
if (op.getNumOperands() > 0) {
os << " result->addOperands({" << getArgumentName(op, 0);
OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
os << ", " << getArgumentName(op, i);
os << "});\n";
@ -284,45 +285,45 @@ void OpEmitter::emitBuilder() {
// Push all attributes to the result
for (const auto &attr : op.getAttributes())
if (!attr.isDerived)
os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
getAttributeName(attr));
os << " }\n";
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
getAttributeName(attr));
OUT(2) << "}\n";
// 2. Aggregated parameters
// Signature
os << " static void build(Builder* builder, OperationState* result, "
<< "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
"ArrayRef<NamedAttribute> attributes) {\n";
OUT(2) << "static void build(Builder* builder, OperationState* result, "
<< "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
"ArrayRef<NamedAttribute> attributes) {\n";
// Result types
os << " assert(resultTypes.size() == " << returnTypes.size()
<< "u && \"mismatched number of return types\");\n"
<< " result->addTypes(resultTypes);\n";
OUT(4) << "assert(resultTypes.size() == " << returnTypes.size()
<< "u && \"mismatched number of return types\");\n"
<< " result->addTypes(resultTypes);\n";
// Operands
os << " assert(args.size() == " << op.getNumOperands()
<< "u && \"mismatched number of parameters\");\n"
<< " result->addOperands(args);\n\n";
OUT(4) << "assert(args.size() == " << op.getNumOperands()
<< "u && \"mismatched number of parameters\");\n"
<< " result->addOperands(args);\n\n";
// Attributes
if (op.getNumAttributes() > 0) {
os << " assert(!attributes.size() && \"no attributes expected\");\n"
<< " }\n";
OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n"
<< " }\n";
} else {
os << " assert(attributes.size() >= " << op.getNumAttributes()
<< "u && \"not enough attributes\");\n"
<< " for (const auto& pair : attributes)\n"
<< " result->addAttribute(pair.first, pair.second);\n"
<< " }\n";
OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes()
<< "u && \"not enough attributes\");\n"
<< " for (const auto& pair : attributes)\n"
<< " result->addAttribute(pair.first, pair.second);\n"
<< " }\n";
}
}
void OpEmitter::emitCanonicalizationPatterns() {
if (!def.getValueAsBit("hasCanonicalizationPatterns"))
return;
os << " static void getCanonicalizationPatterns("
<< "OwningRewritePatternList &results, MLIRContext* context);\n";
OUT(2) << "static void getCanonicalizationPatterns("
<< "OwningRewritePatternList &results, MLIRContext* context);\n";
}
void OpEmitter::emitConstantFolder() {
@ -363,7 +364,7 @@ void OpEmitter::emitVerifier() {
if (!hasCustomVerify && op.getNumArgs() == 0)
return;
os << " bool verify() const {\n";
OUT(2) << "bool verify() const {\n";
// Verify the attributes have the correct type.
for (const auto &attr : op.getAttributes()) {
if (attr.isDerived)
@ -371,17 +372,15 @@ void OpEmitter::emitVerifier() {
auto name = getAttributeName(attr);
if (!hasStringAttribute(*attr.record, "storageType")) {
os << " if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";
OUT(4) << "if (!this->getAttr(\"" << name
<< "\")) return emitOpError(\"requires attribute '" << name
<< "'\");\n";
continue;
}
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.record->getValueAsString("storageType").trim()
<< ">()) return emitOpError(\"requires "
<< attr.record->getValueAsString("returnType").trim() << " attribute '"
<< name << "'\");\n";
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
<< attr.getReturnType() << " attribute '" << name << "'\");\n";
}
// TODO: Handle variadic.
@ -392,17 +391,17 @@ void OpEmitter::emitVerifier() {
if (operand.hasMatcher()) {
auto pred =
"if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n";
os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" +
Twine(opIndex) + ")->getType()");
OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" +
Twine(opIndex) + ")->getType()");
}
++opIndex;
}
if (hasCustomVerify)
os << " " << codeInit->getValue() << "\n";
OUT(4) << codeInit->getValue() << "\n";
else
os << " return false;\n";
os << " }\n";
OUT(4) << "return false;\n";
OUT(2) << "}\n";
}
void OpEmitter::emitTraits() {

View File

@ -150,6 +150,8 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
"' in pattern and op's definition");
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
auto opArg = op.getArg(i);
if (auto argTree = dyn_cast<DagInit>(arg)) {
os.indent(indent) << "{\n";
os.indent(indent + 2) << formatv(
@ -162,12 +164,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
// Verify arguments.
if (auto defInit = dyn_cast<DefInit>(arg)) {
auto opArg = op.getArg(i);
// Verify operands.
if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
// Skip verification where not needed due to definition of op.
if (operand->defInit == defInit)
goto SkipOperandVerification;
goto StateCapture;
if (!defInit->getDef()->isSubClassOf("Type"))
PrintFatalError(pattern->getLoc(),
@ -185,15 +186,24 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
}
}
SkipOperandVerification:
// TODO(jpienaar): Verify attributes.
// TODO(jpienaar): Verify attributes.
if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
}
}
StateCapture:
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << i << ");\n";
if (opArg.is<Operator::Operand *>())
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getOperand(" << i << ");\n";
if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
os.indent(indent) << "state->" << name << " = op" << depth
<< "->getAttrOfType<" << attr->getStorageType()
<< ">(\"" << attr->getName() << "\");\n";
}
}
}
@ -291,13 +301,18 @@ void Pattern::emit(StringRef rewriteName) {
(os << ",\n").indent(6);
// The argument in the result DAG pattern.
auto name = resultOp.getArgName(i);
auto name = resultTree->getArgNameStr(i);
auto opName = resultOp.getArgName(i);
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
if (!value)
PrintFatalError(pattern->getLoc(),
Twine("attribute '") + name +
"' needs to be constant initialized");
if (!value) {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
os << "/*" << opName << "=*/"
<< "s." << name;
continue;
}
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);