forked from OSchip/llvm-project
[TableGen] Use tblgen::DagLeaf to model DAG arguments
This CL added a tblgen::DagLeaf wrapper class with several helper methods for handling DAG arguments. It helps to refactor the rewriter generation logic to be more higher level. This CL also added a tblgen::ConstantAttr wrapper class for constant attributes. PiperOrigin-RevId: 232050683
This commit is contained in:
parent
70e3873e86
commit
e0774c008f
|
@ -111,6 +111,24 @@ public:
|
|||
StringRef getDerivedCodeBody() const;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR constant attribute
|
||||
// defined in TableGen. This class should closely reflect what is defined as
|
||||
// class `ConstantAttr` in TableGen.
|
||||
class ConstantAttr {
|
||||
public:
|
||||
explicit ConstantAttr(const llvm::DefInit *init);
|
||||
|
||||
// Returns the attribute kind.
|
||||
Attribute getAttribute() const;
|
||||
|
||||
// Returns the constant value.
|
||||
StringRef getConstantValue() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this constant attribute.
|
||||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#ifndef MLIR_TABLEGEN_PATTERN_H_
|
||||
#define MLIR_TABLEGEN_PATTERN_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Argument.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -30,9 +31,9 @@
|
|||
#include "llvm/TableGen/Error.h"
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
class Init;
|
||||
class DagInit;
|
||||
class Init;
|
||||
class Record;
|
||||
class StringRef;
|
||||
} // end namespace llvm
|
||||
|
||||
|
@ -42,20 +43,62 @@ namespace tblgen {
|
|||
// Mapping from TableGen Record to Operator wrapper object
|
||||
using RecordOperatorMap = llvm::DenseMap<const llvm::Record *, Operator>;
|
||||
|
||||
// Wrapper around DAG argument.
|
||||
struct DagArg {
|
||||
DagArg(Argument arg, llvm::Init *constraint)
|
||||
: arg(arg), constraint(constraint) {}
|
||||
|
||||
// Returns true if this DAG argument concerns an operation attribute.
|
||||
bool isAttr() const;
|
||||
|
||||
Argument arg;
|
||||
llvm::Init *constraint;
|
||||
};
|
||||
|
||||
class Pattern;
|
||||
|
||||
// Wrapper class providing helper methods for accessing TableGen DAG leaves
|
||||
// used inside Patterns. This class is lightweight and designed to be used like
|
||||
// values.
|
||||
//
|
||||
// A TableGen DAG construct is of the syntax
|
||||
// `(operator, arg0, arg1, ...)`.
|
||||
//
|
||||
// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
|
||||
// for handy helper methods. It only works on `arg*`s that are not nested DAG
|
||||
// constructs.
|
||||
class DagLeaf {
|
||||
public:
|
||||
explicit DagLeaf(const llvm::Init *def) : def(def) {}
|
||||
|
||||
// Returns true if this DAG leaf is not specified in the pattern. That is, it
|
||||
// places no further constraints/transforms and just carries over the original
|
||||
// value.
|
||||
bool isUnspecified() const;
|
||||
|
||||
// Returns true if this DAG leaf is matching an operand. That is, it specifies
|
||||
// a type constraint.
|
||||
bool isOperandMatcher() const;
|
||||
|
||||
// Returns true if this DAG leaf is matching an attribute. That is, it
|
||||
// specifies an attribute constraint.
|
||||
bool isAttrMatcher() const;
|
||||
|
||||
// Returns true if this DAG leaf is transforming an attribute.
|
||||
bool isAttrTransformer() const;
|
||||
|
||||
// Returns true if this DAG leaf is specifying a constant attribute.
|
||||
bool isConstantAttr() const;
|
||||
|
||||
// Returns this DAG leaf as a type constraint. Asserts if fails.
|
||||
TypeConstraint getAsTypeConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as an attribute constraint. Asserts if fails.
|
||||
AttrConstraint getAsAttrConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||||
ConstantAttr getAsConstantAttr() const;
|
||||
|
||||
// Returns the matching condition template inside this DAG leaf. Assumes the
|
||||
// leaf is an operand/attribute matcher and asserts otherwise.
|
||||
std::string getConditionTemplate() const;
|
||||
|
||||
// Returns the transformation template inside this DAG leaf. Assumes the
|
||||
// leaf is an attribute matcher and asserts otherwise.
|
||||
std::string getTransformationTemplate() const;
|
||||
|
||||
private:
|
||||
const llvm::Init *def;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing TableGen DAG constructs
|
||||
// used inside Patterns. This class is lightweight and designed to be used like
|
||||
// values.
|
||||
|
@ -96,10 +139,9 @@ public:
|
|||
// Gets the `index`-th argument as a nested DAG construct if possible. Returns
|
||||
// null DagNode otherwise.
|
||||
DagNode getArgAsNestedDag(unsigned index) const;
|
||||
// Gets the `index`-th argument as a TableGen DefInit* if possible. Returns
|
||||
// nullptr otherwise.
|
||||
// TODO: This method is exposing raw TableGen object and should be changed.
|
||||
llvm::DefInit *getArgAsDefInit(unsigned index) const;
|
||||
|
||||
// Gets the `index`-th argument as a DAG leaf.
|
||||
DagLeaf getArgAsLeaf(unsigned index) const;
|
||||
|
||||
// Returns the specified name of the `index`-th argument.
|
||||
llvm::StringRef getArgName(unsigned index) const;
|
||||
|
@ -146,7 +188,7 @@ public:
|
|||
void ensureArgBoundInSourcePattern(llvm::StringRef name) const;
|
||||
|
||||
// Returns a reference to all the bound arguments in the source pattern.
|
||||
llvm::StringMap<DagArg> &getSourcePatternBoundArgs();
|
||||
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
|
||||
|
||||
// Returns the op that the root node of the source pattern matches.
|
||||
const Operator &getSourceRootOp();
|
||||
|
@ -159,8 +201,10 @@ private:
|
|||
// The TableGen definition of this pattern.
|
||||
const llvm::Record &def;
|
||||
|
||||
RecordOperatorMap *recordOpMap; // All operators
|
||||
llvm::StringMap<DagArg> boundArguments; // All bound arguments
|
||||
// All operators
|
||||
RecordOperatorMap *recordOpMap;
|
||||
// All bound arguments
|
||||
llvm::StringMap<Argument> boundArguments;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -42,6 +42,7 @@ public:
|
|||
explicit TypeConstraint(const llvm::DefInit &init);
|
||||
|
||||
bool operator==(const TypeConstraint &that) { return def == that.def; }
|
||||
bool operator!=(const TypeConstraint &that) { return def != that.def; }
|
||||
|
||||
// Returns the predicate that can be used to check if a type satisfies this
|
||||
// type constraint.
|
||||
|
|
|
@ -133,3 +133,17 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
|||
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||
return def->getValueAsString("body");
|
||||
}
|
||||
|
||||
tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init)
|
||||
: def(init->getDef()) {
|
||||
assert(def->isSubClassOf("ConstantAttr") &&
|
||||
"must be subclass of TableGen 'ConstantAttr' class");
|
||||
}
|
||||
|
||||
tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
|
||||
return Attribute(def->getValueAsDef("attr"));
|
||||
}
|
||||
|
||||
StringRef tblgen::ConstantAttr::getConstantValue() const {
|
||||
return def->getValueAsString("value");
|
||||
}
|
||||
|
|
|
@ -28,8 +28,66 @@ using namespace mlir;
|
|||
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
bool tblgen::DagArg::isAttr() const {
|
||||
return arg.is<tblgen::NamedAttribute *>();
|
||||
bool tblgen::DagLeaf::isUnspecified() const {
|
||||
return !def || isa<llvm::UnsetInit>(def);
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isOperandMatcher() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
// Operand matchers specify a type constraint.
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isAttrMatcher() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
// Attribute matchers specify a type constraint.
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isAttrTransformer() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isConstantAttr() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
|
||||
}
|
||||
|
||||
tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const {
|
||||
assert(isOperandMatcher() && "the DAG leaf must be operand");
|
||||
return TypeConstraint(*cast<llvm::DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const {
|
||||
assert(isAttrMatcher() && "the DAG leaf must be attribute");
|
||||
return AttrConstraint(cast<llvm::DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
|
||||
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
|
||||
return ConstantAttr(cast<llvm::DefInit>(def));
|
||||
}
|
||||
|
||||
std::string tblgen::DagLeaf::getConditionTemplate() const {
|
||||
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||
"the DAG leaf must be operand/attribute matcher");
|
||||
if (isOperandMatcher()) {
|
||||
return getAsTypeConstraint().getConditionTemplate();
|
||||
}
|
||||
return getAsAttrConstraint().getConditionTemplate();
|
||||
}
|
||||
|
||||
std::string tblgen::DagLeaf::getTransformationTemplate() const {
|
||||
assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
|
||||
return cast<llvm::DefInit>(def)
|
||||
->getDef()
|
||||
->getValueAsString("attrTransform")
|
||||
.str();
|
||||
}
|
||||
|
||||
Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
|
||||
|
@ -56,8 +114,9 @@ tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
|
|||
return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
|
||||
}
|
||||
|
||||
llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const {
|
||||
return dyn_cast<llvm::DefInit>(node->getArg(index));
|
||||
tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
|
||||
assert(!isNestedDagArg(index));
|
||||
return DagLeaf(node->getArg(index));
|
||||
}
|
||||
|
||||
StringRef tblgen::DagNode::getArgName(unsigned index) const {
|
||||
|
@ -81,7 +140,7 @@ static void collectBoundArguments(const llvm::DagInit *tree,
|
|||
if (name.empty())
|
||||
continue;
|
||||
|
||||
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg);
|
||||
pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,7 +190,8 @@ void tblgen::Pattern::ensureArgBoundInSourcePattern(
|
|||
Twine("referencing unbound variable '") + name + "'");
|
||||
}
|
||||
|
||||
llvm::StringMap<tblgen::DagArg> &tblgen::Pattern::getSourcePatternBoundArgs() {
|
||||
llvm::StringMap<tblgen::Argument> &
|
||||
tblgen::Pattern::getSourcePatternBoundArgs() {
|
||||
return boundArguments;
|
||||
}
|
||||
|
||||
|
|
|
@ -3,24 +3,21 @@
|
|||
include "mlir/IR/op_base.td"
|
||||
|
||||
// Create a Type and Attribute.
|
||||
def YT : BuildableType<"buildYT">;
|
||||
def Y_Attr : TypeBasedAttr<YT, "Attribute", "attribute of Y type">;
|
||||
def Y_Const_Attr {
|
||||
Attr attr = Y_Attr;
|
||||
string value = "attrValue";
|
||||
}
|
||||
def T : BuildableType<"buildT">;
|
||||
def T_Attr : TypeBasedAttr<T, "Attribute", "attribute of T type">;
|
||||
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
|
||||
|
||||
// Define ops to rewrite.
|
||||
def T1: Type<CPred<"true">, "T1">;
|
||||
def U: Type<CPred<"true">, "U">;
|
||||
def X_AddOp : Op<"x.add"> {
|
||||
let arguments = (ins T1, T1);
|
||||
let arguments = (ins U, U);
|
||||
}
|
||||
def Y_AddOp : Op<"y.add"> {
|
||||
let arguments = (ins T1, T1, Y_Attr:$attrName);
|
||||
let arguments = (ins U, U, T_Attr:$attrName);
|
||||
}
|
||||
|
||||
// Define rewrite pattern.
|
||||
def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>;
|
||||
def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>;
|
||||
|
||||
// CHECK: struct GeneratedConvert0 : public RewritePattern
|
||||
// CHECK: RewritePattern("x.add", 1, context)
|
||||
|
|
|
@ -39,31 +39,11 @@
|
|||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::Argument;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::DagNode;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::Operand;
|
||||
using mlir::tblgen::Operator;
|
||||
using mlir::tblgen::Pattern;
|
||||
using mlir::tblgen::RecordOperatorMap;
|
||||
using mlir::tblgen::Type;
|
||||
|
||||
namespace {
|
||||
|
||||
// Wrapper around DAG argument.
|
||||
struct DagArg {
|
||||
DagArg(Argument arg, Init *constraintInit)
|
||||
: arg(arg), constraintInit(constraintInit) {}
|
||||
bool isAttr();
|
||||
|
||||
Argument arg;
|
||||
Init *constraintInit;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); }
|
||||
|
||||
namespace {
|
||||
class PatternEmitter {
|
||||
|
@ -93,12 +73,19 @@ private:
|
|||
void emitReplaceWithNativeBuilder(DagNode resultTree);
|
||||
|
||||
// Emits the value of constant attribute to `os`.
|
||||
void emitAttributeValue(Record *constAttr);
|
||||
void emitConstantAttr(tblgen::ConstantAttr constAttr);
|
||||
|
||||
// Emits C++ statements for matching the op constrained by the given DAG
|
||||
// `tree`.
|
||||
void emitOpMatch(DagNode tree, int depth);
|
||||
|
||||
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
||||
// `tree` as an operand.
|
||||
void emitOperandMatch(DagNode tree, int index, int depth, int indent);
|
||||
// Emits C++ statements for matching the `index`-th argument of the given DAG
|
||||
// `tree` as an attribute.
|
||||
void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
|
||||
|
||||
private:
|
||||
// Pattern instantiation location followed by the location of multiclass
|
||||
// prototypes used. This is intended to be used as a whole to
|
||||
|
@ -107,14 +94,13 @@ private:
|
|||
// Op's TableGen Record to wrapper object
|
||||
RecordOperatorMap *opMap;
|
||||
// Handy wrapper for pattern being emitted
|
||||
Pattern pattern;
|
||||
tblgen::Pattern pattern;
|
||||
raw_ostream &os;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
void PatternEmitter::emitAttributeValue(Record *constAttr) {
|
||||
Attribute attr(constAttr->getValueAsDef("attr"));
|
||||
auto value = constAttr->getValue("value");
|
||||
void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) {
|
||||
auto attr = constAttr.getAttribute();
|
||||
|
||||
if (!attr.isConstBuildable())
|
||||
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
|
||||
|
@ -122,7 +108,7 @@ void PatternEmitter::emitAttributeValue(Record *constAttr) {
|
|||
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
value->getValue()->getAsUnquotedString());
|
||||
constAttr.getConstantValue());
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
|
@ -137,13 +123,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
|
||||
op.getQualCppClassName());
|
||||
}
|
||||
if (tree.getNumArgs() != op.getNumArgs())
|
||||
PrintFatalError(loc, Twine("mismatch in number of arguments to op '") +
|
||||
op.getOperationName() +
|
||||
"' in pattern and op's definition");
|
||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
||||
"pattern vs. {2} in definition",
|
||||
op.getOperationName(), tree.getNumArgs(),
|
||||
op.getNumArgs()));
|
||||
}
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
|
||||
// Handle nested DAG construct first
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
os.indent(indent) << "{\n";
|
||||
os.indent(indent + 2) << formatv(
|
||||
|
@ -154,51 +144,79 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Verify arguments.
|
||||
if (auto defInit = tree.getArgAsDefInit(i)) {
|
||||
// Verify operands.
|
||||
if (auto *operand = opArg.dyn_cast<Operand *>()) {
|
||||
// Skip verification where not needed due to definition of op.
|
||||
if (operand->type == Type(defInit))
|
||||
goto StateCapture;
|
||||
// Next handle DAG leaf: operand or attribute
|
||||
if (auto *operand = opArg.dyn_cast<Operand *>()) {
|
||||
emitOperandMatch(tree, i, depth, indent);
|
||||
} else if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
|
||||
emitAttributeMatch(tree, i, depth, indent);
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when matching op");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!defInit->getDef()->isSubClassOf("Type"))
|
||||
PrintFatalError(loc, "type argument required for operand");
|
||||
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *operand = op.getArg(index).get<Operand *>();
|
||||
auto matcher = tree.getArgAsLeaf(index);
|
||||
|
||||
auto constraint = tblgen::TypeConstraint(*defInit);
|
||||
os.indent(indent)
|
||||
<< "if (!("
|
||||
<< formatv(constraint.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getOperand({1})->getType()", depth, i))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
|
||||
auto constraint = tblgen::AttrConstraint(defInit);
|
||||
std::string condition = formatv(
|
||||
constraint.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
|
||||
namedAttr->attr.getStorageType(), namedAttr->getName()));
|
||||
os.indent(indent) << "if (!(" << condition
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
// If a constraint is specified, we need to generate C++ statements to
|
||||
// check the constraint.
|
||||
if (!matcher.isUnspecified()) {
|
||||
if (!matcher.isOperandMatcher()) {
|
||||
PrintFatalError(
|
||||
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
|
||||
op.getOperationName(), index + 1));
|
||||
}
|
||||
|
||||
StateCapture:
|
||||
auto name = tree.getArgName(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
if (opArg.is<Operand *>())
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getOperand(" << i << ");\n";
|
||||
if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getAttrOfType<"
|
||||
<< namedAttr->attr.getStorageType() << ">(\""
|
||||
<< namedAttr->getName() << "\");\n";
|
||||
// Only need to verify if the matcher's type is different from the one
|
||||
// of op definition.
|
||||
if (static_cast<tblgen::TypeConstraint>(operand->type) !=
|
||||
matcher.getAsTypeConstraint()) {
|
||||
os.indent(indent) << "if (!("
|
||||
<< formatv(matcher.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getOperand({1})->getType()",
|
||||
depth, index))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
if (!name.empty()) {
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getOperand(" << index << ");\n";
|
||||
}
|
||||
}
|
||||
|
||||
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
||||
int indent) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
|
||||
auto matcher = tree.getArgAsLeaf(index);
|
||||
|
||||
if (!matcher.isUnspecified() && !matcher.isAttrMatcher()) {
|
||||
PrintFatalError(
|
||||
loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
|
||||
op.getOperationName(), index + 1));
|
||||
}
|
||||
|
||||
// If a constraint is specified, we need to generate C++ statements to
|
||||
// check the constraint.
|
||||
std::string condition =
|
||||
formatv(matcher.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
|
||||
namedAttr->attr.getStorageType(), namedAttr->getName()));
|
||||
os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n";
|
||||
|
||||
// Capture the value
|
||||
auto name = tree.getArgName(index);
|
||||
if (!name.empty()) {
|
||||
os.indent(indent) << "state->" << name << " = op" << depth
|
||||
<< "->getAttrOfType<" << namedAttr->attr.getStorageType()
|
||||
<< ">(\"" << namedAttr->getName() << "\");\n";
|
||||
}
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchMethod(DagNode tree) {
|
||||
|
@ -234,11 +252,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
// Emit matched state.
|
||||
os << " struct MatchedState : public PatternState {\n";
|
||||
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
|
||||
if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) {
|
||||
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
|
||||
auto fieldName = arg.first();
|
||||
if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
|
||||
os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName
|
||||
<< ";\n";
|
||||
} else {
|
||||
os.indent(4) << "Value* " << arg.first() << ";\n";
|
||||
os.indent(4) << "Value* " << fieldName << ";\n";
|
||||
}
|
||||
}
|
||||
os << " };\n";
|
||||
|
@ -285,10 +304,10 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
|
|||
rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
|
||||
resultOp.getCppClassName());
|
||||
if (numOpArgs != resultTree.getNumArgs()) {
|
||||
PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") +
|
||||
Twine(numOpArgs) +
|
||||
") and arguments provided for rewrite (" +
|
||||
Twine(resultTree.getNumArgs()) + Twine(')'));
|
||||
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
|
||||
"{1} in pattern vs. {2} in definition",
|
||||
resultOp.getOperationName(),
|
||||
resultTree.getNumArgs(), numOpArgs));
|
||||
}
|
||||
|
||||
// Create the builder call for the result.
|
||||
|
@ -312,38 +331,33 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) {
|
|||
// Start each attribute on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
|
||||
auto leaf = resultTree.getArgAsLeaf(i);
|
||||
// The argument in the result DAG pattern.
|
||||
auto argName = resultTree.getArgName(i);
|
||||
auto opName = resultOp.getArgName(i);
|
||||
auto *defInit = resultTree.getArgAsDefInit(i);
|
||||
auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr;
|
||||
if (!value) {
|
||||
pattern.ensureArgBoundInSourcePattern(argName);
|
||||
auto result = "s." + argName;
|
||||
os << "/*" << opName << "=*/";
|
||||
if (defInit) {
|
||||
auto transform = defInit->getDef();
|
||||
if (transform->isSubClassOf("tAttr")) {
|
||||
// TODO(jpienaar): move to helper class.
|
||||
os << formatv(
|
||||
transform->getValueAsString("attrTransform").str().c_str(),
|
||||
result);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
os << result;
|
||||
continue;
|
||||
auto patArgName = resultTree.getArgName(i);
|
||||
// The argument in the op definition.
|
||||
auto opArgName = resultOp.getArgName(i);
|
||||
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
pattern.ensureArgBoundInSourcePattern(patArgName);
|
||||
os << formatv("/*{0}=*/s.{1}", opArgName, patArgName);
|
||||
} else if (leaf.isAttrTransformer()) {
|
||||
pattern.ensureArgBoundInSourcePattern(patArgName);
|
||||
std::string result = std::string("s.") + patArgName.str();
|
||||
result = formatv(leaf.getTransformationTemplate().c_str(), result);
|
||||
os << formatv("/*{0}=*/{1}", opArgName, result);
|
||||
} else if (leaf.isConstantAttr()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
emitConstantAttr(leaf.getAsConstantAttr());
|
||||
// TODO(jpienaar): verify types
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
|
||||
if (!argName.empty())
|
||||
os << "/*" << argName << "=*/";
|
||||
emitAttributeValue(defInit->getDef());
|
||||
// TODO(jpienaar): verify types
|
||||
}
|
||||
os << "\n );\n";
|
||||
}
|
||||
|
@ -367,7 +381,7 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
|
|||
auto name = resultTree.getArgName(i);
|
||||
pattern.ensureArgBoundInSourcePattern(name);
|
||||
const auto &val = boundedValues.find(name);
|
||||
if (val->second.isAttr() && !printingAttr) {
|
||||
if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
|
||||
os << "}, {";
|
||||
first = true;
|
||||
printingAttr = true;
|
||||
|
|
Loading…
Reference in New Issue