forked from OSchip/llvm-project
[mlir] resolve types from attributes in assemblyFormat
An operation can specify that an operation or result type matches the type of another operation, result, or attribute via the `AllTypesMatch` or `TypesMatchWith` constraints. Use these constraints to also automatically resolve types in the automatically generated assembly parser. This way, only the attribute needs to be listed in `assemblyFormat`, e.g. for constant operations. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D78434
This commit is contained in:
parent
3b5db7fc69
commit
72df59d590
|
@ -780,8 +780,8 @@ There are many operations that have known type equality constraints registered
|
|||
as traits on the operation; for example the true, false, and result values of a
|
||||
`select` operation often have the same type. The assembly format may inspect
|
||||
these equal constraints to discern the types of missing variables. The currently
|
||||
supported traits are: `AllTypesMatch`, `SameTypeOperands`, and
|
||||
`SameOperandsAndResultType`.
|
||||
supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`,
|
||||
and `SameOperandsAndResultType`.
|
||||
|
||||
### `hasCanonicalizer`
|
||||
|
||||
|
|
|
@ -1352,6 +1352,46 @@ def FormatInferVariadicTypeFromNonVariadic
|
|||
let assemblyFormat = "$operands attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllTypesMatch type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
|
||||
AllTypesMatch<["value1", "value2", "result"]>
|
||||
]> {
|
||||
let arguments = (ins AnyType:$value1, AnyType:$value2);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
|
||||
}
|
||||
|
||||
def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
|
||||
AllTypesMatch<["value1", "value2", "result"]>
|
||||
]> {
|
||||
let arguments = (ins AnyAttr:$value1, AnyType:$value2);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "attr-dict $value1 `,` $value2";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypesMatchWith type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
|
||||
TypesMatchWith<"result type matches operand", "value", "result", "$_self">
|
||||
]> {
|
||||
let arguments = (ins AnyType:$value);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "attr-dict $value `:` type($value)";
|
||||
}
|
||||
|
||||
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
|
||||
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
|
||||
]> {
|
||||
let arguments = (ins AnyAttr:$value);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "attr-dict $value";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test SideEffects
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -108,3 +108,23 @@ test.format_optional_operand_result_b_op : i64
|
|||
|
||||
// CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64
|
||||
test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllTypesMatch type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64
|
||||
%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64
|
||||
|
||||
// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]]
|
||||
%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypesMatchWith type inference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_types_match_var %[[I64]] : i64
|
||||
%ignored_res3 = test.format_types_match_var %i64 : i64
|
||||
|
||||
// CHECK: test.format_types_match_attr 1 : i64
|
||||
%ignored_res4 = test.format_types_match_attr 1 : i64
|
||||
|
|
|
@ -270,6 +270,10 @@ private:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
using ConstArgument =
|
||||
llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
|
||||
|
||||
struct OperationFormat {
|
||||
/// This class represents a specific resolver for an operand or result type.
|
||||
class TypeResolution {
|
||||
|
@ -280,15 +284,22 @@ struct OperationFormat {
|
|||
Optional<int> getBuilderIdx() const { return builderIdx; }
|
||||
void setBuilderIdx(int idx) { builderIdx = idx; }
|
||||
|
||||
/// Get the variable this type is resolved to, or None.
|
||||
const NamedTypeConstraint *getVariable() const { return variable; }
|
||||
/// Get the variable this type is resolved to, or nullptr.
|
||||
const NamedTypeConstraint *getVariable() const {
|
||||
return resolver.dyn_cast<const NamedTypeConstraint *>();
|
||||
}
|
||||
/// Get the attribute this type is resolved to, or nullptr.
|
||||
const NamedAttribute *getAttribute() const {
|
||||
return resolver.dyn_cast<const NamedAttribute *>();
|
||||
}
|
||||
/// Get the transformer for the type of the variable, or None.
|
||||
Optional<StringRef> getVarTransformer() const {
|
||||
return variableTransformer;
|
||||
}
|
||||
void setVariable(const NamedTypeConstraint *var,
|
||||
Optional<StringRef> transformer) {
|
||||
variable = var;
|
||||
void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
|
||||
resolver = arg;
|
||||
variableTransformer = transformer;
|
||||
assert(getVariable() || getAttribute());
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -296,8 +307,8 @@ struct OperationFormat {
|
|||
/// 'buildableTypes' in the parent format.
|
||||
Optional<int> builderIdx;
|
||||
/// If the type is resolved based upon another operand or result, this is
|
||||
/// the variable that this type is resolved to.
|
||||
const NamedTypeConstraint *variable;
|
||||
/// the variable or the attribute that this type is resolved to.
|
||||
ConstArgument resolver;
|
||||
/// If the type is resolved based upon another operand or result, this is
|
||||
/// a transformer to apply to the variable when resolving.
|
||||
Optional<StringRef> variableTransformer;
|
||||
|
@ -729,7 +740,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
continue;
|
||||
// Ensure that we don't verify the same variables twice.
|
||||
const NamedTypeConstraint *variable = resolver.getVariable();
|
||||
if (!verifiedVariables.insert(variable).second)
|
||||
if (!variable || !verifiedVariables.insert(variable).second)
|
||||
continue;
|
||||
|
||||
auto constraint = variable->constraint;
|
||||
|
@ -764,6 +775,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
|||
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
|
||||
else
|
||||
body << var->name << "Types";
|
||||
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
|
||||
if (Optional<StringRef> tform = resolver.getVarTransformer())
|
||||
body << tgfmt(*tform,
|
||||
&FmtContext().withSelf(attr->name + "Attr.getType()"));
|
||||
else
|
||||
body << attr->name << "Attr.getType()";
|
||||
} else {
|
||||
body << curVar << "Types";
|
||||
}
|
||||
|
@ -1353,7 +1370,7 @@ private:
|
|||
/// type as well as an optional transformer to apply to that type in order to
|
||||
/// properly resolve the type of a variable.
|
||||
struct TypeResolutionInstance {
|
||||
const NamedTypeConstraint *type;
|
||||
ConstArgument resolver;
|
||||
Optional<StringRef> transformer;
|
||||
};
|
||||
|
||||
|
@ -1392,10 +1409,15 @@ private:
|
|||
void handleSameTypesConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
bool includeResults);
|
||||
/// Check for inferable type resolution based on another operand, result, or
|
||||
/// attribute.
|
||||
void handleTypesMatchConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
llvm::Record def);
|
||||
|
||||
/// Returns an argument with the given name that has been seen within the
|
||||
/// format.
|
||||
const NamedTypeConstraint *findSeenArg(StringRef name);
|
||||
/// Returns an argument or attribute with the given name that has been seen
|
||||
/// within the format.
|
||||
ConstArgument findSeenArg(StringRef name);
|
||||
|
||||
/// Parse a specific element.
|
||||
LogicalResult parseElement(std::unique_ptr<Element> &element,
|
||||
|
@ -1504,9 +1526,7 @@ LogicalResult FormatParser::parse() {
|
|||
} else if (def.getName() == "SameOperandsAndResultType") {
|
||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
|
||||
} else if (def.isSubClassOf("TypesMatchWith")) {
|
||||
if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
|
||||
variableTyResolver[def.getValueAsString("rhs")] = {
|
||||
lhsArg, def.getValueAsString("transformer")};
|
||||
handleTypesMatchConstraint(variableTyResolver, def);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1615,8 +1635,8 @@ LogicalResult FormatParser::verifyOperands(
|
|||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.operandTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
TypeResolutionInstance &resolver = varResolverIt->second;
|
||||
fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1654,8 +1674,8 @@ LogicalResult FormatParser::verifyResults(
|
|||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
TypeResolutionInstance resolver = varResolverIt->second;
|
||||
fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1702,7 +1722,7 @@ void FormatParser::handleAllTypesMatchConstraint(
|
|||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
// Check to see if this value matches a resolved operand or result type.
|
||||
const NamedTypeConstraint *arg = findSeenArg(values[i]);
|
||||
ConstArgument arg = findSeenArg(values[i]);
|
||||
if (!arg)
|
||||
continue;
|
||||
|
||||
|
@ -1739,11 +1759,23 @@ void FormatParser::handleSameTypesConstraint(
|
|||
}
|
||||
}
|
||||
|
||||
const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
|
||||
if (auto *arg = findArg(op.getOperands(), name))
|
||||
void FormatParser::handleTypesMatchConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
llvm::Record def) {
|
||||
StringRef lhsName = def.getValueAsString("lhs");
|
||||
StringRef rhsName = def.getValueAsString("rhs");
|
||||
StringRef transformer = def.getValueAsString("transformer");
|
||||
if (ConstArgument arg = findSeenArg(lhsName))
|
||||
variableTyResolver[rhsName] = {arg, transformer};
|
||||
}
|
||||
|
||||
ConstArgument FormatParser::findSeenArg(StringRef name) {
|
||||
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
|
||||
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
|
||||
if (auto *arg = findArg(op.getResults(), name))
|
||||
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
|
||||
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
|
||||
return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue