Fix support for auxiliary ops in declarative rewrite rules

We allow to generate more ops than what are needed for replacing
the matched root op. Only the last N static values generated are
used as replacement; the others serve as auxiliary ops/values for
building the replacement.

With the introduction of multi-result op support, an op, if used
as a whole, may be used to replace multiple static values of
the matched root op. We need to consider this when calculating
the result range an generated op is to replace.

For example, we can have the following pattern:

```tblgen
def : Pattern<(ThreeResultOp ...),
              [(OneResultOp ...), (OneResultOp ...), (OneResultOp ...)]>;

// Two op to replace all three results
def : Pattern<(ThreeResultOp ...),
              [(TwoResultOp ...), (OneResultOp ...)]>;

// One op to replace all three results
def : Pat<(ThreeResultOp ...), (ThreeResultOp ...)>;

def : Pattern<(ThreeResultOp ...),
              [(AuxiliaryOp ...), (ThreeResultOp ...)]>;
```
PiperOrigin-RevId: 261017235
This commit is contained in:
Lei Zhang 2019-07-31 16:03:13 -07:00 committed by A. Unique TensorFlower
parent e44ba1f8bf
commit e032d0dc63
6 changed files with 215 additions and 181 deletions

View File

@ -102,7 +102,7 @@ public:
// Returns the native code call template inside this DAG leaf.
// Precondition: isNativeCodeCall()
llvm::StringRef getNativeCodeTemplate() const;
StringRef getNativeCodeTemplate() const;
private:
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
@ -134,8 +134,8 @@ public:
// DagNode.
operator bool() const { return node != nullptr; }
// Returns the operation referenced by this DAG node.
llvm::StringRef getOpName() const;
// Returns the symbol bound to this DAG node.
StringRef getSymbol() const;
// Returns the operator wrapper object corresponding to the dialect op matched
// by this DAG. The operator wrapper will be queried from the given `mapper`
@ -160,7 +160,7 @@ public:
DagLeaf getArgAsLeaf(unsigned index) const;
// Returns the specified name of the `index`-th argument.
llvm::StringRef getArgName(unsigned index) const;
StringRef getArgName(unsigned index) const;
// Returns true if this DAG construct means to replace with an existing SSA
// value.
@ -177,7 +177,7 @@ public:
// Returns the native code call template inside this DAG node.
// Precondition: isNativeCodeCall()
llvm::StringRef getNativeCodeTemplate() const;
StringRef getNativeCodeTemplate() const;
private:
const llvm::DagInit *node; // nullptr means null DagNode
@ -194,25 +194,31 @@ public:
// Returns the source pattern to match.
DagNode getSourcePattern() const;
// Returns the number of results generated by applying this rewrite pattern.
int getNumResults() const;
// Returns the number of result patterns generated by applying this rewrite
// rule.
int getNumResultPatterns() const;
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
// Checks whether an argument or op with the given `name` is bound in
// source pattern. Prints fatal error if not; does nothing otherwise.
void ensureBoundInSourcePattern(llvm::StringRef name) const;
void ensureBoundInSourcePattern(StringRef name) const;
// Returns a reference to all the bound arguments in the source pattern.
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// Returns a reference to all the bound ops in the source pattern.
// The returned map contains pointers to the operators inside the
// `RecordOperatorMap` passed-in when constructing this pattern; callers
// should guarantee the lifetime of the returned map does not exceed that
// of the `RecordOperatorMap`.
llvm::StringMap<const Operator *> &getSourcePatternBoundOps();
using SymbolOperatorMap = llvm::StringMap<const Operator *>;
// Returns a reference to all the bound ops in the source pattern.
SymbolOperatorMap &getSourcePatternBoundOps();
// Returns a reference to all the bound ops in the result patterns.
SymbolOperatorMap &getResultPatternBoundOps();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@ -234,9 +240,10 @@ public:
std::vector<IdentifierLine> getLocation() const;
private:
// Recursively collects all bound arguments inside the DAG tree rooted
// at `tree`.
void collectBoundArguments(DagNode tree);
// Recursively collects all bound symbols inside the DAG tree rooted
// at `tree` and updates the given `symOpMap`.
void collectBoundSymbols(DagNode tree, SymbolOperatorMap &symOpMap,
bool isSrcPattern);
// The TableGen definition of this pattern.
const llvm::Record &def;
@ -246,11 +253,14 @@ private:
// for managing the lifetime of shared entities.
RecordOperatorMap *recordOpMap;
// All bound op arguments.
llvm::StringMap<Argument> boundArguments;
// All source pattern bound op arguments.
llvm::StringMap<Argument> srcBoundArguments;
// All bound ops.
llvm::StringMap<const Operator *> boundOps;
// All source pattern bound ops.
SymbolOperatorMap srcBoundOps;
// All result pattern bound ops.
SymbolOperatorMap resBoundOps;
};
} // end namespace tblgen

View File

@ -105,7 +105,7 @@ llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
->getValueAsString("expression");
}
llvm::StringRef tblgen::DagNode::getOpName() const {
llvm::StringRef tblgen::DagNode::getSymbol() const {
return node->getNameStr();
}
@ -158,14 +158,17 @@ bool tblgen::DagNode::isVerifyUnusedValue() const {
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) {
collectBoundArguments(getSourcePattern());
collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
collectBoundSymbols(getResultPattern(i), resBoundOps,
/*isSrcPattern=*/false);
}
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
}
int tblgen::Pattern::getNumResults() const {
int tblgen::Pattern::getNumResultPatterns() const {
auto *results = def.getValueAsListInit("resultPatterns");
return results->size();
}
@ -176,20 +179,25 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
}
void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
if (boundArguments.find(name) == boundArguments.end() &&
boundOps.find(name) == boundOps.end())
if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
srcBoundOps.find(name) == srcBoundOps.end())
PrintFatalError(def.getLoc(),
Twine("referencing unbound variable '") + name + "'");
}
llvm::StringMap<tblgen::Argument> &
tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
return srcBoundArguments;
}
llvm::StringMap<const tblgen::Operator *> &
tblgen::Pattern::getSourcePatternBoundOps() {
return boundOps;
return srcBoundOps;
}
llvm::StringMap<const tblgen::Operator *> &
tblgen::Pattern::getResultPatternBoundOps() {
return resBoundOps;
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@ -248,7 +256,20 @@ tblgen::Pattern::getLocation() const {
return result;
}
void tblgen::Pattern::collectBoundArguments(DagNode tree) {
void tblgen::Pattern::collectBoundSymbols(DagNode tree,
SymbolOperatorMap &symOpMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
if (!tree.isOperation()) {
if (!treeName.empty()) {
PrintFatalError(
def.getLoc(),
formatv("binding symbol '{0}' to non-operation unsupported right now",
treeName));
}
return;
}
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
auto numTreeArgs = tree.getNumArgs();
@ -262,19 +283,19 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) {
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
auto treeName = tree.getOpName();
if (!treeName.empty())
boundOps.try_emplace(treeName, &op);
symOpMap.try_emplace(treeName, &op);
// TODO(jpienaar): Expand to multiple matches.
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundArguments(treeArg);
} else {
collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
} else if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
if (!treeArgName.empty())
boundArguments.try_emplace(treeArgName, op.getArg(i));
srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
}
}
}

View File

@ -360,11 +360,12 @@ def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
def MultiResultOpEnum: I64EnumAttr<
"Multi-result op kinds", "", [
MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
MultiResultOpKind4, MultiResultOpKind5
MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
]>;
def ThreeResultOp : TEST_Op<"three_result"> {
@ -398,16 +399,21 @@ def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
let results = (outs F32:$result1, F32:$result2);
}
def OneResultOp : TEST_Op<"one_result"> {
def OneResultOp1 : TEST_Op<"one_result1"> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs F32:$result1);
}
def AnotherOneResultOp : TEST_Op<"another_one_result"> {
def OneResultOp2 : TEST_Op<"one_result2"> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs I32:$result1);
}
def OneResultOp3 : TEST_Op<"one_result3"> {
let arguments = (ins F32:$input);
let results = (outs I32:$result1);
}
// Test using multi-result op as a whole
def : Pat<(ThreeResultOp MultiResultOpKind1),
(AnotherThreeResultOp MultiResultOpKind1)>;
@ -415,15 +421,15 @@ def : Pat<(ThreeResultOp MultiResultOpKind1),
// Test using multi-result op as a whole for partial replacement
def : Pattern<(ThreeResultOp MultiResultOpKind2),
[(TwoResultOp MultiResultOpKind2),
(OneResultOp MultiResultOpKind2)]>;
(OneResultOp1 MultiResultOpKind2)]>;
def : Pattern<(ThreeResultOp MultiResultOpKind3),
[(AnotherOneResultOp MultiResultOpKind3),
[(OneResultOp2 MultiResultOpKind3),
(AnotherTwoResultOp MultiResultOpKind3)]>;
// Test using results separately in a multi-result op
def : Pattern<(ThreeResultOp MultiResultOpKind4),
[(TwoResultOp:$res1__0 MultiResultOpKind4),
(OneResultOp MultiResultOpKind4),
(OneResultOp1 MultiResultOpKind4),
(TwoResultOp:$res2__1 MultiResultOpKind4)]>;
// Test referencing a single value in the value pack
@ -431,10 +437,21 @@ def HasNoUse: Constraint<
CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
// This rule only matches TwoResultOp if its second result has no use.
def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
[(AnotherOneResultOp MultiResultOpKind5),
(OneResultOp MultiResultOpKind5)],
[(OneResultOp2 MultiResultOpKind5),
(OneResultOp1 MultiResultOpKind5)],
[(HasNoUse $res__1)]>;
// Test using auxiliary ops for replacing multi-result op
def : Pattern<
(ThreeResultOp MultiResultOpKind6), [
// Auxiliary op generated to help building the final result but not
// directly used to replace the source op's results.
(TwoResultOp:$interm MultiResultOpKind6),
(OneResultOp3 $interm__1),
(AnotherTwoResultOp MultiResultOpKind6)
]>;
//===----------------------------------------------------------------------===//
// Test Directives
//===----------------------------------------------------------------------===//

View File

@ -1,89 +0,0 @@
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --dump-input-on-failure
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "";
}
class NS_Op<string mnemonic, list<OpTrait> traits> :
Op<Test_Dialect, mnemonic, traits>;
def ThreeResultOp : NS_Op<"three_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1, I32:$r2, I32:$r3);
}
def TwoResultOp : NS_Op<"two_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1, I32:$r2);
}
def OneResultOp : NS_Op<"one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
}
def b : Pattern<(ThreeResultOp $input), [
(OneResultOp (OneResultOp:$interm $input)),
(OneResultOp $interm),
(OneResultOp (OneResultOp $interm))
]>;
// CHECK-LABEL: struct b
// CHECK: void rewrite(
// CHECK: auto interm = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/s.input
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp2 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp3 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/vOneResultOp2
// CHECK: rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0), vOneResultOp3.getOperation()->getResult(0)});
// Test more result patterns than needed for replacement
// ---
def AdditionalOp : NS_Op<"additional_one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
}
def c : Pattern<(TwoResultOp $input), [
// Additional op generated to help build the final result but not
// directly used to replace the source op
(AdditionalOp:$interm $input),
(OneResultOp $interm),
(OneResultOp $input)
]>;
// CHECK-LABEL: struct c
// CHECK: auto interm = rewriter.create<AdditionalOp>(
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
// CHECK: rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0)});
// Test UnitAttr in rewrite patterns
// ---
def UnitAttrOp : NS_Op<"unit_attr_op", []> {
let arguments = (ins AnyAttr:$value);
let results = (outs NoneType:$output);
}
def NoneOp : NS_Op<"none_op", []> {
let results = (outs NoneType:$output);
}
def d : Pat<(UnitAttrOp UnitAttr:$ignore),
(NoneOp)>;
// CHECK-LABEL: struct d
// CHECK: PatternMatchResult match(Operation *op0) const override {
// CHECK: if (!((attr.isa<UnitAttr>()))) return matchFailure();

