Revert "[DDR] Introduce implicit equality check for the source pattern operands with the same name."

This reverts commit 7271c1bcb9.

This broke the gcc-5 build:

/usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair<const std::__cxx11::basic_string<char>, mlir::tblgen::SymbolInfoMap::SymbolInfo>::pair(llvm::StringRef&, mlir::tblgen::SymbolInfoMap::SymbolInfo)'
  { ::new((void *)__p) _Up(std::forward<_Args>(__args)...); }
    ^
In file included from /usr/include/c++/5/utility:70:0,
                 from llvm/include/llvm/Support/type_traits.h:18,
                 from llvm/include/llvm/Support/Casting.h:18,
                 from mlir/include/mlir/Support/LLVM.h:24,
                 from mlir/include/mlir/TableGen/Pattern.h:17,
                 from mlir/lib/TableGen/Pattern.cpp:14:
/usr/include/c++/5/bits/stl_pair.h:206:9: note: candidate: template<class ... _Args1, long unsigned int ..._Indexes1, class ... _Args2, long unsigned int ..._Indexes2> std::pair<_T1, _T2>::pair(std::tuple<_Args1 ...>&, std::tuple<_Args2 ...>&, std::_Index_tuple<_Indexes1 ...>, std::_Index_tuple<_Indexes2 ...>)
         pair(tuple<_Args1...>&, tuple<_Args2...>&,
         ^
This commit is contained in:
Mehdi Amini 2020-10-14 00:37:10 +00:00
parent 5fe53c4128
commit 0b793c4be0
5 changed files with 21 additions and 250 deletions

View File

@ -21,8 +21,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include <unordered_map>
namespace llvm {
class DagInit;
class Init;
@ -230,9 +228,6 @@ public:
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;
// Returns a variable name for the symbol named as `name`.
std::string getVarName(StringRef name) const;
private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
@ -290,12 +285,9 @@ public:
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
};
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
using BaseT = llvm::StringMap<SymbolInfo>;
// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
@ -308,7 +300,7 @@ public:
const_iterator end() const { return symbolInfoMap.end(); }
// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound and symbols are not operands.
// Returns false if `symbol` is already bound.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
// Binds the given `symbol` to the results the given `op`. Returns false if
@ -325,18 +317,6 @@ public:
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;
// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
const_iterator findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const;
// Returns the bounds of a range that includes all the elements which
// bind to the `key`.
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
// Returns number of times symbol named as `key` was used.
int count(StringRef key) const;
// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
@ -358,9 +338,6 @@ public:
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
// Assign alternative unique names to Operands that have equal names.
void assignUniqueAlternativeNames();
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
@ -370,7 +347,7 @@ public:
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
private:
BaseT symbolInfoMap;
llvm::StringMap<SymbolInfo> symbolInfoMap;
// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.

View File

@ -208,10 +208,6 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
llvm_unreachable("unknown kind");
}
std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
}
std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
@ -223,9 +219,8 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return std::string(
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
getVarName(name)));
return std::string(formatv(
"::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
}
case Kind::Value: {
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
@ -364,34 +359,16 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);
std::string key = symbol.str();
if (auto numberOfEntries = symbolInfoMap.count(key)) {
// Only non unique name for the operand is supported.
if (symInfo.kind != SymbolInfo::Kind::Operand) {
return false;
}
// Cannot add new operand if there is already non operand with the same
// name.
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
return false;
}
}
symbolInfoMap.emplace(key, symInfo);
return true;
return symbolInfoMap.insert({symbol, symInfo}).second;
}
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
StringRef name = getValuePackName(symbol);
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
return symbolInfoMap.count(inserted->first) == 1;
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
}
bool SymbolInfoMap::bindValue(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
}
bool SymbolInfoMap::contains(StringRef symbol) const {
@ -399,38 +376,10 @@ bool SymbolInfoMap::contains(StringRef symbol) const {
}
SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
std::string name = getValuePackName(key).str();
StringRef name = getValuePackName(key);
return symbolInfoMap.find(name);
}
SymbolInfoMap::const_iterator
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const {
std::string name = getValuePackName(key).str();
auto range = symbolInfoMap.equal_range(name);
for (auto it = range.first; it != range.second; ++it) {
if (it->second.op == &op && it->second.argIndex == argIndex) {
return it;
}
}
return symbolInfoMap.end();
}
std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
std::string name = getValuePackName(key).str();
return symbolInfoMap.equal_range(name);
}
int SymbolInfoMap::count(StringRef key) const {
std::string name = getValuePackName(key).str();
return symbolInfoMap.count(name);
}
int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
@ -439,7 +388,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
return find(name)->second.getStaticValueCount();
return find(name)->getValue().getStaticValueCount();
}
std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
@ -448,13 +397,13 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto it = symbolInfoMap.find(name.str());
auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
return it->second.getValueAndRangeUse(name, index, fmt, separator);
return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
}
std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
@ -462,44 +411,13 @@ std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto it = symbolInfoMap.find(name.str());
auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
return it->second.getAllRangeUse(name, index, fmt, separator);
}
void SymbolInfoMap::assignUniqueAlternativeNames() {
llvm::StringSet<> usedNames;
for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;
auto operandName = symbolInfoIt->first;
int startSearchIndex = 0;
for (++startRange; startRange != endRange; ++startRange) {
// Current operand name is not unique, find a unique one
// and set the alternative name.
for (int i = startSearchIndex;; ++i) {
std::string alternativeName = operandName + std::to_string(i);
if (!usedNames.contains(alternativeName) &&
symbolInfoMap.count(alternativeName) == 0) {
usedNames.insert(alternativeName);
startRange->second.alternativeName = alternativeName;
startSearchIndex = i + 1;
break;
}
}
}
symbolInfoIt = endRange;
}
return it->getValue().getAllRangeUse(name, index, fmt, separator);
}
//===----------------------------------------------------------------------===//
@ -527,10 +445,6 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
infoMap.assignUniqueAlternativeNames();
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}
void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {

View File

@ -619,32 +619,6 @@ def OpM : TEST_Op<"op_m"> {
let results = (outs I32);
}
def OpN : TEST_Op<"op_n"> {
let arguments = (ins I32, I32);
let results = (outs I32);
}
def OpO : TEST_Op<"op_o"> {
let arguments = (ins I32);
let results = (outs I32);
}
def OpP : TEST_Op<"op_p"> {
let arguments = (ins I32, I32, I32, I32, I32, I32);
let results = (outs I32);
}
// Test same operand name enforces equality condition check.
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
// Test when equality is enforced at different depth.
def TestNestedOpEqualArgsPattern :
Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
// Test multiple equal arguments check enforced.
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

View File

@ -111,64 +111,6 @@ func @verifyManyArgs(%arg: i32) {
return
}
// CHECK-LABEL: verifyEqualArgs
func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
// def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;
// CHECK: "test.op_o"(%arg0) : (i32) -> i32
"test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)
// CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
"test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)
return
}
// CHECK-LABEL: verifyNestedOpEqualArgs
func @verifyNestedOpEqualArgs(
%arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
// CHECK: %arg1
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
// CHECK: test.op_p
// CHECK: test.op_n
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
return
}
// CHECK-LABEL: verifyMultipleEqualArgs
func @verifyMultipleEqualArgs(
%arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
// def TestMultipleEqualArgsPattern :
// Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
// CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32
// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32
// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32
// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
(i32, i32, i32, i32 , i32, i32) -> i32
return
}
//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//

View File

@ -89,11 +89,6 @@ private:
void emitMatchCheck(int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt);
// Emits C++ for checking a match with a corresponding match failure
// diagnostics.
void emitMatchCheck(int depth, const std::string &matchStr,
const std::string &failureStr);
//===--------------------------------------------------------------------===//
// Rewrite utilities
//===--------------------------------------------------------------------===//
@ -332,9 +327,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
res->second.getVarName(name), depth, argIndex - numPrevAttrs);
os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
argIndex - numPrevAttrs);
}
}
@ -399,15 +393,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
emitMatchCheck(depth, matchFmt.str(), failureFmt.str());
}
void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr,
const std::string &failureStr) {
os << "if (!(" << matchStr << "))";
os << "if (!(" << matchFmt.str() << "))";
os.scope("{\n", "\n}\n").os
<< "return rewriter.notifyMatchFailure(op" << depth
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr
<< ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
<< ";\n});";
}
@ -456,30 +445,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
constraint.getDescription()));
}
}
// Some of the operands could be bound to the same symbol name, we need
// to enforce equality constraint on those.
// TODO: we should be able to emit equality checks early
// and short circuit unnecessary work if vars are not equal.
for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;
auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
for (++startRange; startRange != endRange; ++startRange) {
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
emitMatchCheck(
depth,
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
secondOperand));
}
symbolInfoIt = endRange;
}
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
@ -553,9 +518,8 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Create local variables for storing the arguments and results bound
// to symbols.
for (const auto &symbolInfoPair : symbolInfoMap) {
const auto &symbol = symbolInfoPair.first;
const auto &info = symbolInfoPair.second;
StringRef symbol = symbolInfoPair.getKey();
auto &info = symbolInfoPair.getValue();
os << info.getVarDecl(symbol);
}
// TODO: capture ops with consistent numbering so that it can be
@ -1129,7 +1093,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
range);
} else {
os << formatv("tblgen_values.push_back(");
os << formatv("tblgen_values.push_back(", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));