forked from OSchip/llvm-project
Add attribute matching and transform to pattern rewrites.
Start simple with single predicate match & transform rules for attributes. * Its unclear whether modelling Attr predicates will be needed so start with allowing matching attributes with a single predicate. * The input and output attr type often differs and so add ability to specify a transform between the input and output format. PiperOrigin-RevId: 229580879
This commit is contained in:
parent
27d067e164
commit
a5827fc91d
|
@ -338,6 +338,9 @@ single output.
|
|||
(e.g., bias add rule only matches the case where both Tensors have F32
|
||||
elements).
|
||||
|
||||
1. Attributes can be transformed by transform rules to produce an attribute
|
||||
of a type different than the type matched.
|
||||
|
||||
TODO: Add constraints on the matching rules.
|
||||
|
||||
TODO: Describe the generation of benefit metric given pattern.
|
||||
|
|
|
@ -224,6 +224,9 @@ class Attr {
|
|||
// '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to
|
||||
// 'builder.getStringAttr("foo")'.
|
||||
code constBuilderCall = ?;
|
||||
|
||||
// TODO(jpienaar): Add predicate to verify the validity of Attr so
|
||||
// that verification can be generated.
|
||||
}
|
||||
|
||||
// A generic attribute that must be constructed around a specific type.
|
||||
|
@ -249,7 +252,12 @@ class IntegerAttrBase<BuildableType t> : TypeBasedAttr<t, "IntegerAttr">;
|
|||
def BoolAttr : Attr {
|
||||
let storageType = [{ BoolAttr }];
|
||||
let returnType = [{ bool }];
|
||||
let constBuilderCall = [{ {0}.getBoolAttr({1})" }];
|
||||
let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
|
||||
}
|
||||
def ArrayAttr : Attr {
|
||||
let storageType = [{ ArrayAttr }];
|
||||
let returnType = [{ ArrayAttr }];
|
||||
code convertFromStorage = "{0}";
|
||||
}
|
||||
def ElementsAttr : Attr {
|
||||
let storageType = [{ ElementsAttr }];
|
||||
|
@ -407,4 +415,21 @@ class Pattern<dag patternToMatch, list<dag> resultOps> {
|
|||
// Form of a pattern which produces a single result.
|
||||
class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
|
||||
|
||||
// Attribute matcher. This is the base class to specify a predicate
|
||||
// that has to match. Used on the input attributes of a rewrite rule.
|
||||
class mAttr<CPred pred> {
|
||||
// Code to match the attribute.
|
||||
// Format: {0} represents the attribute.
|
||||
CPred predicate = pred;
|
||||
}
|
||||
|
||||
// Attribute transforms. This is the base class to specify a
|
||||
// transformation of a matched attribute. Used on the output of a rewrite
|
||||
// rule.
|
||||
class tAttr<code transform> {
|
||||
// Code to transform the attribute.
|
||||
// Format: {0} represents the attribute.
|
||||
code attrTransform = transform;
|
||||
}
|
||||
|
||||
#endif // OP_BASE
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace tblgen {
|
|||
class Attribute {
|
||||
public:
|
||||
explicit Attribute(const llvm::Record &def);
|
||||
explicit Attribute(const llvm::Record *def) : Attribute(*def) {}
|
||||
explicit Attribute(const llvm::Record *def);
|
||||
explicit Attribute(const llvm::DefInit *init);
|
||||
|
||||
// Returns true if this attribute is a derived attribute (i.e., a subclass
|
||||
|
@ -79,7 +79,7 @@ public:
|
|||
|
||||
private:
|
||||
// The TableGen definition of this attribute.
|
||||
const llvm::Record &def;
|
||||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -35,25 +35,27 @@ static StringRef getValueAsString(const llvm::Init *init) {
|
|||
return {};
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) {
|
||||
assert(def.isSubClassOf("Attr") &&
|
||||
tblgen::Attribute::Attribute(const llvm::Record *def) : def(def) {
|
||||
assert(def->isSubClassOf("Attr") &&
|
||||
"must be subclass of TableGen 'Attr' class");
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::Record &def) : Attribute(&def) {}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::DefInit *init)
|
||||
: Attribute(*init->getDef()) {}
|
||||
|
||||
bool tblgen::Attribute::isDerivedAttr() const {
|
||||
return def.isSubClassOf("DerivedAttr");
|
||||
return def->isSubClassOf("DerivedAttr");
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::hasStorageType() const {
|
||||
const auto *init = def.getValueInit("storageType");
|
||||
const auto *init = def->getValueInit("storageType");
|
||||
return !getValueAsString(init).empty();
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getStorageType() const {
|
||||
const auto *init = def.getValueInit("storageType");
|
||||
const auto *init = def->getValueInit("storageType");
|
||||
auto type = getValueAsString(init);
|
||||
if (type.empty())
|
||||
return "Attribute";
|
||||
|
@ -61,30 +63,30 @@ StringRef tblgen::Attribute::getStorageType() const {
|
|||
}
|
||||
|
||||
StringRef tblgen::Attribute::getReturnType() const {
|
||||
const auto *init = def.getValueInit("returnType");
|
||||
const auto *init = def->getValueInit("returnType");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getConvertFromStorageCall() const {
|
||||
const auto *init = def.getValueInit("convertFromStorage");
|
||||
const auto *init = def->getValueInit("convertFromStorage");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::isConstBuildable() const {
|
||||
const auto *init = def.getValueInit("constBuilderCall");
|
||||
const auto *init = def->getValueInit("constBuilderCall");
|
||||
return !getValueAsString(init).empty();
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getConstBuilderTemplate() const {
|
||||
const auto *init = def.getValueInit("constBuilderCall");
|
||||
const auto *init = def->getValueInit("constBuilderCall");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getTableGenDefName() const {
|
||||
return def.getName();
|
||||
return def->getName();
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
||||
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||
return def.getValueAsString("body");
|
||||
return def->getValueAsString("body");
|
||||
}
|
||||
|
|
|
@ -190,7 +190,7 @@ void OpEmitter::emitAttrGetters() {
|
|||
OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
|
||||
|
||||
// Return the queried attribute with the correct return type.
|
||||
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
|
||||
std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()",
|
||||
attr.getStorageType(), name);
|
||||
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal)
|
||||
<< ";\n }\n";
|
||||
|
|
|
@ -44,21 +44,19 @@ using mlir::tblgen::Type;
|
|||
|
||||
namespace {
|
||||
|
||||
// Wrapper around dag argument.
|
||||
// Wrapper around DAG argument.
|
||||
struct DagArg {
|
||||
DagArg(Init *init) : init(init) {}
|
||||
DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit)
|
||||
: arg(arg), constraintInit(constraintInit) {}
|
||||
bool isAttr();
|
||||
|
||||
Init *init;
|
||||
mlir::tblgen::Operator::Argument arg;
|
||||
Init *constraintInit;
|
||||
};
|
||||
|
||||
} // end namespace
|
||||
|
||||
bool DagArg::isAttr() {
|
||||
if (auto defInit = dyn_cast<DefInit>(init))
|
||||
return defInit->getDef()->isSubClassOf("Attr");
|
||||
return false;
|
||||
}
|
||||
bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); }
|
||||
|
||||
namespace {
|
||||
class Pattern {
|
||||
|
@ -80,9 +78,18 @@ private:
|
|||
// Collect bound arguments.
|
||||
void collectBoundArguments(DagInit *tree);
|
||||
|
||||
// Helper function to match patterns.
|
||||
void matchOp(DagInit *tree, int depth);
|
||||
|
||||
// Returns the Operator stored for the given record.
|
||||
Operator &getOperator(const llvm::Record *record);
|
||||
|
||||
// Map from bound argument name to DagArg.
|
||||
StringMap<DagArg> boundArguments;
|
||||
|
||||
// Map from Record* to Operator.
|
||||
DenseMap<const llvm::Record *, Operator> opMap;
|
||||
|
||||
// Number of the operations in the input pattern.
|
||||
int numberOfOpsMatched = 0;
|
||||
|
||||
|
@ -91,6 +98,11 @@ private:
|
|||
};
|
||||
} // end namespace
|
||||
|
||||
// Returns the Operator stored for the given record.
|
||||
auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
|
||||
return opMap.try_emplace(record, record).first->second;
|
||||
}
|
||||
|
||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||
Attribute attr(constAttr->getValueAsDef("attr"));
|
||||
auto value = constAttr->getValue("value");
|
||||
|
@ -107,6 +119,7 @@ void Pattern::emitAttributeValue(Record *constAttr) {
|
|||
|
||||
void Pattern::collectBoundArguments(DagInit *tree) {
|
||||
++numberOfOpsMatched;
|
||||
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
|
@ -117,14 +130,13 @@ void Pattern::collectBoundArguments(DagInit *tree) {
|
|||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
boundArguments.try_emplace(name, arg);
|
||||
boundArguments.try_emplace(name, op.getArg(i), arg);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
static void matchOp(Record *pattern, DagInit *tree, int depth,
|
||||
raw_ostream &os) {
|
||||
Operator op(cast<DefInit>(tree->getOperator())->getDef());
|
||||
void Pattern::matchOp(DagInit *tree, int depth) {
|
||||
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
|
||||
int indent = 4 + 2 * depth;
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
|
@ -148,7 +160,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
os.indent(indent + 2) << formatv(
|
||||
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
|
||||
depth + 1, depth, i);
|
||||
matchOp(pattern, argTree, depth + 1, os);
|
||||
matchOp(argTree, depth + 1);
|
||||
os.indent(indent) << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
@ -174,7 +186,19 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
|
|||
}
|
||||
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
|
||||
if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
|
||||
// TODO(jpienaar): move to helper class.
|
||||
if (defInit->getDef()->isSubClassOf("mAttr")) {
|
||||
auto pred =
|
||||
tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
|
||||
os.indent(indent)
|
||||
<< "if (!("
|
||||
<< formatv(pred.getCondition().str().c_str(),
|
||||
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
|
||||
namedAttr->attr.getStorageType(),
|
||||
namedAttr->getName()))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,7 +226,7 @@ void Pattern::emitMatcher(DagInit *tree) {
|
|||
if (op0->getNumResults() != 1) return matchFailure();
|
||||
auto state = std::make_unique<MatchedState>();)"
|
||||
<< "\n";
|
||||
matchOp(pattern, tree, 0, os);
|
||||
matchOp(tree, 0);
|
||||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
}
|
||||
|
||||
|
@ -224,9 +248,9 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
// Emit matched state.
|
||||
os << " struct MatchedState : public PatternState {\n";
|
||||
for (auto &arg : boundArguments) {
|
||||
if (arg.second.isAttr()) {
|
||||
DefInit *defInit = cast<DefInit>(arg.second.init);
|
||||
os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
|
||||
if (auto namedAttr =
|
||||
arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) {
|
||||
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
|
||||
<< ";\n";
|
||||
} else {
|
||||
os.indent(4) << "Value* " << arg.first() << ";\n";
|
||||
|
@ -247,7 +271,7 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
}
|
||||
|
||||
DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
|
||||
Operator resultOp(*resultRoot->getDef());
|
||||
Operator &resultOp = getOperator(resultRoot->getDef());
|
||||
auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
|
||||
|
||||
os << formatv(R"(
|
||||
|
@ -296,8 +320,19 @@ void Pattern::emit(StringRef rewriteName) {
|
|||
if (boundArguments.find(name) == boundArguments.end())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
os << "/*" << opName << "=*/"
|
||||
<< "s." << name;
|
||||
auto result = "s." + name;
|
||||
os << "/*" << opName << "=*/";
|
||||
if (defInit) {
|
||||
auto transform = defInit->getDef();
|
||||
if (transform->isSubClassOf("tAttr")) {
|
||||
// TODO(jpienaar): move to helper class.
|
||||
os << formatv(
|
||||
transform->getValueAsString("attrTransform").str().c_str(),
|
||||
result);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
os << result;
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue