Replace the verifyUnusedValue directive with HasNoUseOf constraint

verifyUnusedValue is a bit strange given that it is specified in a
result pattern but used to generate match statements. Now we are
able to support multi-result ops better, we can retire it and replace
it with a HasNoUseOf constraint. This reduces the number of mechanisms.

PiperOrigin-RevId: 261166863
This commit is contained in:
Lei Zhang 2019-08-01 11:50:47 -07:00 committed by A. Unique TensorFlower
parent 88b175eea5
commit c72d849eb9
8 changed files with 28 additions and 101 deletions

View File

@ -1173,9 +1173,17 @@ class Results<dag rets> {
dag results = rets; dag results = rets;
} }
//===----------------------------------------------------------------------===//
// Common value constraints
//===----------------------------------------------------------------------===//
def HasNoUseOf: Constraint<
CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Common op type constraints // Common op type constraints
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// These traits are for verifying properties of an op that require knowledge of // These traits are for verifying properties of an op that require knowledge of
// multiple arguments or results. For verifying properties of a single argument // multiple arguments or results. For verifying properties of a single argument
// or result, prefer operand type constraints. // or result, prefer operand type constraints.
@ -1407,10 +1415,4 @@ class NativeCodeCall<string expr> {
// so to replace the matched DAG with an existing SSA value. // so to replace the matched DAG with an existing SSA value.
def replaceWithValue; def replaceWithValue;
// Directive used in result pattern to indicate that no replacement is generated
// for the current result. Predicates are generated to make sure the
// corresponding result in source pattern is unused.
// syntax: (verifyUnusedValue)
def verifyUnusedValue;
#endif // OP_BASE #endif // OP_BASE

View File

@ -74,9 +74,13 @@ private:
// An constraint and the concrete entities to place the constraint on. // An constraint and the concrete entities to place the constraint on.
struct AppliedConstraint { struct AppliedConstraint {
AppliedConstraint(Constraint &&c, std::vector<std::string> &&e); AppliedConstraint(Constraint &&constraint, StringRef self,
std::vector<std::string> &&entities);
Constraint constraint; Constraint constraint;
// The symbol to replace `$_self` special placeholder in the constraint.
std::string self;
// The symbols to replace `$N` positional placeholders in the constraint.
std::vector<std::string> entities; std::vector<std::string> entities;
}; };

View File

@ -166,9 +166,6 @@ public:
// value. // value.
bool isReplaceWithValue() const; bool isReplaceWithValue() const;
// Returns true if this DAG node is the `verifyUnusedValue` directive.
bool isVerifyUnusedValue() const;
// Returns true if this DAG node is wrapping native code call. // Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const; bool isNativeCodeCall() const;

View File

@ -63,6 +63,6 @@ llvm::StringRef Constraint::getDescription() const {
return doc; return doc;
} }
AppliedConstraint::AppliedConstraint(Constraint &&c, AppliedConstraint::AppliedConstraint(Constraint &&constraint, StringRef self,
std::vector<std::string> &&e) std::vector<std::string> &&entities)
: constraint(c), entities(std::move(e)) {} : constraint(constraint), self(self), entities(std::move(entities)) {}

View File

@ -95,7 +95,7 @@ bool tblgen::DagNode::isNativeCodeCall() const {
} }
bool tblgen::DagNode::isOperation() const { bool tblgen::DagNode::isOperation() const {
return !(isNativeCodeCall() || isVerifyUnusedValue() || isReplaceWithValue()); return !(isNativeCodeCall() || isReplaceWithValue());
} }
llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
@ -151,11 +151,6 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue"; return dagOpDef->getName() == "replaceWithValue";
} }
bool tblgen::DagNode::isVerifyUnusedValue() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "verifyUnusedValue";
}
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) { : def(*def), recordOpMap(mapper) {
collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true); collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
@ -225,7 +220,7 @@ std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
entities.push_back(argName->getValue()); entities.push_back(argName->getValue());
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
std::move(entities)); dagInit->getNameStr(), std::move(entities));
} }
return ret; return ret;
} }

View File

@ -475,13 +475,11 @@ def : Pattern<(ThreeResultOp MultiResultOpKind4),
(TwoResultOp:$res2__1 MultiResultOpKind4)]>; (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
// Test referencing a single value in the value pack // Test referencing a single value in the value pack
def HasNoUse: Constraint<
CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
// This rule only matches TwoResultOp if its second result has no use. // This rule only matches TwoResultOp if its second result has no use.
def : Pattern<(TwoResultOp:$res MultiResultOpKind5), def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
[(OneResultOp2 MultiResultOpKind5), [(OneResultOp2 MultiResultOpKind5),
(OneResultOp1 MultiResultOpKind5)], (OneResultOp1 MultiResultOpKind5)],
[(HasNoUse $res__1)]>; [(HasNoUseOf:$res__1)]>;
// Test using auxiliary ops for replacing multi-result op // Test using auxiliary ops for replacing multi-result op
def : Pattern< def : Pattern<
@ -494,20 +492,6 @@ def : Pattern<
(AnotherTwoResultOp MultiResultOpKind6) (AnotherTwoResultOp MultiResultOpKind6)
]>; ]>;
//===----------------------------------------------------------------------===//
// Test Directives
//===----------------------------------------------------------------------===//
// Test 'verifyUnusedValue'
def VUVTwoResultOp : TEST_Op<"vuv_two_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1, I32:$r2);
}
def VUVFoldTwoResultOp : Pattern<(VUVTwoResultOp $input), [
(verifyUnusedValue),
(replaceWithValue $input)
]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test Legalization // Test Legalization
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1,20 +0,0 @@
// RUN: mlir-opt -test-patterns %s | FileCheck %s
//===----------------------------------------------------------------------===//
// Test 'verifyUnusedValue'
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @match_success_on_unused_first_result
func @match_success_on_unused_first_result(%arg0 : i32) -> i32 {
// CHECK-NEXT: return {{.*}} : i32
%result:2 = "test.vuv_two_result_op"(%arg0) : (i32) -> (i32, i32)
return %result#1 : i32
}
// CHECK-LABEL: @match_fail_on_used_first_result
func @match_fail_on_used_first_result(%arg0 : i32) -> i32 {
// CHECK-NEXT: "test.vuv_two_result_op"(%arg0) : (i32) -> (i32, i32)
%result:2 = "test.vuv_two_result_op"(%arg0) : (i32) -> (i32, i32)
"foo.unknown_op"(%result#0) : (i32) -> ()
return %result#1 : i32
}

View File

@ -274,10 +274,6 @@ private:
// replacement. // replacement.
std::string handleReplaceWithValue(DagNode tree); std::string handleReplaceWithValue(DagNode tree);
// Handles the `verifyUnusedValue` directive: emitting C++ statements to check
// the `index`-th result of the source op is not used.
void handleVerifyUnusedValue(DagNode tree, int index);
// Emits the C++ statement to build a new op out of the given DAG `tree` and // Emits the C++ statement to build a new op out of the given DAG `tree` and
// returns the variable name that this op is assigned to. If the root op in // returns the variable name that this op is assigned to. If the root op in
// DAG `tree` has a specified name, the created op will be assigned to a // DAG `tree` has a specified name, the created op will be assigned to a
@ -502,15 +498,6 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
s.autogeneratedRewritePatternOps[0] = op0; s.autogeneratedRewritePatternOps[0] = op0;
)"; )";
// The rewrite pattern may specify that certain outputs should be unused in
// the source IR. Check it here.
for (int i = 0, e = pattern.getNumResultPatterns(); i < e; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
if (resultTree.isVerifyUnusedValue()) {
handleVerifyUnusedValue(resultTree, i);
}
}
emitOpMatch(tree, 0); emitOpMatch(tree, 0);
for (auto &appliedConstraint : pattern.getConstraints()) { for (auto &appliedConstraint : pattern.getConstraints()) {
@ -528,7 +515,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
PrintFatalError( PrintFatalError(
loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
} else { } else {
// TODO(fengliuai): replace formatv arguments with the exact specified // TODO(b/138794486): replace formatv arguments with the exact specified
// args. // args.
if (entities.size() > 4) { if (entities.size() > 4) {
PrintFatalError(loc, "only support up to 4-entity constraints now"); PrintFatalError(loc, "only support up to 4-entity constraints now");
@ -537,10 +524,14 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
int i = 0; int i = 0;
for (int e = entities.size(); i < e; ++i) for (int e = entities.size(); i < e; ++i)
names.push_back(resolveSymbol(entities[i])); names.push_back(resolveSymbol(entities[i]));
std::string self = appliedConstraint.self;
if (!self.empty())
self = resolveSymbol(self);
for (; i < 4; ++i) for (; i < 4; ++i)
names.push_back("<unused>"); names.push_back("<unused>");
os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0], os.indent(4) << formatv(cmd,
names[1], names[2], names[3])); tgfmt(condition, &matchCtx.withSelf(self),
names[0], names[1], names[2], names[3]));
} }
} }
@ -678,25 +669,6 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
if (resultTree.isNativeCodeCall()) if (resultTree.isNativeCodeCall())
return emitReplaceWithNativeCodeCall(resultTree); return emitReplaceWithNativeCodeCall(resultTree);
if (resultTree.isVerifyUnusedValue()) {
if (depth > 0) {
// TODO: Revisit this when we have use cases of matching an intermediate
// multi-result op with no uses of its certain results.
PrintFatalError(loc, "verifyUnusedValue directive can only be used to "
"verify top-level result");
}
if (!resultTree.getSymbol().empty()) {
PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
}
// The C++ statements to check that this result value is unused are already
// emitted in the match() method. So returning a nullptr here directly
// should be safe because the C++ RewritePattern harness will use it to
// replace nothing.
return "nullptr";
}
if (resultTree.isReplaceWithValue()) if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree); return handleReplaceWithValue(resultTree);
@ -718,19 +690,12 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
} }
if (!tree.getSymbol().empty()) { if (!tree.getSymbol().empty()) {
PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue"); PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
} }
return resolveSymbol(tree.getArgName(0)); return resolveSymbol(tree.getArgName(0));
} }
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
assert(tree.isVerifyUnusedValue());
os.indent(4) << "if (!op0->getResult(" << index
<< ")->use_empty()) return matchFailure();\n";
}
std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
if (leaf.isConstantAttr()) { if (leaf.isConstantAttr()) {
auto constAttr = leaf.getAsConstantAttr(); auto constAttr = leaf.getAsConstantAttr();
@ -759,7 +724,7 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) { std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
auto fmt = tree.getNativeCodeTemplate(); auto fmt = tree.getNativeCodeTemplate();
// TODO(fengliuai): replace formatv arguments with the exact specified args. // TODO(b/138794486): replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8); SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) { if (tree.getNumArgs() > 8) {
PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +