[mlir][ODS] Get rid of limitations in rewriters generator

Do not limit the number of arguments in rewriter pattern.

Introduce separate `FmtStrVecObject` class to handle
format of variadic `std::string` array.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D97839
This commit is contained in:
Vladislav Vinogradov 2021-03-03 12:04:08 +03:00
parent 209a626ede
commit 02834e1bd9
4 changed files with 81 additions and 24 deletions

View File

@ -186,6 +186,20 @@ public:
}
};
class FmtStrVecObject : public FmtObjectBase {
public:
using StrFormatAdapter =
decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params);
FmtStrVecObject(FmtStrVecObject const &that) = delete;
FmtStrVecObject(FmtStrVecObject &&that);
private:
SmallVector<StrFormatAdapter, 16> parameters;
};
/// Formats text by substituting placeholders in format string with replacement
/// parameters.
///
@ -234,6 +248,11 @@ inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
}
inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params) {
return FmtStrVecObject(fmt, ctx, params);
}
} // end namespace tblgen
} // end namespace mlir

View File

@ -173,3 +173,22 @@ void FmtObjectBase::format(raw_ostream &s) const {
adapters[repl.index]->format(s, /*Options=*/"");
}
}
FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
ArrayRef<std::string> params)
: FmtObjectBase(fmt, ctx, params.size()) {
parameters.reserve(params.size());
for (std::string p : params)
parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
adapters.reserve(parameters.size());
for (auto &p : parameters)
adapters.push_back(&p);
}
FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
adapters.reserve(parameters.size());
for (auto &p : parameters)
adapters.push_back(&p);
}

View File

@ -58,3 +58,30 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
def test3 : Pat<(BOp $attr, (AOp:$a $input)),
(BOp $attr, (AOp $input), (location $a))>;
def DOp : NS_Op<"d_op", []> {
let arguments = (ins
AnyInteger:$v1,
AnyInteger:$v2,
AnyInteger:$v3,
AnyInteger:$v4,
AnyInteger:$v5,
AnyInteger:$v6,
AnyInteger:$v7,
AnyInteger:$v8,
AnyInteger:$v9,
AnyInteger:$v10
);
let results = (outs AnyInteger);
}
def NativeBuilder :
NativeCodeCall<[{
nativeCall($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)
}]>;
// Check Pattern with large number of DAG arguments passed to NativeCodeCall
// CHECK: struct test4 : public ::mlir::RewritePattern {
// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
(NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;

View File

@ -251,12 +251,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
// TODO(suderman): iterate through arguments, determine their types, output
// names.
SmallVector<std::string, 8> capture(8);
if (tree.getNumArgs() > 8) {
PrintFatalError(loc,
"unsupported NativeCodeCall matcher argument numbers: " +
Twine(tree.getNumArgs()));
}
SmallVector<std::string, 8> capture;
capture.push_back(opName.str());
raw_indented_ostream::DelimitedScope scope(os);
@ -274,7 +270,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
}
}
capture[i] = std::move(argName);
capture.push_back(std::move(argName));
}
bool hasLocationDirective;
@ -282,21 +278,20 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
auto fmt = tree.getNativeCodeTemplate();
auto nativeCodeCall = std::string(tgfmt(
fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
auto nativeCodeCall =
std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto name = tree.getArgName(i);
if (!name.empty() && name != "_") {
os << formatv("{0} = {1};\n", name, capture[i]);
os << formatv("{0} = {1};\n", name, capture[i + 1]);
}
}
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = capture[i];
std::string argName = capture[i + 1];
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@ -915,29 +910,26 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
LLVM_DEBUG(llvm::dbgs() << '\n');
auto fmt = tree.getNativeCodeTemplate();
// TODO: replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) {
PrintFatalError(loc,
"unsupported NativeCodeCall replace argument numbers: " +
Twine(tree.getNumArgs()));
}
SmallVector<std::string, 16> attrs;
bool hasLocationDirective;
std::string locToUse;
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
if (tree.isNestedDagArg(i)) {
attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
attrs.push_back(
handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
} else {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
attrs.push_back(
handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
}
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
<< " replacement: " << attrs[i] << "\n");
}
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
attrs[6], attrs[7]));
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
}
int PatternEmitter::getNodeValueCount(DagNode node) {