forked from OSchip/llvm-project
Support referencing a single value generated by a matched multi-result op
It's quite common that we want to put further constraints on the matched multi-result op's specific results. This CL enables referencing symbols bound to source op with the `__N` syntax. PiperOrigin-RevId: 260122401
This commit is contained in:
parent
54175c240a
commit
9f02e88946
|
@ -208,7 +208,11 @@ public:
|
|||
llvm::StringMap<Argument> &getSourcePatternBoundArgs();
|
||||
|
||||
// Returns a reference to all the bound ops in the source pattern.
|
||||
llvm::StringSet<> &getSourcePatternBoundOps();
|
||||
// The returned map contains pointers to the operators inside the
|
||||
// `RecordOperatorMap` passed-in when constructing this pattern; callers
|
||||
// should guarantee the lifetime of the returned map does not exceed that
|
||||
// of the `RecordOperatorMap`.
|
||||
llvm::StringMap<const Operator *> &getSourcePatternBoundOps();
|
||||
|
||||
// Returns the op that the root node of the source pattern matches.
|
||||
const Operator &getSourceRootOp();
|
||||
|
@ -238,13 +242,15 @@ private:
|
|||
const llvm::Record &def;
|
||||
|
||||
// All operators.
|
||||
// TODO(antiagainst): we need a proper context manager, like MLIRContext,
|
||||
// for managing the lifetime of shared entities.
|
||||
RecordOperatorMap *recordOpMap;
|
||||
|
||||
// All bound op arguments.
|
||||
llvm::StringMap<Argument> boundArguments;
|
||||
|
||||
// All bound ops.
|
||||
llvm::StringSet<> boundOps;
|
||||
llvm::StringMap<const Operator *> boundOps;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -187,7 +187,8 @@ tblgen::Pattern::getSourcePatternBoundArgs() {
|
|||
return boundArguments;
|
||||
}
|
||||
|
||||
llvm::StringSet<> &tblgen::Pattern::getSourcePatternBoundOps() {
|
||||
llvm::StringMap<const tblgen::Operator *> &
|
||||
tblgen::Pattern::getSourcePatternBoundOps() {
|
||||
return boundOps;
|
||||
}
|
||||
|
||||
|
@ -263,7 +264,7 @@ void tblgen::Pattern::collectBoundArguments(DagNode tree) {
|
|||
// results generated from this op. It should be remembered as bound results.
|
||||
auto treeName = tree.getOpName();
|
||||
if (!treeName.empty())
|
||||
boundOps.insert(treeName);
|
||||
boundOps.try_emplace(treeName, &op);
|
||||
|
||||
// TODO(jpienaar): Expand to multiple matches.
|
||||
for (int i = 0; i != numTreeArgs; ++i) {
|
||||
|
|
|
@ -337,11 +337,12 @@ def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
|
|||
def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
|
||||
def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
|
||||
def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
|
||||
def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
|
||||
|
||||
def MultiResultOpEnum: I64EnumAttr<
|
||||
"Multi-result op kinds", "", [
|
||||
MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
|
||||
MultiResultOpKind4
|
||||
MultiResultOpKind4, MultiResultOpKind5
|
||||
]>;
|
||||
|
||||
def ThreeResultOp : TEST_Op<"three_result"> {
|
||||
|
@ -403,6 +404,15 @@ def : Pattern<(ThreeResultOp MultiResultOpKind4),
|
|||
(OneResultOp MultiResultOpKind4),
|
||||
(TwoResultOp:$res2__1 MultiResultOpKind4)]>;
|
||||
|
||||
// Test referencing a single value in the value pack
|
||||
def HasNoUse: Constraint<
|
||||
CPred<"$0->use_begin() == $0->use_end()">, "has no use">;
|
||||
// This rule only matches TwoResultOp if its second result has no use.
|
||||
def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
|
||||
[(AnotherOneResultOp MultiResultOpKind5),
|
||||
(OneResultOp MultiResultOpKind5)],
|
||||
[(HasNoUse $res__1)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Directives
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -30,11 +30,10 @@ def OpD : NS_Op<"op_d", []> {
|
|||
}
|
||||
|
||||
def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
|
||||
def getResult0 : NativeCodeCall<"$_self->getResult(0)">;
|
||||
|
||||
def : Pattern<(OpA:$res_a $operand, $attr),
|
||||
[(OpC:$res_c (OpB:$res_b $operand)),
|
||||
(OpD $res_b, $res_c, getResult0:$res_a, $attr)],
|
||||
(OpD $res_b, $res_c, $res_a, $attr)],
|
||||
[(hasOneUse $res_a)]>;
|
||||
|
||||
// CHECK-LABEL: GeneratedConvert0
|
||||
|
@ -55,7 +54,7 @@ def : Pattern<(OpA:$res_a $operand, $attr),
|
|||
// CHECK: s.operand = op0->getOperand(0);
|
||||
// CHECK: attr = op0->getAttrOfType<IntegerAttr>("attr");
|
||||
// CHECK: s.attr = attr;
|
||||
// CHECK: if (!((s.res_a->hasOneUse()))) return matchFailure();
|
||||
// CHECK: if (!((s.res_a->getResult(0)->hasOneUse()))) return matchFailure();
|
||||
|
||||
// Test bound results in result pattern
|
||||
// ---
|
||||
|
|
|
@ -157,3 +157,14 @@ func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
|
|||
%0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32)
|
||||
return %0#0, %0#1, %0#2 : i32, f32, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @constraintOnSourceOpResult
|
||||
func @constraintOnSourceOpResult() -> (i32, f32, i32) {
|
||||
// CHECK: %0:2 = "test.two_result"()
|
||||
// CHECK: %1 = "test.another_one_result"()
|
||||
// CHECK: %2 = "test.one_result"()
|
||||
// CHECK: return %0#0, %0#1, %1
|
||||
%0:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
|
||||
%1:2 = "test.two_result"() {kind = 5} : () -> (i32, f32)
|
||||
return %0#0, %0#1, %1#0 : i32, f32, i32
|
||||
}
|
||||
|
|
|
@ -87,8 +87,8 @@ static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
|
|||
// This extracts one value from the pack if `symbol` contains an index,
|
||||
// otherwise it extracts all values sequentially and returns them as a
|
||||
// comma-separated list.
|
||||
static std::string formtValuePack(const char *fmt, StringRef symbol,
|
||||
unsigned count, unsigned offset) {
|
||||
static std::string formatValuePack(const char *fmt, StringRef symbol,
|
||||
unsigned count, unsigned offset) {
|
||||
auto getNthValue = [fmt, offset](StringRef results,
|
||||
unsigned index) -> std::string {
|
||||
return formatv(fmt, results, index + offset);
|
||||
|
@ -142,28 +142,31 @@ namespace {
|
|||
class PatternSymbolResolver {
|
||||
public:
|
||||
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
||||
const StringSet<> &srcOperations);
|
||||
const StringMap<const Operator *> &srcOperations);
|
||||
|
||||
// Marks the given `symbol` as bound to a value pack with `numValues` and
|
||||
// returns true on success. Returns false if the `symbol` is already bound.
|
||||
bool add(StringRef symbol, int numValues);
|
||||
|
||||
// Queries the substitution for the given `symbol`.
|
||||
// Queries the substitution for the given `symbol`. Returns empty string if
|
||||
// symbol not found. If the symbol represents a value pack, returns all the
|
||||
// values separated via comma.
|
||||
std::string query(StringRef symbol) const;
|
||||
|
||||
private:
|
||||
// Symbols bound to arguments in source pattern.
|
||||
const StringMap<Argument> &sourceArguments;
|
||||
// Symbols bound to ops (for their results) in source pattern.
|
||||
const StringSet<> &sourceOps;
|
||||
const StringMap<const Operator *> &sourceOps;
|
||||
// Symbols bound to ops (for their results) in result patterns.
|
||||
// Key: symbol; value: number of values inside the pack
|
||||
StringMap<int> resultOps;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
|
||||
const StringSet<> &srcOperations)
|
||||
PatternSymbolResolver::PatternSymbolResolver(
|
||||
const StringMap<Argument> &srcArgs,
|
||||
const StringMap<const Operator *> &srcOperations)
|
||||
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
|
||||
|
||||
bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
|
||||
|
@ -176,8 +179,8 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
|
|||
StringRef name = getValuePackName(symbol);
|
||||
auto it = resultOps.find(name);
|
||||
if (it != resultOps.end())
|
||||
return formtValuePack("{0}.getOperation()->getResult({1})", symbol,
|
||||
it->second, 0);
|
||||
return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
|
||||
it->second, /*offset=*/0);
|
||||
}
|
||||
{
|
||||
auto it = sourceArguments.find(symbol);
|
||||
|
@ -185,9 +188,12 @@ std::string PatternSymbolResolver::query(StringRef symbol) const {
|
|||
return getBoundSymbol(symbol).str();
|
||||
}
|
||||
{
|
||||
auto it = sourceOps.find(symbol);
|
||||
StringRef name = getValuePackName(symbol);
|
||||
auto it = sourceOps.find(name);
|
||||
if (it != sourceOps.end())
|
||||
return getBoundSymbol(symbol).str();
|
||||
return formatValuePack("{0}->getResult({1})",
|
||||
getBoundSymbol(symbol).str(),
|
||||
it->second->getNumResults(), /*offset=*/0);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -490,9 +496,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
auto cmd = "if (!({0})) return matchFailure();\n";
|
||||
|
||||
if (isa<TypeConstraint>(constraint)) {
|
||||
auto self = formatv("(*{0}->result_type_begin())",
|
||||
resolveSymbol(entities.front()));
|
||||
// TODO(jpienaar): Verify op only has one result.
|
||||
auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
|
||||
os.indent(4) << formatv(cmd,
|
||||
tgfmt(condition, &matchCtx.withSelf(self.str())));
|
||||
} else if (isa<AttrConstraint>(constraint)) {
|
||||
|
@ -650,8 +654,8 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
|
|||
// We need to get all the values out of this local variable if we've created a
|
||||
// multi-result op.
|
||||
const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
|
||||
return formtValuePack("{0}.getOperation()->getResult({1})", results,
|
||||
numResults, 0);
|
||||
return formatValuePack("{0}.getOperation()->getResult({1})", results,
|
||||
numResults, /*offset=*/0);
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
|
||||
|
@ -799,8 +803,8 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
|||
|
||||
// We need to specify the types for all results.
|
||||
auto resultTypes =
|
||||
formtValuePack("op->getResult({1})->getType()", valuePackName,
|
||||
resultOp.getNumResults(), resultIndex);
|
||||
formatValuePack("op->getResult({1})->getType()", valuePackName,
|
||||
resultOp.getNumResults(), resultIndex);
|
||||
|
||||
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc, {2}",
|
||||
valuePackName, resultOp.getQualCppClassName(),
|
||||
|
|
Loading…
Reference in New Issue