[mlir][DeclarativeParser] Emit an error if a `:` follows an attribute with a non-constant type.

Summary: The attribute grammar includes an optional trailing colon type, so for attributes without a constant buildable type this will generally lead to unexpected and undesired behavior. Given that, it's better to just error out on these cases.

Differential Revision: https://reviews.llvm.org/D77293
This commit is contained in:
River Riddle 2020-04-03 19:20:33 -07:00
parent ca47ac3d5f
commit e3bb36370d
6 changed files with 187 additions and 56 deletions

View File

@ -20,6 +20,7 @@ def AffineMapAttr : Attr<
CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
let storageType = [{ AffineMapAttr }];
let returnType = [{ AffineMap }];
let valueType = Index;
let constBuilderCall = "AffineMapAttr::get($0)";
}

View File

@ -319,7 +319,8 @@ class BuildableType<code builder> {
def AnyType : Type<CPred<"true">, "any type">;
// None type
def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">,
BuildableType<"$_builder.getType<NoneType>()">;
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
@ -835,6 +836,7 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
let storageType = [{ BoolAttr }];
let returnType = [{ bool }];
let valueType = I1;
let constBuilderCall = "$_builder.getBoolAttr($0)";
}
@ -942,11 +944,18 @@ class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
let storageType = [{ StringAttr }];
let returnType = [{ StringRef }];
let valueType = NoneType;
}
def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
"string attribute">;
// String attribute that has a specific value type.
class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
"string attribute"> {
let valueType = ty;
}
// Base class for attributes containing types. Example:
// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
// defines a type attribute containing an integer type.
@ -957,6 +966,7 @@ class TypeAttrBase<string retType, string description> :
description> {
let storageType = [{ TypeAttr }];
let returnType = retType;
let valueType = NoneType;
let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
}
@ -970,6 +980,7 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
let constBuilderCall = "$_builder.getUnitAttr()";
let convertFromStorage = "$_self != nullptr";
let returnType = "bool";
let valueType = NoneType;
let isOptional = 1;
}
@ -1166,6 +1177,7 @@ class DictionaryAttrBase : Attr<CPred<"$_self.isa<DictionaryAttr>()">,
"dictionary of named attribute values"> {
let storageType = [{ DictionaryAttr }];
let returnType = [{ DictionaryAttr }];
let valueType = NoneType;
let convertFromStorage = "$_self";
}
@ -1285,6 +1297,7 @@ class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayAttr }];
let valueType = NoneType;
let convertFromStorage = "$_self";
}
@ -1364,6 +1377,7 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
"symbol reference attribute"> {
let storageType = [{ SymbolRefAttr }];
let returnType = [{ SymbolRefAttr }];
let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let convertFromStorage = "$_self";
}
@ -1371,6 +1385,7 @@ def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
"flat symbol reference attribute"> {
let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let convertFromStorage = "$_self.getValue()";
}

View File

@ -247,7 +247,7 @@ func @non_type_in_type_array_attr_fail() {
// CHECK-LABEL: func @string_attr_custom_type
func @string_attr_custom_type() {
// CHECK: "string_data" : !foo.string
test.string_attr_with_type "string_data"
test.string_attr_with_type "string_data" : !foo.string
return
}

View File

@ -158,15 +158,8 @@ def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
let arguments = (ins TypeArrayAttr:$attr);
}
def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
let arguments = (ins StrAttr:$attr);
let printer = [{ p << getAttr("attr"); }];
let parser = [{
Attribute attr;
Type stringType = OpaqueType::get(Identifier::get("foo",
result.getContext()), "string",
result.getContext());
return parser.parseAttribute(attr, stringType, "attr", result.attributes);
}];
let arguments = (ins TypedStrAttr<AnyType>:$attr);
let assemblyFormat = "$attr attr-dict";
}
def StrCaseA: StrEnumAttrCase<"A">;

View File

@ -1,4 +1,4 @@
// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure
// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s --dump-input-on-failure
// This file contains tests for the specification of the declarative op format.
@ -275,6 +275,21 @@ def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{
}]> {
let successors = (successor AnySuccessor:$successor);
}
// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
$attr `:` attr-dict
}]>, Arguments<(ins ElementsAttr:$attr)>;
// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
// CHECK-NOT: error:
def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
$attr `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
//===----------------------------------------------------------------------===//
// Coverage Checks

View File

@ -92,13 +92,23 @@ public:
}
const VarT *getVar() { return var; }
private:
protected:
const VarT *var;
};
/// This class represents a variable that refers to an attribute argument.
using AttributeVariable =
VariableElement<NamedAttribute, Element::Kind::AttributeVariable>;
struct AttributeVariable
: public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
using VariableElement<NamedAttribute,
Element::Kind::AttributeVariable>::VariableElement;
/// Return the constant builder call for the type of this attribute, or None
/// if it doesn't have one.
Optional<StringRef> getTypeBuilder() const {
Optional<Type> attrType = var->attr.getValueType();
return attrType ? attrType->getBuilderCall() : llvm::None;
}
};
/// This class represents a variable that refers to an operand argument.
using OperandVariable =
@ -574,11 +584,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
// If this attribute has a buildable type, use that when parsing the
// attribute.
std::string attrTypeStr;
if (Optional<Type> attrType = var->attr.getValueType()) {
if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
llvm::raw_string_ostream os(attrTypeStr);
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
}
if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
llvm::raw_string_ostream os(attrTypeStr);
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
}
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
@ -932,8 +940,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
}
// Elide the attribute type if it is buildable.
Optional<Type> attrType = var->attr.getValueType();
if (attrType && attrType->getBuilderCall())
if (attr->getTypeBuilder())
body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
else
body << " p.printAttribute(" << var->name << "Attr());\n";
@ -1234,6 +1241,22 @@ private:
Optional<StringRef> transformer;
};
/// Verify the state of operation attributes within the format.
LogicalResult verifyAttributes(llvm::SMLoc loc);
/// Verify the state of operation operands within the format.
LogicalResult
verifyOperands(llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation results within the format.
LogicalResult
verifyResults(llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation successors within the format.
LogicalResult verifySuccessors(llvm::SMLoc loc);
/// Given the values of an `AllTypesMatch` trait, check for inferable type
/// resolution.
void handleAllTypesMatchConstraint(
@ -1357,37 +1380,86 @@ LogicalResult FormatParser::parse() {
}
}
// Check that all of the result types can be inferred.
auto &buildableTypes = fmt.buildableTypes;
if (!fmt.allResultTypes) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
if (seenResultTypes.test(i))
continue;
// Verify the state of the various operation components.
if (failed(verifyAttributes(loc)) ||
failed(verifyResults(loc, variableTyResolver)) ||
failed(verifyOperands(loc, variableTyResolver)) ||
failed(verifySuccessors(loc)))
return failure();
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
continue;
// Check to see if we are formatting all of the operands.
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
return isa<OperandsDirective>(elt.get());
});
return success();
}
LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
// Check that there are no `:` literals after an attribute without a constant
// type. The attribute grammar contains an optional trailing colon type, which
// can lead to unexpected and generally unintended behavior. Given that, it is
// better to just error out here instead.
using ElementsIterT = llvm::pointee_iterator<
std::vector<std::unique_ptr<Element>>::const_iterator>;
SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
while (!iteratorStack.empty()) {
auto &stackIt = iteratorStack.back();
ElementsIterT &it = stackIt.first, e = stackIt.second;
while (it != e) {
Element *element = &*(it++);
// Traverse into optional groups.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
iteratorStack.emplace_back(elements.begin(), elements.end());
break;
}
// If the result is not variadic, allow for the case where the type has a
// builder that we can use.
NamedTypeConstraint &result = op.getResult(i);
Optional<StringRef> builder = result.constraint.getBuilderCall();
if (!builder || result.constraint.isVariadic()) {
return emitError(loc, "format missing instance of result #" + Twine(i) +
"('" + result.name + "') type");
// We are checking for an attribute element followed by a `:`, so there is
// no need to check the end.
if (it == e && iteratorStack.size() == 1)
break;
// Check for an attribute with a constant type builder, followed by a `:`.
auto *prevAttr = dyn_cast<AttributeVariable>(element);
if (!prevAttr || prevAttr->getTypeBuilder())
continue;
// Check the next iterator within the stack for literal elements.
for (auto &nextItPair : iteratorStack) {
ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
for (; nextIt != nextE; ++nextIt) {
// Skip any trailing optional groups or attribute dictionaries.
if (isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
continue;
// We are only interested in `:` literals.
auto *literal = dyn_cast<LiteralElement>(&*nextIt);
if (!literal || literal->getLiteral() != ":")
break;
// TODO: Use the location of the literal element itself.
return emitError(
loc, llvm::formatv("format ambiguity caused by `:` literal found "
"after attribute `{0}` which does not have "
"a buildable type",
prevAttr->getVar()->name));
}
}
// Note in the format that this result uses the custom builder.
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
if (it == e)
iteratorStack.pop_back();
}
return success();
}
LogicalResult FormatParser::verifyOperands(
llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
// Check that all of the operands are within the format, and their types can
// be inferred.
auto &buildableTypes = fmt.buildableTypes;
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
NamedTypeConstraint &operand = op.getOperand(i);
@ -1419,22 +1491,57 @@ LogicalResult FormatParser::parse() {
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second);
}
return success();
}
LogicalResult FormatParser::verifyResults(
llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
// If we format all of the types together, there is nothing to check.
if (fmt.allResultTypes)
return success();
// Check that all of the result types can be inferred.
auto &buildableTypes = fmt.buildableTypes;
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
if (seenResultTypes.test(i))
continue;
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
fmt.resultTypes[i].setVariable(varResolverIt->second.type,
varResolverIt->second.transformer);
continue;
}
// If the result is not variadic, allow for the case where the type has a
// builder that we can use.
NamedTypeConstraint &result = op.getResult(i);
Optional<StringRef> builder = result.constraint.getBuilderCall();
if (!builder || result.constraint.isVariadic()) {
return emitError(loc, "format missing instance of result #" + Twine(i) +
"('" + result.name + "') type");
}
// Note in the format that this result uses the custom builder.
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
return success();
}
LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
// Check that all of the successors are within the format.
if (!hasAllSuccessors) {
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
const NamedSuccessor &successor = op.getSuccessor(i);
if (!seenSuccessors.count(&successor)) {
return emitError(loc, "format missing instance of successor #" +
Twine(i) + "('" + successor.name + "')");
}
if (hasAllSuccessors)
return success();
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
const NamedSuccessor &successor = op.getSuccessor(i);
if (!seenSuccessors.count(&successor)) {
return emitError(loc, "format missing instance of successor #" +
Twine(i) + "('" + successor.name + "')");
}
}
// Check to see if we are formatting all of the operands.
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
return isa<OperandsDirective>(elt.get());
});
return success();
}