forked from OSchip/llvm-project
[mlir][ods] Fix OpFormatGen sometimes not calling inferReturnTypes
Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D115522
This commit is contained in:
parent
d658a4bb97
commit
e40624ae60
|
@ -1136,7 +1136,9 @@ def ThreeResultOp : TEST_Op<"three_result"> {
|
|||
let results = (outs I32:$result1, F32:$result2, F32:$result3);
|
||||
}
|
||||
|
||||
def AnotherThreeResultOp : TEST_Op<"another_three_result", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
def AnotherThreeResultOp
|
||||
: TEST_Op<"another_three_result",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let arguments = (ins MultiResultOpEnum:$kind);
|
||||
let results = (outs I32:$result1, F32:$result2, F32:$result3);
|
||||
}
|
||||
|
@ -2101,6 +2103,53 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
|
|||
}];
|
||||
}
|
||||
|
||||
// Base class for testing mixing allOperandTypes, allOperands, and
|
||||
// inferResultTypes.
|
||||
class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
|
||||
: TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
|
||||
let arguments = (ins Variadic<AnyType>:$args);
|
||||
let results = (outs Variadic<AnyType>:$outs);
|
||||
let extraClassDeclaration = [{
|
||||
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
|
||||
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
|
||||
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
::mlir::TypeRange operandTypes = operands.getTypes();
|
||||
inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
|
||||
return ::mlir::success();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// Test inferReturnTypes is called when allOperandTypes and allOperands is true.
|
||||
def FormatInferTypeAllOperandsAndTypesOp
|
||||
: FormatInferAllTypesBaseOp<"format_infer_type_all_operands_and_types"> {
|
||||
let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
// Test inferReturnTypes is called when allOperandTypes is true and there is one
|
||||
// ODS operand.
|
||||
def FormatInferTypeAllOperandsAndTypesOneOperandOp
|
||||
: FormatInferAllTypesBaseOp<"format_infer_type_all_types_one_operand"> {
|
||||
let assemblyFormat = "`(` $args `)` attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
// Test inferReturnTypes is called when allOperandTypes is true and there are
|
||||
// more than one ODS operands.
|
||||
def FormatInferTypeAllOperandsAndTypesTwoOperandsOp
|
||||
: FormatInferAllTypesBaseOp<"format_infer_type_all_types_two_operands",
|
||||
[SameVariadicOperandSize]> {
|
||||
let arguments = (ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1);
|
||||
let assemblyFormat = "`(` $args0 `)` `(` $args1 `)` attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
// Test inferReturnTypes is called when allOperands is true and operand types
|
||||
// are separately specified.
|
||||
def FormatInferTypeAllTypesOp
|
||||
: FormatInferAllTypesBaseOp<"format_infer_type_all_types"> {
|
||||
let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test SideEffects
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -411,6 +411,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
|
|||
// CHECK: test.format_infer_type
|
||||
%ignored_res7 = test.format_infer_type
|
||||
|
||||
// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
|
||||
%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
|
||||
|
||||
// CHECK: test.format_infer_type_all_types_one_operand(%[[I64]], %[[I32]]) : i64, i32
|
||||
%ignored_res9:2 = test.format_infer_type_all_types_one_operand(%i64, %i32) : i64, i32
|
||||
|
||||
// CHECK: test.format_infer_type_all_types_two_operands(%[[I64]], %[[I32]]) (%[[I64]], %[[I32]]) : i64, i32, i64, i32
|
||||
%ignored_res10:4 = test.format_infer_type_all_types_two_operands(%i64, %i32) (%i64, %i32) : i64, i32, i64, i32
|
||||
|
||||
// CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
|
||||
%ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Check DefaultValuedStrAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -424,14 +424,18 @@ struct OperationFormat {
|
|||
/// Generate the parser code for a specific format element.
|
||||
void genElementParser(Element *element, MethodBody &body,
|
||||
FmtContext &attrTypeCtx);
|
||||
/// Generate the c++ to resolve the types of operands and results during
|
||||
/// Generate the C++ to resolve the types of operands and results during
|
||||
/// parsing.
|
||||
void genParserTypeResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to resolve regions during parsing.
|
||||
/// Generate the C++ to resolve the types of the operands during parsing.
|
||||
void genParserOperandTypeResolution(
|
||||
Operator &op, MethodBody &body,
|
||||
function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
|
||||
/// Generate the C++ to resolve regions during parsing.
|
||||
void genParserRegionResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to resolve successors during parsing.
|
||||
/// Generate the C++ to resolve successors during parsing.
|
||||
void genParserSuccessorResolution(Operator &op, MethodBody &body);
|
||||
/// Generate the c++ to handling variadic segment size traits.
|
||||
/// Generate the C++ to handling variadic segment size traits.
|
||||
void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
|
||||
|
||||
/// Generate the operation printer from this format.
|
||||
|
@ -1462,17 +1466,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
|
|||
}
|
||||
}
|
||||
|
||||
// Early exit if there are no operands.
|
||||
if (op.getNumOperands() == 0) {
|
||||
// Handle return type inference here if there are no operands
|
||||
if (infersResultTypes)
|
||||
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
|
||||
return;
|
||||
}
|
||||
// Emit the operand type resolutions.
|
||||
genParserOperandTypeResolution(op, body, emitTypeResolver);
|
||||
|
||||
// Handle the case where all operand types are in one group.
|
||||
// Handle return type inference once all operands have been resolved
|
||||
if (infersResultTypes)
|
||||
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
|
||||
}
|
||||
|
||||
void OperationFormat::genParserOperandTypeResolution(
|
||||
Operator &op, MethodBody &body,
|
||||
function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
|
||||
// Early exit if there are no operands.
|
||||
if (op.getNumOperands() == 0)
|
||||
return;
|
||||
|
||||
// Handle the case where all operand types are grouped together with
|
||||
// "types(operands)".
|
||||
if (allOperandTypes) {
|
||||
// If we have all operands together, use the full operand list directly.
|
||||
// If `operands` was specified, use the full operand list directly.
|
||||
if (allOperands) {
|
||||
body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
|
||||
"allOperandLoc, result.operands))\n"
|
||||
|
@ -1496,7 +1508,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
|
|||
<< " return ::mlir::failure();\n";
|
||||
return;
|
||||
}
|
||||
// Handle the case where all of the operands were grouped together.
|
||||
|
||||
// Handle the case where all operands are grouped together with "operands".
|
||||
if (allOperands) {
|
||||
body << " if (parser.resolveOperands(allOperands, ";
|
||||
|
||||
|
@ -1551,10 +1564,6 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
|
|||
body << ", " << operand.name << "OperandsLoc";
|
||||
body << ", result.operands))\n return ::mlir::failure();\n";
|
||||
}
|
||||
|
||||
// Handle return type inference once all operands have been resolved
|
||||
if (infersResultTypes)
|
||||
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
|
||||
}
|
||||
|
||||
void OperationFormat::genParserRegionResolution(Operator &op,
|
||||
|
@ -1833,7 +1842,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||
// keyword.
|
||||
llvm::BitVector nonKeywordCases(cases.size());
|
||||
bool hasStrCase = false;
|
||||
for (auto it : llvm::enumerate(cases)) {
|
||||
for (auto &it : llvm::enumerate(cases)) {
|
||||
hasStrCase = it.value().isStrCase();
|
||||
if (!canFormatStringAsKeyword(it.value().getStr()))
|
||||
nonKeywordCases.set(it.index());
|
||||
|
@ -1860,7 +1869,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||
// overlap with other cases. For simplicity sake, only allow cases with a
|
||||
// single bit value.
|
||||
if (enumAttr.isBitEnum()) {
|
||||
for (auto it : llvm::enumerate(cases)) {
|
||||
for (auto &it : llvm::enumerate(cases)) {
|
||||
int64_t value = it.value().getValue();
|
||||
if (value < 0 || !llvm::isPowerOf2_64(value))
|
||||
nonKeywordCases.set(it.index());
|
||||
|
@ -1873,7 +1882,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||
body << " switch (caseValue) {\n";
|
||||
StringRef cppNamespace = enumAttr.getCppNamespace();
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
for (auto it : llvm::enumerate(cases)) {
|
||||
for (auto &it : llvm::enumerate(cases)) {
|
||||
if (nonKeywordCases.test(it.index()))
|
||||
continue;
|
||||
StringRef symbol = it.value().getSymbol();
|
||||
|
|
Loading…
Reference in New Issue