forked from OSchip/llvm-project
Match attributes in input pattern.
Bind attributes similar to operands. Use to rewrite leakyreulo and const rewrite pattern. The attribute type/attributes are not currently checked so should only be used where the attributes match due to the construction of the op. To support current attribute namespacing, convert __ in attribute name to "$" for matching purposes ('$' is not valid character in variable in TableGen). Some simplification to make it simpler to specify indented ostream and avoid so many spaces. The goal is not to have perfectly formatted code generated but good enough so that its still easy to read for a user. PiperOrigin-RevId: 228183639
This commit is contained in:
parent
8d849eb4b9
commit
aae85ddce1
|
@ -55,6 +55,10 @@ public:
|
|||
|
||||
// Operations attribute accessors.
|
||||
struct Attribute {
|
||||
std::string getName() const;
|
||||
StringRef getReturnType() const;
|
||||
StringRef getStorageType() const;
|
||||
|
||||
llvm::StringInit *name;
|
||||
llvm::Record *record;
|
||||
bool isDerived;
|
||||
|
|
|
@ -141,6 +141,23 @@ void Operator::populateOperandsAndAttributes() {
|
|||
}
|
||||
}
|
||||
|
||||
std::string mlir::Operator::Attribute::getName() const {
|
||||
std::string ret = name->getAsUnquotedString();
|
||||
// TODO(jpienaar): Revise this post dialect prefixing attribute discussion.
|
||||
auto split = StringRef(ret).split("__");
|
||||
if (split.second.empty())
|
||||
return ret;
|
||||
return llvm::join_items("$", split.first, split.second);
|
||||
}
|
||||
|
||||
StringRef mlir::Operator::Attribute::getReturnType() const {
|
||||
return record->getValueAsString("returnType").trim();
|
||||
}
|
||||
|
||||
StringRef mlir::Operator::Attribute::getStorageType() const {
|
||||
return record->getValueAsString("storageType").trim();
|
||||
}
|
||||
|
||||
bool mlir::Operator::Operand::hasMatcher() const {
|
||||
llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue();
|
||||
return !isa<llvm::UnsetInit>(matcher);
|
||||
|
|
|
@ -34,6 +34,9 @@ using namespace mlir;
|
|||
|
||||
static const char *const generatedArgName = "_arg";
|
||||
|
||||
// Helper macro that returns indented os.
|
||||
#define OUT(X) os.indent((X))
|
||||
|
||||
// TODO(jpienaar): The builder body should probably be separate from the header.
|
||||
|
||||
// Variation of method in FormatVariadic.h which takes a StringRef as input
|
||||
|
@ -164,7 +167,7 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
|||
os << "> {\npublic:\n";
|
||||
|
||||
// Build operation name.
|
||||
os << " static StringRef getOperationName() { return \""
|
||||
OUT(2) << "static StringRef getOperationName() { return \""
|
||||
<< emitter.op.getOperationName() << "\"; };\n";
|
||||
|
||||
emitter.emitNamedOperands();
|
||||
|
@ -176,8 +179,8 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) {
|
|||
emitter.emitCanonicalizationPatterns();
|
||||
emitter.emitConstantFolder();
|
||||
|
||||
os << "private:\n friend class ::mlir::OperationInst;\n";
|
||||
os << " explicit " << emitter.op.cppClassName()
|
||||
os << "private:\n friend class ::mlir::OperationInst;\n"
|
||||
<< " explicit " << emitter.op.cppClassName()
|
||||
<< "(const OperationInst* state) : Op(state) {}\n};\n";
|
||||
emitter.mapOverClassNamespaces(
|
||||
[&os](StringRef ns) { os << "} // end namespace " << ns << "\n"; });
|
||||
|
@ -190,20 +193,18 @@ void OpEmitter::emitAttrGetters() {
|
|||
|
||||
// Emit the derived attribute body.
|
||||
if (attr.isDerived) {
|
||||
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
|
||||
<< "() const {" << def->getValueAsString("body") << " }\n";
|
||||
OUT(2) << attr.getReturnType() << ' ' << name << "() const {"
|
||||
<< def->getValueAsString("body") << " }\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Emit normal emitter.
|
||||
os << " " << def->getValueAsString("returnType").trim() << ' ' << name
|
||||
<< "() const {\n";
|
||||
OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
|
||||
|
||||
// Return the queried attribute with the correct return type.
|
||||
std::string attrVal =
|
||||
formatv("this->getAttrOfType<{0}>(\"{1}\")",
|
||||
def->getValueAsString("storageType").trim(), name);
|
||||
os << " return "
|
||||
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
|
||||
attr.getStorageType(), name);
|
||||
OUT(4) << "return "
|
||||
<< formatv(def->getValueAsString("convertFromStorage"), attrVal)
|
||||
<< ";\n }\n";
|
||||
}
|
||||
|
@ -243,7 +244,7 @@ void OpEmitter::emitBuilder() {
|
|||
// 1. Stand-alone parameters
|
||||
|
||||
std::vector<Record *> returnTypes = def.getValueAsListOfDefs("returnTypes");
|
||||
os << " static void build(Builder* builder, OperationState* result";
|
||||
OUT(2) << "static void build(Builder* builder, OperationState* result";
|
||||
|
||||
// Emit parameters for all return types
|
||||
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
|
||||
|
@ -267,7 +268,7 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Push all result types to the result
|
||||
if (!returnTypes.empty()) {
|
||||
os << " result->addTypes({returnType0";
|
||||
OUT(4) << "result->addTypes({returnType0";
|
||||
for (unsigned i = 1, e = returnTypes.size(); i != e; ++i)
|
||||
os << ", returnType" << i;
|
||||
os << "});\n\n";
|
||||
|
@ -275,7 +276,7 @@ void OpEmitter::emitBuilder() {
|
|||
|
||||
// Push all operands to the result
|
||||
if (op.getNumOperands() > 0) {
|
||||
os << " result->addOperands({" << getArgumentName(op, 0);
|
||||
OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
|
||||
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
|
||||
os << ", " << getArgumentName(op, i);
|
||||
os << "});\n";
|
||||
|
@ -284,33 +285,33 @@ void OpEmitter::emitBuilder() {
|
|||
// Push all attributes to the result
|
||||
for (const auto &attr : op.getAttributes())
|
||||
if (!attr.isDerived)
|
||||
os.indent(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
|
||||
OUT(4) << formatv("result->addAttribute(\"{0}\", {0});\n",
|
||||
getAttributeName(attr));
|
||||
os << " }\n";
|
||||
OUT(2) << "}\n";
|
||||
|
||||
// 2. Aggregated parameters
|
||||
|
||||
// Signature
|
||||
os << " static void build(Builder* builder, OperationState* result, "
|
||||
OUT(2) << "static void build(Builder* builder, OperationState* result, "
|
||||
<< "ArrayRef<Type> resultTypes, ArrayRef<Value*> args, "
|
||||
"ArrayRef<NamedAttribute> attributes) {\n";
|
||||
|
||||
// Result types
|
||||
os << " assert(resultTypes.size() == " << returnTypes.size()
|
||||
OUT(4) << "assert(resultTypes.size() == " << returnTypes.size()
|
||||
<< "u && \"mismatched number of return types\");\n"
|
||||
<< " result->addTypes(resultTypes);\n";
|
||||
|
||||
// Operands
|
||||
os << " assert(args.size() == " << op.getNumOperands()
|
||||
OUT(4) << "assert(args.size() == " << op.getNumOperands()
|
||||
<< "u && \"mismatched number of parameters\");\n"
|
||||
<< " result->addOperands(args);\n\n";
|
||||
|
||||
// Attributes
|
||||
if (op.getNumAttributes() > 0) {
|
||||
os << " assert(!attributes.size() && \"no attributes expected\");\n"
|
||||
OUT(4) << "assert(!attributes.size() && \"no attributes expected\");\n"
|
||||
<< " }\n";
|
||||
} else {
|
||||
os << " assert(attributes.size() >= " << op.getNumAttributes()
|
||||
OUT(4) << "assert(attributes.size() >= " << op.getNumAttributes()
|
||||
<< "u && \"not enough attributes\");\n"
|
||||
<< " for (const auto& pair : attributes)\n"
|
||||
<< " result->addAttribute(pair.first, pair.second);\n"
|
||||
|
@ -321,7 +322,7 @@ void OpEmitter::emitBuilder() {
|
|||
void OpEmitter::emitCanonicalizationPatterns() {
|
||||
if (!def.getValueAsBit("hasCanonicalizationPatterns"))
|
||||
return;
|
||||
os << " static void getCanonicalizationPatterns("
|
||||
OUT(2) << "static void getCanonicalizationPatterns("
|
||||
<< "OwningRewritePatternList &results, MLIRContext* context);\n";
|
||||
}
|
||||
|
||||
|
@ -363,7 +364,7 @@ void OpEmitter::emitVerifier() {
|
|||
if (!hasCustomVerify && op.getNumArgs() == 0)
|
||||
return;
|
||||
|
||||
os << " bool verify() const {\n";
|
||||
OUT(2) << "bool verify() const {\n";
|
||||
// Verify the attributes have the correct type.
|
||||
for (const auto &attr : op.getAttributes()) {
|
||||
if (attr.isDerived)
|
||||
|
@ -371,17 +372,15 @@ void OpEmitter::emitVerifier() {
|
|||
|
||||
auto name = getAttributeName(attr);
|
||||
if (!hasStringAttribute(*attr.record, "storageType")) {
|
||||
os << " if (!this->getAttr(\"" << name
|
||||
OUT(4) << "if (!this->getAttr(\"" << name
|
||||
<< "\")) return emitOpError(\"requires attribute '" << name
|
||||
<< "'\");\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||
<< attr.record->getValueAsString("storageType").trim()
|
||||
<< ">()) return emitOpError(\"requires "
|
||||
<< attr.record->getValueAsString("returnType").trim() << " attribute '"
|
||||
<< name << "'\");\n";
|
||||
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
|
||||
<< attr.getReturnType() << " attribute '" << name << "'\");\n";
|
||||
}
|
||||
|
||||
// TODO: Handle variadic.
|
||||
|
@ -392,17 +391,17 @@ void OpEmitter::emitVerifier() {
|
|||
if (operand.hasMatcher()) {
|
||||
auto pred =
|
||||
"if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n";
|
||||
os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" +
|
||||
OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" +
|
||||
Twine(opIndex) + ")->getType()");
|
||||
}
|
||||
++opIndex;
|
||||
}
|
||||
|
||||
if (hasCustomVerify)
|
||||
os << " " << codeInit->getValue() << "\n";
|
||||
OUT(4) << codeInit->getValue() << "\n";
|
||||
else
|
||||
os << " return false;\n";
|
||||
os << " }\n";
|
||||
OUT(4) << "return false;\n";
|
||||
OUT(2) << "}\n";
|
||||
}
|
||||
|
||||
void OpEmitter::emitTraits() {
|
||||
|
|
|
@ -150,6 +150,8 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
"' in pattern and op's definition");
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
auto opArg = op.getArg(i);
|
||||
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
os.indent(indent) << "{\n";
|
||||
os.indent(indent + 2) << formatv(
|
||||
|
@ -162,12 +164,11 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
|
||||
// Verify arguments.
|
||||
if (auto defInit = dyn_cast<DefInit>(arg)) {
|
||||
auto opArg = op.getArg(i);
|
||||
// Verify operands.
|
||||
if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
|
||||
// Skip verification where not needed due to definition of op.
|
||||
if (operand->defInit == defInit)
|
||||
goto SkipOperandVerification;
|
||||
goto StateCapture;
|
||||
|
||||
if (!defInit->getDef()->isSubClassOf("Type"))
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
|
@ -185,15 +186,24 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
formatv("op{0}->getOperand({1})->getType()", depth, i))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
}
|
||||
SkipOperandVerification:
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) {
|
||||
}
|
||||
}
|
||||
|
||||
StateCapture:
|
||||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
if (opArg.is<Operator::Operand *>())
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getOperand(" << i << ");\n";
|
||||
if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) {
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getAttrOfType<" << attr->getStorageType()
|
||||
<< ">(\"" << attr->getName() << "\");\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -291,13 +301,18 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
(os << ",\n").indent(6);
|
||||
|
||||
// The argument in the result DAG pattern.
|
||||
auto name = resultOp.getArgName(i);
|
||||
auto name = resultTree->getArgNameStr(i);
|
||||
auto opName = resultOp.getArgName(i);
|
||||
auto defInit = dyn_cast<DefInit>(resultTree->getArg(i));
|
||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||
if (!value)
|
||||
if (!value) {
|
||||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("attribute '") + name +
|
||||
"' needs to be constant initialized");
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
os << "/*" << opName << "=*/"
|
||||
<< "s." << name;
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
|
|
Loading…
Reference in New Issue