Support referencing a single value generated by a matched multi-result op

It's quite common that we want to put further constraints on the matched
multi-result op's specific results. This CL enables referencing symbols
bound to source op with the `__N` syntax.

PiperOrigin-RevId: 260122401
This commit is contained in:
Lei Zhang 2019-07-26 04:31:15 -07:00 committed by A. Unique TensorFlower
parent 54175c240a
commit 9f02e88946
6 changed files with 57 additions and 26 deletions

View File

@ -208,7 +208,11 @@ public:
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
// Returns a reference to all the bound ops in the source pattern.
llvm::StringSet<> &getSourcePatternBoundOps();
// The returned map contains pointers to the operators inside the
// `RecordOperatorMap` passed-in when constructing this pattern; callers
// should guarantee the lifetime of the returned map does not exceed that
// of the `RecordOperatorMap`.
llvm::StringMap<const Operator *> &getSourcePatternBoundOps();
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@ -238,13 +242,15 @@ private:
const llvm::Record &def;
// All operators.
// TODO(antiagainst): we need a proper context manager, like MLIRContext,
// for managing the lifetime of shared entities.
RecordOperatorMap *recordOpMap;
// All bound op arguments.
llvm::StringMap<Argument> boundArguments;
// All bound ops.
llvm::StringSet<> boundOps;
llvm::StringMap<const Operator *> boundOps;
};
} // end namespace tblgen

View File

@ -187,7 +187,8 @@ tblgen::Pattern::getSourcePatternBoundArgs() {
return boundArguments;
}
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() {
llvm::StringMap<const tblgen::Operator *> &
tblgen::Pattern::getSourcePatternBoundOps() {
return boundOps;
}
@ -263,7 +264,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) {
// results generated from this op. It should be remembered as bound results.
auto treeName = tree.getOpName();
if (!treeName.empty())
boundOps.insert(treeName);
boundOps.try_emplace(treeName, &op);
// TODO(jpienaar): Expand to multiple matches.
for (int i = 0; i != numTreeArgs; ++i) {

View File

@ -337,11 +337,12 @@ def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
def MultiResultOpEnum: I64EnumAttr<
"Multi-result op kinds", "", [
MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
MultiResultOpKind4
MultiResultOpKind4, MultiResultOpKind5
]>;
def ThreeResultOp : TEST_Op<"three_result"> {
@ -403,6 +404,15 @@ def : Pattern<(ThreeResultOp MultiResultOpKind4),
(OneResultOp MultiResultOpKind4),
(TwoResultOp:$res2__1 MultiResultOpKind4)]>;
// Test referencing a single value in the value pack
def HasNoUse: Constraint<
CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
// This rule only matches TwoResultOp if its second result has no use.
def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
[(AnotherOneResultOp MultiResultOpKind5),
(OneResultOp MultiResultOpKind5)],
[(HasNoUse $res__1)]>;
//===----------------------------------------------------------------------===//
// Test Directives
//===----------------------------------------------------------------------===//

View File

@ -30,11 +30,10 @@ def OpD : NS_Op<"op_d", []> {
}
def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
def getResult0 : NativeCodeCall<"$_self->getResult(0)">;
def : Pattern<(OpA:$res_a $operand, $attr),
[(OpC:$res_c (OpB:$res_b $operand)),
(OpD $res_b, $res_c, getResult0:$res_a, $attr)],
(OpD $res_b, $res_c, $res_a, $attr)],
[(hasOneUse $res_a)]>;
// CHECK-LABEL: GeneratedConvert0
@ -55,7 +54,7 @@ def : Pattern<(OpA:$res_a $operand, $attr),
// CHECK: s.operand = op0->getOperand(0);
// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr");
// CHECK: s.attr = attr;
// CHECK: if (!((s.res_a->hasOneUse()))) return matchFailure();
// CHECK: if (!((s.res_a->getResult(0)->hasOneUse()))) return matchFailure();
// Test bound results in result pattern
// ---

View File

@ -157,3 +157,14 @@ func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
%0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32)
return %0#0, %0#1, %0#2 : i32, f32, f32
}
// CHECK-LABEL: @constraintOnSourceOpResult
func @constraintOnSourceOpResult() -> (i32, f32, i32) {
// CHECK: %0:2 = "test.two_result"()
// CHECK: %1 = "test.another_one_result"()
// CHECK: %2 = "test.one_result"()
// CHECK: return %0#0, %0#1, %1
%0:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
%1:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
return %0#0, %0#1, %1#0 : i32, f32, i32
}

View File

@ -87,8 +87,8 @@ static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
// This extracts one value from the pack if `symbol` contains an index,
// otherwise it extracts all values sequentially and returns them as a
// comma-separated list.
static std::string formtValuePack(const char *fmt, StringRef symbol,
unsigned count, unsigned offset) {
static std::string formatValuePack(const char *fmt, StringRef symbol,
unsigned count, unsigned offset) {
auto getNthValue = [fmt, offset](StringRef results,
unsigned index) -> std::string {
return formatv(fmt, results, index + offset);
@ -142,28 +142,31 @@ namespace {
class PatternSymbolResolver {
public:
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcOperations);
const StringMap<const Operator *> &srcOperations);
// Marks the given `symbol` as bound to a value pack with `numValues` and
// returns true on success. Returns false if the `symbol` is already bound.
bool add(StringRef symbol, int numValues);
// Queries the substitution for the given `symbol`.
// Queries the substitution for the given `symbol`. Returns empty string if
// symbol not found. If the symbol represents a value pack, returns all the
// values separated via comma.
std::string query(StringRef symbol) const;
private:
// Symbols bound to arguments in source pattern.
const StringMap<Argument> &sourceArguments;
// Symbols bound to ops (for their results) in source pattern.
const StringSet<> &sourceOps;
const StringMap<const Operator *> &sourceOps;
// Symbols bound to ops (for their results) in result patterns.
// Key: symbol; value: number of values inside the pack
StringMap<int> resultOps;
};
} // end anonymous namespace
PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcOperations)
PatternSymbolResolver::PatternSymbolResolver(
const StringMap<Argument> &srcArgs,
const StringMap<const Operator *> &srcOperations)
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
@ -176,8 +179,8 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
auto it = resultOps.find(name);
if (it != resultOps.end())
return formtValuePack("{0}.getOperation()->getResult({1})", symbol,
it->second, 0);
return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
it->second, /*offset=*/0);
}
{
auto it = sourceArguments.find(symbol);
@ -185,9 +188,12 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
return getBoundSymbol(symbol).str();
}
{
auto it = sourceOps.find(symbol);
StringRef name = getValuePackName(symbol);
auto it = sourceOps.find(name);
if (it != sourceOps.end())
return getBoundSymbol(symbol).str();
return formatValuePack("{0}->getResult({1})",
getBoundSymbol(symbol).str(),
it->second->getNumResults(), /*offset=*/0);
}
return {};
}
@ -490,9 +496,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
auto cmd = "if (!({0})) return matchFailure();\n";
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("(*{0}->result_type_begin())",
resolveSymbol(entities.front()));
// TODO(jpienaar): Verify op only has one result.
auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
os.indent(4) << formatv(cmd,
tgfmt(condition, &matchCtx.withSelf(self.str())));
} else if (isa<AttrConstraint>(constraint)) {
@ -650,8 +654,8 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
// We need to get all the values out of this local variable if we've created a
// multi-result op.
const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
return formtValuePack("{0}.getOperation()->getResult({1})", results,
numResults, 0);
return formatValuePack("{0}.getOperation()->getResult({1})", results,
numResults, /*offset=*/0);
}
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
@ -799,8 +803,8 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
// We need to specify the types for all results.
auto resultTypes =
formtValuePack("op->getResult({1})->getType()", valuePackName,
resultOp.getNumResults(), resultIndex);
formatValuePack("op->getResult({1})->getType()", valuePackName,
resultOp.getNumResults(), resultIndex);
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc, {2}",
valuePackName, resultOp.getQualCppClassName(),