[mlir][ods] Fix builder gen for VariadicRegion with inferred types

Builders generated for ops with variadic regions and inferred return types were not being correctly generated (missing parameter).
This commit is contained in:
Mogball 2022-04-07 18:22:14 +00:00
parent ee2d9b8723
commit 2f78b43f4b
2 changed files with 21 additions and 2 deletions

View File

@ -347,6 +347,22 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}
def VariadicRegionInferredTypesOp : TEST_Op<"variadic_region_inferred",
[InferTypeOpInterface]> {
let regions = (region VariadicRegion<AnyRegion>:$bodies);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
static mlir::LogicalResult inferReturnTypes(mlir::MLIRContext *context,
llvm::Optional<::mlir::Location> location, mlir::ValueRange operands,
mlir::DictionaryAttr attributes, mlir::RegionRange regions,
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({mlir::IntegerType::get(context, 16)});
return mlir::success();
}
}];
}
//===----------------------------------------------------------------------===//
// NoTerminator Operation
//===----------------------------------------------------------------------===//

View File

@ -1373,13 +1373,16 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
}
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OperationState &", builderOpState);
paramList.emplace_back("::mlir::ValueRange", "operands");
StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
"attributes", "{}");
"attributes", attributesDefaultValue);
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
// If the builder is redundant, skip generating the method
if (!m)