forked from OSchip/llvm-project
[mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.
Summary: The attribute grammar includes an optional trailing colon type, so for attributes without a constant buildable type this will generally lead to unexpected and undesired behavior. Given that, it's better to just error out on these cases. Differential Revision: https://reviews.llvm.org/D77293
This commit is contained in:
parent
ca47ac3d5f
commit
e3bb36370d
|
@ -20,6 +20,7 @@ def AffineMapAttr : Attr<
|
|||
CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
|
||||
let storageType = [{ AffineMapAttr }];
|
||||
let returnType = [{ AffineMap }];
|
||||
let valueType = Index;
|
||||
let constBuilderCall = "AffineMapAttr::get($0)";
|
||||
}
|
||||
|
||||
|
|
|
@ -319,7 +319,8 @@ class BuildableType<code builder> {
|
|||
def AnyType : Type<CPred<"true">, "any type">;
|
||||
|
||||
// None type
|
||||
def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
|
||||
def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">,
|
||||
BuildableType<"$_builder.getType<NoneType>()">;
|
||||
|
||||
// Any type from the given list
|
||||
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
|
||||
|
@ -835,6 +836,7 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
|
|||
def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
|
||||
let storageType = [{ BoolAttr }];
|
||||
let returnType = [{ bool }];
|
||||
let valueType = I1;
|
||||
let constBuilderCall = "$_builder.getBoolAttr($0)";
|
||||
}
|
||||
|
||||
|
@ -942,11 +944,18 @@ class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
|
|||
let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
|
||||
let storageType = [{ StringAttr }];
|
||||
let returnType = [{ StringRef }];
|
||||
let valueType = NoneType;
|
||||
}
|
||||
|
||||
def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
|
||||
"string attribute">;
|
||||
|
||||
// String attribute that has a specific value type.
|
||||
class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
|
||||
"string attribute"> {
|
||||
let valueType = ty;
|
||||
}
|
||||
|
||||
// Base class for attributes containing types. Example:
|
||||
// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
|
||||
// defines a type attribute containing an integer type.
|
||||
|
@ -957,6 +966,7 @@ class TypeAttrBase<string retType, string description> :
|
|||
description> {
|
||||
let storageType = [{ TypeAttr }];
|
||||
let returnType = retType;
|
||||
let valueType = NoneType;
|
||||
let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
|
||||
}
|
||||
|
||||
|
@ -970,6 +980,7 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
|
|||
let constBuilderCall = "$_builder.getUnitAttr()";
|
||||
let convertFromStorage = "$_self != nullptr";
|
||||
let returnType = "bool";
|
||||
let valueType = NoneType;
|
||||
let isOptional = 1;
|
||||
}
|
||||
|
||||
|
@ -1166,6 +1177,7 @@ class DictionaryAttrBase : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
|
|||
"dictionary of named attribute values"> {
|
||||
let storageType = [{ DictionaryAttr }];
|
||||
let returnType = [{ DictionaryAttr }];
|
||||
let valueType = NoneType;
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
|
@ -1285,6 +1297,7 @@ class ArrayAttrBase<Pred condition, string description> :
|
|||
Attr<condition, description> {
|
||||
let storageType = [{ ArrayAttr }];
|
||||
let returnType = [{ ArrayAttr }];
|
||||
let valueType = NoneType;
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
|
@ -1364,6 +1377,7 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
|
|||
"symbol reference attribute"> {
|
||||
let storageType = [{ SymbolRefAttr }];
|
||||
let returnType = [{ SymbolRefAttr }];
|
||||
let valueType = NoneType;
|
||||
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
@ -1371,6 +1385,7 @@ def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
|
|||
"flat symbol reference attribute"> {
|
||||
let storageType = [{ FlatSymbolRefAttr }];
|
||||
let returnType = [{ StringRef }];
|
||||
let valueType = NoneType;
|
||||
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
|
||||
let convertFromStorage = "$_self.getValue()";
|
||||
}
|
||||
|
|
|
@ -247,7 +247,7 @@ func @non_type_in_type_array_attr_fail() {
|
|||
// CHECK-LABEL: func @string_attr_custom_type
|
||||
func @string_attr_custom_type() {
|
||||
// CHECK: "string_data" : !foo.string
|
||||
test.string_attr_with_type "string_data"
|
||||
test.string_attr_with_type "string_data" : !foo.string
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -158,15 +158,8 @@ def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
|
|||
let arguments = (ins TypeArrayAttr:$attr);
|
||||
}
|
||||
def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
|
||||
let arguments = (ins StrAttr:$attr);
|
||||
let printer = [{ p << getAttr("attr"); }];
|
||||
let parser = [{
|
||||
Attribute attr;
|
||||
Type stringType = OpaqueType::get(Identifier::get("foo",
|
||||
result.getContext()), "string",
|
||||
result.getContext());
|
||||
return parser.parseAttribute(attr, stringType, "attr", result.attributes);
|
||||
}];
|
||||
let arguments = (ins TypedStrAttr<AnyType>:$attr);
|
||||
let assemblyFormat = "$attr attr-dict";
|
||||
}
|
||||
|
||||
def StrCaseA: StrEnumAttrCase<"A">;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure
|
||||
// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// This file contains tests for the specification of the declarative op format.
|
||||
|
||||
|
@ -275,6 +275,21 @@ def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
|
|||
}]> {
|
||||
let successors = (successor AnySuccessor:$successor);
|
||||
}
|
||||
// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
|
||||
def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
|
||||
$attr `:` attr-dict
|
||||
}]>, Arguments<(ins ElementsAttr:$attr)>;
|
||||
// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
|
||||
def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
|
||||
(`foo` $attr^)? `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
|
||||
// CHECK-NOT: error:
|
||||
def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
|
||||
$attr `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
|
||||
def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
|
||||
(`foo` $attr^)? `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Coverage Checks
|
||||
|
|
|
@ -92,13 +92,23 @@ public:
|
|||
}
|
||||
const VarT *getVar() { return var; }
|
||||
|
||||
private:
|
||||
protected:
|
||||
const VarT *var;
|
||||
};
|
||||
|
||||
/// This class represents a variable that refers to an attribute argument.
|
||||
using AttributeVariable =
|
||||
VariableElement<NamedAttribute, Element::Kind::AttributeVariable>;
|
||||
struct AttributeVariable
|
||||
: public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
|
||||
using VariableElement<NamedAttribute,
|
||||
Element::Kind::AttributeVariable>::VariableElement;
|
||||
|
||||
/// Return the constant builder call for the type of this attribute, or None
|
||||
/// if it doesn't have one.
|
||||
Optional<StringRef> getTypeBuilder() const {
|
||||
Optional<Type> attrType = var->attr.getValueType();
|
||||
return attrType ? attrType->getBuilderCall() : llvm::None;
|
||||
}
|
||||
};
|
||||
|
||||
/// This class represents a variable that refers to an operand argument.
|
||||
using OperandVariable =
|
||||
|
@ -574,11 +584,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
|||
// If this attribute has a buildable type, use that when parsing the
|
||||
// attribute.
|
||||
std::string attrTypeStr;
|
||||
if (Optional<Type> attrType = var->attr.getValueType()) {
|
||||
if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
|
||||
llvm::raw_string_ostream os(attrTypeStr);
|
||||
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
|
||||
}
|
||||
if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
|
||||
llvm::raw_string_ostream os(attrTypeStr);
|
||||
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
|
||||
}
|
||||
|
||||
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
|
||||
|
@ -932,8 +940,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
|||
}
|
||||
|
||||
// Elide the attribute type if it is buildable.
|
||||
Optional<Type> attrType = var->attr.getValueType();
|
||||
if (attrType && attrType->getBuilderCall())
|
||||
if (attr->getTypeBuilder())
|
||||
body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
|
||||
else
|
||||
body << " p.printAttribute(" << var->name << "Attr());\n";
|
||||
|
@ -1234,6 +1241,22 @@ private:
|
|||
Optional<StringRef> transformer;
|
||||
};
|
||||
|
||||
/// Verify the state of operation attributes within the format.
|
||||
LogicalResult verifyAttributes(llvm::SMLoc loc);
|
||||
|
||||
/// Verify the state of operation operands within the format.
|
||||
LogicalResult
|
||||
verifyOperands(llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
|
||||
/// Verify the state of operation results within the format.
|
||||
LogicalResult
|
||||
verifyResults(llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
|
||||
/// Verify the state of operation successors within the format.
|
||||
LogicalResult verifySuccessors(llvm::SMLoc loc);
|
||||
|
||||
/// Given the values of an `AllTypesMatch` trait, check for inferable type
|
||||
/// resolution.
|
||||
void handleAllTypesMatchConstraint(
|
||||
|
@ -1357,37 +1380,86 @@ LogicalResult FormatParser::parse() {
|
|||
}
|
||||
}
|
||||
|
||||
// Check that all of the result types can be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
if (!fmt.allResultTypes) {
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
if (seenResultTypes.test(i))
|
||||
continue;
|
||||
// Verify the state of the various operation components.
|
||||
if (failed(verifyAttributes(loc)) ||
|
||||
failed(verifyResults(loc, variableTyResolver)) ||
|
||||
failed(verifyOperands(loc, variableTyResolver)) ||
|
||||
failed(verifySuccessors(loc)))
|
||||
return failure();
|
||||
|
||||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
continue;
|
||||
// Check to see if we are formatting all of the operands.
|
||||
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
|
||||
return isa<OperandsDirective>(elt.get());
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
|
||||
// Check that there are no `:` literals after an attribute without a constant
|
||||
// type. The attribute grammar contains an optional trailing colon type, which
|
||||
// can lead to unexpected and generally unintended behavior. Given that, it is
|
||||
// better to just error out here instead.
|
||||
using ElementsIterT = llvm::pointee_iterator<
|
||||
std::vector<std::unique_ptr<Element>>::const_iterator>;
|
||||
SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
|
||||
iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
|
||||
while (!iteratorStack.empty()) {
|
||||
auto &stackIt = iteratorStack.back();
|
||||
ElementsIterT &it = stackIt.first, e = stackIt.second;
|
||||
while (it != e) {
|
||||
Element *element = &*(it++);
|
||||
|
||||
// Traverse into optional groups.
|
||||
if (auto *optional = dyn_cast<OptionalElement>(element)) {
|
||||
auto elements = optional->getElements();
|
||||
iteratorStack.emplace_back(elements.begin(), elements.end());
|
||||
break;
|
||||
}
|
||||
|
||||
// If the result is not variadic, allow for the case where the type has a
|
||||
// builder that we can use.
|
||||
NamedTypeConstraint &result = op.getResult(i);
|
||||
Optional<StringRef> builder = result.constraint.getBuilderCall();
|
||||
if (!builder || result.constraint.isVariadic()) {
|
||||
return emitError(loc, "format missing instance of result #" + Twine(i) +
|
||||
"('" + result.name + "') type");
|
||||
// We are checking for an attribute element followed by a `:`, so there is
|
||||
// no need to check the end.
|
||||
if (it == e && iteratorStack.size() == 1)
|
||||
break;
|
||||
|
||||
// Check for an attribute with a constant type builder, followed by a `:`.
|
||||
auto *prevAttr = dyn_cast<AttributeVariable>(element);
|
||||
if (!prevAttr || prevAttr->getTypeBuilder())
|
||||
continue;
|
||||
|
||||
// Check the next iterator within the stack for literal elements.
|
||||
for (auto &nextItPair : iteratorStack) {
|
||||
ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
|
||||
for (; nextIt != nextE; ++nextIt) {
|
||||
// Skip any trailing optional groups or attribute dictionaries.
|
||||
if (isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
|
||||
continue;
|
||||
|
||||
// We are only interested in `:` literals.
|
||||
auto *literal = dyn_cast<LiteralElement>(&*nextIt);
|
||||
if (!literal || literal->getLiteral() != ":")
|
||||
break;
|
||||
|
||||
// TODO: Use the location of the literal element itself.
|
||||
return emitError(
|
||||
loc, llvm::formatv("format ambiguity caused by `:` literal found "
|
||||
"after attribute `{0}` which does not have "
|
||||
"a buildable type",
|
||||
prevAttr->getVar()->name));
|
||||
}
|
||||
}
|
||||
// Note in the format that this result uses the custom builder.
|
||||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.resultTypes[i].setBuilderIdx(it.first->second);
|
||||
}
|
||||
if (it == e)
|
||||
iteratorStack.pop_back();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifyOperands(
|
||||
llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
// Check that all of the operands are within the format, and their types can
|
||||
// be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
NamedTypeConstraint &operand = op.getOperand(i);
|
||||
|
||||
|
@ -1419,22 +1491,57 @@ LogicalResult FormatParser::parse() {
|
|||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.operandTypes[i].setBuilderIdx(it.first->second);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifyResults(
|
||||
llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
// If we format all of the types together, there is nothing to check.
|
||||
if (fmt.allResultTypes)
|
||||
return success();
|
||||
|
||||
// Check that all of the result types can be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
if (seenResultTypes.test(i))
|
||||
continue;
|
||||
|
||||
// Check to see if we can infer this type from another variable.
|
||||
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
|
||||
if (varResolverIt != variableTyResolver.end()) {
|
||||
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
|
||||
varResolverIt->second.transformer);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the result is not variadic, allow for the case where the type has a
|
||||
// builder that we can use.
|
||||
NamedTypeConstraint &result = op.getResult(i);
|
||||
Optional<StringRef> builder = result.constraint.getBuilderCall();
|
||||
if (!builder || result.constraint.isVariadic()) {
|
||||
return emitError(loc, "format missing instance of result #" + Twine(i) +
|
||||
"('" + result.name + "') type");
|
||||
}
|
||||
// Note in the format that this result uses the custom builder.
|
||||
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
|
||||
fmt.resultTypes[i].setBuilderIdx(it.first->second);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
|
||||
// Check that all of the successors are within the format.
|
||||
if (!hasAllSuccessors) {
|
||||
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
|
||||
const NamedSuccessor &successor = op.getSuccessor(i);
|
||||
if (!seenSuccessors.count(&successor)) {
|
||||
return emitError(loc, "format missing instance of successor #" +
|
||||
Twine(i) + "('" + successor.name + "')");
|
||||
}
|
||||
if (hasAllSuccessors)
|
||||
return success();
|
||||
|
||||
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
|
||||
const NamedSuccessor &successor = op.getSuccessor(i);
|
||||
if (!seenSuccessors.count(&successor)) {
|
||||
return emitError(loc, "format missing instance of successor #" +
|
||||
Twine(i) + "('" + successor.name + "')");
|
||||
}
|
||||
}
|
||||
|
||||
// Check to see if we are formatting all of the operands.
|
||||
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
|
||||
return isa<OperandsDirective>(elt.get());
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue