forked from OSchip/llvm-project
[mlir][ODS] Support result type inference in custom assembly format
Operations that have the InferTypeOpInterface trait can now omit the return types in their custom assembly formats. Differential Revision: https://reviews.llvm.org/D111326
This commit is contained in:
parent
25fabc434a
commit
1760d8b36b
|
@ -929,6 +929,11 @@ these equal constraints to discern the types of missing variables. The currently
|
|||
supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`, and
|
||||
`SameOperandsAndResultType`.
|
||||
|
||||
* InferTypeOpInterface
|
||||
|
||||
Operations that implement `InferTypeOpInterface` can omit their result types in
|
||||
their assembly format since the result types can be inferred from the operands.
|
||||
|
||||
### `hasCanonicalizer`
|
||||
|
||||
This boolean field indicate whether canonicalization patterns have been defined
|
||||
|
|
|
@ -2021,6 +2021,24 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
|
|||
let assemblyFormat = "attr-dict $value `:` type($value)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InferTypeOpInterface type inference in assembly format
|
||||
|
||||
def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
|
||||
let results = (outs AnyType);
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
|
||||
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
|
||||
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
|
||||
return ::mlir::success();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test SideEffects
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
// This file contains tests for the specification of the declarative op format.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
def TestDialect : Dialect {
|
||||
let name = "test";
|
||||
|
@ -566,4 +567,6 @@ def ZCoverageValidH : TestFormat_Op<[{
|
|||
operands type($result) attr-dict
|
||||
}], [AllTypesMatch<["operand", "result"]>]>,
|
||||
Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>;
|
||||
|
||||
def ZCoverageValidI : TestFormat_Op<[{
|
||||
operands type(operands) attr-dict
|
||||
}], [InferTypeOpInterface]>, Arguments<(ins Variadic<I64>:$inputs)>, Results<(outs I64:$result)>;
|
||||
|
|
|
@ -354,3 +354,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
|
|||
|
||||
// CHECK: test.format_types_match_context %[[I64]] : i64
|
||||
%ignored_res6 = test.format_types_match_context %i64 : i64
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InferTypeOpInterface type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_infer_type
|
||||
%ignored_res7 = test.format_infer_type
|
||||
|
|
|
@ -441,7 +441,8 @@ struct OperationFormat {
|
|||
};
|
||||
|
||||
OperationFormat(const Operator &op)
|
||||
: allOperands(false), allOperandTypes(false), allResultTypes(false) {
|
||||
: allOperands(false), allOperandTypes(false), allResultTypes(false),
|
||||
infersResultTypes(false) {
|
||||
operandTypes.resize(op.getNumOperands(), TypeResolution());
|
||||
resultTypes.resize(op.getNumResults(), TypeResolution());
|
||||
|
||||
|
@ -482,6 +483,9 @@ struct OperationFormat {
|
|||
/// contains these, it can not contain individual type resolvers.
|
||||
bool allOperands, allOperandTypes, allResultTypes;
|
||||
|
||||
/// A flag indicating if this operation infers its result types
|
||||
bool infersResultTypes;
|
||||
|
||||
/// A flag indicating if this operation has the SingleBlockImplicitTerminator
|
||||
/// trait.
|
||||
bool hasImplicitTermTrait;
|
||||
|
@ -682,6 +686,19 @@ const char *const functionalTypeParserCode = R"(
|
|||
{1}Types = {0}__{1}_functionType.getResults();
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call to infer return types.
|
||||
///
|
||||
/// {0}: The operation class name
|
||||
const char *const inferReturnTypesParserCode = R"(
|
||||
::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
|
||||
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
|
||||
result.location, result.operands,
|
||||
result.attributes.getDictionary(parser.getContext()),
|
||||
result.regions, inferredReturnTypes)))
|
||||
return ::mlir::failure();
|
||||
result.addTypes(inferredReturnTypes);
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a region list.
|
||||
///
|
||||
/// {0}: The name for the region list.
|
||||
|
@ -1437,19 +1454,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
};
|
||||
|
||||
// Resolve each of the result types.
|
||||
if (allResultTypes) {
|
||||
body << " result.addTypes(allResultTypes);\n";
|
||||
} else {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
body << " result.addTypes(";
|
||||
emitTypeResolver(resultTypes[i], op.getResultName(i));
|
||||
body << ");\n";
|
||||
if (!infersResultTypes) {
|
||||
if (allResultTypes) {
|
||||
body << " result.addTypes(allResultTypes);\n";
|
||||
} else {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
body << " result.addTypes(";
|
||||
emitTypeResolver(resultTypes[i], op.getResultName(i));
|
||||
body << ");\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if there are no operands.
|
||||
if (op.getNumOperands() == 0)
|
||||
if (op.getNumOperands() == 0) {
|
||||
// Handle return type inference here if there are no operands
|
||||
if (infersResultTypes)
|
||||
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle the case where all operand types are in one group.
|
||||
if (allOperandTypes) {
|
||||
|
@ -1532,6 +1555,10 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
body << ", " << operand.name << "OperandsLoc";
|
||||
body << ", result.operands))\n return ::mlir::failure();\n";
|
||||
}
|
||||
|
||||
// Handle return type inference once all operands have been resolved
|
||||
if (infersResultTypes)
|
||||
body << formatv(inferReturnTypesParserCode, op.getCppClassName());
|
||||
}
|
||||
|
||||
void OperationFormat::genParserRegionResolution(Operator &op,
|
||||
|
@ -2478,6 +2505,7 @@ private:
|
|||
// during parsing.
|
||||
bool hasAttrDict = false;
|
||||
bool hasAllRegions = false, hasAllSuccessors = false;
|
||||
bool canInferResultTypes = false;
|
||||
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
|
||||
llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
|
||||
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
|
||||
|
@ -2515,6 +2543,9 @@ LogicalResult FormatParser::parse() {
|
|||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
|
||||
} else if (def.isSubClassOf("TypesMatchWith")) {
|
||||
handleTypesMatchConstraint(variableTyResolver, def);
|
||||
} else if (def.getName() == "InferTypeOpInterface" &&
|
||||
!op.allResultTypesKnown()) {
|
||||
canInferResultTypes = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2684,6 +2715,14 @@ LogicalResult FormatParser::verifyResults(
|
|||
if (fmt.allResultTypes)
|
||||
return ::mlir::success();
|
||||
|
||||
// If no result types are specified and we can infer them, infer all result
|
||||
// types
|
||||
if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
|
||||
canInferResultTypes) {
|
||||
fmt.infersResultTypes = true;
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
// Check that all of the result types can be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
|
|
Loading…
Reference in New Issue