forked from OSchip/llvm-project
[mlir][ods] Fix OpDefinitionsGen infer return types builder with regions
Despite handling regions and inferred return types, the builder was never generated for ops with both InferReturnTypeOpInterface and regions. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D115525
This commit is contained in:
parent
a47af1ac34
commit
843534db3c
|
@ -2315,11 +2315,11 @@ def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
|
// Base class for testing `build` methods for ops with
|
||||||
// Tests suppression of ambiguous build methods for operations with
|
// InferReturnTypeOpInterface.
|
||||||
// SameOperandsAndResultType and InferTypeOpInterface.
|
class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
|
||||||
def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
list<OpTrait> traits = []>
|
||||||
[SameOperandsAndResultType, InferTypeOpInterface]> {
|
: TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
|
||||||
let arguments = (ins Variadic<AnyType>:$inputs);
|
let arguments = (ins Variadic<AnyType>:$inputs);
|
||||||
let results = (outs AnyType:$result);
|
let results = (outs AnyType:$result);
|
||||||
|
|
||||||
|
@ -2334,6 +2334,18 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
|
||||||
|
// Tests suppression of ambiguous build methods for operations with
|
||||||
|
// SameOperandsAndResultType and InferTypeOpInterface.
|
||||||
|
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
|
||||||
|
"tblgen_build_5", [SameOperandsAndResultType]>;
|
||||||
|
|
||||||
|
// Op with InferTypeOpInterface and regions.
|
||||||
|
def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
|
||||||
|
"tblgen_build_6", [InferTypeOpInterface]> {
|
||||||
|
let regions = (region AnyRegion:$body);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Test BufferPlacement
|
// Test BufferPlacement
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1220,8 +1220,7 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool canInferType(Operator &op) {
|
static bool canInferType(Operator &op) {
|
||||||
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
|
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
|
||||||
op.getNumRegions() == 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void OpEmitter::genSeparateArgParamBuilder() {
|
void OpEmitter::genSeparateArgParamBuilder() {
|
||||||
|
@ -1304,7 +1303,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
||||||
// ambiguous function detection will elide those ones.
|
// ambiguous function detection will elide those ones.
|
||||||
for (auto attrType : attrBuilderType) {
|
for (auto attrType : attrBuilderType) {
|
||||||
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
|
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
|
||||||
if (canInferType(op))
|
if (canInferType(op) && op.getNumRegions() == 0)
|
||||||
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
||||||
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
||||||
}
|
}
|
||||||
|
@ -1396,17 +1395,18 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
|
||||||
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
|
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
|
||||||
{1}.location, operands,
|
{1}.location, operands,
|
||||||
{1}.attributes.getDictionary({1}.getContext()),
|
{1}.attributes.getDictionary({1}.getContext()),
|
||||||
/*regions=*/{{}, inferredReturnTypes))) {{)",
|
{1}.regions, inferredReturnTypes))) {{)",
|
||||||
opClass.getClassName(), builderOpState);
|
opClass.getClassName(), builderOpState);
|
||||||
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
||||||
body << " assert(inferredReturnTypes.size()"
|
body << "\n assert(inferredReturnTypes.size()"
|
||||||
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
||||||
<< "u && \"mismatched number of return types\");\n";
|
<< "u && \"mismatched number of return types\");";
|
||||||
body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
|
body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
|
||||||
|
|
||||||
body << formatv(R"(
|
body << formatv(R"(
|
||||||
} else
|
} else {{
|
||||||
::llvm::report_fatal_error("Failed to infer result type(s).");)",
|
::llvm::report_fatal_error("Failed to infer result type(s).");
|
||||||
|
})",
|
||||||
opClass.getClassName(), builderOpState);
|
opClass.getClassName(), builderOpState);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1606,7 +1606,7 @@ void OpEmitter::genCollectiveParamBuilder() {
|
||||||
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
body << " " << builderOpState << ".addTypes(resultTypes);\n";
|
||||||
|
|
||||||
// Generate builder that infers type too.
|
// Generate builder that infers type too.
|
||||||
// TODO: Expand to handle regions and successors.
|
// TODO: Expand to handle successors.
|
||||||
if (canInferType(op) && op.getNumSuccessors() == 0)
|
if (canInferType(op) && op.getNumSuccessors() == 0)
|
||||||
genInferredTypeCollectiveParamBuilder();
|
genInferredTypeCollectiveParamBuilder();
|
||||||
}
|
}
|
||||||
|
|
|
@ -219,4 +219,11 @@ TEST_F(
|
||||||
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
|
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
|
||||||
|
auto op = builder.create<test::TableGenBuildOp6>(
|
||||||
|
loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
|
||||||
|
ASSERT_EQ(op->getNumRegions(), 1u);
|
||||||
|
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstF32}, noAttrs);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
Loading…
Reference in New Issue