Support variadic ops in declarative rewrite rules

This CL extends declarative rewrite rules to support matching and
generating ops with variadic operands/results. For this, the
generated `matchAndRewrite()` method for each pattern now are
changed to

* Use "range" types for the local variables used to store captured
  values (`operand_range` for operands, `ArrayRef<Value *>` for
  values, *Op for results). This allows us to have a unified way
  of handling both single values and value ranges.
* Create local variables for each operand for op creation. If the
  operand is variadic, then a `SmallVector<Value*>` will be created
  to collect all values for that operand; otherwise a `Value*` will
  be created.
* Use a collective result type builder. All result types are
  specified via a single parameter to the builder.

We can use one result pattern to replace multiple results of the
matched root op. When that happens, it will require specifying
types for multiple results. Add a new collective-type builder.

PiperOrigin-RevId: 264588559
This commit is contained in:
Lei Zhang 2019-08-21 05:35:07 -07:00 committed by A. Unique TensorFlower
parent 69cf811d5b
commit 31cfee6077
6 changed files with 487 additions and 132 deletions

View File

@ -262,8 +262,19 @@ public:
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values). `name` is the
// name of the C++ variable that this symbol bounds to. `index` should only
// be used for indexing results.
std::string getValueAndRangeUse(StringRef name, int index) const;
// be used for indexing results. `fmt` is used to format each value.
// `separator` is used to separate values if this is a value range.
std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `name` is the name of the C++ variable that this symbol
// bounds to. `index` should only be used for indexing results. `fmt` is
// used to format each value. `separator` is used to separate values in the
// range.
std::string getAllRangeUse(StringRef name, int index, const char *fmt,
const char *separator) const;
const Operator *op; // The op where the bound entity belongs
Kind kind; // The kind of the bound entity
@ -309,8 +320,18 @@ public:
// Returns a string containing the C++ expression for referencing this
// symbol as a value (if this symbol represents one static value) or a value
// range (if this symbol represents multiple static values).
std::string getValueAndRangeUse(StringRef symbol) const;
// range (if this symbol represents multiple static values). `fmt` is used to
// format each value. `separator` is used to seperate values if `symbol`
// represents a value range.
std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
// Returns a string containing the C++ expression for referencing this
// symbol as a value range regardless of how many static values this symbol
// represents. `fmt` is used to format each value. `seperator` is used to
// separate values in the range.
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;
// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on sucess. Returns `symbol`

View File

@ -204,45 +204,99 @@ tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
return formatv("{0} {1};\n", type, name);
}
case Kind::Operand:
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
}
case Kind::Value: {
return formatv("Value *{0};\n", name);
return formatv("ArrayRef<Value *> {0};\n", name);
}
case Kind::Result: {
// Use the op itself for the results.
// Use the op itself for captured results.
return formatv("{0} {1};\n", op->getQualCppClassName(), name);
}
}
llvm_unreachable("unknown kind");
}
std::string
tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
int index) const {
std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
switch (kind) {
case Kind::Attr: {
assert(index < 0);
return formatv(fmt, name);
}
case Kind::Operand: {
assert(index < 0);
auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
// If this operand is variadic, then return a range. Otherwise, return the
// value itself.
if (operand->isVariadic()) {
return formatv(fmt, name);
}
return formatv(fmt, formatv("(*{0}.begin())", name));
}
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
// result of a multi-result op. The result can still be variadic.
if (index >= 0) {
std::string v = formatv("{0}.getODSResults({1})", name, index);
if (!op->getResult(index).isVariadic())
v = formatv("(*{0}.begin())", v);
return formatv(fmt, v);
}
// We are referencing all results of the multi-result op. A specific result
// can either be a value or a range. Then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
std::string v = formatv("{0}.getODSResults({1})", name, i);
if (!op->getResult(i).isVariadic()) {
v = formatv("(*{0}.begin())", v);
}
values.push_back(formatv(fmt, v));
}
return llvm::join(values, separator);
}
case Kind::Value: {
assert(index < 0);
assert(op == nullptr);
return formatv(fmt, name);
}
}
}
std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
StringRef name, int index, const char *fmt, const char *separator) const {
switch (kind) {
case Kind::Attr:
case Kind::Operand: {
assert(index < 0 && "only allowed for symbol bound to result");
return name;
return formatv(fmt, name);
}
case Kind::Result: {
// TODO(b/133341698): The following is incorrect for variadic results. We
// should use getODSResults().
if (index >= 0) {
return formatv("{0}.getOperation()->getResult({1})", name, index);
return formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
}
// If referencing multiple results, compose a comma-separated list.
// We are referencing all results of the multi-result op. Each result should
// have a value range, and then join them with `separator`.
SmallVector<std::string, 4> values;
values.reserve(op->getNumResults());
for (int i = 0, e = op->getNumResults(); i < e; ++i) {
values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
values.push_back(
formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
}
return llvm::join(values, ", ");
return llvm::join(values, separator);
}
case Kind::Value: {
assert(index < 0 && "only allowed for symbol bound to result");
assert(op == nullptr);
return name;
return formatv(fmt, formatv("{{{0}}", name));
}
}
llvm_unreachable("unknown kind");
@ -294,7 +348,9 @@ int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return find(name)->getValue().getStaticValueCount();
}
std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
std::string
tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);
@ -304,7 +360,22 @@ std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
PrintFatalError(loc, error);
}
return it->getValue().getValueAndRangeUse(name, index);
return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
}
std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);
auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}
return it->getValue().getAllRangeUse(name, index, fmt, separator);
}
//===----------------------------------------------------------------------===//

View File

@ -505,6 +505,116 @@ def : Pattern<
(AnotherTwoResultOp MultiResultOpKind6)
]>;
//===----------------------------------------------------------------------===//
// Test Patterns (Variadic Ops)
def OneVResOneVOperandOp1 : TEST_Op<"one_variadic_out_one_variadic_in1"> {
let arguments = (ins Variadic<I32>:$inputs);
let results = (outs Variadic<I32>:$outputs);
}
def OneVResOneVOperandOp2 : TEST_Op<"one_variadic_out_one_variadic_in2"> {
let arguments = (ins Variadic<I32>:$inputs);
let results = (outs Variadic<I32>:$outputs);
}
// Rewrite an op with one variadic operand and one variadic result to
// another similiar op.
def : Pat<(OneVResOneVOperandOp1 $inputs), (OneVResOneVOperandOp2 $inputs)>;
def MixedVOperandOp1 : TEST_Op<"mixed_variadic_in1",
[SameVariadicOperandSize]> {
let arguments = (ins
Variadic<I32>:$input1,
F32:$input2,
Variadic<I32>:$input3
);
}
def MixedVOperandOp2 : TEST_Op<"mixed_variadic_in2",
[SameVariadicOperandSize]> {
let arguments = (ins
Variadic<I32>:$input1,
F32:$input2,
Variadic<I32>:$input3
);
}
// Rewrite an op with both variadic operands and normal operands.
def : Pat<(MixedVOperandOp1 $input1, $input2, $input3),
(MixedVOperandOp2 $input1, $input2, $input3)>;
def MixedVResultOp1 : TEST_Op<"mixed_variadic_out1", [SameVariadicResultSize]> {
let results = (outs
Variadic<I32>:$output1,
F32:$output2,
Variadic<I32>:$output3
);
}
def MixedVResultOp2 : TEST_Op<"mixed_variadic_out2", [SameVariadicResultSize]> {
let results = (outs
Variadic<I32>:$output1,
F32:$output2,
Variadic<I32>:$output3
);
}
// Rewrite an op with both variadic results and normal results.
// Note that because we are generating the op with a top-level result pattern,
// we are able to deduce the correct result types for the generated op using
// the information from the matched root op.
def : Pat<(MixedVResultOp1), (MixedVResultOp2)>;
def OneI32ResultOp : TEST_Op<"one_i32_out"> {
let results = (outs I32:$output);
}
def MixedVOperandOp3 : TEST_Op<"mixed_variadic_in3",
[SameVariadicOperandSize]> {
let arguments = (ins
I32:$input1,
Variadic<I32>:$input2,
Variadic<I32>:$input3,
I32Attr:$count
);
let results = (outs I32:$output);
}
def MixedVResultOp3 : TEST_Op<"mixed_variadic_out3",
[SameVariadicResultSize]> {
let arguments = (ins I32Attr:$count);
let results = (outs
I32:$output1,
Variadic<I32>:$output2,
Variadic<I32>:$output3
);
// We will use this op in a nested result pattern, where we cannot deduce the
// result type. So need to provide a builder not requiring result types.
let builders = [
OpBuilder<
"Builder *builder, OperationState *state, IntegerAttr count",
[{
auto i32Type = builder->getIntegerType(32);
state->addTypes(i32Type); // $ouput1
SmallVector<Type, 4> types(count.getInt(), i32Type);
state->addTypes(types); // $ouput2
state->addTypes(types); // $ouput3
state->addAttribute("count", count);
}]>
];
}
// Generates an op with variadic results using nested pattern.
def : Pat<(OneI32ResultOp),
(MixedVOperandOp3
(MixedVResultOp3:$results__0 ConstantAttr<I32Attr, "2">),
(replaceWithValue $results__1),
(replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>;
//===----------------------------------------------------------------------===//
// Test Legalization
//===----------------------------------------------------------------------===//

View File

@ -215,3 +215,55 @@ func @useAuxiliaryOpToReplaceMultiResultOp() -> (i32, f32, f32) {
%0:3 = "test.three_result"() {kind = 6} : () -> (i32, f32, f32)
return %0#0, %0#1, %0#2 : i32, f32, f32
}
//===----------------------------------------------------------------------===//
// Test Multi-result Ops
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @replaceOneVariadicOutOneVariadicInOp
func @replaceOneVariadicOutOneVariadicInOp(%arg0: i32, %arg1: i32, %arg2: i32) -> (i32, i32, i32, i32, i32, i32) {
// CHECK: %[[cnt1:.*]] = "test.one_variadic_out_one_variadic_in2"(%arg0)
// CHECK: %[[cnt2:.*]]:2 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1)
// CHECK: %[[cnt3:.*]]:3 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1, %arg2)
// CHECK: return %[[cnt1]], %[[cnt2]]#0, %[[cnt2]]#1, %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2
%0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> (i32)
%1:2 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1) : (i32, i32) -> (i32, i32)
%2:3 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1, %arg2) : (i32, i32, i32) -> (i32, i32, i32)
return %0, %1#0, %1#1, %2#0, %2#1, %2#2 : i32, i32, i32, i32, i32, i32
}
// CHECK-LABEL: @replaceMixedVariadicInputOp
func @replaceMixedVariadicInputOp(%arg0: i32, %arg1: f32, %arg2: i32) -> () {
// CHECK: "test.mixed_variadic_in2"(%arg1)
// CHECK: "test.mixed_variadic_in2"(%arg0, %arg1, %arg2)
// CHECK: "test.mixed_variadic_in2"(%arg0, %arg0, %arg1, %arg2, %arg2)
"test.mixed_variadic_in1"(%arg1) : (f32) -> ()
"test.mixed_variadic_in1"(%arg0, %arg1, %arg2) : (i32, f32, i32) -> ()
"test.mixed_variadic_in1"(%arg0, %arg0, %arg1, %arg2, %arg2) : (i32, i32, f32, i32, i32) -> ()
return
}
// CHECK-LABEL: @replaceMixedVariadicOutputOp
func @replaceMixedVariadicOutputOp() -> (f32, i32, f32, i32, i32, i32, f32, i32, i32) {
// CHECK: %[[cnt1:.*]] = "test.mixed_variadic_out2"()
// CHECK: %[[cnt3:.*]]:3 = "test.mixed_variadic_out2"()
// CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out2"()
// CHECK: return %[[cnt1]], %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2, %[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4
%0 = "test.mixed_variadic_out1"() : () -> (f32)
%1:3 = "test.mixed_variadic_out1"() : () -> (i32, f32, i32)
%2:5 = "test.mixed_variadic_out1"() : () -> (i32, i32, f32, i32, i32)
return %0, %1#0, %1#1, %1#2, %2#0, %2#1, %2#2, %2#3, %2#4 : f32, i32, f32, i32, i32, i32, f32, i32, i32
}
// CHECK-LABEL: @generateVaridicOutputOpInNestedPattern
func @generateVaridicOutputOpInNestedPattern() -> (i32) {
// CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out3"()
// CHECK: %[[res:.*]] = "test.mixed_variadic_in3"(%[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4)
// CHECK: return %[[res]]
%0 = "test.one_i32_out"() : () -> (i32)
return %0 : i32
}

View File

@ -481,6 +481,10 @@ private:
// result types for all results.
void genSeparateParamBuilder();
// Generates the build() method that takes a single parameter for all the
// result types and a separate parameter for each operand/attribute.
void genCollectiveTypeParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. This build() method uses first operand's type
// as all result's types.
@ -495,6 +499,8 @@ private:
// one parameter. Similarly for operands and attributes.
void genCollectiveParamBuilder();
enum class TypeParamKind { None, Separate, Collective };
// Builds the parameter list for build() method of this op. This method writes
// to `paramList` the comma-separated parameter list. If `includeResultTypes`
// is true then `paramList` will also contain the parameters for all results
@ -502,7 +508,7 @@ private:
// result type.
void buildParamList(std::string &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
bool includeResultTypes);
TypeParamKind kind);
// Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
@ -765,7 +771,7 @@ void OpEmitter::genNamedRegionGetters() {
void OpEmitter::genSeparateParamBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, /*includeResultTypes=*/true);
buildParamList(paramList, resultNames, TypeParamKind::Separate);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
@ -777,10 +783,35 @@ void OpEmitter::genSeparateParamBuilder() {
}
}
void OpEmitter::genCollectiveTypeParamBuilder() {
auto numResults = op.getNumResults();
// If this op has no results, then just skip generating this builder.
// Otherwise we are generating the same signature as the separate-parameter
// builder.
if (numResults == 0)
return;
// Similarly for ops with one single variadic result, which will also have one
// `ArrayRef<Type>` parameter for the result type.
if (numResults == 1 && op.getResult(0).isVariadic())
return;
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::Collective);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
// Push all result types to the operation state
m.body() << formatv(" {0}->addTypes(resultTypes);\n", builderOpState);
}
void OpEmitter::genUseOperandAsResultTypeBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
buildParamList(paramList, resultNames, TypeParamKind::None);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
@ -802,7 +833,7 @@ void OpEmitter::genUseOperandAsResultTypeBuilder() {
void OpEmitter::genUseAttrAsResultTypeBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
buildParamList(paramList, resultNames, TypeParamKind::None);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
@ -861,10 +892,13 @@ void OpEmitter::genBuilder() {
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
genSeparateParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// 2. one having a stand-alone parameter for each operand / attribute and
// an aggregrated parameter for all result types, and
genCollectiveTypeParamBuilder();
// 3. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
// 3. one having a stand-alone prameter for each operand and attribute,
// 4. one having a stand-alone prameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
@ -920,16 +954,18 @@ void OpEmitter::genCollectiveParamBuilder() {
void OpEmitter::buildParamList(std::string &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
bool includeResultTypes) {
TypeParamKind kind) {
resultTypeNames.clear();
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
paramList = "Builder *, OperationState *";
paramList.append(builderOpState);
if (includeResultTypes) {
resultTypeNames.clear();
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
switch (kind) {
case TypeParamKind::None:
break;
case TypeParamKind::Separate: {
// Add parameters for all return types
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
@ -942,6 +978,11 @@ void OpEmitter::buildParamList(std::string &paramList,
resultTypeNames.emplace_back(std::move(resultName));
}
} break;
case TypeParamKind::Collective: {
paramList.append(", ArrayRef<Type> resultTypes");
resultTypeNames.push_back("resultTypes");
} break;
}
int numOperands = 0;
@ -1226,8 +1267,8 @@ void OpEmitter::genTraits() {
if (numResults == numVariadicResults)
opClass.addTrait("OpTrait::VariadicResults");
else
opClass.addTrait("OpTrait::AtLeastNResults<" + Twine(numResults - 1) +
">::Impl");
opClass.addTrait("OpTrait::AtLeastNResults<" +
Twine(numResults - numVariadicResults) + ">::Impl");
} else {
switch (numResults) {
case 0:
@ -1256,8 +1297,8 @@ void OpEmitter::genTraits() {
if (numOperands == numVariadicOperands)
opClass.addTrait("OpTrait::VariadicOperands");
else
opClass.addTrait("OpTrait::AtLeastNOperands<" + Twine(numOperands - 1) +
">::Impl");
opClass.addTrait("OpTrait::AtLeastNOperands<" +
Twine(numOperands - numVariadicOperands) + ">::Impl");
} else {
switch (numOperands) {
case 0:

View File

@ -90,9 +90,16 @@ private:
// Rewrite utilities
//===--------------------------------------------------------------------===//
// Entry point for handling a result pattern rooted at `resultTree` and
// dispatches to concrete handlers. The given tree is the `resultIndex`-th
// argument of the enclosing DAG.
// The entry point for handling a result pattern rooted at `resultTree`. This
// method dispatches to concrete handlers according to `resultTree`'s kind and
// returns a symbol representing the whole value pack. Callers are expected to
// further resolve the symbol according to the specific use case.
//
// `depth` is the nesting level of `resultTree`; 0 means top-level result
// pattern. For top-level result pattern, `resultIndex` indicates which result
// of the matched root op this pattern is intended to replace, which can be
// used to deduce the result type of the op generated from this result
// pattern.
std::string handleResultPattern(DagNode resultTree, int resultIndex,
int depth);
@ -133,9 +140,6 @@ private:
// Symbol utilities
//===--------------------------------------------------------------------===//
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
std::string resolveSymbol(StringRef symbol);
// Returns how many static values the given DAG `node` correspond to.
int getNodeValueCount(DagNode node);
@ -187,11 +191,6 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
// Helper function to match patterns.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
Operator &op = tree.getDialectOp(opMap);
if (op.isVariadic()) {
PrintFatalError(loc, formatv("matching op '{0}' with variadic "
"operands/results is unsupported right now",
op.getOperationName()));
}
int indent = 4 + 2 * depth;
os.indent(indent) << formatv(
@ -220,10 +219,20 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (operand->isVariadic()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
op.getOperationName(), i);
PrintFatalError(loc, error);
}
}
os.indent(indent) << "{\n";
os.indent(indent + 2)
<< formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
depth + 1, depth, i);
os.indent(indent + 2) << formatv(
"auto *op{0} = "
"(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
os.indent(indent + 2)
<< formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
@ -260,7 +269,15 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Only need to verify if the matcher's type is different from the one
// of op definition.
if (operand->constraint != matcher.getAsConstraint()) {
auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
if (operand->isVariadic()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), index);
PrintFatalError(loc, error);
}
auto self =
formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
depth, index);
os.indent(indent) << "if (!("
<< tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf(self))
@ -271,8 +288,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Capture the value
auto name = tree.getArgName(index);
if (!name.empty()) {
os.indent(indent) << formatv("{0} = op{1}->getOperand({2});\n", name, depth,
index);
os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
name, depth, index);
}
}
@ -339,7 +356,8 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
auto cmd = "if (!({0})) return matchFailure();\n";
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("({0}->getType())", resolveSymbol(entities.front()));
auto self = formatv("({0}->getType())",
symbolInfoMap.getValueAndRangeUse(entities.front()));
os.indent(4) << formatv(cmd,
tgfmt(condition, &fmtCtx.withSelf(self.str())));
} else if (isa<AttrConstraint>(constraint)) {
@ -354,10 +372,10 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
SmallVector<std::string, 4> names;
int i = 0;
for (int e = entities.size(); i < e; ++i)
names.push_back(resolveSymbol(entities[i]));
names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
std::string self = appliedConstraint.self;
if (!self.empty())
self = resolveSymbol(self);
self = symbolInfoMap.getValueAndRangeUse(self);
for (; i < 4; ++i)
names.push_back("<unused>");
os.indent(4) << formatv(cmd,
@ -476,25 +494,31 @@ void PatternEmitter::emitRewriteLogic() {
PrintFatalError(loc, error);
}
os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
os.indent(4) << "auto loc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
}
os << "}); (void)loc;\n";
// Collect the replacement value for each result
// Process each result pattern and record the result symbol.
llvm::SmallVector<std::string, 2> resultValues;
for (int i = 0; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
}
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op0, {";
interleaveComma(
ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
[&](const std::string &symbol) { os << resolveSymbol(symbol); });
os << "});\n";
os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
// Only use the last portion for replacing the matched root op's results.
auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
for (const auto &val : range) {
os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
os << symbolInfoMap.getAllRangeUse(
val, " for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
}
os.indent(4) << "\n";
os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
}
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
@ -535,10 +559,11 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
}
return resolveSymbol(tree.getArgName(0));
return tree.getArgName(0);
}
std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
StringRef patArgName) {
if (leaf.isConstantAttr()) {
auto constAttr = leaf.getAsConstantAttr();
return handleConstantAttr(constAttr.getAttribute(),
@ -553,6 +578,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
std::string val = std::to_string(enumCase.getValue());
return handleConstantAttr(enumCase, val);
}
auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return argName;
}
@ -577,14 +604,6 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
attrs[5], attrs[6], attrs[7]);
}
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
if (subst.empty()) {
PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
}
return subst;
}
int PatternEmitter::getNodeValueCount(DagNode node) {
if (node.isOperation()) {
// If the op is bound to a symbol in the rewrite rule, query its result
@ -606,12 +625,6 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
if (resultOp.isVariadic()) {
PrintFatalError(loc, formatv("generating op '{0}' with variadic "
"operands/results is unsupported now",
resultOp.getOperationName()));
}
if (numOpArgs != tree.getNumArgs()) {
PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
@ -620,30 +633,88 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
// A map to collect all nested DAG child nodes' names, with operand index as
// the key. This includes both bound and unbound child nodes. Bound child
// nodes will additionally be tracked in `symbolResolver` so they can be
// referenced by other patterns. Unbound child nodes will only be used once
// to build this op.
// the key. This includes both bound and unbound child nodes.
llvm::DenseMap<unsigned, std::string> childNodeNames;
// First go through all the child nodes who are nested DAG constructs to
// create ops for them, so that we can use the results in the current node.
// This happens in a recursive manner.
// create ops for them and remember the symbol names for them, so that we can
// use the results in the current node. This happens in a recursive manner.
for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
if (auto child = tree.getArgAsNestedDag(i)) {
childNodeNames[i] = handleResultPattern(child, i, depth + 1);
}
}
// Use the specified name for this op if available. Generate one otherwise.
std::string resultValue = tree.getSymbol();
if (resultValue.empty())
resultValue = getUniqueSymbol(&resultOp);
// Strip the index to get the name for the value pack. This will be used to
// name the local variable for the op.
StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
// The name of the local variable holding this op.
std::string valuePackName;
// The symbol for holding the result of this pattern. Note that the result of
// this pattern is not necessarily the same as the variable created by this
// pattern because we can use `__N` suffix to refer only a specific result if
// the generated op is a multi-result op.
std::string resultValue;
if (tree.getSymbol().empty()) {
// No symbol is explicitly bound to this op in the pattern. Generate a
// unique name.
valuePackName = resultValue = getUniqueSymbol(&resultOp);
} else {
resultValue = tree.getSymbol();
// Strip the index to get the name for the value pack and use it to name the
// local variable for the op.
valuePackName = SymbolInfoMap::getValuePackName(resultValue);
}
// Then we build the new op corresponding to this DAG node.
// Create the local variable for this op.
os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
valuePackName);
os.indent(4) << "{\n";
// Now prepare operands used for building this op:
// * If the operand is non-variadic, we create a `Value*` local variable.
// * If the operand is variadic, we create a `SmallVector<Value*>` local
// variable.
int argIndex = 0; // The current index to this op's ODS argument
int valueIndex = 0; // An index for uniquing local variable names.
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
const auto &operand = resultOp.getOperand(argIndex);
std::string varName;
if (operand.isVariadic()) {
varName = formatv("tblgen_values_{0}", valueIndex++);
os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
std::string range;
if (tree.isNestedDagArg(argIndex)) {
range = childNodeNames[argIndex];
} else {
range = tree.getArgName(argIndex);
}
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
varName);
} else {
varName = formatv("tblgen_value_{0}", valueIndex++);
os.indent(6) << formatv("Value *{0} = ", varName);
if (tree.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
DagLeaf leaf = tree.getArgAsLeaf(argIndex);
auto symbol =
symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
os << symbol;
}
}
os << ";\n";
}
// Update to use the newly created local variable for building the op later.
childNodeNames[argIndex] = varName;
}
// Then we create the builder call.
// Right now we don't have general type inference in MLIR. Except a few
// special cases listed below, we need to supply types for all results
@ -657,8 +728,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
usePartialResults || depth > 0 || resultIndex < 0) {
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
valuePackName, resultOp.getQualCppClassName());
os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
resultOp.getQualCppClassName());
} else {
// If depth == 0 and resultIndex >= 0, it means we are replacing the values
// generated from the source pattern root op. Then we can use the source
@ -666,50 +737,38 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here.
// We need to specify the types for all results.
SmallVector<std::string, 4> resultTypes;
int numResults = resultOp.getNumResults();
resultTypes.reserve(numResults);
for (int i = 0; i < numResults; ++i) {
resultTypes.push_back(
formatv("op0->getResult({0})->getType()", resultIndex + i));
}
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
valuePackName, resultOp.getQualCppClassName())
<< (resultTypes.empty() ? "" : ", ")
<< llvm::join(resultTypes, ", ");
}
// Create the builder call for the result.
// Add operands.
int argIndex = 0;
for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
const auto &operand = resultOp.getOperand(argIndex);
// Start each operand on its own line.
(os << ",\n").indent(6);
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
if (tree.isNestedDagArg(argIndex)) {
os << childNodeNames[argIndex];
} else {
DagLeaf leaf = tree.getArgAsLeaf(argIndex);
auto symbol = resolveSymbol(tree.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
os << symbol;
if (numResults != 0) {
os.indent(6) << "tblgen_types.clear();\n";
for (int i = 0; i < numResults; ++i) {
os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
"tblgen_types.push_back(v->getType());\n",
resultIndex + i);
}
}
os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
resultOp.getQualCppClassName());
if (numResults != 0)
os.indent(6) << ", tblgen_types";
}
// Add operands for the builder all.
for (int i = 0; i < argIndex; ++i) {
const auto &operand = resultOp.getOperand(i);
// Start each operand on its own line.
(os << ",\n").indent(8);
if (!operand.name.empty()) {
os << "/*" << operand.name << "=*/";
}
os << childNodeNames[i];
// TODO(jpienaar): verify types
}
// Add attributes.
// Add attributes for the builder call.
for (; argIndex != numOpArgs; ++argIndex) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
(os << ",\n").indent(8);
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
@ -735,7 +794,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
os << handleOpArgument(leaf, patArgName);
}
}
os << "\n );\n";
os << "\n );\n";
os.indent(4) << "}\n";
return resultValue;
}