forked from OSchip/llvm-project
[mlir][ODS] Get rid of limitations in rewriters generator
Do not limit the number of arguments in rewriter pattern. Introduce separate `FmtStrVecObject` class to handle format of variadic `std::string` array. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D97839
This commit is contained in:
parent
209a626ede
commit
02834e1bd9
|
@ -186,6 +186,20 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class FmtStrVecObject : public FmtObjectBase {
|
||||
public:
|
||||
using StrFormatAdapter =
|
||||
decltype(llvm::detail::build_format_adapter(std::declval<std::string>()));
|
||||
|
||||
FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
|
||||
ArrayRef<std::string> params);
|
||||
FmtStrVecObject(FmtStrVecObject const &that) = delete;
|
||||
FmtStrVecObject(FmtStrVecObject &&that);
|
||||
|
||||
private:
|
||||
SmallVector<StrFormatAdapter, 16> parameters;
|
||||
};
|
||||
|
||||
/// Formats text by substituting placeholders in format string with replacement
|
||||
/// parameters.
|
||||
///
|
||||
|
@ -234,6 +248,11 @@ inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals)
|
|||
llvm::detail::build_format_adapter(std::forward<Ts>(vals))...));
|
||||
}
|
||||
|
||||
inline FmtStrVecObject tgfmt(StringRef fmt, const FmtContext *ctx,
|
||||
ArrayRef<std::string> params) {
|
||||
return FmtStrVecObject(fmt, ctx, params);
|
||||
}
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -173,3 +173,22 @@ void FmtObjectBase::format(raw_ostream &s) const {
|
|||
adapters[repl.index]->format(s, /*Options=*/"");
|
||||
}
|
||||
}
|
||||
|
||||
FmtStrVecObject::FmtStrVecObject(StringRef fmt, const FmtContext *ctx,
|
||||
ArrayRef<std::string> params)
|
||||
: FmtObjectBase(fmt, ctx, params.size()) {
|
||||
parameters.reserve(params.size());
|
||||
for (std::string p : params)
|
||||
parameters.push_back(llvm::detail::build_format_adapter(std::move(p)));
|
||||
|
||||
adapters.reserve(parameters.size());
|
||||
for (auto &p : parameters)
|
||||
adapters.push_back(&p);
|
||||
}
|
||||
|
||||
FmtStrVecObject::FmtStrVecObject(FmtStrVecObject &&that)
|
||||
: FmtObjectBase(std::move(that)), parameters(std::move(that.parameters)) {
|
||||
adapters.reserve(parameters.size());
|
||||
for (auto &p : parameters)
|
||||
adapters.push_back(&p);
|
||||
}
|
||||
|
|
|
@ -58,3 +58,30 @@ def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
|
|||
def test3 : Pat<(BOp $attr, (AOp:$a $input)),
|
||||
(BOp $attr, (AOp $input), (location $a))>;
|
||||
|
||||
def DOp : NS_Op<"d_op", []> {
|
||||
let arguments = (ins
|
||||
AnyInteger:$v1,
|
||||
AnyInteger:$v2,
|
||||
AnyInteger:$v3,
|
||||
AnyInteger:$v4,
|
||||
AnyInteger:$v5,
|
||||
AnyInteger:$v6,
|
||||
AnyInteger:$v7,
|
||||
AnyInteger:$v8,
|
||||
AnyInteger:$v9,
|
||||
AnyInteger:$v10
|
||||
);
|
||||
|
||||
let results = (outs AnyInteger);
|
||||
}
|
||||
|
||||
def NativeBuilder :
|
||||
NativeCodeCall<[{
|
||||
nativeCall($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
}]>;
|
||||
|
||||
// Check Pattern with large number of DAG arguments passed to NativeCodeCall
|
||||
// CHECK: struct test4 : public ::mlir::RewritePattern {
|
||||
// CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
|
||||
def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
|
||||
(NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
|
||||
|
|
|
@ -251,12 +251,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
|||
|
||||
// TODO(suderman): iterate through arguments, determine their types, output
|
||||
// names.
|
||||
SmallVector<std::string, 8> capture(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
PrintFatalError(loc,
|
||||
"unsupported NativeCodeCall matcher argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
}
|
||||
SmallVector<std::string, 8> capture;
|
||||
capture.push_back(opName.str());
|
||||
|
||||
raw_indented_ostream::DelimitedScope scope(os);
|
||||
|
||||
|
@ -274,7 +270,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
|||
}
|
||||
}
|
||||
|
||||
capture[i] = std::move(argName);
|
||||
capture.push_back(std::move(argName));
|
||||
}
|
||||
|
||||
bool hasLocationDirective;
|
||||
|
@ -282,21 +278,20 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
|||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||
|
||||
auto fmt = tree.getNativeCodeTemplate();
|
||||
auto nativeCodeCall = std::string(tgfmt(
|
||||
fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
|
||||
capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
|
||||
auto nativeCodeCall =
|
||||
std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture));
|
||||
|
||||
os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
auto name = tree.getArgName(i);
|
||||
if (!name.empty() && name != "_") {
|
||||
os << formatv("{0} = {1};\n", name, capture[i]);
|
||||
os << formatv("{0} = {1};\n", name, capture[i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
std::string argName = capture[i];
|
||||
std::string argName = capture[i + 1];
|
||||
|
||||
// Handle nested DAG construct first
|
||||
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
|
||||
|
@ -915,29 +910,26 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
|
|||
LLVM_DEBUG(llvm::dbgs() << '\n');
|
||||
|
||||
auto fmt = tree.getNativeCodeTemplate();
|
||||
// TODO: replace formatv arguments with the exact specified args.
|
||||
SmallVector<std::string, 8> attrs(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
PrintFatalError(loc,
|
||||
"unsupported NativeCodeCall replace argument numbers: " +
|
||||
Twine(tree.getNumArgs()));
|
||||
}
|
||||
|
||||
SmallVector<std::string, 16> attrs;
|
||||
|
||||
bool hasLocationDirective;
|
||||
std::string locToUse;
|
||||
std::tie(hasLocationDirective, locToUse) = getLocation(tree);
|
||||
|
||||
for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
|
||||
if (tree.isNestedDagArg(i)) {
|
||||
attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
|
||||
attrs.push_back(
|
||||
handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
|
||||
} else {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
attrs.push_back(
|
||||
handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
|
||||
<< " replacement: " << attrs[i] << "\n");
|
||||
}
|
||||
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
|
||||
attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
|
||||
attrs[6], attrs[7]));
|
||||
|
||||
return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs));
|
||||
}
|
||||
|
||||
int PatternEmitter::getNodeValueCount(DagNode node) {
|
||||
|
|
Loading…
Reference in New Issue