forked from OSchip/llvm-project
[MLIR][TableGen] Fix ambiguous build methods when inferring result types.
- Fix ODS framework to suppress build methods that infer result types and are ambiguous with collective variants. This applies to operations with a single variadic inputs whose result types can be inferred. - Extended OpBuildGenTest to test these kinds of ops. Differential Revision: https://reviews.llvm.org/D85060
This commit is contained in:
parent
3162c6aa45
commit
13d05787d0
|
@ -151,6 +151,17 @@ public:
|
|||
// Returns the total number of arguments.
|
||||
int getNumArgs() const { return arguments.size(); }
|
||||
|
||||
// Returns true of the operation has a single variadic arg.
|
||||
bool hasSingleVariadicArg() const;
|
||||
|
||||
// Returns true if the operation has a single variadic result.
|
||||
bool hasSingleVariadicResult() const {
|
||||
return getNumResults() == 1 && getResult(0).isVariadic();
|
||||
}
|
||||
|
||||
// Returns true of the operation has no variadic regions.
|
||||
bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
|
||||
|
||||
using arg_iterator = const Argument *;
|
||||
using arg_range = llvm::iterator_range<arg_iterator>;
|
||||
|
||||
|
|
|
@ -134,6 +134,11 @@ unsigned tblgen::Operator::getNumVariableLengthOperands() const {
|
|||
});
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasSingleVariadicArg() const {
|
||||
return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
|
||||
getOperand(0).isVariadic();
|
||||
}
|
||||
|
||||
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
|
||||
return arguments.begin();
|
||||
}
|
||||
|
|
|
@ -1526,4 +1526,31 @@ def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
|
|||
let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
|
||||
}
|
||||
|
||||
// Single variadic arg, non variadic results, with SameOperandsAndResultType.
|
||||
// Tests suppression of ambiguious build methods for operations with
|
||||
// SameOperandsAndResultType trait.
|
||||
def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
|
||||
let arguments = (ins Variadic<AnyType>:$inputs);
|
||||
let results = (outs AnyType:$result);
|
||||
}
|
||||
|
||||
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
|
||||
// Tests suppression of ambiguious build methods for operations with
|
||||
// SameOperandsAndResultType and InferTypeOpInterface.
|
||||
def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
||||
[SameOperandsAndResultType, InferTypeOpInterface]> {
|
||||
let arguments = (ins Variadic<AnyType>:$inputs);
|
||||
let results = (outs AnyType:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypes(MLIRContext *,
|
||||
Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
return success();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -110,8 +110,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
|
|||
let results = (outs AnyTensor:$result);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
|
||||
// CHECK: odsState.addTypes({input.front().getType()});
|
||||
// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
|
||||
// CHECK: odsState.addTypes({operands[0].getType()});
|
||||
|
||||
// Test with inferred shapes and interleaved with operands/attributes.
|
||||
//
|
||||
|
|
|
@ -232,6 +232,10 @@ private:
|
|||
// operand's type as all results' types.
|
||||
void genUseOperandAsResultTypeCollectiveParamBuilder();
|
||||
|
||||
// Returns true if the inferred collective param build method should be
|
||||
// generated.
|
||||
bool shouldGenerateInferredTypeCollectiveParamBuilder();
|
||||
|
||||
// Generates the build() method that takes aggregate operands/attributes
|
||||
// parameters. This build() method uses inferred types as result types.
|
||||
// Requires: The type needs to be inferable via InferTypeOpInterface.
|
||||
|
@ -984,40 +988,37 @@ void OpEmitter::genSeparateArgParamBuilder() {
|
|||
// result
|
||||
//
|
||||
// In that case, skip generating such ambiguous build methods here.
|
||||
bool hasSingleVariadicResult =
|
||||
op.getNumResults() == 1 && op.getResult(0).isVariadic();
|
||||
|
||||
bool hasSingleVariadicArg =
|
||||
op.getNumArgs() == 1 &&
|
||||
op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
|
||||
op.getOperand(0).isVariadic();
|
||||
bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
|
||||
|
||||
for (auto attrType : attrBuilderType) {
|
||||
// Case 3b above.
|
||||
if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
|
||||
hasSingleVariadicResult))
|
||||
if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
|
||||
op.hasSingleVariadicResult()))
|
||||
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
|
||||
if (canInferType(op))
|
||||
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
||||
if (canInferType(op)) {
|
||||
// When inferType = true, the generated build method does not have
|
||||
// result types. If the op has a single variadic arg, then this build
|
||||
// method will be ambiguious with the collective inferred build method
|
||||
// generated in `genInferredTypeCollectiveParamBuilder`. If we are going
|
||||
// to generate that collective inferred method, suppress generating the
|
||||
// ambiguious build method here.
|
||||
bool buildMethodAmbiguious =
|
||||
op.hasSingleVariadicArg() &&
|
||||
shouldGenerateInferredTypeCollectiveParamBuilder();
|
||||
if (!buildMethodAmbiguious)
|
||||
emit(attrType, TypeParamKind::None, /*inferType=*/true);
|
||||
}
|
||||
// The separate arg + collective param kind method will be:
|
||||
// (a) Same as the separate arg + separate param kind method if there is
|
||||
// only one variadic result.
|
||||
// (b) Ambiguous with the collective params method under conditions in (3a)
|
||||
// above.
|
||||
// In either case, skip generating such build method.
|
||||
if (!hasSingleVariadicResult &&
|
||||
!(hasNoVariadicRegions && hasSingleVariadicArg))
|
||||
if (!op.hasSingleVariadicResult() &&
|
||||
!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
|
||||
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
||||
// If this op has a variadic result, we cannot generate this builder because
|
||||
// we don't know how many results to create.
|
||||
if (op.getNumVariableLengthResults() != 0)
|
||||
return;
|
||||
|
||||
int numResults = op.getNumResults();
|
||||
|
||||
// Signature
|
||||
|
@ -1055,6 +1056,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
|
|||
<< llvm::join(resultTypes, ", ") << "});\n\n";
|
||||
}
|
||||
|
||||
bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
|
||||
return canInferType(op) && op.getNumSuccessors() == 0;
|
||||
}
|
||||
|
||||
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
|
||||
// TODO: Expand to support regions.
|
||||
std::string params =
|
||||
|
@ -1209,8 +1214,21 @@ void OpEmitter::genBuilder() {
|
|||
// to facilitate different call patterns.
|
||||
if (op.getNumVariableLengthResults() == 0) {
|
||||
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
|
||||
genUseOperandAsResultTypeSeparateParamBuilder();
|
||||
genUseOperandAsResultTypeCollectiveParamBuilder();
|
||||
// If the operation has a single variadic input, then the build method
|
||||
// generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
|
||||
// ambiguious with the one generated by
|
||||
// `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
|
||||
// a single `ValueRange` argument for operands, and the collective one
|
||||
// will have a `ArrayRef<NamedAttribute>` argument initalized to empty).
|
||||
// Suppress such ambiguious build method.
|
||||
if (!op.hasSingleVariadicArg())
|
||||
genUseOperandAsResultTypeSeparateParamBuilder();
|
||||
|
||||
// The build method generated by the inferred type collective param
|
||||
// builder and one generated here have the same arguments and hence
|
||||
// generating both will be ambiguious. Enable just one of them.
|
||||
if (!shouldGenerateInferredTypeCollectiveParamBuilder())
|
||||
genUseOperandAsResultTypeCollectiveParamBuilder();
|
||||
}
|
||||
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
|
||||
genUseAttrAsResultTypeBuilder();
|
||||
|
@ -1269,7 +1287,7 @@ void OpEmitter::genCollectiveParamBuilder() {
|
|||
|
||||
// Generate builder that infers type too.
|
||||
// TODO: Expand to handle regions and successors.
|
||||
if (canInferType(op) && op.getNumSuccessors() == 0)
|
||||
if (shouldGenerateInferredTypeCollectiveParamBuilder())
|
||||
genInferredTypeCollectiveParamBuilder();
|
||||
}
|
||||
|
||||
|
|
|
@ -63,6 +63,28 @@ protected:
|
|||
concreteOp.erase();
|
||||
}
|
||||
|
||||
// Helper method to test ops with inferred result types and single variadic
|
||||
// input.
|
||||
template <typename OpTy>
|
||||
void testSingleVariadicInputInferredType() {
|
||||
// Test separate arg, separate param build method.
|
||||
auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
|
||||
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
|
||||
|
||||
// Test collective params build method.
|
||||
op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
|
||||
ArrayRef<Value>{cstI32, cstI32});
|
||||
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
|
||||
|
||||
// Test build method with no result types, default value of attributes.
|
||||
op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
|
||||
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
|
||||
|
||||
// Test build method with no result types and supplied attributes.
|
||||
op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
|
||||
verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
|
||||
}
|
||||
|
||||
protected:
|
||||
MLIRContext ctx;
|
||||
OpBuilder builder;
|
||||
|
@ -178,4 +200,19 @@ TEST_F(OpBuildGenTest,
|
|||
verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
|
||||
}
|
||||
|
||||
// The next 2 tests test supression of ambiguious build methods for ops that
|
||||
// have a single variadic input, and single non-variadic result, and which
|
||||
// support the SameOperandsAndResultType trait and and optionally the
|
||||
// InferOpTypeInterface interface. For such ops, the ODS framework generates
|
||||
// build methods with no result types as they are inferred from the input types.
|
||||
TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
|
||||
testSingleVariadicInputInferredType<TableGenBuildOp4>();
|
||||
}
|
||||
|
||||
TEST_F(
|
||||
OpBuildGenTest,
|
||||
BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
|
||||
testSingleVariadicInputInferredType<TableGenBuildOp5>();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue