Add attribute matching and transform to pattern rewrites.

Start simple with single predicate match & transform rules for attributes.
* Its unclear whether modelling Attr predicates will be needed so start with allowing matching attributes with a single predicate.
*  The input and output attr type often differs and so add ability to specify a transform between the input and output format.

PiperOrigin-RevId: 229580879
This commit is contained in:
Jacques Pienaar 2019-01-16 10:23:21 -08:00 committed by jpienaar
parent 27d067e164
commit a5827fc91d
6 changed files with 101 additions and 36 deletions

View File

@ -338,6 +338,9 @@ single output.
(e.g., bias add rule only matches the case where both Tensors have F32
elements).
1. Attributes can be transformed by transform rules to produce an attribute
of a type different than the type matched.
TODO: Add constraints on the matching rules.
TODO: Describe the generation of benefit metric given pattern.

View File

@ -224,6 +224,9 @@ class Attr {
// '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to
// 'builder.getStringAttr("foo")'.
code constBuilderCall = ?;
// TODO(jpienaar): Add predicate to verify the validity of Attr so
// that verification can be generated.
}
// A generic attribute that must be constructed around a specific type.
@ -249,7 +252,12 @@ class IntegerAttrBase<BuildableType t> : TypeBasedAttr<t, "IntegerAttr">;
def BoolAttr : Attr {
let storageType = [{ BoolAttr }];
let returnType = [{ bool }];
let constBuilderCall = [{ {0}.getBoolAttr({1})" }];
let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
}
def ArrayAttr : Attr {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayAttr }];
code convertFromStorage = "{0}";
}
def ElementsAttr : Attr {
let storageType = [{ ElementsAttr }];
@ -407,4 +415,21 @@ class Pattern<dag patternToMatch, list<dag> resultOps> {
// Form of a pattern which produces a single result.
class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
// Attribute matcher. This is the base class to specify a predicate
// that has to match. Used on the input attributes of a rewrite rule.
class mAttr<CPred pred> {
// Code to match the attribute.
// Format: {0} represents the attribute.
CPred predicate = pred;
}
// Attribute transforms. This is the base class to specify a
// transformation of a matched attribute. Used on the output of a rewrite
// rule.
class tAttr<code transform> {
// Code to transform the attribute.
// Format: {0} represents the attribute.
code attrTransform = transform;
}
#endif // OP_BASE

View File

@ -40,7 +40,7 @@ namespace tblgen {
class Attribute {
public:
explicit Attribute(const llvm::Record &def);
explicit Attribute(const llvm::Record *def) : Attribute(*def) {}
explicit Attribute(const llvm::Record *def);
explicit Attribute(const llvm::DefInit *init);
// Returns true if this attribute is a derived attribute (i.e., a subclass
@ -79,7 +79,7 @@ public:
private:
// The TableGen definition of this attribute.
const llvm::Record &def;
const llvm::Record *def;
};
} // end namespace tblgen

View File

@ -35,25 +35,27 @@ static StringRef getValueAsString(const llvm::Init *init) {
return {};
}
tblgen::Attribute::Attribute(const llvm::Record &def) : def(def) {
assert(def.isSubClassOf("Attr") &&
tblgen::Attribute::Attribute(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class");
}
tblgen::Attribute::Attribute(const llvm::Record &def) : Attribute(&def) {}
tblgen::Attribute::Attribute(const llvm::DefInit *init)
: Attribute(*init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const {
return def.isSubClassOf("DerivedAttr");
return def->isSubClassOf("DerivedAttr");
}
bool tblgen::Attribute::hasStorageType() const {
const auto *init = def.getValueInit("storageType");
const auto *init = def->getValueInit("storageType");
return !getValueAsString(init).empty();
}
StringRef tblgen::Attribute::getStorageType() const {
const auto *init = def.getValueInit("storageType");
const auto *init = def->getValueInit("storageType");
auto type = getValueAsString(init);
if (type.empty())
return "Attribute";
@ -61,30 +63,30 @@ StringRef tblgen::Attribute::getStorageType() const {
}
StringRef tblgen::Attribute::getReturnType() const {
const auto *init = def.getValueInit("returnType");
const auto *init = def->getValueInit("returnType");
return getValueAsString(init);
}
StringRef tblgen::Attribute::getConvertFromStorageCall() const {
const auto *init = def.getValueInit("convertFromStorage");
const auto *init = def->getValueInit("convertFromStorage");
return getValueAsString(init);
}
bool tblgen::Attribute::isConstBuildable() const {
const auto *init = def.getValueInit("constBuilderCall");
const auto *init = def->getValueInit("constBuilderCall");
return !getValueAsString(init).empty();
}
StringRef tblgen::Attribute::getConstBuilderTemplate() const {
const auto *init = def.getValueInit("constBuilderCall");
const auto *init = def->getValueInit("constBuilderCall");
return getValueAsString(init);
}
StringRef tblgen::Attribute::getTableGenDefName() const {
return def.getName();
return def->getName();
}
StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def.getValueAsString("body");
return def->getValueAsString("body");
}

View File

@ -190,7 +190,7 @@ void OpEmitter::emitAttrGetters() {
OUT(2) << attr.getReturnType() << ' ' << name << "() const {\n";
// Return the queried attribute with the correct return type.
std::string attrVal = formatv("this->getAttrOfType<{0}>(\"{1}\")",
std::string attrVal = formatv("this->getAttr(\"{1}\").dyn_cast<{0}>()",
attr.getStorageType(), name);
OUT(4) << "return " << formatv(attr.getConvertFromStorageCall(), attrVal)
<< ";\n }\n";

View File

@ -44,21 +44,19 @@ using mlir::tblgen::Type;
namespace {
// Wrapper around dag argument.
// Wrapper around DAG argument.
struct DagArg {
DagArg(Init *init) : init(init) {}
DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit)
: arg(arg), constraintInit(constraintInit) {}
bool isAttr();
Init *init;
mlir::tblgen::Operator::Argument arg;
Init *constraintInit;
};
} // end namespace
bool DagArg::isAttr() {
if (auto defInit = dyn_cast<DefInit>(init))
return defInit->getDef()->isSubClassOf("Attr");
return false;
}
bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); }
namespace {
class Pattern {
@ -80,9 +78,18 @@ private:
// Collect bound arguments.
void collectBoundArguments(DagInit *tree);
// Helper function to match patterns.
void matchOp(DagInit *tree, int depth);
// Returns the Operator stored for the given record.
Operator &getOperator(const llvm::Record *record);
// Map from bound argument name to DagArg.
StringMap<DagArg> boundArguments;
// Map from Record* to Operator.
DenseMap<const llvm::Record *, Operator> opMap;
// Number of the operations in the input pattern.
int numberOfOpsMatched = 0;
@ -91,6 +98,11 @@ private:
};
} // end namespace
// Returns the Operator stored for the given record.
auto Pattern::getOperator(const llvm::Record *record) -> Operator & {
return opMap.try_emplace(record, record).first->second;
}
void Pattern::emitAttributeValue(Record *constAttr) {
Attribute attr(constAttr->getValueAsDef("attr"));
auto value = constAttr->getValue("value");
@ -107,6 +119,7 @@ void Pattern::emitAttributeValue(Record *constAttr) {
void Pattern::collectBoundArguments(DagInit *tree) {
++numberOfOpsMatched;
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
// TODO(jpienaar): Expand to multiple matches.
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
auto arg = tree->getArg(i);
@ -117,14 +130,13 @@ void Pattern::collectBoundArguments(DagInit *tree) {
auto name = tree->getArgNameStr(i);
if (name.empty())
continue;
boundArguments.try_emplace(name, arg);
boundArguments.try_emplace(name, op.getArg(i), arg);
}
}
// Helper function to match patterns.
static void matchOp(Record *pattern, DagInit *tree, int depth,
raw_ostream &os) {
Operator op(cast<DefInit>(tree->getOperator())->getDef());
void Pattern::matchOp(DagInit *tree, int depth) {
Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef());
int indent = 4 + 2 * depth;
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
@ -148,7 +160,7 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
os.indent(indent + 2) << formatv(
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
depth + 1, depth, i);
matchOp(pattern, argTree, depth + 1, os);
matchOp(argTree, depth + 1);
os.indent(indent) << "}\n";
continue;
}
@ -174,7 +186,19 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
}
// TODO(jpienaar): Verify attributes.
if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) {
// TODO(jpienaar): move to helper class.
if (defInit->getDef()->isSubClassOf("mAttr")) {
auto pred =
tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
os.indent(indent)
<< "if (!("
<< formatv(pred.getCondition().str().c_str(),
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
namedAttr->attr.getStorageType(),
namedAttr->getName()))
<< ")) return matchFailure();\n";
}
}
}
@ -202,7 +226,7 @@ void Pattern::emitMatcher(DagInit *tree) {
if (op0->getNumResults() != 1) return matchFailure();
auto state = std::make_unique<MatchedState>();)"
<< "\n";
matchOp(pattern, tree, 0, os);
matchOp(tree, 0);
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
}
@ -224,9 +248,9 @@ void Pattern::emit(StringRef rewriteName) {
// Emit matched state.
os << " struct MatchedState : public PatternState {\n";
for (auto &arg : boundArguments) {
if (arg.second.isAttr()) {
DefInit *defInit = cast<DefInit>(arg.second.init);
os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first()
if (auto namedAttr =
arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) {
os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first()
<< ";\n";
} else {
os.indent(4) << "Value* " << arg.first() << ";\n";
@ -247,7 +271,7 @@ void Pattern::emit(StringRef rewriteName) {
}
DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
Operator resultOp(*resultRoot->getDef());
Operator &resultOp = getOperator(resultRoot->getDef());
auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
os << formatv(R"(
@ -296,8 +320,19 @@ void Pattern::emit(StringRef rewriteName) {
if (boundArguments.find(name) == boundArguments.end())
PrintFatalError(pattern->getLoc(),
Twine("referencing unbound variable '") + name + "'");
os << "/*" << opName << "=*/"
<< "s." << name;
auto result = "s." + name;
os << "/*" << opName << "=*/";
if (defInit) {
auto transform = defInit->getDef();
if (transform->isSubClassOf("tAttr")) {
// TODO(jpienaar): move to helper class.
os << formatv(
transform->getValueAsString("attrTransform").str().c_str(),
result);
continue;
}
}
os << result;
continue;
}