[MLIR][TableGen] Fix ambiguous build methods when inferring result types.

- Fix ODS framework to suppress build methods that infer result types and are
  ambiguous with collective variants. This applies to operations with a single variadic
  inputs whose result types can be inferred.
- Extended OpBuildGenTest to test these kinds of ops.

Differential Revision: https://reviews.llvm.org/D85060
This commit is contained in:
Rahul Joshi 2020-08-07 14:02:19 -07:00
parent 3162c6aa45
commit 13d05787d0
6 changed files with 123 additions and 25 deletions

View File

@ -151,6 +151,17 @@ public:
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
// Returns true of the operation has a single variadic arg.
bool hasSingleVariadicArg() const;
// Returns true if the operation has a single variadic result.
bool hasSingleVariadicResult() const {
return getNumResults() == 1 && getResult(0).isVariadic();
}
// Returns true of the operation has no variadic regions.
bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
using arg_iterator = const Argument *;
using arg_range = llvm::iterator_range<arg_iterator>;

View File

@ -134,6 +134,11 @@ unsigned tblgen::Operator::getNumVariableLengthOperands() const {
});
}
bool tblgen::Operator::hasSingleVariadicArg() const {
return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
getOperand(0).isVariadic();
}
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
return arguments.begin();
}

View File

@ -1526,4 +1526,31 @@ def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
}
// Single variadic arg, non variadic results, with SameOperandsAndResultType.
// Tests suppression of ambiguious build methods for operations with
// SameOperandsAndResultType trait.
def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs AnyType:$result);
}
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
// Tests suppression of ambiguious build methods for operations with
// SameOperandsAndResultType and InferTypeOpInterface.
def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
[SameOperandsAndResultType, InferTypeOpInterface]> {
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
static LogicalResult inferReturnTypes(MLIRContext *,
Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands[0].getType()});
return success();
}
}];
}
#endif // TEST_OPS

View File

@ -110,8 +110,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
let results = (outs AnyTensor:$result);
}
// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
// CHECK: odsState.addTypes({input.front().getType()});
// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
// CHECK: odsState.addTypes({operands[0].getType()});
// Test with inferred shapes and interleaved with operands/attributes.
//

View File

@ -232,6 +232,10 @@ private:
// operand's type as all results' types.
void genUseOperandAsResultTypeCollectiveParamBuilder();
// Returns true if the inferred collective param build method should be
// generated.
bool shouldGenerateInferredTypeCollectiveParamBuilder();
// Generates the build() method that takes aggregate operands/attributes
// parameters. This build() method uses inferred types as result types.
// Requires: The type needs to be inferable via InferTypeOpInterface.
@ -984,40 +988,37 @@ void OpEmitter::genSeparateArgParamBuilder() {
// result
//
// In that case, skip generating such ambiguous build methods here.
bool hasSingleVariadicResult =
op.getNumResults() == 1 && op.getResult(0).isVariadic();
bool hasSingleVariadicArg =
op.getNumArgs() == 1 &&
op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
op.getOperand(0).isVariadic();
bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
for (auto attrType : attrBuilderType) {
// Case 3b above.
if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
hasSingleVariadicResult))
if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
op.hasSingleVariadicResult()))
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
if (canInferType(op))
emit(attrType, TypeParamKind::None, /*inferType=*/true);
if (canInferType(op)) {
// When inferType = true, the generated build method does not have
// result types. If the op has a single variadic arg, then this build
// method will be ambiguious with the collective inferred build method
// generated in `genInferredTypeCollectiveParamBuilder`. If we are going
// to generate that collective inferred method, suppress generating the
// ambiguious build method here.
bool buildMethodAmbiguious =
op.hasSingleVariadicArg() &&
shouldGenerateInferredTypeCollectiveParamBuilder();
if (!buildMethodAmbiguious)
emit(attrType, TypeParamKind::None, /*inferType=*/true);
}
// The separate arg + collective param kind method will be:
// (a) Same as the separate arg + separate param kind method if there is
// only one variadic result.
// (b) Ambiguous with the collective params method under conditions in (3a)
// above.
// In either case, skip generating such build method.
if (!hasSingleVariadicResult &&
!(hasNoVariadicRegions && hasSingleVariadicArg))
if (!op.hasSingleVariadicResult() &&
!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
// If this op has a variadic result, we cannot generate this builder because
// we don't know how many results to create.
if (op.getNumVariableLengthResults() != 0)
return;
int numResults = op.getNumResults();
// Signature
@ -1055,6 +1056,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
return canInferType(op) && op.getNumSuccessors() == 0;
}
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
std::string params =
@ -1209,8 +1214,21 @@ void OpEmitter::genBuilder() {
// to facilitate different call patterns.
if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
genUseOperandAsResultTypeSeparateParamBuilder();
genUseOperandAsResultTypeCollectiveParamBuilder();
// If the operation has a single variadic input, then the build method
// generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
// ambiguious with the one generated by
// `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
// a single `ValueRange` argument for operands, and the collective one
// will have a `ArrayRef<NamedAttribute>` argument initalized to empty).
// Suppress such ambiguious build method.
if (!op.hasSingleVariadicArg())
genUseOperandAsResultTypeSeparateParamBuilder();
// The build method generated by the inferred type collective param
// builder and one generated here have the same arguments and hence
// generating both will be ambiguious. Enable just one of them.
if (!shouldGenerateInferredTypeCollectiveParamBuilder())
genUseOperandAsResultTypeCollectiveParamBuilder();
}
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
@ -1269,7 +1287,7 @@ void OpEmitter::genCollectiveParamBuilder() {
// Generate builder that infers type too.
// TODO: Expand to handle regions and successors.
if (canInferType(op) && op.getNumSuccessors() == 0)
if (shouldGenerateInferredTypeCollectiveParamBuilder())
genInferredTypeCollectiveParamBuilder();
}

View File

@ -63,6 +63,28 @@ protected:
concreteOp.erase();
}
// Helper method to test ops with inferred result types and single variadic
// input.
template <typename OpTy>
void testSingleVariadicInputInferredType() {
// Test separate arg, separate param build method.
auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
// Test collective params build method.
op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
ArrayRef<Value>{cstI32, cstI32});
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
// Test build method with no result types, default value of attributes.
op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
// Test build method with no result types and supplied attributes.
op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
}
protected:
MLIRContext ctx;
OpBuilder builder;
@ -178,4 +200,19 @@ TEST_F(OpBuildGenTest,
verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
}
// The next 2 tests test supression of ambiguious build methods for ops that
// have a single variadic input, and single non-variadic result, and which
// support the SameOperandsAndResultType trait and and optionally the
// InferOpTypeInterface interface. For such ops, the ODS framework generates
// build methods with no result types as they are inferred from the input types.
TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
testSingleVariadicInputInferredType<TableGenBuildOp4>();
}
TEST_F(
OpBuildGenTest,
BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
testSingleVariadicInputInferredType<TableGenBuildOp5>();
}
} // namespace mlir