NFC: Refactoring PatternSymbolResolver into SymbolInfoMap

In declarative rewrite rules, a symbol can be bound to op arguments or
results in the source pattern, and it can be bound to op results in the
result pattern. This means that a given symbol in the pattern can stand
for different things: op operand, op attribute, single op result,
op result pack. We need a better way to model this complexity so that
we can handle each symbol according to the specific kind it corresponds to.

Created SymbolInfo class for maintaining the information regarding a
symbol. Also created a companion SymbolInfoMap class for a map of
such symbols, providing insertion and querying depending on use cases.

PiperOrigin-RevId: 262675515
This commit is contained in:
Lei Zhang 2019-08-09 19:03:58 -07:00 committed by A. Unique TensorFlower
parent 41968fb475
commit ac68637ba9
3 changed files with 429 additions and 310 deletions

View File

@ -180,6 +180,154 @@ private:
const llvm::DagInit *node; // nullptr means null DagNode
};
// A class for maintaining information for symbols bound in patterns and
// providing methods for resolving them according to specific use cases.
//
// Symbols can be bound to
//
// * Op arguments and op results in the source pattern and
// * Op results in result patterns.
//
// Symbols can be referenced in result patterns and additional constraints to
// the pattern.
//
// For example, in
//
// ```
// def : Pattern<
// (SrcOp:$results1 $arg0, $arg1),
// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
// `SrcOp`. `$results2` is bound to `ResOp1`. `$results2` is referenced to build
// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to represent all results generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th *static* result as declared in ODS, and that can still
// correspond to multiple *dynamic* values if the N-th *static* result is
// variadic.
//
// This class keeps track of such symbols and resolves them into their bound
// values in a suitable way.
class SymbolInfoMap {
public:
// Creates an empty map. `loc` is the pattern instantiation location used when
// reporting errors.
// NOTE(review): `loc` is kept as an `ArrayRef`, which does not own its
// elements; callers presumably must keep the underlying locations alive for
// as long as this map is used -- confirm.
explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
// Class for information regarding a symbol.
class SymbolInfo {
public:
// Returns a string for defining a variable named as `name` to store the
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;
private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
// What kind of entity this symbol represents:
// * Attr: op attribute
// * Operand: op operand
// * Result: op result
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
enum class Kind : uint8_t { Attr, Operand, Result, Value };
// Creates a SymbolInfo instance. `index` is only used for `Attr` and
// `Operand`; the factory methods below pass `llvm::None` for the `Result`
// and `Value` kinds.
SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
// Static methods for creating SymbolInfo.
static SymbolInfo getAttr(const Operator *op, int index) {
return SymbolInfo(op, Kind::Attr, index);
}
static SymbolInfo getOperand(const Operator *op, int index) {
return SymbolInfo(op, Kind::Operand, index);
}
static SymbolInfo getResult(const Operator *op) {
return SymbolInfo(op, Kind::Result, llvm::None);
}
// A `Value` symbol is not attached to any op, hence the null `op`.
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, llvm::None);
}
// Returns the number of static values this symbol corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol
// only represents one static value, but symbols bound to op results can
// represent more than one if the op is a multi-result op.
int getStaticValueCount() const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values). `name` is the
// name of the C++ variable that this symbol is bound to. `index` should only
// be used for indexing results.
std::string getValueAndRangeUse(StringRef name, int index) const;
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
};
using BaseT = llvm::StringMap<SymbolInfo>;
// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
iterator begin() { return symbolInfoMap.begin(); }
iterator end() { return symbolInfoMap.end(); }
// Const iterators for accessing all symbols.
using const_iterator = BaseT::const_iterator;
const_iterator begin() const { return symbolInfoMap.begin(); }
const_iterator end() const { return symbolInfoMap.end(); }
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
// Binds the given `symbol` to the results of the given `op`. Returns false if
// `symbol` is already bound.
bool bindOpResult(StringRef symbol, const Operator &op);
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
// is already bound.
bool bindValue(StringRef symbol);
// Returns true if the given `symbol` is bound.
bool contains(StringRef symbol) const;
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;
// Returns the number of static values the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
// more than one if the op is a multi-result op.
int getStaticValueCount(StringRef symbol) const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values).
std::string getValueAndRangeUse(StringRef symbol) const;
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns `symbol`
// itself if it does not contain an index.
//
// We can use `name__N` to access the `N`-th value in the value pack bound to
// `name`. `name` is typically the results of a multi-result op.
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
private:
// Storage for all bound symbols, keyed by symbol name.
llvm::StringMap<SymbolInfo> symbolInfoMap;
// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
ArrayRef<llvm::SMLoc> loc;
};
// Wrapper class providing helper methods for accessing MLIR Pattern defined
// in TableGen. This class should closely reflect what is defined as class
// `Pattern` in TableGen. This class contains maps so it is not intended to be
@ -198,24 +346,11 @@ public:
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Checks whether an argument or op with the given `name` is bound in
// source pattern. Prints fatal error if not; does nothing otherwise.
void ensureBoundInSourcePattern(StringRef name) const;
// Collects all symbols bound in the source pattern into `infoMap`.
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
// Returns a reference to all the bound arguments in the source pattern.
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// The returned map contains pointers to the operators inside the
// `RecordOperatorMap` passed-in when constructing this pattern; callers
// should guarantee the lifetime of the returned map does not exceed that
// of the `RecordOperatorMap`.
using SymbolOperatorMap = llvm::StringMap<const Operator *>;
// Returns a reference to all the bound ops in the source pattern.
SymbolOperatorMap &getSourcePatternBoundOps();
// Returns a reference to all the bound ops in the result patterns.
SymbolOperatorMap &getResultPatternBoundOps();
// Collects all symbols bound in result patterns into `infoMap`.
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@ -238,8 +373,8 @@ public:
private:
// Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `symOpMap`.
void collectBoundSymbols(DagNode tree, SymbolOperatorMap &symOpMap,
// at `tree` and updates the given `infoMap`.
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern);
// The TableGen definition of this pattern.
@ -249,15 +384,6 @@ private:
// TODO(antiagainst): we need a proper context manager, like MLIRContext,
// for managing the lifetime of shared entities.
RecordOperatorMap *recordOpMap;
// All source pattern bound op arguments.
llvm::StringMap<Argument> srcBoundArguments;
// All source pattern bound ops.
SymbolOperatorMap srcBoundOps;
// All result pattern bound ops.
SymbolOperatorMap resBoundOps;
};
} // end namespace tblgen

View File

@ -31,6 +31,10 @@ using namespace mlir;
using llvm::formatv;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// DagLeaf
//===----------------------------------------------------------------------===//
// Returns true if the wrapped init is non-null and is an `llvm::UnsetInit`.
bool tblgen::DagLeaf::isUnspecified() const {
  const auto *unsetInit = dyn_cast_or_null<llvm::UnsetInit>(def);
  return unsetInit != nullptr;
}
@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
return false;
}
//===----------------------------------------------------------------------===//
// DagNode
//===----------------------------------------------------------------------===//
bool tblgen::DagNode::isNativeCodeCall() const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
return defInit->getDef()->isSubClassOf("NativeCodeCall");
@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {
collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
collectBoundSymbols(getResultPattern(i), resBoundOps,
/*isSrcPattern=*/false);
//===----------------------------------------------------------------------===//
// SymbolInfoMap
//===----------------------------------------------------------------------===//
// Splits `symbol` of the form `name__N` into the value pack name `name` and
// the index `N`.
//
// Returns `name` and, if `index` is non-null, stores `N` into `*index` when
// the text after the last `__` is, in its entirety, a non-negative decimal
// integer. Otherwise returns `symbol` unchanged and leaves `*index` untouched.
StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
                                                  int *index) {
  StringRef name, indexStr;
  std::tie(name, indexStr) = symbol.rsplit("__");

  int idx = -1;
  // getAsInteger() returns true on failure. Unlike consumeInteger(), it fails
  // if anything is left unparsed, so a symbol such as `foo__1abc` is not
  // mistaken for element 1 of pack `foo`. Negative indices are rejected too,
  // since callers use a negative index to mean "no index".
  if (indexStr.getAsInteger(10, idx) || idx < 0) {
    // The second part is not an index; we return the whole symbol as-is.
    return symbol;
  }

  if (index) {
    *index = idx;
  }
  return name;
}
// Stores `op`, `kind` and `index` (as `argIndex`) verbatim; no validation of
// the kind/index combination is performed here.
tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
SymbolInfo::Kind kind,
Optional<int> index)
: op(op), kind(kind), argIndex(index) {}
// Returns the number of static values this symbol stands for: 1 for
// attributes, operands and plain values, and the op's result count for
// symbols bound to op results.
int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
  switch (kind) {
  case Kind::Attr:
  case Kind::Operand:
  case Kind::Value:
    return 1;
  case Kind::Result:
    return op->getNumResults();
  }
  // The switch above covers every enumerator. This is only reached if `kind`
  // holds an invalid value; do not fall off the end of a non-void function.
  assert(false && "unknown symbol kind");
  return 0;
}
// Returns the C++ statement declaring a local variable called `name` that can
// hold the entity bound to this symbol:
//   * Attr: the attribute's storage type
//   * Operand / Value: `Value *`
//   * Result: the op's qualified C++ class
std::string
tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
  switch (kind) {
  case Kind::Attr: {
    auto type =
        op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
    return formatv("{0} {1};\n", type, name);
  }
  case Kind::Operand:
  case Kind::Value: {
    return formatv("Value *{0};\n", name);
  }
  case Kind::Result: {
    // Use the op itself for the results.
    return formatv("{0} {1};\n", op->getQualCppClassName(), name);
  }
  }
  // The switch above covers every enumerator. This is only reached if `kind`
  // holds an invalid value; do not fall off the end of a non-void function.
  assert(false && "unknown symbol kind");
  return std::string();
}
// Returns the C++ expression that references the entity bound to this symbol.
//
// `name` is the C++ variable holding the bound entity. `index` selects one
// result of a `Result` symbol when non-negative; a negative `index` means the
// symbol is referenced as a whole. For all other kinds `index` must be
// negative and the expression is simply `name`.
std::string
tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
                                                       int index) const {
  switch (kind) {
  case Kind::Attr:
  case Kind::Operand: {
    assert(index < 0 && "only allowed for symbol bound to result");
    return name.str();
  }
  case Kind::Result: {
    // TODO(b/133341698): The following is incorrect for variadic results. We
    // should use getODSResults().
    if (index >= 0) {
      return formatv("{0}.getOperation()->getResult({1})", name, index);
    }

    // If referencing multiple results, compose a comma-separated list.
    SmallVector<std::string, 4> values;
    for (int i = 0, e = op->getNumResults(); i < e; ++i) {
      values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
    }
    return llvm::join(values, ", ");
  }
  case Kind::Value: {
    assert(index < 0 && "only allowed for symbol bound to result");
    assert(op == nullptr);
    return name.str();
  }
  }
  // The switch above covers every enumerator. This is only reached if `kind`
  // holds an invalid value; do not fall off the end of a non-void function.
  assert(false && "unknown symbol kind");
  return std::string();
}
// Records `symbol` as bound to argument number `argIndex` of `op`, classified
// as an attribute or an operand depending on what that argument is. Emits a
// fatal error if `symbol` carries a trailing `__N` index. Returns false if
// `symbol` was already present in the map.
bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
                                           int argIndex) {
  const bool hasTrailingIndex = getValuePackName(symbol) != symbol;
  if (hasTrailingIndex) {
    auto error = formatv(
        "symbol '{0}' with trailing index cannot bind to op argument", symbol);
    PrintFatalError(loc, error);
  }

  const bool isAttribute = op.getArg(argIndex).is<NamedAttribute *>();
  auto info = isAttribute ? SymbolInfo::getAttr(&op, argIndex)
                          : SymbolInfo::getOperand(&op, argIndex);

  auto insertion = symbolInfoMap.insert({symbol, info});
  return insertion.second;
}
// Records the value pack name of `symbol` (i.e., `symbol` without any trailing
// `__N` index) as bound to the results of `op`. Returns false if that name was
// already present in the map.
bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
  StringRef packName = getValuePackName(symbol);
  auto insertion = symbolInfoMap.insert({packName, SymbolInfo::getResult(&op)});
  return insertion.second;
}
// Records `symbol`, used verbatim as the key, as bound to a value that is not
// attached to any op. Returns false if `symbol` was already present.
bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
  auto insertion = symbolInfoMap.insert({symbol, SymbolInfo::getValue()});
  return insertion.second;
}
// Returns true if `symbol` (ignoring any trailing `__N` index) is in the map.
bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
  const_iterator it = find(symbol);
  return it != symbolInfoMap.end();
}
// Looks up `key` after stripping any trailing `__N` index from it. Returns the
// map's end iterator if there is no such entry.
tblgen::SymbolInfoMap::const_iterator
tblgen::SymbolInfoMap::find(StringRef key) const {
  return symbolInfoMap.find(getValuePackName(key));
}
// Returns how many static values `symbol` stands for.
//
// A symbol with a trailing `__N` index always references exactly one static
// value. Otherwise the count comes from the symbol's recorded info; a fatal
// error is reported if the symbol is not bound.
int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
  StringRef name = getValuePackName(symbol);
  if (name != symbol) {
    // If there is a trailing index inside symbol, it references just one
    // static value.
    return 1;
  }

  // Otherwise, find how many it represents by querying the symbol's info.
  auto it = symbolInfoMap.find(name);
  if (it == symbolInfoMap.end()) {
    // Report the problem instead of dereferencing the end iterator.
    auto error = formatv("referencing unbound symbol '{0}'", symbol);
    PrintFatalError(loc, error);
  }
  return it->getValue().getStaticValueCount();
}
// Returns the C++ expression referencing the entity bound to `symbol`. A
// trailing `__N` index in `symbol` is split off and forwarded to the symbol's
// info. Emits a fatal error if the symbol is not bound.
std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
  int resultIndex = -1;
  StringRef packName = getValuePackName(symbol, &resultIndex);

  auto entry = symbolInfoMap.find(packName);
  const bool isBound = entry != symbolInfoMap.end();
  if (!isBound) {
    auto error = formatv("referencing unbound symbol '{0}'", symbol);
    PrintFatalError(loc, error);
  }

  const SymbolInfo &info = entry->getValue();
  return info.getValueAndRangeUse(packName, resultIndex);
}
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//
// Wraps the TableGen record `def` and remembers `mapper`. The body is empty:
// no bound symbols are collected at construction time.
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {}
// Returns the DAG stored in this pattern record's `sourcePattern` field,
// wrapped as a DagNode.
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
  auto *sourceDag = def.getValueAsDag("sourcePattern");
  return tblgen::DagNode(sourceDag);
}
@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
srcBoundOps.find(name) == srcBoundOps.end())
PrintFatalError(def.getLoc(),
Twine("referencing unbound variable '") + name + "'");
void tblgen::Pattern::collectSourcePatternBoundSymbols(
tblgen::SymbolInfoMap &infoMap) {
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
}
llvm::StringMap<tblgen::Argument> &
tblgen::Pattern::getSourcePatternBoundArgs() {
return srcBoundArguments;
}
llvm::StringMap<const tblgen::Operator *> &
tblgen::Pattern::getSourcePatternBoundOps() {
return srcBoundOps;
}
llvm::StringMap<const tblgen::Operator *> &
tblgen::Pattern::getResultPatternBoundOps() {
return resBoundOps;
void tblgen::Pattern::collectResultPatternBoundSymbols(
tblgen::SymbolInfoMap &infoMap) {
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
auto pattern = getResultPattern(i);
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
}
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const {
return result;
}
void tblgen::Pattern::collectBoundSymbols(DagNode tree,
SymbolOperatorMap &symOpMap,
void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
if (!tree.isOperation()) {
@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree,
auto numTreeArgs = tree.getNumArgs();
if (numOpArgs != numTreeArgs) {
PrintFatalError(def.getLoc(),
formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs));
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
PrintFatalError(def.getLoc(), err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
if (!treeName.empty())
symOpMap.try_emplace(treeName, &op);
if (!treeName.empty()) {
if (!infoMap.bindOpResult(treeName, op))
PrintFatalError(def.getLoc(),
formatv("symbol '{0}' bound more than once", treeName));
}
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
} else if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
if (!treeArgName.empty())
srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
if (!treeArgName.empty()) {
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
PrintFatalError(def.getLoc(), err);
}
}
}
}
}

View File

@ -51,166 +51,6 @@ template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
};
} // end namespace llvm
// Gets the dynamic value pack's name by removing the index suffix from
// `symbol`. Returns `symbol` itself if it does not contain an index.
//
// We can use `name__<index>` to access the `<index>`-th value in the dynamic
// value pack bound to `name`. `name` is typically the results of a
// multi-result op.
//
// The suffix only counts as an index when the text after the last `__` is, in
// its entirety, an unsigned decimal integer. On success the index is written
// to `*index` if `index` is non-null.
static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
  StringRef name, indexStr;
  std::tie(name, indexStr) = symbol.rsplit("__");

  unsigned idx = 0;
  // getAsInteger() returns true on failure. Unlike consumeInteger(), it fails
  // if anything is left unparsed, so `foo__1abc` is not treated as element 1
  // of pack `foo`.
  if (indexStr.getAsInteger(10, idx)) {
    // The second part is not an index.
    return symbol;
  }
  if (index)
    *index = idx;
  return name;
}
// Formats all values from a dynamic value pack `symbol` according to the given
// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
// and `{1}` as a placeholder for the value index, which will be offsetted by
// `offset`. The `symbol` value pack has a total of `count` values.
//
// This extracts one value from the pack if `symbol` contains an index,
// otherwise it extracts all values sequentially and returns them as a
// comma-separated list.
static std::string formatValuePack(const char *fmt, StringRef symbol,
unsigned count, unsigned offset) {
auto getNthValue = [fmt, offset](StringRef results,
unsigned index) -> std::string {
return formatv(fmt, results, index + offset);
};
unsigned index = 0;
StringRef name = getValuePackName(symbol, &index);
if (name != symbol) {
// The symbol contains an index.
return getNthValue(name, index);
}
// The symbol does not contain an index. Treat the symbol as a whole.
SmallVector<std::string, 4> values;
values.reserve(count);
for (unsigned i = 0; i < count; ++i)
values.emplace_back(getNthValue(symbol, i));
return llvm::join(values, ", ");
}
//===----------------------------------------------------------------------===//
// PatternSymbolResolver
//===----------------------------------------------------------------------===//
namespace {
// A class for resolving symbols bound in patterns.
//
// Symbols can be bound to op arguments and ops in the source pattern and ops
// in result patterns. For example, in
//
// ```
// def : Pattern<(SrcOp:$op1 $arg0, $arg1),
// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
// ```
//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
// If a symbol binds to a multi-result op and it does not have the `__N`
// suffix, the symbol is expanded to the whole value pack generated by the
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
// only the N-th result.
//
// This class keeps track of such symbols and translates them into their bound
// values.
//
// Note that we also generate local variables for unnamed DAG nodes, like
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
// generated local variable will be implicitly named. Those implicit names are
// not tracked in this class.
class PatternSymbolResolver {
public:
// Creates a resolver over the symbols bound to arguments (`srcArgs`) and to
// ops (`srcOperations`) in the source pattern.
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringMap<const Operator *> &srcOperations);
// Marks the given `symbol` as bound to a value pack with `numValues` and
// returns true on success. Returns false if the `symbol` is already bound.
bool add(StringRef symbol, int numValues);
// Queries the substitution for the given `symbol`. Returns empty string if
// symbol not found. If the symbol represents a value pack, returns all the
// values separated via comma.
std::string query(StringRef symbol) const;
// Returns how many static values the given `symbol` corresponds to. Returns a
// negative value if the given symbol is not bound.
//
// Normally a symbol would correspond to just one value; for symbols bound to
// multi-result ops, it can be more than one.
int getValueCount(StringRef symbol) const;
private:
// Symbols bound to arguments in source pattern.
const StringMap<Argument> &sourceArguments;
// Symbols bound to ops (for their results) in source pattern.
const StringMap<const Operator *> &sourceOps;
// Symbols bound to ops (for their results) in result patterns.
// Key: symbol; value: number of values inside the pack
StringMap<int> resultOps;
};
} // end anonymous namespace
// Binds the resolver to the given source-pattern symbol maps. Both are taken
// by const reference and used to initialize `sourceArguments` and `sourceOps`.
// NOTE(review): if those members are references (as the parameter types
// suggest), the maps presumably must outlive this resolver -- confirm.
PatternSymbolResolver::PatternSymbolResolver(
const StringMap<Argument> &srcArgs,
const StringMap<const Operator *> &srcOperations)
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
// Registers the value pack name of `symbol` (its trailing index, if any, is
// dropped) as a result-pattern op producing `numValues` values. Returns false
// if that name was already registered.
bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
  auto insertion = resultOps.try_emplace(getValuePackName(symbol), numValues);
  return insertion.second;
}
// Returns the substitution for `symbol`, trying in order: ops generated in
// result patterns, arguments matched in the source pattern, and ops matched in
// the source pattern. Returns an empty string if `symbol` is none of these.
std::string PatternSymbolResolver::query(StringRef symbol) const {
  StringRef packName = getValuePackName(symbol);

  // Symbols bound to generated ops.
  auto generated = resultOps.find(packName);
  if (generated != resultOps.end()) {
    return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
                           generated->second, /*offset=*/0);
  }

  // Symbols bound to matched op arguments resolve to themselves.
  const bool isSourceArgument =
      sourceArguments.find(symbol) != sourceArguments.end();
  if (isSourceArgument) {
    return symbol;
  }

  // Symbols bound to matched op results.
  auto matched = sourceOps.find(packName);
  if (matched == sourceOps.end()) {
    return {};
  }
  return formatValuePack("{0}->getResult({1})", symbol,
                         matched->second->getNumResults(), /*offset=*/0);
}
// Returns the number of static values `symbol` corresponds to, or -1 if the
// symbol is not bound. A symbol with a trailing index always counts as one
// value; a whole pack counts as many values as the bound op produces.
int PatternSymbolResolver::getValueCount(StringRef symbol) const {
  StringRef packName = getValuePackName(symbol);
  const bool refersToWholePack = packName == symbol;

  // Symbols bound to generated ops.
  auto generated = resultOps.find(packName);
  if (generated != resultOps.end()) {
    return refersToWholePack ? generated->second : 1;
  }

  // Symbols bound to matched op arguments.
  if (sourceArguments.count(symbol)) {
    return 1;
  }

  // Symbols bound to matched op results.
  auto matched = sourceOps.find(packName);
  if (matched == sourceOps.end()) {
    return -1;
  }
  return refersToWholePack ? matched->second->getNumResults() : 1;
}
//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//
@ -286,17 +126,13 @@ private:
// Collects all of the operations within the given dag tree.
void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
// Returns a unique name for a value of the given `op`.
std::string getUniqueValueName(const Operator *op);
// Returns a unique symbol for a local variable of the given `op`.
std::string getUniqueSymbol(const Operator *op);
//===--------------------------------------------------------------------===//
// Symbol utilities
//===--------------------------------------------------------------------===//
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
// is already bound.
void addSymbol(StringRef symbol, int numValues);
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
std::string resolveSymbol(StringRef symbol);
@ -308,13 +144,19 @@ private:
// prototypes used. This is intended to be used as a whole to
// PrintFatalError() on errors.
ArrayRef<llvm::SMLoc> loc;
// Op's TableGen Record to wrapper object
// Op's TableGen Record to wrapper object.
RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted
// Handy wrapper for pattern being emitted.
Pattern pattern;
PatternSymbolResolver symbolResolver;
// The next unused ID for newly created values
// Map for all bound symbols' info.
SymbolInfoMap symbolInfoMap;
// The next unused ID for newly created values.
unsigned nextValueId;
raw_ostream &os;
// Format contexts containing placeholder substitutations.
@ -328,9 +170,7 @@ private:
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolResolver(pattern.getSourcePatternBoundArgs(),
pattern.getSourcePatternBoundOps()),
nextValueId(0), os(os) {
symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
fmtCtx.withBuilder("rewriter");
}
@ -354,13 +194,14 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
}
int indent = 4 + 2 * depth;
os.indent(indent) << formatv(
"auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
depth, op.getQualCppClassName());
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
os.indent(indent) << formatv(
"if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
op.getQualCppClassName());
os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@ -372,7 +213,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
os.indent(indent) << formatv("{0} = op{1};\n", name, depth);
os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
@ -381,7 +222,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os.indent(indent) << "{\n";
os.indent(indent + 2)
<< formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
<< formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
os.indent(indent + 2)
@ -569,21 +410,17 @@ void PatternEmitter::emit(StringRef rewriteName) {
PatternRewriter &rewriter) const override {
)";
// Register all symbols bound in the source pattern.
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
os.indent(4) << "// Variables for capturing values and attributes used for "
"creating ops\n";
// Create local variables for storing the arguments bound to symbols.
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
auto fieldName = arg.first();
if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(),
fieldName);
} else {
os.indent(4) << "Value *" << fieldName << ";\n";
}
}
// Create local variables for storing the ops bound to symbols.
for (const auto &result : pattern.getSourcePatternBoundOps()) {
os.indent(4) << formatv("Operation *{0};\n", result.getKey());
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
StringRef symbol = symbolInfoPair.getKey();
auto &info = symbolInfoPair.getValue();
os.indent(4) << info.getVarDecl(symbol);
}
// TODO(jpienaar): capture ops with consistent numbering so that it can be
// reused for fused loc.
@ -609,20 +446,22 @@ void PatternEmitter::emitRewriteLogic() {
int numResultPatterns = pattern.getNumResultPatterns();
// First register all symbols bound to ops generated in result patterns.
for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
}
pattern.collectResultPatternBoundSymbols(symbolInfoMap);
// Only the last N static values generated are used to replace the matched
// root N-result op. We need to calculate the starting index (of the results
// of the matched op) each result pattern is to replace.
SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
int replStartIndex = -1;
// If we don't need to replace any value at all, set the replacement starting
// index as the number of result patterns so we skip all of them when trying
// to replace the matched op's results.
int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
for (int i = numResultPatterns - 1; i >= 0; --i) {
auto numValues = getNodeValueCount(pattern.getResultPattern(i));
offsets[i] = offsets[i + 1] - numValues;
if (offsets[i] == 0) {
replStartIndex = i;
if (replStartIndex == -1)
replStartIndex = i;
} else if (offsets[i] < 0 && offsets[i + 1] > 0) {
auto error = formatv(
"cannot use the same multi-result op '{0}' to generate both "
@ -652,31 +491,36 @@ void PatternEmitter::emitRewriteLogic() {
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op0, {";
interleave(
ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
[&](const std::string &name) { os << name; }, [&]() { os << ", "; });
interleaveComma(
ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
[&](const std::string &symbol) { os << resolveSymbol(symbol); });
os << "});\n";
}
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
}
// Handles one DAG node of a result pattern and returns a string for it. The
// node is dispatched three ways: native code call, replace-with-value, or op
// creation via handleOpCreation().
//
// NOTE(review): this span interleaves two versions of the body. The statements
// that mention `symbolInfoMap` look like the post-change text; the duplicated
// conditions and the `formatValuePack` return look like the pre-change text.
// Read it as a before/after listing, not as one compilable function -- confirm
// against the real file.
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
int resultIndex, int depth) {
// Native code call, first form: return the helper's string directly.
if (resultTree.isNativeCodeCall())
return handleReplaceWithNativeCodeCall(resultTree);
// Native code call, second form: additionally register the returned string
// with symbolInfoMap.bindValue() before handing it back.
if (resultTree.isNativeCodeCall()) {
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
symbolInfoMap.bindValue(symbol);
return symbol;
}
// Replace-with-value: the same condition appears twice (unbraced, then
// braced); both forms delegate to handleReplaceWithValue().
if (resultTree.isReplaceWithValue())
if (resultTree.isReplaceWithValue()) {
return handleReplaceWithValue(resultTree);
}
// Create the op and get the local variable for it.
auto results = handleOpCreation(resultTree, resultIndex, depth);
// We need to get all the values out of this local variable if we've created a
// multi-result op.
const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
return formatValuePack("{0}.getOperation()->getResult({1})", results,
numResults, /*offset=*/0);
// Normal op creation.
// Second form of the op-creation tail: the string from handleOpCreation() is
// returned as-is instead of being expanded through formatValuePack().
auto symbol = handleOpCreation(resultTree, resultIndex, depth);
if (resultTree.getSymbol().empty()) {
// This is an op not explicitly bound to a symbol in the rewrite rule.
// Register the auto-generated symbol for it.
symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
}
return symbol;
}
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
@ -709,7 +553,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
std::string val = std::to_string(enumCase.getValue());
return handleConstantAttr(enumCase, val);
}
pattern.ensureBoundInSourcePattern(argName);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return argName;
}
@ -734,27 +577,23 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
attrs[5], attrs[6], attrs[7]);
}
// Records `symbol` together with `numValues` in `symbolResolver`.
//
// A false return from symbolResolver.add() is reported through
// PrintFatalError() at `loc` as the symbol having been bound more than once.
//
// NOTE(review): this helper still goes through `symbolResolver`, while nearby
// code in this listing uses `symbolInfoMap`; it is presumably the pre-change
// text of a function the refactoring removes -- confirm against the real file.
void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
if (!symbolResolver.add(symbol, numValues))
PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
}
// Looks up the substitution string for `symbol` and returns it. An empty
// lookup result is reported through PrintFatalError() at `loc` as a reference
// to an unbound symbol.
//
// NOTE(review): `subst` is declared twice and the emptiness check appears in
// both an unbraced and a braced form. This reads as the pre-change lookup
// (`symbolResolver.query`) and the post-change lookup
// (`symbolInfoMap.getValueAndRangeUse`) listed together, not as one compilable
// body -- confirm against the real file.
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
// Lookup through the older resolver object.
auto subst = symbolResolver.query(symbol);
if (subst.empty())
// Lookup through the symbol info map.
auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
if (subst.empty()) {
PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
}
return subst;
}
int PatternEmitter::getNodeValueCount(DagNode node) {
if (node.isOperation()) {
// First to see whether this op is bound and we just want a specific result
// of it with `__N` suffix in symbol.
int count = symbolResolver.getValueCount(node.getSymbol());
if (count >= 0)
return count;
// No symbol. Then we are using all the results.
// If the op is bound to a symbol in the rewrite rule, query its result
// count from the symbol info map.
auto symbol = node.getSymbol();
if (!symbol.empty()) {
return symbolInfoMap.getStaticValueCount(symbol);
}
// Otherwise this is an unbound op; we will use all its results.
return pattern.getDialectOp(node).getNumResults();
}
// TODO(antiagainst): This considers all NativeCodeCall as returning one
@ -799,10 +638,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// Use the specified name for this op if available. Generate one otherwise.
std::string resultValue = tree.getSymbol();
if (resultValue.empty())
resultValue = getUniqueValueName(&resultOp);
resultValue = getUniqueSymbol(&resultOp);
// Strip the index to get the name for the value pack. This will be used to
// name the local variable for the op.
StringRef valuePackName = getValuePackName(resultValue);
StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
// Then we build the new op corresponding to this DAG node.
@ -826,20 +665,25 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here.
// We need to specify the types for all results.
auto resultTypes =
formatValuePack("op0->getResult({1})->getType()", valuePackName,
resultOp.getNumResults(), resultIndex);
SmallVector<std::string, 4> resultTypes;
int numResults = resultOp.getNumResults();
resultTypes.reserve(numResults);
for (int i = 0; i < numResults; ++i) {
resultTypes.push_back(
formatv("op0->getResult({0})->getType()", resultIndex + i));
}
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
valuePackName, resultOp.getQualCppClassName())
<< (resultTypes.empty() ? "" : ", ") << resultTypes;
<< (resultTypes.empty() ? "" : ", ")
<< llvm::join(resultTypes, ", ");
}
// Create the builder call for the result.
// Add operands.
int i = 0;
for (int e = resultOp.getNumOperands(); i < e; ++i) {
const auto &operand = resultOp.getOperand(i);
int argIndex = 0;
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
const auto &operand = resultOp.getOperand(argIndex);
// Start each operand on its own line.
(os << ",\n").indent(6);
@ -847,11 +691,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
if (tree.isNestedDagArg(i)) {
os << childNodeNames[i];
if (tree.isNestedDagArg(argIndex)) {
os << childNodeNames[argIndex];
} else {
DagLeaf leaf = tree.getArgAsLeaf(i);
auto symbol = resolveSymbol(tree.getArgName(i));
DagLeaf leaf = tree.getArgAsLeaf(argIndex);
auto symbol = resolveSymbol(tree.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
@ -862,26 +706,26 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
// Add attributes.
for (int e = tree.getNumArgs(); i != e; ++i) {
for (; argIndex != numOpArgs; ++argIndex) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(i);
if (auto subTree = tree.getArgAsNestedDag(i)) {
auto opArgName = resultOp.getArgName(argIndex);
if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
handleReplaceWithNativeCodeCall(subTree));
} else {
auto leaf = tree.getArgAsLeaf(i);
auto leaf = tree.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
auto patArgName = tree.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
auto argument = resultOp.getArg(argIndex);
if (!argument.is<NamedAttribute *>())
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
} else {