[mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D115522
This commit is contained in:
Mogball 2021-12-10 15:04:46 +00:00
parent d658a4bb97
commit e40624ae60
3 changed files with 92 additions and 22 deletions

View File

@ -1136,7 +1136,9 @@ def ThreeResultOp : TEST_Op<"three_result"> {
let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
def AnotherThreeResultOp
: TEST_Op<"another_three_result",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins MultiResultOpEnum:$kind);
let results = (outs I32:$result1, F32:$result2, F32:$result3);
}
@ -2101,6 +2103,53 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
}];
}
// Base class for testing mixing allOperandTypes, allOperands, and
// inferResultTypes.
class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
: TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
let arguments = (ins Variadic<AnyType>:$args);
let results = (outs Variadic<AnyType>:$outs);
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
::mlir::TypeRange operandTypes = operands.getTypes();
inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
return ::mlir::success();
}
}];
}
// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
def FormatInferTypeAllOperandsAndTypesOp
: FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
}
// Test inferReturnTypes is called when allOperandTypes is true and there is one
// ODS operand.
def FormatInferTypeAllOperandsAndTypesOneOperandOp
: FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
}
// Test inferReturnTypes is called when allOperandTypes is true and there are
// more than one ODS operands.
def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
: FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
[SameVariadicOperandSize]> {
let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
}
// Test inferReturnTypes is called when allOperands is true and operand types
// are separately specified.
def FormatInferTypeAllTypesOp
: FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
}
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

View File

@ -411,6 +411,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_infer_type
%ignored_res7 = test.format_infer_type
// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32
%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32
// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32
%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32
// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
//===----------------------------------------------------------------------===//
// Check DefaultValuedStrAttr
//===----------------------------------------------------------------------===//

View File

@ -424,14 +424,18 @@ struct OperationFormat {
/// Generate the parser code for a specific format element.
void genElementParser(Element *element, MethodBody &body,
FmtContext &attrTypeCtx);
/// Generate the c++ to resolve the types of operands and results during
/// Generate the C++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve regions during parsing.
/// Generate the C++ to resolve the types of the operands during parsing.
void genParserOperandTypeResolution(
Operator &op, MethodBody &body,
function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
/// Generate the C++ to resolve regions during parsing.
void genParserRegionResolution(Operator &op, MethodBody &body);
/// Generate the c++ to resolve successors during parsing.
/// Generate the C++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, MethodBody &body);
/// Generate the c++ to handling variadic segment size traits.
/// Generate the C++ to handling variadic segment size traits.
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
/// Generate the operation printer from this format.
@ -1462,17 +1466,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
}
}
// Early exit if there are no operands.
if (op.getNumOperands() == 0) {
// Handle return type inference here if there are no operands
if (infersResultTypes)
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
return;
}
// Emit the operand type resolutions.
genParserOperandTypeResolution(op, body, emitTypeResolver);
// Handle the case where all operand types are in one group.
// Handle return type inference once all operands have been resolved
if (infersResultTypes)
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserOperandTypeResolution(
Operator &op, MethodBody &body,
function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
// Early exit if there are no operands.
if (op.getNumOperands() == 0)
return;
// Handle the case where all operand types are grouped together with
// "types(operands)".
if (allOperandTypes) {
// If we have all operands together, use the full operand list directly.
// If `operands` was specified, use the full operand list directly.
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
"allOperandLoc, result.operands))\n"
@ -1496,7 +1508,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
<< " return ::mlir::failure();\n";
return;
}
// Handle the case where all of the operands were grouped together.
// Handle the case where all operands are grouped together with "operands".
if (allOperands) {
body << " if (parser.resolveOperands(allOperands, ";
@ -1551,10 +1564,6 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return ::mlir::failure();\n";
}
// Handle return type inference once all operands have been resolved
if (infersResultTypes)
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserRegionResolution(Operator &op,
@ -1833,7 +1842,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
// keyword.
llvm::BitVector nonKeywordCases(cases.size());
bool hasStrCase = false;
for (auto it : llvm::enumerate(cases)) {
for (auto &it : llvm::enumerate(cases)) {
hasStrCase = it.value().isStrCase();
if (!canFormatStringAsKeyword(it.value().getStr()))
nonKeywordCases.set(it.index());
@ -1860,7 +1869,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
// overlap with other cases. For simplicity sake, only allow cases with a
// single bit value.
if (enumAttr.isBitEnum()) {
for (auto it : llvm::enumerate(cases)) {
for (auto &it : llvm::enumerate(cases)) {
int64_t value = it.value().getValue();
if (value < 0 || !llvm::isPowerOf2_64(value))
nonKeywordCases.set(it.index());
@ -1873,7 +1882,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
body << " switch (caseValue) {\n";
StringRef cppNamespace = enumAttr.getCppNamespace();
StringRef enumName = enumAttr.getEnumClassName();
for (auto it : llvm::enumerate(cases)) {
for (auto &it : llvm::enumerate(cases)) {
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();