[mlir] resolve types from attributes in assemblyFormat

An operation can specify that an operation or result type matches the
type of another operation, result, or attribute via the `AllTypesMatch`
or `TypesMatchWith` constraints.

Use these constraints to also automatically resolve types in the
automatically generated assembly parser.
This way, only the attribute needs to be listed in `assemblyFormat`,
e.g. for constant operations.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D78434
This commit is contained in:
Martin Waitz 2020-07-07 04:40:01 +00:00 committed by Mehdi Amini
parent 3b5db7fc69
commit 72df59d590
4 changed files with 117 additions and 25 deletions

View File

@ -780,8 +780,8 @@ There are many operations that have known type equality constraints registered
as traits on the operation; for example the true, false, and result values of a
`select` operation often have the same type. The assembly format may inspect
these equal constraints to discern the types of missing variables. The currently
supported traits are: `AllTypesMatch`, `SameTypeOperands`, and
`SameOperandsAndResultType`.
supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`,
and `SameOperandsAndResultType`.
### `hasCanonicalizer`

View File

@ -1352,6 +1352,46 @@ def FormatInferVariadicTypeFromNonVariadic
let assemblyFormat = "$operands attr-dict `:` type($result)";
}
//===----------------------------------------------------------------------===//
// AllTypesMatch type inference
//===----------------------------------------------------------------------===//
def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
AllTypesMatch<["value1", "value2", "result"]>
]> {
let arguments = (ins AnyType:$value1, AnyType:$value2);
let results = (outs AnyType:$result);
let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
}
def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
AllTypesMatch<["value1", "value2", "result"]>
]> {
let arguments = (ins AnyAttr:$value1, AnyType:$value2);
let results = (outs AnyType:$result);
let assemblyFormat = "attr-dict $value1 `,` $value2";
}
//===----------------------------------------------------------------------===//
// TypesMatchWith type inference
//===----------------------------------------------------------------------===//
def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
TypesMatchWith<"result type matches operand", "value", "result", "$_self">
]> {
let arguments = (ins AnyType:$value);
let results = (outs AnyType:$result);
let assemblyFormat = "attr-dict $value `:` type($value)";
}
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
]> {
let arguments = (ins AnyAttr:$value);
let results = (outs AnyType:$result);
let assemblyFormat = "attr-dict $value";
}
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

View File

@ -108,3 +108,23 @@ test.format_optional_operand_result_b_op : i64
// CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64
test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
//===----------------------------------------------------------------------===//
// AllTypesMatch type inference
//===----------------------------------------------------------------------===//
// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64
%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64
// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]]
%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64
//===----------------------------------------------------------------------===//
// TypesMatchWith type inference
//===----------------------------------------------------------------------===//
// CHECK: test.format_types_match_var %[[I64]] : i64
%ignored_res3 = test.format_types_match_var %i64 : i64
// CHECK: test.format_types_match_attr 1 : i64
%ignored_res4 = test.format_types_match_attr 1 : i64

View File

@ -270,6 +270,10 @@ private:
//===----------------------------------------------------------------------===//
namespace {
using ConstArgument =
llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
struct OperationFormat {
/// This class represents a specific resolver for an operand or result type.
class TypeResolution {
@ -280,15 +284,22 @@ struct OperationFormat {
Optional<int> getBuilderIdx() const { return builderIdx; }
void setBuilderIdx(int idx) { builderIdx = idx; }
/// Get the variable this type is resolved to, or None.
const NamedTypeConstraint *getVariable() const { return variable; }
/// Get the variable this type is resolved to, or nullptr.
const NamedTypeConstraint *getVariable() const {
return resolver.dyn_cast<const NamedTypeConstraint *>();
}
/// Get the attribute this type is resolved to, or nullptr.
const NamedAttribute *getAttribute() const {
return resolver.dyn_cast<const NamedAttribute *>();
}
/// Get the transformer for the type of the variable, or None.
Optional<StringRef> getVarTransformer() const {
return variableTransformer;
}
void setVariable(const NamedTypeConstraint *var,
Optional<StringRef> transformer) {
variable = var;
void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
resolver = arg;
variableTransformer = transformer;
assert(getVariable() || getAttribute());
}
private:
@ -296,8 +307,8 @@ struct OperationFormat {
/// 'buildableTypes' in the parent format.
Optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
/// the variable that this type is resolved to.
const NamedTypeConstraint *variable;
/// the variable or the attribute that this type is resolved to.
ConstArgument resolver;
/// If the type is resolved based upon another operand or result, this is
/// a transformer to apply to the variable when resolving.
Optional<StringRef> variableTransformer;
@ -729,7 +740,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
continue;
// Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable();
if (!verifiedVariables.insert(variable).second)
if (!variable || !verifiedVariables.insert(variable).second)
continue;
auto constraint = variable->constraint;
@ -764,6 +775,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
else
body << var->name << "Types";
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
if (Optional<StringRef> tform = resolver.getVarTransformer())
body << tgfmt(*tform,
&FmtContext().withSelf(attr->name + "Attr.getType()"));
else
body << attr->name << "Attr.getType()";
} else {
body << curVar << "Types";
}
@ -1353,7 +1370,7 @@ private:
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
struct TypeResolutionInstance {
const NamedTypeConstraint *type;
ConstArgument resolver;
Optional<StringRef> transformer;
};
@ -1392,10 +1409,15 @@ private:
void handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults);
/// Check for inferable type resolution based on another operand, result, or
/// attribute.
void handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
llvm::Record def);
/// Returns an argument with the given name that has been seen within the
/// format.
const NamedTypeConstraint *findSeenArg(StringRef name);
/// Returns an argument or attribute with the given name that has been seen
/// within the format.
ConstArgument findSeenArg(StringRef name);
/// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element,
@ -1504,9 +1526,7 @@ LogicalResult FormatParser::parse() {
} else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
variableTyResolver[def.getValueAsString("rhs")] = {
lhsArg, def.getValueAsString("transformer")};
handleTypesMatchConstraint(variableTyResolver, def);
}
}
@ -1615,8 +1635,8 @@ LogicalResult FormatParser::verifyOperands(
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
if (varResolverIt != variableTyResolver.end()) {
fmt.operandTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
TypeResolutionInstance &resolver = varResolverIt->second;
fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
continue;
}
@ -1654,8 +1674,8 @@ LogicalResult FormatParser::verifyResults(
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
TypeResolutionInstance resolver = varResolverIt->second;
fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
continue;
}
@ -1702,7 +1722,7 @@ void FormatParser::handleAllTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
const NamedTypeConstraint *arg = findSeenArg(values[i]);
ConstArgument arg = findSeenArg(values[i]);
if (!arg)
continue;
@ -1739,11 +1759,23 @@ void FormatParser::handleSameTypesConstraint(
}
}
const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
if (auto *arg = findArg(op.getOperands(), name))
void FormatParser::handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
llvm::Record def) {
StringRef lhsName = def.getValueAsString("lhs");
StringRef rhsName = def.getValueAsString("rhs");
StringRef transformer = def.getValueAsString("transformer");
if (ConstArgument arg = findSeenArg(lhsName))
variableTyResolver[rhsName] = {arg, transformer};
}
ConstArgument FormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
if (auto *arg = findArg(op.getResults(), name))
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
return nullptr;
}