forked from OSchip/llvm-project
NFC: Refactoring PatternSymbolResolver into SymbolInfoMap
In declarative rewrite rules, a symbol can be bound to op arguments or results in the source pattern, and it can be bound to op results in the result pattern. This means given a symbol in the pattern, it can stands for different things: op operand, op attribute, single op result, op result pack. We need a better way to model this complexity so that we can handle according to the specific kind a symbol corresponds to. Created SymbolInfo class for maintaining the information regarding a symbol. Also created a companion SymbolInfoMap class for a map of such symbols, providing insertion and querying depending on use cases. PiperOrigin-RevId: 262675515
This commit is contained in:
parent
41968fb475
commit
ac68637ba9
|
@ -180,6 +180,154 @@ private:
|
|||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
};
|
||||
|
||||
// A class for maintaining information for symbols bound in patterns and
|
||||
// provides methods for resolving them according to specific use cases.
|
||||
//
|
||||
// Symbols can be bound to
|
||||
//
|
||||
// * Op arguments and op results in the source pattern and
|
||||
// * Op results in result patterns.
|
||||
//
|
||||
// Symbols can be referenced in result patterns and additional constraints to
|
||||
// the pattern.
|
||||
//
|
||||
// For example, in
|
||||
//
|
||||
// ```
|
||||
// def : Pattern<
|
||||
// (SrcOp:$results1 $arg0, %arg1),
|
||||
// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
|
||||
// ```
|
||||
//
|
||||
// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
|
||||
// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
|
||||
// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
|
||||
//
|
||||
// If a symbol binds to a multi-result op and it does not have the `__N`
|
||||
// suffix, the symbol is expanded to represent all results generated by the
|
||||
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
|
||||
// only the N-th *static* result as declared in ODS, and that can still
|
||||
// corresponds to multiple *dynamic* values if the N-th *static* result is
|
||||
// variadic.
|
||||
//
|
||||
// This class keeps track of such symbols and resolves them into their bound
|
||||
// values in a suitable way.
|
||||
class SymbolInfoMap {
|
||||
public:
|
||||
explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
|
||||
|
||||
// Class for information regarding a symbol.
|
||||
class SymbolInfo {
|
||||
public:
|
||||
// Returns a string for defining a variable named as `name` to store the
|
||||
// value bound by this symbol.
|
||||
std::string getVarDecl(StringRef name) const;
|
||||
|
||||
private:
|
||||
// Allow SymbolInfoMap to access private methods.
|
||||
friend class SymbolInfoMap;
|
||||
|
||||
// What kind of entity this symbol represents:
|
||||
// * Attr: op attribute
|
||||
// * Operand: op operand
|
||||
// * Result: op result
|
||||
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
|
||||
enum class Kind : uint8_t { Attr, Operand, Result, Value };
|
||||
|
||||
// Creates a SymbolInfo instance. `index` is only used for `Attr` and
|
||||
// `Operand` so should be negative for `Result` and `Value` kind.
|
||||
SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
|
||||
|
||||
// Static methods for creating SymbolInfo.
|
||||
static SymbolInfo getAttr(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Attr, index);
|
||||
}
|
||||
static SymbolInfo getOperand(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Operand, index);
|
||||
}
|
||||
static SymbolInfo getResult(const Operator *op) {
|
||||
return SymbolInfo(op, Kind::Result, llvm::None);
|
||||
}
|
||||
static SymbolInfo getValue() {
|
||||
return SymbolInfo(nullptr, Kind::Value, llvm::None);
|
||||
}
|
||||
|
||||
// Returns the number of static values this symbol corresponds to.
|
||||
// A static value is an operand/result declared in ODS. Normally a symbol
|
||||
// only represents one static value, but symbols bound to op results can
|
||||
// represent more than one if the op is a multi-result op.
|
||||
int getStaticValueCount() const;
|
||||
|
||||
// Returns a string containing the C++ expression for referencing this
|
||||
// symbol as a value (if this symbol represents one static value) or a value
|
||||
// range (if this symbol represents multiple static values). `name` is the
|
||||
// name of the C++ variable that this symbol bounds to. `index` should only
|
||||
// be used for indexing results.
|
||||
std::string getValueAndRangeUse(StringRef name, int index) const;
|
||||
|
||||
const Operator *op; // The op where the bound entity belongs
|
||||
Kind kind; // The kind of the bound entity
|
||||
// The argument index (for `Attr` and `Operand` only)
|
||||
Optional<int> argIndex;
|
||||
};
|
||||
|
||||
using BaseT = llvm::StringMap<SymbolInfo>;
|
||||
|
||||
// Iterators for accessing all symbols.
|
||||
using iterator = BaseT::iterator;
|
||||
iterator begin() { return symbolInfoMap.begin(); }
|
||||
iterator end() { return symbolInfoMap.end(); }
|
||||
|
||||
// Const iterators for accessing all symbols.
|
||||
using const_iterator = BaseT::const_iterator;
|
||||
const_iterator begin() const { return symbolInfoMap.begin(); }
|
||||
const_iterator end() const { return symbolInfoMap.end(); }
|
||||
|
||||
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
|
||||
// Returns false if `symbol` is already bound.
|
||||
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
|
||||
|
||||
// Binds the given `symbol` to the results the given `op`. Returns false if
|
||||
// `symbol` is already bound.
|
||||
bool bindOpResult(StringRef symbol, const Operator &op);
|
||||
|
||||
// Registers the given `symbol` as bound to a value. Returns false if `symbol`
|
||||
// is already bound.
|
||||
bool bindValue(StringRef symbol);
|
||||
|
||||
// Returns true if the given `symbol` is bound.
|
||||
bool contains(StringRef symbol) const;
|
||||
|
||||
// Returns an interator to the information of the given symbol named as `key`.
|
||||
const_iterator find(StringRef key) const;
|
||||
|
||||
// Returns the number of static values of the given `symbol` corresponds to.
|
||||
// A static value is a operand/result declared in ODS. Normally a symbol only
|
||||
// represents one static value, but symbols bound to op results can represent
|
||||
// more than one if the op is a multi-result op.
|
||||
int getStaticValueCount(StringRef symbol) const;
|
||||
|
||||
// Returns a string containing the C++ expression for referencing this
|
||||
// symbol as a value (if this symbol represents one static value) or a value
|
||||
// range (if this symbol represents multiple static values).
|
||||
std::string getValueAndRangeUse(StringRef symbol) const;
|
||||
|
||||
// Splits the given `symbol` into a value pack name and an index. Returns the
|
||||
// value pack name and writes the index to `index` on sucess. Returns `symbol`
|
||||
// itself if it does not contain an index.
|
||||
//
|
||||
// We can use `name__N` to access the `N`-th value in the value pack bound to
|
||||
// `name`. `name` is typically the results of an multi-result op.
|
||||
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
|
||||
|
||||
private:
|
||||
llvm::StringMap<SymbolInfo> symbolInfoMap;
|
||||
|
||||
// Pattern instantiation location. This is intended to be used as parameter
|
||||
// to PrintFatalError() to report errors.
|
||||
ArrayRef<llvm::SMLoc> loc;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Pattern defined
|
||||
// in TableGen. This class should closely reflect what is defined as class
|
||||
// `Pattern` in TableGen. This class contains maps so it is not intended to be
|
||||
|
@ -198,24 +346,11 @@ public:
|
|||
// Returns the DAG tree root node of the `index`-th result pattern.
|
||||
DagNode getResultPattern(unsigned index) const;
|
||||
|
||||
// Checks whether an argument or op with the given `name` is bound in
|
||||
// source pattern. Prints fatal error if not; does nothing otherwise.
|
||||
void ensureBoundInSourcePattern(StringRef name) const;
|
||||
// Collects all symbols bound in the source pattern into `infoMap`.
|
||||
void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
|
||||
|
||||
// Returns a reference to all the bound arguments in the source pattern.
|
||||
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
|
||||
|
||||
// The returned map contains pointers to the operators inside the
|
||||
// `RecordOperatorMap` passed-in when constructing this pattern; callers
|
||||
// should guarantee the lifetime of the returned map does not exceed that
|
||||
// of the `RecordOperatorMap`.
|
||||
using SymbolOperatorMap = llvm::StringMap<const Operator *>;
|
||||
|
||||
// Returns a reference to all the bound ops in the source pattern.
|
||||
SymbolOperatorMap &getSourcePatternBoundOps();
|
||||
|
||||
// Returns a reference to all the bound ops in the result patterns.
|
||||
SymbolOperatorMap &getResultPatternBoundOps();
|
||||
// Collects all symbols bound in result patterns into `infoMap`.
|
||||
void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
|
||||
|
||||
// Returns the op that the root node of the source pattern matches.
|
||||
const Operator &getSourceRootOp();
|
||||
|
@ -238,8 +373,8 @@ public:
|
|||
|
||||
private:
|
||||
// Recursively collects all bound symbols inside the DAG tree rooted
|
||||
// at `tree` and updates the given `symOpMap`.
|
||||
void collectBoundSymbols(DagNode tree, SymbolOperatorMap &symOpMap,
|
||||
// at `tree` and updates the given `infoMap`.
|
||||
void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
bool isSrcPattern);
|
||||
|
||||
// The TableGen definition of this pattern.
|
||||
|
@ -249,15 +384,6 @@ private:
|
|||
// TODO(antiagainst): we need a proper context manager, like MLIRContext,
|
||||
// for managing the lifetime of shared entities.
|
||||
RecordOperatorMap *recordOpMap;
|
||||
|
||||
// All source pattern bound op arguments.
|
||||
llvm::StringMap<Argument> srcBoundArguments;
|
||||
|
||||
// All source pattern bound ops.
|
||||
SymbolOperatorMap srcBoundOps;
|
||||
|
||||
// All result pattern bound ops.
|
||||
SymbolOperatorMap resBoundOps;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -31,6 +31,10 @@ using namespace mlir;
|
|||
using llvm::formatv;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DagLeaf
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool tblgen::DagLeaf::isUnspecified() const {
|
||||
return dyn_cast_or_null<llvm::UnsetInit>(def);
|
||||
}
|
||||
|
@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DagNode
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool tblgen::DagNode::isNativeCodeCall() const {
|
||||
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
|
||||
return defInit->getDef()->isSubClassOf("NativeCodeCall");
|
||||
|
@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const {
|
|||
return dagOpDef->getName() == "replaceWithValue";
|
||||
}
|
||||
|
||||
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||
: def(*def), recordOpMap(mapper) {
|
||||
collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
|
||||
for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
|
||||
collectBoundSymbols(getResultPattern(i), resBoundOps,
|
||||
/*isSrcPattern=*/false);
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolInfoMap
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
|
||||
int *index) {
|
||||
StringRef name, indexStr;
|
||||
int idx = -1;
|
||||
std::tie(name, indexStr) = symbol.rsplit("__");
|
||||
|
||||
if (indexStr.consumeInteger(10, idx)) {
|
||||
// The second part is not an index; we return the whole symbol as-is.
|
||||
return symbol;
|
||||
}
|
||||
if (index) {
|
||||
*index = idx;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
|
||||
SymbolInfo::Kind kind,
|
||||
Optional<int> index)
|
||||
: op(op), kind(kind), argIndex(index) {}
|
||||
|
||||
int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Operand:
|
||||
case Kind::Value:
|
||||
return 1;
|
||||
case Kind::Result:
|
||||
return op->getNumResults();
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
|
||||
switch (kind) {
|
||||
case Kind::Attr: {
|
||||
auto type =
|
||||
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
|
||||
return formatv("{0} {1};\n", type, name);
|
||||
}
|
||||
case Kind::Operand:
|
||||
case Kind::Value: {
|
||||
return formatv("Value *{0};\n", name);
|
||||
}
|
||||
case Kind::Result: {
|
||||
// Use the op itself for the results.
|
||||
return formatv("{0} {1};\n", op->getQualCppClassName(), name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
|
||||
int index) const {
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Operand: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
return name;
|
||||
}
|
||||
case Kind::Result: {
|
||||
// TODO(b/133341698): The following is incorrect for variadic results. We
|
||||
// should use getODSResults().
|
||||
if (index >= 0) {
|
||||
return formatv("{0}.getOperation()->getResult({1})", name, index);
|
||||
}
|
||||
|
||||
// If referencing multiple results, compose a comma-separated list.
|
||||
SmallVector<std::string, 4> values;
|
||||
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
|
||||
values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
|
||||
}
|
||||
return llvm::join(values, ", ");
|
||||
}
|
||||
case Kind::Value: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
assert(op == nullptr);
|
||||
return name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
|
||||
int argIndex) {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
if (name != symbol) {
|
||||
auto error = formatv(
|
||||
"symbol '{0}' with trailing index cannot bind to op argument", symbol);
|
||||
PrintFatalError(loc, error);
|
||||
}
|
||||
|
||||
auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
|
||||
? SymbolInfo::getAttr(&op, argIndex)
|
||||
: SymbolInfo::getOperand(&op, argIndex);
|
||||
|
||||
return symbolInfoMap.insert({symbol, symInfo}).second;
|
||||
}
|
||||
|
||||
bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
|
||||
}
|
||||
|
||||
bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
|
||||
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
|
||||
}
|
||||
|
||||
bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
|
||||
return find(symbol) != symbolInfoMap.end();
|
||||
}
|
||||
|
||||
tblgen::SymbolInfoMap::const_iterator
|
||||
tblgen::SymbolInfoMap::find(StringRef key) const {
|
||||
StringRef name = getValuePackName(key);
|
||||
return symbolInfoMap.find(name);
|
||||
}
|
||||
|
||||
int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
if (name != symbol) {
|
||||
// If there is a trailing index inside symbol, it references just one
|
||||
// static value.
|
||||
return 1;
|
||||
}
|
||||
// Otherwise, find how many it represents by querying the symbol's info.
|
||||
return find(name)->getValue().getStaticValueCount();
|
||||
}
|
||||
|
||||
std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
|
||||
int index = -1;
|
||||
StringRef name = getValuePackName(symbol, &index);
|
||||
|
||||
auto it = symbolInfoMap.find(name);
|
||||
if (it == symbolInfoMap.end()) {
|
||||
auto error = formatv("referencing unbound symbol '{0}'", symbol);
|
||||
PrintFatalError(loc, error);
|
||||
}
|
||||
|
||||
return it->getValue().getValueAndRangeUse(name, index);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern
|
||||
//==----------------------------------------------------------------------===//
|
||||
|
||||
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||
: def(*def), recordOpMap(mapper) {}
|
||||
|
||||
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
|
||||
return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
|
||||
}
|
||||
|
@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
|
|||
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
|
||||
if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
|
||||
srcBoundOps.find(name) == srcBoundOps.end())
|
||||
PrintFatalError(def.getLoc(),
|
||||
Twine("referencing unbound variable '") + name + "'");
|
||||
void tblgen::Pattern::collectSourcePatternBoundSymbols(
|
||||
tblgen::SymbolInfoMap &infoMap) {
|
||||
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
|
||||
}
|
||||
|
||||
llvm::StringMap<tblgen::Argument> &
|
||||
tblgen::Pattern::getSourcePatternBoundArgs() {
|
||||
return srcBoundArguments;
|
||||
}
|
||||
|
||||
llvm::StringMap<const tblgen::Operator *> &
|
||||
tblgen::Pattern::getSourcePatternBoundOps() {
|
||||
return srcBoundOps;
|
||||
}
|
||||
|
||||
llvm::StringMap<const tblgen::Operator *> &
|
||||
tblgen::Pattern::getResultPatternBoundOps() {
|
||||
return resBoundOps;
|
||||
void tblgen::Pattern::collectResultPatternBoundSymbols(
|
||||
tblgen::SymbolInfoMap &infoMap) {
|
||||
for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
|
||||
auto pattern = getResultPattern(i);
|
||||
collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
|
||||
|
@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
void tblgen::Pattern::collectBoundSymbols(DagNode tree,
|
||||
SymbolOperatorMap &symOpMap,
|
||||
void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
bool isSrcPattern) {
|
||||
auto treeName = tree.getSymbol();
|
||||
if (!tree.isOperation()) {
|
||||
|
@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree,
|
|||
auto numTreeArgs = tree.getNumArgs();
|
||||
|
||||
if (numOpArgs != numTreeArgs) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
formatv("op '{0}' argument number mismatch: "
|
||||
"{1} in pattern vs. {2} in definition",
|
||||
op.getOperationName(), numTreeArgs, numOpArgs));
|
||||
auto err = formatv("op '{0}' argument number mismatch: "
|
||||
"{1} in pattern vs. {2} in definition",
|
||||
op.getOperationName(), numTreeArgs, numOpArgs);
|
||||
PrintFatalError(def.getLoc(), err);
|
||||
}
|
||||
|
||||
// The name attached to the DAG node's operator is for representing the
|
||||
// results generated from this op. It should be remembered as bound results.
|
||||
if (!treeName.empty())
|
||||
symOpMap.try_emplace(treeName, &op);
|
||||
if (!treeName.empty()) {
|
||||
if (!infoMap.bindOpResult(treeName, op))
|
||||
PrintFatalError(def.getLoc(),
|
||||
formatv("symbol '{0}' bound more than once", treeName));
|
||||
}
|
||||
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
if (auto treeArg = tree.getArgAsNestedDag(i)) {
|
||||
// This DAG node argument is a DAG node itself. Go inside recursively.
|
||||
collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
|
||||
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
|
||||
} else if (isSrcPattern) {
|
||||
// We can only bind symbols to op arguments in source pattern. Those
|
||||
// symbols are referenced in result patterns.
|
||||
auto treeArgName = tree.getArgName(i);
|
||||
if (!treeArgName.empty())
|
||||
srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
|
||||
if (!treeArgName.empty()) {
|
||||
if (!infoMap.bindOpArgument(treeArgName, op, i)) {
|
||||
auto err = formatv("symbol '{0}' bound more than once", treeArgName);
|
||||
PrintFatalError(def.getLoc(), err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,166 +51,6 @@ template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
|
|||
};
|
||||
} // end namespace llvm
|
||||
|
||||
// Gets the dynamic value pack's name by removing the index suffix from
|
||||
// `symbol`. Returns `symbol` itself if it does not contain an index.
|
||||
//
|
||||
// We can use `name__<index>` to access the `<index>`-th value in the dynamic
|
||||
// value pack bound to `name`. `name` is typically the results of an
|
||||
// multi-result op.
|
||||
static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
|
||||
StringRef name, indexStr;
|
||||
unsigned idx = 0;
|
||||
std::tie(name, indexStr) = symbol.rsplit("__");
|
||||
if (indexStr.consumeInteger(10, idx)) {
|
||||
// The second part is not an index.
|
||||
return symbol;
|
||||
}
|
||||
if (index)
|
||||
*index = idx;
|
||||
return name;
|
||||
}
|
||||
|
||||
// Formats all values from a dynamic value pack `symbol` according to the given
|
||||
// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
|
||||
// and `{1}` as a placeholder for the value index, which will be offsetted by
|
||||
// `offset`. The `symbol` value pack has a total of `count` values.
|
||||
//
|
||||
// This extracts one value from the pack if `symbol` contains an index,
|
||||
// otherwise it extracts all values sequentially and returns them as a
|
||||
// comma-separated list.
|
||||
static std::string formatValuePack(const char *fmt, StringRef symbol,
|
||||
unsigned count, unsigned offset) {
|
||||
auto getNthValue = [fmt, offset](StringRef results,
|
||||
unsigned index) -> std::string {
|
||||
return formatv(fmt, results, index + offset);
|
||||
};
|
||||
|
||||
unsigned index = 0;
|
||||
StringRef name = getValuePackName(symbol, &index);
|
||||
if (name != symbol) {
|
||||
// The symbol contains an index.
|
||||
return getNthValue(name, index);
|
||||
}
|
||||
|
||||
// The symbol does not contain an index. Treat the symbol as a whole.
|
||||
SmallVector<std::string, 4> values;
|
||||
values.reserve(count);
|
||||
for (unsigned i = 0; i < count; ++i)
|
||||
values.emplace_back(getNthValue(symbol, i));
|
||||
return llvm::join(values, ", ");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternSymbolResolver
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// A class for resolving symbols bound in patterns.
|
||||
//
|
||||
// Symbols can be bound to op arguments and ops in the source pattern and ops
|
||||
// in result patterns. For example, in
|
||||
//
|
||||
// ```
|
||||
// def : Pattern<(SrcOp:$op1 $arg0, %arg1),
|
||||
// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
|
||||
// ```
|
||||
//
|
||||
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
|
||||
// `$op2` is bound to `ResOp1`.
|
||||
//
|
||||
// If a symbol binds to a multi-result op and it does not have the `__N`
|
||||
// suffix, the symbol is expanded to the whole value pack generated by the
|
||||
// multi-result op. If the symbol has a `__N` suffix, then it will expand to
|
||||
// only the N-th result.
|
||||
//
|
||||
// This class keeps track of such symbols and translates them into their bound
|
||||
// values.
|
||||
//
|
||||
// Note that we also generate local variables for unnamed DAG nodes, like
|
||||
// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
|
||||
// generated local variable will be implicitly named. Those implicit names are
|
||||
// not tracked in this class.
|
||||
class PatternSymbolResolver {
|
||||
public:
|
||||
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
||||
const StringMap<const Operator *> &srcOperations);
|
||||
|
||||
// Marks the given `symbol` as bound to a value pack with `numValues` and
|
||||
// returns true on success. Returns false if the `symbol` is already bound.
|
||||
bool add(StringRef symbol, int numValues);
|
||||
|
||||
// Queries the substitution for the given `symbol`. Returns empty string if
|
||||
// symbol not found. If the symbol represents a value pack, returns all the
|
||||
// values separated via comma.
|
||||
std::string query(StringRef symbol) const;
|
||||
|
||||
// Returns how many static values the given `symbol` correspond to. Returns a
|
||||
// negative value if the given symbol is not bound.
|
||||
//
|
||||
// Normally a symbol would correspond to just one value; for symbols bound to
|
||||
// multi-result ops, it can be more than one.
|
||||
int getValueCount(StringRef symbol) const;
|
||||
|
||||
private:
|
||||
// Symbols bound to arguments in source pattern.
|
||||
const StringMap<Argument> &sourceArguments;
|
||||
// Symbols bound to ops (for their results) in source pattern.
|
||||
const StringMap<const Operator *> &sourceOps;
|
||||
// Symbols bound to ops (for their results) in result patterns.
|
||||
// Key: symbol; value: number of values inside the pack
|
||||
StringMap<int> resultOps;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
PatternSymbolResolver::PatternSymbolResolver(
|
||||
const StringMap<Argument> &srcArgs,
|
||||
const StringMap<const Operator *> &srcOperations)
|
||||
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
|
||||
|
||||
bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
return resultOps.try_emplace(name, numValues).second;
|
||||
}
|
||||
|
||||
std::string PatternSymbolResolver::query(StringRef symbol) const {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
// Handle symbols bound to generated ops
|
||||
auto resOpIt = resultOps.find(name);
|
||||
if (resOpIt != resultOps.end())
|
||||
return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
|
||||
resOpIt->second, /*offset=*/0);
|
||||
|
||||
// Handle symbols bound to matched op arguments
|
||||
auto srcArgIt = sourceArguments.find(symbol);
|
||||
if (srcArgIt != sourceArguments.end())
|
||||
return symbol;
|
||||
|
||||
// Handle symbols bound to matched op results
|
||||
auto srcOpIt = sourceOps.find(name);
|
||||
if (srcOpIt != sourceOps.end())
|
||||
return formatValuePack("{0}->getResult({1})", symbol,
|
||||
srcOpIt->second->getNumResults(), /*offset=*/0);
|
||||
return {};
|
||||
}
|
||||
|
||||
int PatternSymbolResolver::getValueCount(StringRef symbol) const {
|
||||
StringRef name = getValuePackName(symbol);
|
||||
// Handle symbols bound to generated ops
|
||||
auto resOpIt = resultOps.find(name);
|
||||
if (resOpIt != resultOps.end())
|
||||
return name == symbol ? resOpIt->second : 1;
|
||||
|
||||
// Handle symbols bound to matched op arguments
|
||||
if (sourceArguments.count(symbol))
|
||||
return 1;
|
||||
|
||||
// Handle symbols bound to matched op results
|
||||
auto srcOpIt = sourceOps.find(name);
|
||||
if (srcOpIt != sourceOps.end())
|
||||
return name == symbol ? srcOpIt->second->getNumResults() : 1;
|
||||
return -1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternEmitter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -286,17 +126,13 @@ private:
|
|||
// Collects all of the operations within the given dag tree.
|
||||
void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
|
||||
|
||||
// Returns a unique name for a value of the given `op`.
|
||||
std::string getUniqueValueName(const Operator *op);
|
||||
// Returns a unique symbol for a local variable of the given `op`.
|
||||
std::string getUniqueSymbol(const Operator *op);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Symbol utilities
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
|
||||
// is already bound.
|
||||
void addSymbol(StringRef symbol, int numValues);
|
||||
|
||||
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
|
||||
std::string resolveSymbol(StringRef symbol);
|
||||
|
||||
|
@ -308,13 +144,19 @@ private:
|
|||
// prototypes used. This is intended to be used as a whole to
|
||||
// PrintFatalError() on errors.
|
||||
ArrayRef<llvm::SMLoc> loc;
|
||||
// Op's TableGen Record to wrapper object
|
||||
|
||||
// Op's TableGen Record to wrapper object.
|
||||
RecordOperatorMap *opMap;
|
||||
// Handy wrapper for pattern being emitted
|
||||
|
||||
// Handy wrapper for pattern being emitted.
|
||||
Pattern pattern;
|
||||
PatternSymbolResolver symbolResolver;
|
||||
// The next unused ID for newly created values
|
||||
|
||||
// Map for all bound symbols' info.
|
||||
SymbolInfoMap symbolInfoMap;
|
||||
|
||||
// The next unused ID for newly created values.
|
||||
unsigned nextValueId;
|
||||
|
||||
raw_ostream &os;
|
||||
|
||||
// Format contexts containing placeholder substitutations.
|
||||
|
@ -328,9 +170,7 @@ private:
|
|||
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
||||
raw_ostream &os)
|
||||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
|
||||
symbolResolver(pattern.getSourcePatternBoundArgs(),
|
||||
pattern.getSourcePatternBoundOps()),
|
||||
nextValueId(0), os(os) {
|
||||
symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
|
||||
fmtCtx.withBuilder("rewriter");
|
||||
}
|
||||
|
||||
|
@ -354,13 +194,14 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
}
|
||||
|
||||
int indent = 4 + 2 * depth;
|
||||
os.indent(indent) << formatv(
|
||||
"auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
|
||||
depth, op.getQualCppClassName());
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
// Skip if there is no defining operation (e.g., arguments to function).
|
||||
os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
|
||||
os.indent(indent) << formatv(
|
||||
"if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
|
||||
op.getQualCppClassName());
|
||||
os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
|
||||
depth);
|
||||
}
|
||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
|
||||
|
@ -372,7 +213,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// If the operand's name is set, set to that variable.
|
||||
auto name = tree.getSymbol();
|
||||
if (!name.empty())
|
||||
os.indent(indent) << formatv("{0} = op{1};\n", name, depth);
|
||||
os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto opArg = op.getArg(i);
|
||||
|
@ -381,7 +222,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
os.indent(indent) << "{\n";
|
||||
os.indent(indent + 2)
|
||||
<< formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
|
||||
<< formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
|
||||
depth + 1, depth, i);
|
||||
emitOpMatch(argTree, depth + 1);
|
||||
os.indent(indent + 2)
|
||||
|
@ -569,21 +410,17 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
PatternRewriter &rewriter) const override {
|
||||
)";
|
||||
|
||||
// Register all symbols bound in the source pattern.
|
||||
pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
|
||||
|
||||
os.indent(4) << "// Variables for capturing values and attributes used for "
|
||||
"creating ops\n";
|
||||
// Create local variables for storing the arguments bound to symbols.
|
||||
for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
|
||||
auto fieldName = arg.first();
|
||||
if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
|
||||
os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(),
|
||||
fieldName);
|
||||
} else {
|
||||
os.indent(4) << "Value *" << fieldName << ";\n";
|
||||
}
|
||||
}
|
||||
// Create local variables for storing the ops bound to symbols.
|
||||
for (const auto &result : pattern.getSourcePatternBoundOps()) {
|
||||
os.indent(4) << formatv("Operation *{0};\n", result.getKey());
|
||||
// Create local variables for storing the arguments and results bound
|
||||
// to symbols.
|
||||
for (const auto &symbolInfoPair : symbolInfoMap) {
|
||||
StringRef symbol = symbolInfoPair.getKey();
|
||||
auto &info = symbolInfoPair.getValue();
|
||||
os.indent(4) << info.getVarDecl(symbol);
|
||||
}
|
||||
// TODO(jpienaar): capture ops with consistent numbering so that it can be
|
||||
// reused for fused loc.
|
||||
|
@ -609,20 +446,22 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
int numResultPatterns = pattern.getNumResultPatterns();
|
||||
|
||||
// First register all symbols bound to ops generated in result patterns.
|
||||
for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
|
||||
addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
|
||||
}
|
||||
pattern.collectResultPatternBoundSymbols(symbolInfoMap);
|
||||
|
||||
// Only the last N static values generated are used to replace the matched
|
||||
// root N-result op. We need to calculate the starting index (of the results
|
||||
// of the matched op) each result pattern is to replace.
|
||||
SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
|
||||
int replStartIndex = -1;
|
||||
// If we don't need to replace any value at all, set the replacement starting
|
||||
// index as the number of result patterns so we skip all of them when trying
|
||||
// to replace the matched op's results.
|
||||
int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
|
||||
for (int i = numResultPatterns - 1; i >= 0; --i) {
|
||||
auto numValues = getNodeValueCount(pattern.getResultPattern(i));
|
||||
offsets[i] = offsets[i + 1] - numValues;
|
||||
if (offsets[i] == 0) {
|
||||
replStartIndex = i;
|
||||
if (replStartIndex == -1)
|
||||
replStartIndex = i;
|
||||
} else if (offsets[i] < 0 && offsets[i + 1] > 0) {
|
||||
auto error = formatv(
|
||||
"cannot use the same multi-result op '{0}' to generate both "
|
||||
|
@ -652,31 +491,36 @@ void PatternEmitter::emitRewriteLogic() {
|
|||
|
||||
// Emit the final replaceOp() statement
|
||||
os.indent(4) << "rewriter.replaceOp(op0, {";
|
||||
interleave(
|
||||
ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
|
||||
[&](const std::string &name) { os << name; }, [&]() { os << ", "; });
|
||||
interleaveComma(
|
||||
ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
|
||||
[&](const std::string &symbol) { os << resolveSymbol(symbol); });
|
||||
os << "});\n";
|
||||
}
|
||||
|
||||
std::string PatternEmitter::getUniqueValueName(const Operator *op) {
|
||||
return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
|
||||
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
|
||||
return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
|
||||
int resultIndex, int depth) {
|
||||
if (resultTree.isNativeCodeCall())
|
||||
return handleReplaceWithNativeCodeCall(resultTree);
|
||||
if (resultTree.isNativeCodeCall()) {
|
||||
auto symbol = handleReplaceWithNativeCodeCall(resultTree);
|
||||
symbolInfoMap.bindValue(symbol);
|
||||
return symbol;
|
||||
}
|
||||
|
||||
if (resultTree.isReplaceWithValue())
|
||||
if (resultTree.isReplaceWithValue()) {
|
||||
return handleReplaceWithValue(resultTree);
|
||||
}
|
||||
|
||||
// Create the op and get the local variable for it.
|
||||
auto results = handleOpCreation(resultTree, resultIndex, depth);
|
||||
// We need to get all the values out of this local variable if we've created a
|
||||
// multi-result op.
|
||||
const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
|
||||
return formatValuePack("{0}.getOperation()->getResult({1})", results,
|
||||
numResults, /*offset=*/0);
|
||||
// Normal op creation.
|
||||
auto symbol = handleOpCreation(resultTree, resultIndex, depth);
|
||||
if (resultTree.getSymbol().empty()) {
|
||||
// This is an op not explicitly bound to a symbol in the rewrite rule.
|
||||
// Register the auto-generated symbol for it.
|
||||
symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
|
||||
}
|
||||
return symbol;
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
|
||||
|
@ -709,7 +553,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
|
|||
std::string val = std::to_string(enumCase.getValue());
|
||||
return handleConstantAttr(enumCase, val);
|
||||
}
|
||||
pattern.ensureBoundInSourcePattern(argName);
|
||||
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
|
||||
return argName;
|
||||
}
|
||||
|
@ -734,27 +577,23 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
|
|||
attrs[5], attrs[6], attrs[7]);
|
||||
}
|
||||
|
||||
void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
|
||||
if (!symbolResolver.add(symbol, numValues))
|
||||
PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
|
||||
}
|
||||
|
||||
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
|
||||
auto subst = symbolResolver.query(symbol);
|
||||
if (subst.empty())
|
||||
auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
|
||||
if (subst.empty()) {
|
||||
PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
|
||||
}
|
||||
return subst;
|
||||
}
|
||||
|
||||
int PatternEmitter::getNodeValueCount(DagNode node) {
|
||||
if (node.isOperation()) {
|
||||
// First to see whether this op is bound and we just want a specific result
|
||||
// of it with `__N` suffix in symbol.
|
||||
int count = symbolResolver.getValueCount(node.getSymbol());
|
||||
if (count >= 0)
|
||||
return count;
|
||||
|
||||
// No symbol. Then we are using all the results.
|
||||
// If the op is bound to a symbol in the rewrite rule, query its result
|
||||
// count from the symbol info map.
|
||||
auto symbol = node.getSymbol();
|
||||
if (!symbol.empty()) {
|
||||
return symbolInfoMap.getStaticValueCount(symbol);
|
||||
}
|
||||
// Otherwise this is an unbound op; we will use all its results.
|
||||
return pattern.getDialectOp(node).getNumResults();
|
||||
}
|
||||
// TODO(antiagainst): This considers all NativeCodeCall as returning one
|
||||
|
@ -799,10 +638,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// Use the specified name for this op if available. Generate one otherwise.
|
||||
std::string resultValue = tree.getSymbol();
|
||||
if (resultValue.empty())
|
||||
resultValue = getUniqueValueName(&resultOp);
|
||||
resultValue = getUniqueSymbol(&resultOp);
|
||||
// Strip the index to get the name for the value pack. This will be used to
|
||||
// name the local variable for the op.
|
||||
StringRef valuePackName = getValuePackName(resultValue);
|
||||
StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
|
||||
|
||||
// Then we build the new op corresponding to this DAG node.
|
||||
|
||||
|
@ -826,20 +665,25 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// here.
|
||||
|
||||
// We need to specify the types for all results.
|
||||
auto resultTypes =
|
||||
formatValuePack("op0->getResult({1})->getType()", valuePackName,
|
||||
resultOp.getNumResults(), resultIndex);
|
||||
SmallVector<std::string, 4> resultTypes;
|
||||
int numResults = resultOp.getNumResults();
|
||||
resultTypes.reserve(numResults);
|
||||
for (int i = 0; i < numResults; ++i) {
|
||||
resultTypes.push_back(
|
||||
formatv("op0->getResult({0})->getType()", resultIndex + i));
|
||||
}
|
||||
|
||||
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
|
||||
valuePackName, resultOp.getQualCppClassName())
|
||||
<< (resultTypes.empty() ? "" : ", ") << resultTypes;
|
||||
<< (resultTypes.empty() ? "" : ", ")
|
||||
<< llvm::join(resultTypes, ", ");
|
||||
}
|
||||
|
||||
// Create the builder call for the result.
|
||||
// Add operands.
|
||||
int i = 0;
|
||||
for (int e = resultOp.getNumOperands(); i < e; ++i) {
|
||||
const auto &operand = resultOp.getOperand(i);
|
||||
int argIndex = 0;
|
||||
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
|
||||
const auto &operand = resultOp.getOperand(argIndex);
|
||||
|
||||
// Start each operand on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
|
@ -847,11 +691,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
if (!operand.name.empty())
|
||||
os << "/*" << operand.name << "=*/";
|
||||
|
||||
if (tree.isNestedDagArg(i)) {
|
||||
os << childNodeNames[i];
|
||||
if (tree.isNestedDagArg(argIndex)) {
|
||||
os << childNodeNames[argIndex];
|
||||
} else {
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
auto symbol = resolveSymbol(tree.getArgName(i));
|
||||
DagLeaf leaf = tree.getArgAsLeaf(argIndex);
|
||||
auto symbol = resolveSymbol(tree.getArgName(argIndex));
|
||||
if (leaf.isNativeCodeCall()) {
|
||||
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
|
||||
} else {
|
||||
|
@ -862,26 +706,26 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
}
|
||||
|
||||
// Add attributes.
|
||||
for (int e = tree.getNumArgs(); i != e; ++i) {
|
||||
for (; argIndex != numOpArgs; ++argIndex) {
|
||||
// Start each attribute on its own line.
|
||||
(os << ",\n").indent(6);
|
||||
// The argument in the op definition.
|
||||
auto opArgName = resultOp.getArgName(i);
|
||||
if (auto subTree = tree.getArgAsNestedDag(i)) {
|
||||
auto opArgName = resultOp.getArgName(argIndex);
|
||||
if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
|
||||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
os << formatv("/*{0}=*/{1}", opArgName,
|
||||
handleReplaceWithNativeCodeCall(subTree));
|
||||
} else {
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
auto leaf = tree.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = tree.getArgName(i);
|
||||
auto patArgName = tree.getArgName(argIndex);
|
||||
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
auto argument = resultOp.getArg(argIndex);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue