Fix RewriterGen to support using NativeCodeCall as auxiliary pattern

NativeCodeCall is handled differently than normal op creation in RewriterGen
(because its flexibility). It will only be materialized to output stream if
it is used. But when using it for auxiliary patterns, we still want the side
effect even if it is not replacing matched root op's results.

PiperOrigin-RevId: 275265467
This commit is contained in:
Lei Zhang 2019-10-17 08:39:13 -07:00 committed by A. Unique TensorFlower
parent 1358df19ca
commit 603117b2d6
4 changed files with 34 additions and 7 deletions

View File

@ -458,6 +458,15 @@ def : Pat<(OpNativeCodeCall1 $input1, $input2,
ConstBoolAttrFalse, $attr1, $attr2),
(UseOpResult $input2)>;
def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
let arguments = (ins I32:$input);
let results = (outs I32);
}
// Test that NativeCodeCall is not ignored if it is not used to directly
// replace the matched root op.
def : Pattern<(OpNativeCodeCall3 $input),
[(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
// Test AllAttrConstraintsOf.
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
let arguments = (ins I64ArrayAttr:$attr);

View File

@ -26,6 +26,10 @@ static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) {
return choice.getValue() ? input1 : input2;
}
static void createOpI(PatternRewriter &rewriter, Value *input) {
rewriter.create<OpI>(rewriter.getUnknownLoc(), input);
}
namespace {
#include "TestPatterns.inc"
} // end anonymous namespace

View File

@ -44,6 +44,14 @@ func @verifyNativeCodeCall(%arg0: i32, %arg1: i32) -> (i32, i32) {
return %0, %1: i32, i32
}
// CHECK-LABEL: verifyAuxiliaryNativeCodeCall
func @verifyAuxiliaryNativeCodeCall(%arg0: i32) -> (i32) {
// CHECK: test.op_i
// CHECK: test.op_k
%0 = "test.native_code_call3"(%arg0) : (i32) -> (i32)
return %0 : i32
}
// CHECK-LABEL: verifyAllAttrConstraintOf
func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
// CHECK: "test.all_attr_constraint_of2"

View File

@ -530,17 +530,23 @@ void PatternEmitter::emitRewriteLogic() {
}
os << "}); (void)loc;\n";
// Process each result pattern and record the result symbol.
llvm::SmallVector<std::string, 2> resultValues;
for (int i = 0; i < numResultPatterns; ++i) {
// Process auxiliary result patterns.
for (int i = 0; i < replStartIndex; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
auto val = handleResultPattern(resultTree, offsets[i], 0);
// Normal op creation will be streamed to `os` by the above call; but
// NativeCodeCall will only be materialized to `os` if it is used. Here
// we are handling auxiliary patterns so we want the side effect even if
// NativeCodeCall is not replacing matched root op's results.
if (resultTree.isNativeCodeCall())
os.indent(4) << val << ";\n";
}
// Process replacement result patterns.
os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
// Only use the last portion for replacing the matched root op's results.
auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
for (const auto &val : range) {
for (int i = replStartIndex; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
os << symbolInfoMap.getAllRangeUse(