[mlir][ods] Fix OpFormatGen calling inferReturnTypes before region/segment resolution

The generated parser for ops with type inference calls `inferReturnTypes` before region resolution and segment attribute resolution, i.e. regions and the segment attributes are not passed to the `inferReturnTypes` even though it may need that information.

In particular, an op that has sized operand segments which queries those operands in its `inferReturnTypes` function will crash because the segment attributes hadn't been added yet.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D115782
This commit is contained in:
Mogball 2021-12-15 06:22:04 +00:00
parent 4c8dbe96d7
commit 1aa0b84fa4
3 changed files with 53 additions and 1 deletions

View File

@ -2150,6 +2150,48 @@ def FormatInferTypeAllTypesOp
let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)"; let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
} }
// Test inferReturnTypes coupled with regions.
def FormatInferTypeRegionsOp
: TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> {
let results = (outs Variadic<AnyType>:$outs);
let regions = (region AnyRegion:$region);
let assemblyFormat = "$region attr-dict";
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
if (regions.empty())
return ::mlir::failure();
auto types = regions.front()->getArgumentTypes();
inferredReturnTypes.assign(types.begin(), types.end());
return ::mlir::success();
}
}];
}
// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes).
def FormatInferTypeVariadicOperandsOp
: TEST_Op<"format_infer_type_variadic_operands",
[InferTypeOpInterface, AttrSizedOperandSegments]> {
let arguments = (ins Variadic<I32>:$a, Variadic<I64>:$b);
let results = (outs Variadic<AnyType>:$outs);
let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict";
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
FormatInferTypeVariadicOperandsOpAdaptor adaptor(operands, attributes);
auto aTypes = adaptor.getA().getTypes();
auto bTypes = adaptor.getB().getTypes();
inferredReturnTypes.append(aTypes.begin(), aTypes.end());
inferredReturnTypes.append(bTypes.begin(), bTypes.end());
return ::mlir::success();
}
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Test SideEffects // Test SideEffects
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -423,6 +423,16 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32 // CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32 %ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
// CHECK: test.format_infer_type_regions
// CHECK-NEXT: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):
%ignored_res12:2 = test.format_infer_type_regions {
^bb0(%arg0: i32, %arg1: f32):
"test.terminator"() : () -> ()
}
// CHECK: test.format_infer_type_variadic_operands(%[[I32]], %[[I32]] : i32, i32) (%[[I64]], %[[I64]] : i64, i64)
%ignored_res13:4 = test.format_infer_type_variadic_operands(%i32, %i32 : i32, i32) (%i64, %i64 : i64, i64)
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Check DefaultValuedStrAttr // Check DefaultValuedStrAttr
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1185,10 +1185,10 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
// Generate the code to resolve the operand/result types and successors now // Generate the code to resolve the operand/result types and successors now
// that they have been parsed. // that they have been parsed.
genParserTypeResolution(op, body);
genParserRegionResolution(op, body); genParserRegionResolution(op, body);
genParserSuccessorResolution(op, body); genParserSuccessorResolution(op, body);
genParserVariadicSegmentResolution(op, body); genParserVariadicSegmentResolution(op, body);
genParserTypeResolution(op, body);
body << " return ::mlir::success();\n"; body << " return ::mlir::success();\n";
} }