View File

@ -147,7 +147,7 @@ func @useMultiResultOpToReplaceWhole() -> (i32, f32, f32) {
// CHECK-LABEL: @useMultiResultOpToReplacePartial1
func @useMultiResultOpToReplacePartial1() -> (i32, f32, f32) {
// CHECK: %0:2 = "test.two_result"()
// CHECK: %1 = "test.one_result"()
// CHECK: %1 = "test.one_result1"()
// CHECK: return %0#0, %0#1, %1
%0:3 = "test.three_result"() {kind = 2} : () -> (i32, f32, f32)
return %0#0, %0#1, %0#2 : i32, f32, f32
@ -155,7 +155,7 @@ func @useMultiResultOpToReplacePartial1() -> (i32, f32, f32) {
// CHECK-LABEL: @useMultiResultOpToReplacePartial2
func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) {
// CHECK: %0 = "test.another_one_result"()
// CHECK: %0 = "test.one_result2"()
// CHECK: %1:2 = "test.another_two_result"()
// CHECK: return %0, %1#0, %1#1
%0:3 = "test.three_result"() {kind = 3} : () -> (i32, f32, f32)
@ -165,7 +165,7 @@ func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) {
// CHECK-LABEL: @useMultiResultOpResultsSeparately
func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
// CHECK: %0:2 = "test.two_result"()
// CHECK: %1 = "test.one_result"()
// CHECK: %1 = "test.one_result1"()
// CHECK: %2:2 = "test.two_result"()
// CHECK: return %0#0, %1, %2#1
%0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32)
@ -175,10 +175,23 @@ func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
// CHECK-LABEL: @constraintOnSourceOpResult
func @constraintOnSourceOpResult() -> (i32, f32, i32) {
// CHECK: %0:2 = "test.two_result"()
// CHECK: %1 = "test.another_one_result"()
// CHECK: %2 = "test.one_result"()
// CHECK: %1 = "test.one_result2"()
// CHECK: %2 = "test.one_result1"()
// CHECK: return %0#0, %0#1, %1
%0:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
%1:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
return %0#0, %0#1, %1#0 : i32, f32, i32
}
// CHECK-LABEL: @useAuxiliaryOpToReplaceMultiResultOp
func @useAuxiliaryOpToReplaceMultiResultOp() -> (i32, f32, f32) {
// An auxiliary op is generated to help building the op for replacing the
// matched op.
// CHECK: %0:2 = "test.two_result"()
// CHECK: %1 = "test.one_result3"(%0#1)
// CHECK: %2:2 = "test.another_two_result"()
// CHECK: return %1, %2#0, %2#1
%0:3 = "test.three_result"() {kind = 6} : () -> (i32, f32, f32)
return %0#0, %0#1, %0#2 : i32, f32, f32
}

View File

@ -153,6 +153,13 @@ public:
// values separated via comma.
std::string query(StringRef symbol) const;
// Returns how many static values the given `symbol` correspond to. Returns a
// negative value if the given symbol is not bound.
//
// Normally a symbol would correspond to just one value; for symbols bound to
// multi-result ops, it can be more than one.
int getValueCount(StringRef symbol) const;
private:
// Symbols bound to arguments in source pattern.
const StringMap<Argument> &sourceArguments;
@ -175,29 +182,44 @@ bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
}
std::string PatternSymbolResolver::query(StringRef symbol) const {
{
StringRef name = getValuePackName(symbol);
auto it = resultOps.find(name);
if (it != resultOps.end())
return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
it->second, /*offset=*/0);
}
{
auto it = sourceArguments.find(symbol);
if (it != sourceArguments.end())
return getBoundSymbol(symbol).str();
}
{
StringRef name = getValuePackName(symbol);
auto it = sourceOps.find(name);
if (it != sourceOps.end())
return formatValuePack("{0}->getResult({1})",
getBoundSymbol(symbol).str(),
it->second->getNumResults(), /*offset=*/0);
}
StringRef name = getValuePackName(symbol);
// Handle symbols bound to generated ops
auto resOpIt = resultOps.find(name);
if (resOpIt != resultOps.end())
return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
resOpIt->second, /*offset=*/0);
// Handle symbols bound to matched op arguments
auto srcArgIt = sourceArguments.find(symbol);
if (srcArgIt != sourceArguments.end())
return getBoundSymbol(symbol).str();
// Handle symbols bound to matched op results
auto srcOpIt = sourceOps.find(name);
if (srcOpIt != sourceOps.end())
return formatValuePack("{0}->getResult({1})", getBoundSymbol(symbol).str(),
srcOpIt->second->getNumResults(), /*offset=*/0);
return {};
}
int PatternSymbolResolver::getValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
// Handle symbols bound to generated ops
auto resOpIt = resultOps.find(name);
if (resOpIt != resultOps.end())
return name == symbol ? resOpIt->second : 1;
// Handle symbols bound to matched op arguments
if (sourceArguments.count(symbol))
return 1;
// Handle symbols bound to matched op results
auto srcOpIt = sourceOps.find(name);
if (srcOpIt != sourceOps.end())
return name == symbol ? srcOpIt->second->getNumResults() : 1;
return -1;
}
//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//
@ -269,15 +291,18 @@ private:
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern.
std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
// Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
// is already bound.
void addSymbol(DagNode node);
void addSymbol(StringRef symbol, int numValues);
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
std::string resolveSymbol(StringRef symbol);
// Returns how many static values the given DAG `node` correspond to.
int getNodeValueCount(DagNode node);
private:
// Pattern instantiation location followed by the location of multiclass
// prototypes used. This is intended to be used as a whole to
@ -349,7 +374,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
}
// If the operand's name is set, set to that variable.
auto name = tree.getOpName();
auto name = tree.getSymbol();
if (!name.empty())
os.indent(indent) << formatv("{0} = op{1};\n", getBoundSymbol(name), depth);
@ -479,7 +504,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
// The rewrite pattern may specify that certain outputs should be unused in
// the source IR. Check it here.
for (int i = 0, e = pattern.getNumResults(); i < e; ++i) {
for (int i = 0, e = pattern.getNumResultPatterns(); i < e; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
if (resultTree.isVerifyUnusedValue()) {
handleVerifyUnusedValue(resultTree, i);
@ -543,7 +568,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Collect the set of result operations.
llvm::SmallPtrSet<const Operator *, 4> results;
for (unsigned i = 0, e = pattern.getNumResults(); i != e; ++i)
for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
collectOps(pattern.getResultPattern(i), results);
// Emit RewritePattern for Pattern.
@ -587,7 +612,36 @@ void PatternEmitter::emit(StringRef rewriteName) {
void PatternEmitter::emitRewriteMethod() {
const Operator &rootOp = pattern.getSourceRootOp();
int numExpectedResults = rootOp.getNumResults();
int numProvidedResults = pattern.getNumResults();
int numResultPatterns = pattern.getNumResultPatterns();
// First register all symbols bound to ops generated in result patterns.
for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
}
// Only the last N static values generated are used to replace the matched
// root N-result op. We need to calculate the starting index (of the results
// of the matched op) each result pattern is to replace.
SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
int replStartIndex = -1;
for (int i = numResultPatterns - 1; i >= 0; --i) {
auto numValues = getNodeValueCount(pattern.getResultPattern(i));
offsets[i] = offsets[i + 1] - numValues;
if (offsets[i] == 0) {
replStartIndex = i;
} else if (offsets[i] < 0 && offsets[i + 1] > 0) {
auto error = formatv(
"cannot use the same multi-result op '{0}' to generate both "
"auxiliary values and values to be used for replacing the matched op",
pattern.getResultPattern(i).getSymbol());
PrintFatalError(loc, error);
}
}
if (offsets.front() > 0) {
const char error[] = "no enough values generated to replace the matched op";
PrintFatalError(loc, error);
}
os << R"(
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
@ -602,18 +656,15 @@ void PatternEmitter::emitRewriteMethod() {
// Collect the replacement value for each result
llvm::SmallVector<std::string, 2> resultValues;
for (int i = 0; i < numProvidedResults; ++i) {
for (int i = 0; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleRewritePattern(resultTree, i, 0));
// Keep track of bound symbols at the top-level DAG nodes
addSymbol(resultTree);
resultValues.push_back(handleRewritePattern(resultTree, offsets[i], 0));
}
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op, {";
interleave(
// We only use the last numExpectedResults ones to replace the root op.
ArrayRef<std::string>(resultValues).take_back(numExpectedResults),
ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
[&](const std::string &name) { os << name; }, [&]() { os << ", "; });
os << "});\n }\n";
}
@ -635,7 +686,7 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
"verify top-level result");
}
if (!resultTree.getOpName().empty()) {
if (!resultTree.getSymbol().empty()) {
PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
}
@ -666,7 +717,7 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
loc, "replaceWithValue directive must take exactly one argument");
}
if (!tree.getOpName().empty()) {
if (!tree.getSymbol().empty()) {
PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
}
@ -680,8 +731,7 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
<< ")->use_empty()) return matchFailure();\n";
}
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
llvm::StringRef argName) {
std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
if (leaf.isConstantAttr()) {
auto constAttr = leaf.getAsConstantAttr();
return handleConstantAttr(constAttr.getAttribute(),
@ -722,12 +772,8 @@ std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
attrs[4], attrs[5], attrs[6], attrs[7]);
}
void PatternEmitter::addSymbol(DagNode node) {
StringRef symbol = node.getOpName();
// Skip empty-named symbols, which happen for unbound ops in result patterns.
if (symbol.empty())
return;
if (!symbolResolver.add(symbol, pattern.getDialectOp(node).getNumResults()))
void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
if (!symbolResolver.add(symbol, numValues))
PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
}
@ -738,6 +784,22 @@ std::string PatternEmitter::resolveSymbol(StringRef symbol) {
return subst;
}
int PatternEmitter::getNodeValueCount(DagNode node) {
if (node.isOperation()) {
// First to see whether this op is bound and we just want a specific result
// of it with `__N` suffix in symbol.
int count = symbolResolver.getValueCount(node.getSymbol());
if (count >= 0)
return count;
// No symbol. Then we are using all the results.
return pattern.getDialectOp(node).getNumResults();
}
// TODO(antiagainst): This considers all NativeCodeCall as returning one
// value. Enhance if multi-value ones are needed.
return 1;
}
std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
int depth) {
Operator &resultOp = tree.getDialectOp(opMap);
@ -769,13 +831,11 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) {
childNodeNames[i] = handleRewritePattern(child, i, depth + 1);
// Keep track of bound symbols at the middle-level DAG nodes
addSymbol(child);
}
}
// Use the specified name for this op if available. Generate one otherwise.
std::string resultValue = tree.getOpName();
std::string resultValue = tree.getSymbol();
if (resultValue.empty())
resultValue = getUniqueValueName(&resultOp);
// Strip the index to get the name for the value pack. This will be used to
@ -794,12 +854,14 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
bool usePartialResults = valuePackName != resultValue;
if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
usePartialResults || depth > 0) {
usePartialResults || depth > 0 || resultIndex < 0) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
valuePackName, resultOp.getQualCppClassName());
} else {
// If depth == 0 we can use the equivalence of the source and target root
// ops in the pattern to determine the return type.
// If depth == 0 and resultIndex >= 0, it means we are replacing the values
// generated from the source pattern root op. Then we can use the source
// pattern's value types to determine the value type of the generated op
// here.
// We need to specify the types for all results.
auto resultTypes =