[mlir][ODS] Support result type inference in custom assembly format

Operations that have the InferTypeOpInterface trait can now omit the return
types in their custom assembly formats.

Differential Revision: https://reviews.llvm.org/D111326
This commit is contained in:
Daniel Resnick 2021-10-06 18:50:38 -06:00
parent 25fabc434a
commit 1760d8b36b
5 changed files with 82 additions and 10 deletions

View File

@ -929,6 +929,11 @@ these equal constraints to discern the types of missing variables. The currently
supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, and
`SameOperandsAndResultType`.
* InferTypeOpInterface
Operations that implement `InferTypeOpInterface` can omit their result types in
their assembly format since the result types can be inferred from the operands.
### `hasCanonicalizer`
This boolean field indicate whether canonicalization patterns have been defined

View File

@ -2021,6 +2021,24 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
let assemblyFormat = "attr-dict $value `:` type($value)";
}
//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference in assembly format
def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
let results = (outs AnyType);
let assemblyFormat = "attr-dict";
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
}
}];
}
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

View File

@ -3,6 +3,7 @@
// This file contains tests for the specification of the declarative op format.
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def TestDialect : Dialect {
let name = "test";
@ -566,4 +567,6 @@ def ZCoverageValidH : TestFormat_Op<[{
operands type($result) attr-dict
}], [AllTypesMatch<["operand", "result"]>]>,
Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
def ZCoverageValidI : TestFormat_Op<[{
operands type(operands) attr-dict
}], [InferTypeOpInterface]>, Arguments<(ins Variadic<I64>:$inputs)>, Results<(outs I64:$result)>;

View File

@ -354,3 +354,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_types_match_context %[[I64]] : i64
%ignored_res6 = test.format_types_match_context %i64 : i64
//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference
//===----------------------------------------------------------------------===//
// CHECK: test.format_infer_type
%ignored_res7 = test.format_infer_type

View File

@ -441,7 +441,8 @@ struct OperationFormat {
};
OperationFormat(const Operator &op)
: allOperands(false), allOperandTypes(false), allResultTypes(false) {
: allOperands(false), allOperandTypes(false), allResultTypes(false),
infersResultTypes(false) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
@ -482,6 +483,9 @@ struct OperationFormat {
/// contains these, it can not contain individual type resolvers.
bool allOperands, allOperandTypes, allResultTypes;
/// A flag indicating if this operation infers its result types
bool infersResultTypes;
/// A flag indicating if this operation has the SingleBlockImplicitTerminator
/// trait.
bool hasImplicitTermTrait;
@ -682,6 +686,19 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults();
)";
/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
result.regions, inferredReturnTypes)))
return ::mlir::failure();
result.addTypes(inferredReturnTypes);
)";
/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
@ -1437,19 +1454,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
};
// Resolve each of the result types.
if (allResultTypes) {
body << " result.addTypes(allResultTypes);\n";
} else {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
body << " result.addTypes(";
emitTypeResolver(resultTypes[i], op.getResultName(i));
body << ");\n";
if (!infersResultTypes) {
if (allResultTypes) {
body << " result.addTypes(allResultTypes);\n";
} else {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
body << " result.addTypes(";
emitTypeResolver(resultTypes[i], op.getResultName(i));
body << ");\n";
}
}
}
// Early exit if there are no operands.
if (op.getNumOperands() == 0)
if (op.getNumOperands() == 0) {
// Handle return type inference here if there are no operands
if (infersResultTypes)
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
return;
}
// Handle the case where all operand types are in one group.
if (allOperandTypes) {
@ -1532,6 +1555,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
body << ", " << operand.name << "OperandsLoc";
body << ", result.operands))\n return ::mlir::failure();\n";
}
// Handle return type inference once all operands have been resolved
if (infersResultTypes)
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
}
void OperationFormat::genParserRegionResolution(Operator &op,
@ -2478,6 +2505,7 @@ private:
// during parsing.
bool hasAttrDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
bool canInferResultTypes = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
@ -2515,6 +2543,9 @@ LogicalResult FormatParser::parse() {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
} else if (def.getName() == "InferTypeOpInterface" &&
!op.allResultTypesKnown()) {
canInferResultTypes = true;
}
}
@ -2684,6 +2715,14 @@ LogicalResult FormatParser::verifyResults(
if (fmt.allResultTypes)
return ::mlir::success();
// If no result types are specified and we can infer them, infer all result
// types
if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
canInferResultTypes) {
fmt.infersResultTypes = true;
return ::mlir::success();
}
// Check that all of the result types can be inferred.
auto &buildableTypes = fmt.buildableTypes;
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {