diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/RangeOp.h b/mlir/examples/Linalg/Linalg1/include/linalg1/RangeOp.h index 1973b462ad9e..09e725212f41 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/RangeOp.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/RangeOp.h @@ -41,7 +41,8 @@ public: static void build(mlir::Builder *b, mlir::OperationState *result, mlir::Value *min, mlir::Value *max, mlir::Value *step); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/SliceOp.h b/mlir/examples/Linalg/Linalg1/include/linalg1/SliceOp.h index 48f119f8a4a1..e9b5a8858493 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/SliceOp.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/SliceOp.h @@ -40,7 +40,8 @@ public: static void build(mlir::Builder *b, mlir::OperationState *result, mlir::Value *view, mlir::Value *indexing, unsigned dim); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h index 54af9a611f73..d24b709db55d 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h @@ -43,7 +43,8 @@ public: mlir::Value *memRef, llvm::ArrayRef indexings); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg1/lib/RangeOp.cpp b/mlir/examples/Linalg/Linalg1/lib/RangeOp.cpp index 6899ed698693..080a5e8400f0 100644 --- a/mlir/examples/Linalg/Linalg1/lib/RangeOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/RangeOp.cpp @@ -48,15 +48,17 @@ mlir::LogicalResult linalg::RangeOp::verify() { return mlir::success(); } -bool linalg::RangeOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult linalg::RangeOp::parse(OpAsmParser *parser, + OperationState *result) { SmallVector rangeInfo(3); RangeType type; auto indexTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || - parser->parseOperand(rangeInfo[1]) || parser->parseColon() || - parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || - parser->resolveOperands(rangeInfo, indexTy, result->operands) || - parser->addTypeToList(type, result->types); + return failure( + parser->parseOperand(rangeInfo[0]) || parser->parseColon() || + parser->parseOperand(rangeInfo[1]) || parser->parseColon() || + parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || + parser->resolveOperands(rangeInfo, indexTy, result->operands) || + parser->addTypeToList(type, result->types)); } // A RangeOp prints as: diff --git a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp index 56a4aede2f40..dc1ffce6d68c 100644 --- a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp @@ -74,7 +74,8 @@ mlir::LogicalResult linalg::SliceOp::verify() { return mlir::success(); } -bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult linalg::SliceOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType viewInfo; SmallVector indexingInfo; SmallVector types; @@ -83,7 +84,7 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) - return true; + return failure(); if (indexingInfo.size() != 1) return parser->emitError(parser->getNameLoc(), "expected 1 indexing type"); @@ -107,10 +108,10 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { ViewType resultViewType = ViewType::get(viewType.getContext(), viewType.getElementType(), rank); - return parser->resolveOperand(viewInfo, viewType, result->operands) || - parser->resolveOperands(indexingInfo[0], types.back(), - result->operands) || - parser->addTypeToList(resultViewType, result->types); + return failure(parser->resolveOperand(viewInfo, viewType, result->operands) || + parser->resolveOperands(indexingInfo[0], types.back(), + result->operands) || + parser->addTypeToList(resultViewType, result->types)); } // A SliceOp prints as: diff --git a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp index 6cd0f27c4a4c..1ce241a67722 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp @@ -89,7 +89,7 @@ LogicalResult linalg::ViewOp::verify() { return success(); } -bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memRefInfo; SmallVector indexingsInfo; SmallVector types; @@ -98,7 +98,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) - return true; + return failure(); if (types.size() != 2 + indexingsInfo.size()) return parser->emitError(parser->getNameLoc(), @@ -120,12 +120,13 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "expected " + Twine(memRefType.getRank()) + " indexing types"); - return parser->resolveOperand(memRefInfo, memRefType, result->operands) || - (!indexingsInfo.empty() && - parser->resolveOperands(indexingsInfo, indexingTypes, - indexingsInfo.front().location, - result->operands)) || - parser->addTypeToList(viewType, result->types); + return failure( + parser->resolveOperand(memRefInfo, memRefType, result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, indexingTypes, + indexingsInfo.front().location, + result->operands)) || + parser->addTypeToList(viewType, result->types)); } // A ViewOp prints as: diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h index 940f8d7d312c..19dd6956f791 100644 --- a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h @@ -83,8 +83,9 @@ mlir::LogicalResult linalg::TensorContractionBase::verify() { } template -bool linalg::TensorContractionBase::parse( - mlir::OpAsmParser *parser, mlir::OperationState *result) { +mlir::ParseResult +linalg::TensorContractionBase::parse(mlir::OpAsmParser *parser, + mlir::OperationState *result) { llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); } diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h index 5eccbe68c24d..a4a341fac6d0 100644 --- a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h @@ -38,7 +38,8 @@ protected: ////////////////////////////////////////////////////////////////////////////// /// Generic implementation of hooks that should be called from `ConcreteType`s mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); public: @@ -118,7 +119,8 @@ public: return build(b, result, {A, B, C}); } mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// @@ -179,7 +181,8 @@ public: return build(b, result, {A, B, C}); } mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// @@ -240,7 +243,8 @@ public: return build(b, result, {A, B, C}); } mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp b/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp index 8a47e5d70eab..ce19d10681a6 100644 --- a/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp @@ -58,8 +58,8 @@ LogicalResult linalg::DotOp::verify() { } // Parsing of the linalg dialect is not supported in this tutorial. -bool linalg::DotOp::parse(mlir::OpAsmParser *parser, - mlir::OperationState *result) { +ParseResult linalg::DotOp::parse(mlir::OpAsmParser *parser, + mlir::OperationState *result) { return TensorContractionBaseType::parse(parser, result); } @@ -92,8 +92,8 @@ LogicalResult linalg::MatvecOp::verify() { } // Parsing of the linalg dialect is not supported in this tutorial. -bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser, - mlir::OperationState *result) { +ParseResult linalg::MatvecOp::parse(mlir::OpAsmParser *parser, + mlir::OperationState *result) { return TensorContractionBaseType::parse(parser, result); } @@ -123,8 +123,8 @@ LogicalResult linalg::MatmulOp::verify() { } // Parsing of the linalg dialect is not supported in this tutorial. -bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser, - mlir::OperationState *result) { +ParseResult linalg::MatmulOp::parse(mlir::OpAsmParser *parser, + mlir::OperationState *result) { return TensorContractionBaseType::parse(parser, result); } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/LoadStoreOps.h b/mlir/examples/Linalg/Linalg3/include/linalg3/LoadStoreOps.h index a227613857a0..eb0e1ac360ed 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/LoadStoreOps.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/LoadStoreOps.h @@ -41,7 +41,8 @@ public: mlir::Value *view, mlir::ArrayRef indices = {}); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// @@ -71,7 +72,8 @@ public: mlir::Value *valueToStore, mlir::Value *view, mlir::ArrayRef indices = {}); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static mlir::ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); ////////////////////////////////////////////////////////////////////////////// diff --git a/mlir/examples/Linalg/Linalg3/lib/LoadStoreOps.cpp b/mlir/examples/Linalg/Linalg3/lib/LoadStoreOps.cpp index 340916f013b7..7a733339d61e 100644 --- a/mlir/examples/Linalg/Linalg3/lib/LoadStoreOps.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/LoadStoreOps.cpp @@ -49,9 +49,9 @@ void linalg::LoadOp::print(OpAsmPrinter *p) { *p << " : " << getViewType(); } -bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) { llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); - return false; + return success(); } LogicalResult linalg::LoadOp::verify() { @@ -101,9 +101,10 @@ void linalg::StoreOp::print(OpAsmPrinter *p) { *p << " : " << getViewType(); } -bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult linalg::StoreOp::parse(OpAsmParser *parser, + OperationState *result) { assert(false && "NYI"); - return false; + return success(); } LogicalResult linalg::StoreOp::verify() { diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index d46ecc55c20a..35244f93cc74 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -80,7 +80,7 @@ public: static StringRef getOperationName() { return "affine.apply"; } // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); Attribute constantFold(ArrayRef operands, MLIRContext *context); @@ -130,7 +130,7 @@ public: static void build(Builder *builder, OperationState *result, int64_t lb, int64_t ub, int64_t step = 1); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -326,7 +326,7 @@ public: Region &getElseBlocks(); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); }; diff --git a/mlir/include/mlir/GPU/GPUDialect.h b/mlir/include/mlir/GPU/GPUDialect.h index e555e00bfb0f..384af8e82bec 100644 --- a/mlir/include/mlir/GPU/GPUDialect.h +++ b/mlir/include/mlir/GPU/GPUDialect.h @@ -80,7 +80,7 @@ public: /// Custom syntax support. void print(OpAsmPrinter *p); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); static StringRef getOperationName() { return "gpu.launch"; } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index a0f0d549da73..986b9a40b4e0 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -54,6 +54,18 @@ template struct IsSingleResult { OpType *, OpTrait::OneResult *>::value; }; +/// This class represents success/failure for operation parsing. It is +/// essentially a simple wrapper class around LogicalResult that allows for +/// explicit conversion to bool. This allows for the parser to chain together +/// parse rules without the clutter of "failed/succeeded". +class ParseResult : public LogicalResult { +public: + ParseResult(LogicalResult result = success()) : LogicalResult(result) {} + + /// Failure is true in a boolean context. + explicit operator bool() const { return failed(*this); } +}; + /// This is the concrete base class that holds the operation pointer and has /// non-generic methods that only depend on State (to avoid having them /// instantiated on template types that don't affect them. @@ -132,10 +144,9 @@ protected: LogicalResult verify() { return success(); } /// Unless overridden, the custom assembly form of an op is always rejected. - /// Op implementations should implement this to return true on failure. - /// On success, they should return false and fill in result with the fields to - /// use. - static bool parse(OpAsmParser *parser, OperationState *result); + /// Op implementations should implement this to return failure. + /// On success, they should fill in result with the fields to use. + static ParseResult parse(OpAsmParser *parser, OperationState *result); // The fallback for the printer is to print it the generic assembly form. void print(OpAsmPrinter *p); @@ -768,9 +779,10 @@ public: /// This is the hook used by the AsmParser to parse the custom form of this /// op from an .mlir file. Op implementations should provide a parse method, - /// which returns boolean true on failure. On success, they should return - /// false and fill in result with the fields to use. - static bool parseAssembly(OpAsmParser *parser, OperationState *result) { + /// which returns failure. On success, they should return fill in result with + /// the fields to use. + static ParseResult parseAssembly(OpAsmParser *parser, + OperationState *result) { return ConcreteType::parse(parser, result); } @@ -854,7 +866,7 @@ private: namespace impl { void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, Value *rhs); -bool parseBinaryOp(OpAsmParser *parser, OperationState *result); +ParseResult parseBinaryOp(OpAsmParser *parser, OperationState *result); // Prints the given binary `op` in custom assembly form if both the two operands // and the result have the same time. Otherwise, prints the generic assembly // form. @@ -866,7 +878,7 @@ void printBinaryOp(Operation *op, OpAsmPrinter *p); namespace impl { void buildCastOp(Builder *builder, OperationState *result, Value *source, Type destType); -bool parseCastOp(OpAsmParser *parser, OperationState *result); +ParseResult parseCastOp(OpAsmParser *parser, OperationState *result); void printCastOp(Operation *op, OpAsmPrinter *p); Value *foldCastOp(Operation *op); } // namespace impl @@ -888,7 +900,7 @@ public: Type destType) { impl::buildCastOp(builder, result, source, destType); } - static bool parse(OpAsmParser *parser, OperationState *result) { + static ParseResult parse(OpAsmParser *parser, OperationState *result) { return impl::parseCastOp(parser, result); } void print(OpAsmPrinter *p) { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a818c01178e1..8d15921f41db 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -148,134 +148,138 @@ public: // High level parsing methods. //===--------------------------------------------------------------------===// - // These emit an error and return true on failure, or return false on success. + // These emit an error and return failure or success. // This allows these to be chained together into a linear sequence of || // expressions in many cases. /// Get the location of the next token and store it into the argument. This /// always succeeds. - virtual bool getCurrentLocation(llvm::SMLoc *loc) = 0; + virtual ParseResult getCurrentLocation(llvm::SMLoc *loc) = 0; /// This parses... a comma! - virtual bool parseComma() = 0; + virtual ParseResult parseComma() = 0; /// Parses a comma if present. - virtual bool parseOptionalComma() = 0; + virtual ParseResult parseOptionalComma() = 0; /// Parse a `:` token. - virtual bool parseColon() = 0; + virtual ParseResult parseColon() = 0; /// Parse a '(' token. - virtual bool parseLParen() = 0; + virtual ParseResult parseLParen() = 0; /// Parse a ')' token. - virtual bool parseRParen() = 0; + virtual ParseResult parseRParen() = 0; /// This parses an equal(=) token! - virtual bool parseEqual() = 0; + virtual ParseResult parseEqual() = 0; /// Parse a type. - virtual bool parseType(Type &result) = 0; + virtual ParseResult parseType(Type &result) = 0; /// Parse a colon followed by a type. - virtual bool parseColonType(Type &result) = 0; + virtual ParseResult parseColonType(Type &result) = 0; /// Parse a type of a specific kind, e.g. a FunctionType. - template bool parseColonType(TypeType &result) { + template ParseResult parseColonType(TypeType &result) { llvm::SMLoc loc; getCurrentLocation(&loc); // Parse any kind of type. Type type; if (parseColonType(type)) - return true; + return failure(); // Check for the right kind of attribute. result = type.dyn_cast(); if (!result) return emitError(loc, "invalid kind of type specified"); - return false; + return success(); } /// Parse a colon followed by a type list, which must have at least one type. - virtual bool parseColonTypeList(SmallVectorImpl &result) = 0; + virtual ParseResult parseColonTypeList(SmallVectorImpl &result) = 0; /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) { - return parseKeyword(keyword) || parseType(result); + ParseResult parseKeywordType(const char *keyword, Type &result) { + return failure(parseKeyword(keyword) || parseType(result)); } /// Parse a keyword. - bool parseKeyword(const char *keyword, const Twine &msg = "") { + ParseResult parseKeyword(const char *keyword, const Twine &msg = "") { if (parseOptionalKeyword(keyword)) return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg); - return false; + return success(); } /// If a keyword is present, then parse it. - virtual bool parseOptionalKeyword(const char *keyword) = 0; + virtual ParseResult parseOptionalKeyword(const char *keyword) = 0; /// Add the specified type to the end of the specified type list and return - /// false. This is a helper designed to allow parse methods to be simple and - /// chain through || operators. - bool addTypeToList(Type type, SmallVectorImpl &result) { + /// success. This is a helper designed to allow parse methods to be simple + /// and chain through || operators. + ParseResult addTypeToList(Type type, SmallVectorImpl &result) { result.push_back(type); - return false; + return success(); } /// Add the specified types to the end of the specified type list and return - /// false. This is a helper designed to allow parse methods to be simple and - /// chain through || operators. - bool addTypesToList(ArrayRef types, SmallVectorImpl &result) { + /// success. This is a helper designed to allow parse methods to be simple + /// and chain through || operators. + ParseResult addTypesToList(ArrayRef types, + SmallVectorImpl &result) { result.append(types.begin(), types.end()); - return false; + return success(); } /// Parse an arbitrary attribute and return it in result. This also adds the /// attribute to the specified attribute list with the specified name. - virtual bool parseAttribute(Attribute &result, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult + parseAttribute(Attribute &result, StringRef attrName, + SmallVectorImpl &attrs) = 0; /// Parse an arbitrary attribute of a given type and return it in result. This /// also adds the attribute to the specified attribute list with the specified /// name. - virtual bool parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult + parseAttribute(Attribute &result, Type type, StringRef attrName, + SmallVectorImpl &attrs) = 0; /// Parse an attribute of a specific kind and type. template - bool parseAttribute(AttrType &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) { + ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, + SmallVectorImpl &attrs) { llvm::SMLoc loc; getCurrentLocation(&loc); // Parse any kind of attribute. Attribute attr; if (parseAttribute(attr, type, attrName, attrs)) - return true; + return failure(); // Check for the right kind of attribute. result = attr.dyn_cast(); if (!result) return emitError(loc, "invalid kind of constant specified"); - return false; + return success(); } /// If a named attribute dictionary is present, parse it into result. - virtual bool + virtual ParseResult parseOptionalAttributeDict(SmallVectorImpl &result) = 0; /// Parse a function name like '@foo' and return the name in a form that can /// be passed to resolveFunctionName when a function type is available. - virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) = 0; + virtual ParseResult parseFunctionName(StringRef &result, + llvm::SMLoc &loc) = 0; /// Parse a function name like '@foo` if present and return the name without /// the sigil in `result`. Return true if the next token is not a function /// name and keep `result` unchanged. - virtual bool parseOptionalFunctionName(StringRef &result, - llvm::SMLoc &loc) = 0; + virtual ParseResult parseOptionalFunctionName(StringRef &result, + llvm::SMLoc &loc) = 0; /// This is the representation of an operand reference. struct OperandType { @@ -285,11 +289,12 @@ public: }; /// Parse a single operand. - virtual bool parseOperand(OperandType &result) = 0; + virtual ParseResult parseOperand(OperandType &result) = 0; /// Parse a single operation successor and it's operand list. - virtual bool parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) = 0; + virtual ParseResult + parseSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) = 0; /// These are the supported delimiters around operand lists, used by /// parseOperandList. @@ -308,14 +313,15 @@ public: /// Parse zero or more SSA comma-separated operand references with a specified /// surrounding delimiter, and an optional required operand count. - virtual bool parseOperandList(SmallVectorImpl &result, - int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) = 0; + virtual ParseResult + parseOperandList(SmallVectorImpl &result, + int requiredOperandCount = -1, + Delimiter delimiter = Delimiter::None) = 0; /// Parse zero or more trailing SSA comma-separated trailing operand /// references with a specified surrounding delimiter, and an optional /// required operand count. A leading comma is expected before the operands. - virtual bool + virtual ParseResult parseTrailingOperandList(SmallVectorImpl &result, int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; @@ -323,12 +329,13 @@ public: /// Parses a region. Any parsed blocks are appended to "region" and must be /// moved to the op regions after the op is created. The first block of the /// region takes "arguments" of types "argTypes". - virtual bool parseRegion(Region ®ion, ArrayRef arguments, - ArrayRef argTypes) = 0; + virtual ParseResult parseRegion(Region ®ion, + ArrayRef arguments, + ArrayRef argTypes) = 0; /// Parse a region argument. Region arguments define new values, so this also /// checks if the values with the same name has not been defined yet. - virtual bool parseRegionArgument(OperandType &argument) = 0; + virtual ParseResult parseRegionArgument(OperandType &argument) = 0; //===--------------------------------------------------------------------===// // Methods for interacting with the parser @@ -341,46 +348,45 @@ public: /// Return the location of the original name token. virtual llvm::SMLoc getNameLoc() const = 0; - /// Resolve an operand to an SSA value, emitting an error and returning true - /// on failure. - virtual bool resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) = 0; + /// Resolve an operand to an SSA value, emitting an error on failure. + virtual ParseResult resolveOperand(const OperandType &operand, Type type, + SmallVectorImpl &result) = 0; - /// Resolve a list of operands to SSA values, emitting an error and returning - /// true on failure, or appending the results to the list on success. - /// This method should be used when all operands have the same type. - virtual bool resolveOperands(ArrayRef operands, Type type, - SmallVectorImpl &result) { + /// Resolve a list of operands to SSA values, emitting an error on failure, or + /// appending the results to the list on success. This method should be used + /// when all operands have the same type. + virtual ParseResult resolveOperands(ArrayRef operands, Type type, + SmallVectorImpl &result) { for (auto elt : operands) if (resolveOperand(elt, type, result)) - return true; - return false; + return failure(); + return success(); } /// Resolve a list of operands and a list of operand types to SSA values, - /// emitting an error and returning true on failure, or appending the results + /// emitting an error and returning failure, or appending the results /// to the list on success. - virtual bool resolveOperands(ArrayRef operands, - ArrayRef types, llvm::SMLoc loc, - SmallVectorImpl &result) { + virtual ParseResult resolveOperands(ArrayRef operands, + ArrayRef types, llvm::SMLoc loc, + SmallVectorImpl &result) { if (operands.size() != types.size()) return emitError(loc, Twine(operands.size()) + " operands present, but expected " + Twine(types.size())); - for (unsigned i = 0, e = operands.size(); i != e; ++i) { + for (unsigned i = 0, e = operands.size(); i != e; ++i) if (resolveOperand(operands[i], types[i], result)) - return true; - } - return false; + return failure(); + return success(); } /// Resolve a parse function name and a type into a function reference. - virtual bool resolveFunctionName(StringRef name, FunctionType type, - llvm::SMLoc loc, Function *&result) = 0; + virtual ParseResult resolveFunctionName(StringRef name, FunctionType type, + llvm::SMLoc loc, + Function *&result) = 0; - /// Emit a diagnostic at the specified location and return true. - virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0; + /// Emit a diagnostic at the specified location and return failure. + virtual ParseResult emitError(llvm::SMLoc loc, const Twine &message) = 0; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index ecab75aabc75..ad2825b3a1c1 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -42,6 +42,7 @@ struct OperationState; class OpAsmParser; class OpAsmParserResult; class OpAsmPrinter; +class ParseResult; class Pattern; class Region; class RewritePattern; @@ -85,7 +86,7 @@ public: bool (&isClassFor)(Operation *op); /// Use the specified object to parse this ops custom assembly format. - bool (&parseAssembly)(OpAsmParser *parser, OperationState *result); + ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result); /// This hook implements the AsmPrinter for this operation. void (&printAssembly)(Operation *op, OpAsmPrinter *p); @@ -150,7 +151,7 @@ private: AbstractOperation( StringRef name, Dialect &dialect, OperationProperties opProperties, bool (&isClassFor)(Operation *op), - bool (&parseAssembly)(OpAsmParser *parser, OperationState *result), + ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result), void (&printAssembly)(Operation *op, OpAsmPrinter *p), LogicalResult (&verifyInvariants)(Operation *op), LogicalResult (&constantFoldHook)(Operation *op, diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index 58ff5a95d8f3..9472c71d1eab 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -43,7 +43,7 @@ public: static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; } static void build(Builder *b, OperationState *result, Type type, Value *size); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); // Op-specific functionality. @@ -67,7 +67,7 @@ public: static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; } static void build(Builder *b, OperationState *result, Value *buffer); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); // Op-specific functionality. @@ -94,7 +94,7 @@ public: static void build(Builder *b, OperationState *result, Value *min, Value *max, Value *step); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); // Op-specific functionality. @@ -156,7 +156,8 @@ public: static void build(mlir::Builder *b, mlir::OperationState *result, mlir::Value *base, llvm::ArrayRef indexings); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); // Op-specific functionality. @@ -207,7 +208,8 @@ public: mlir::Value *buffer, llvm::ArrayRef indexings); mlir::LogicalResult verify(); - static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + static ParseResult parse(mlir::OpAsmParser *parser, + mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); // Op-specific functionality. diff --git a/mlir/include/mlir/StandardOps/Ops.h b/mlir/include/mlir/StandardOps/Ops.h index 408810889676..8d2ee87776df 100644 --- a/mlir/include/mlir/StandardOps/Ops.h +++ b/mlir/include/mlir/StandardOps/Ops.h @@ -80,7 +80,7 @@ public: static void build(Builder *builder, OperationState *result, MemRefType memrefType, ArrayRef operands = {}); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); @@ -108,7 +108,7 @@ public: ArrayRef operands = {}); // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); /// Return the block this branch jumps to. @@ -149,7 +149,7 @@ public: operand_iterator arg_operand_end() { return operand_end(); } // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); }; @@ -183,7 +183,7 @@ public: operand_iterator arg_operand_end() { return operand_end(); } // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); static void getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -249,7 +249,7 @@ public: static void build(Builder *builder, OperationState *result, CmpIPredicate, Value *lhs, Value *rhs); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); Attribute constantFold(ArrayRef operands, MLIRContext *context); @@ -324,7 +324,7 @@ public: static void build(Builder *builder, OperationState *result, CmpFPredicate, Value *lhs, Value *rhs); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); Attribute constantFold(ArrayRef operands, MLIRContext *context); @@ -362,7 +362,7 @@ public: Block *falseDest, ArrayRef falseOperands); // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); @@ -521,7 +521,7 @@ public: // Hooks to customize behavior of this op. static void build(Builder *builder, OperationState *result, Value *memref); LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); @@ -553,7 +553,7 @@ public: // Hooks to customize behavior of this op. LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); }; @@ -682,7 +682,7 @@ public: } static StringRef getOperationName() { return "std.dma_start"; } - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); @@ -748,7 +748,7 @@ public: // Returns the number of elements transferred in the associated DMA operation. Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); } - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); @@ -785,7 +785,7 @@ public: // Hooks to customize behavior of this op. LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); Attribute constantFold(ArrayRef operands, MLIRContext *context); }; @@ -821,7 +821,7 @@ public: static StringRef getOperationName() { return "std.load"; } LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); @@ -881,7 +881,7 @@ public: ArrayRef results = {}); // Hooks to customize behavior of this op. - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); }; @@ -906,7 +906,7 @@ public: static StringRef getOperationName() { return "std.select"; } static void build(Builder *builder, OperationState *result, Value *condition, Value *trueValue, Value *falseValue); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); @@ -953,7 +953,7 @@ public: static StringRef getOperationName() { return "std.store"; } LogicalResult verify(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); static void getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -994,9 +994,9 @@ void printDimAndSymbolList(Operation::operand_iterator begin, OpAsmPrinter *p); /// Parses dimension and symbol list and returns true if parsing failed. -bool parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims); +ParseResult parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims); } // end namespace mlir diff --git a/mlir/include/mlir/VectorOps/VectorOps.h b/mlir/include/mlir/VectorOps/VectorOps.h index 30e5161b5767..f529f98c29dd 100644 --- a/mlir/include/mlir/VectorOps/VectorOps.h +++ b/mlir/include/mlir/VectorOps/VectorOps.h @@ -115,7 +115,7 @@ public: Optional getPaddingValue(); AffineMap getPermutationMap(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); }; @@ -177,7 +177,7 @@ public: operand_range getIndices(); AffineMap getPermutationMap(); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); }; @@ -200,7 +200,7 @@ public: static StringRef getOperationName() { return "vector.type_cast"; } static void build(Builder *builder, OperationState *result, Value *srcVector, Type dstType); - static bool parse(OpAsmParser *parser, OperationState *result); + static ParseResult parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p); LogicalResult verify(); }; diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 4135c612e18b..51209da7385d 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -131,7 +131,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result, result->addAttribute("map", builder->getAffineMapAttr(map)); } -bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); auto affineIntTy = builder.getIndexType(); @@ -140,7 +140,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->parseAttribute(mapAttr, "map", result->attributes) || parseDimAndSymbolList(parser, result->operands, numDims) || parser->parseOptionalAttributeDict(result->attributes)) - return true; + return failure(); auto map = mapAttr.getValue(); if (map.getNumDims() != numDims || @@ -150,7 +150,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { } result->types.append(map.getNumResults(), affineIntTy); - return false; + return success(); } void AffineApplyOp::print(OpAsmPrinter *p) { @@ -801,10 +801,12 @@ LogicalResult AffineForOp::verify() { } /// Parse a for operation loop bounds. -static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { +static ParseResult parseBound(bool isLower, OperationState *result, + OpAsmParser *p) { // 'min' / 'max' prefixes are generally syntactic sugar, but are required if // the map has multiple results. - bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); + bool failedToParsedMinMax = + failed(p->parseOptionalKeyword(isLower ? "max" : "min")); auto &builder = p->getBuilder(); auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() @@ -813,7 +815,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { // Parse ssa-id as identity map. SmallVector boundOpInfos; if (p->parseOperandList(boundOpInfos)) - return true; + return failure(); if (!boundOpInfos.empty()) { // Check that only one operand was parsed. @@ -825,14 +827,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { // Currently it is 'use of value ... expects different type than prior uses' if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), result->operands)) - return true; + return failure(); // Create an identity map using symbol id. This representation is optimized // for storage. Analysis passes may expand it into a multi-dimensional map // if desired. AffineMap map = builder.getSymbolIdentityMap(); result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); - return false; + return success(); } // Get the attribute location. @@ -842,14 +844,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { Attribute boundAttr; if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, result->attributes)) - return true; + return failure(); // Parse full form - affine map followed by dim and symbol list. if (auto affineMapAttr = boundAttr.dyn_cast()) { unsigned currentNumOperands = result->operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result->operands, numDims)) - return true; + return failure(); auto map = affineMapAttr.getValue(); if (map.getNumDims() != numDims) @@ -874,7 +876,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { return p->emitError(attrLoc, "upper loop bound affine map with multiple " "results requires 'min' prefix"); } - return false; + return success(); } // Parse custom assembly form. @@ -883,7 +885,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { result->addAttribute( boundAttrName, builder.getAffineMapAttr( builder.getConstantAffineMap(integerAttr.getInt()))); - return false; + return success(); } return p->emitError( @@ -891,18 +893,18 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { "expected valid affine map representation for loop bounds"); } -bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult AffineForOp::parse(OpAsmParser *parser, OperationState *result) { auto &builder = parser->getBuilder(); OpAsmParser::OperandType inductionVariable; // Parse the induction variable followed by '='. if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) - return true; + return failure(); // Parse loop bounds. if (parseBound(/*isLower=*/true, result, parser) || parser->parseKeyword("to", " between bounds") || parseBound(/*isLower=*/false, result, parser)) - return true; + return failure(); // Parse the optional loop step, we default to 1 if one is not present. if (parser->parseOptionalKeyword("step")) { @@ -915,7 +917,7 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->getCurrentLocation(&stepLoc) || parser->parseAttribute(stepAttr, builder.getIndexType(), getStepAttrName().data(), result->attributes)) - return true; + return failure(); if (stepAttr.getValue().getSExtValue() < 0) return parser->emitError( @@ -926,17 +928,17 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { // Parse the body region. Region *body = result->addRegion(); if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) - return true; + return failure(); ensureAffineTerminator(*body, builder, result->location); // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) - return true; + return failure(); // Set the operands list as resizable so that we can freely modify the bounds. result->setOperandListToResizable(); - return false; + return success(); } static void printBound(AffineMapAttr boundMap, @@ -1253,14 +1255,14 @@ LogicalResult AffineIfOp::verify() { return success(); } -bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { // Parse the condition attribute set. IntegerSetAttr conditionAttr; unsigned numDims; if (parser->parseAttribute(conditionAttr, getConditionAttrName(), result->attributes) || parseDimAndSymbolList(parser, result->operands, numDims)) - return true; + return failure(); // Verify the condition operands. auto set = conditionAttr.getValue(); @@ -1281,21 +1283,21 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { // Parse the 'then' region. if (parser->parseRegion(*thenRegion, {}, {})) - return true; + return failure(); ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location); // If we find an 'else' keyword then parse the 'else' region. if (!parser->parseOptionalKeyword("else")) { if (parser->parseRegion(*elseRegion, {}, {})) - return true; + return failure(); ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location); } // Parse the optional attribute list. if (parser->parseOptionalAttributeDict(result->attributes)) - return true; + return failure(); - return false; + return success(); } void AffineIfOp::print(OpAsmPrinter *p) { diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index a55e92576f1f..69d2b7291c29 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -159,7 +159,7 @@ void LaunchOp::print(OpAsmPrinter *p) { // where %region_arg are percent-identifiers for the region arguments to be // introduced futher (SSA defs), and %operand are percent-identifiers for the // SSA value uses. -static bool +static ParseResult parseSizeAssignment(OpAsmParser *parser, MutableArrayRef sizes, MutableArrayRef regionSizes, @@ -169,14 +169,14 @@ parseSizeAssignment(OpAsmParser *parser, parser->parseComma() || parser->parseRegionArgument(indices[2]) || parser->parseRParen() || parser->parseKeyword("in") || parser->parseLParen()) - return true; + return failure(); for (int i = 0; i < 3; ++i) { if (i != 0 && parser->parseComma()) - return true; + return failure(); if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() || parser->parseOperand(sizes[i])) - return true; + return failure(); } return parser->parseRParen(); @@ -188,7 +188,7 @@ parseSizeAssignment(OpAsmParser *parser, // (`args` ssa-reassignment `:` type-list)? // region attr-dict? // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` -bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { // Sizes of the grid and block. SmallVector sizes( kNumConfigOperands); @@ -217,7 +217,7 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) { regionArgsRef.slice(3, 3)) || parser->resolveOperands(sizes, parser->getBuilder().getIndexType(), result->operands)) - return true; + return failure(); // If kernel argument renaming segment is present, parse it. When present, // the segment should have at least one element. If this segment is present, @@ -232,20 +232,20 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() || parser->parseRegionArgument(regionArgs.back()) || parser->parseEqual() || parser->parseOperand(dataOperands.back())) - return true; + return failure(); while (!parser->parseOptionalComma()) { regionArgs.push_back({}); dataOperands.push_back({}); if (parser->parseRegionArgument(regionArgs.back()) || parser->parseEqual() || parser->parseOperand(dataOperands.back())) - return true; + return failure(); } if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) || parser->resolveOperands(dataOperands, dataTypes, argsLoc, result->operands)) - return true; + return failure(); } // Introduce the body region and parse it. The region has @@ -255,11 +255,10 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) { Type index = parser->getBuilder().getIndexType(); dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index); Region *body = result->addRegion(); - return parser->parseRegion(*body, regionArgs, dataTypes) || - parser->parseOptionalAttributeDict(result->attributes); + return failure(parser->parseRegion(*body, regionArgs, dataTypes) || + parser->parseOptionalAttributeDict(result->attributes)); } - //===----------------------------------------------------------------------===// // LaunchFuncOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 6074fd592673..992c66ea1d1c 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -640,7 +640,7 @@ Operation *Operation::clone(MLIRContext *context) { //===----------------------------------------------------------------------===// // The fallback for the parser is to reject the custom assembly form. -bool OpState::parse(OpAsmParser *parser, OperationState *result) { +ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "has no custom assembly form"); } @@ -948,14 +948,14 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, result->types.push_back(lhs->getType()); } -bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { +ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { SmallVector ops; Type type; - return parser->parseOperandList(ops, 2) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperands(ops, type, result->operands) || - parser->addTypeToList(type, result->types); + return failure(parser->parseOperandList(ops, 2) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperands(ops, type, result->operands) || + parser->addTypeToList(type, result->types)); } void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) { @@ -988,13 +988,14 @@ void impl::buildCastOp(Builder *builder, OperationState *result, Value *source, result->addTypes(destType); } -bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { +ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType srcInfo; Type srcType, dstType; - return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || - parser->resolveOperand(srcInfo, srcType, result->operands) || - parser->parseKeywordType("to", dstType) || - parser->addTypeToList(dstType, result->types); + return failure(parser->parseOperand(srcInfo) || + parser->parseColonType(srcType) || + parser->resolveOperand(srcInfo, srcType, result->operands) || + parser->parseKeywordType("to", dstType) || + parser->addTypeToList(dstType, result->types)); } void impl::printCastOp(Operation *op, OpAsmPrinter *p) { diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 20d5463a6b73..48cc47620523 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -126,7 +126,7 @@ static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) { // ::= `llvm.icmp` string-literal ssa-use `,` ssa-use // attribute-dict? `:` type -static bool parseICmpOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) { Builder &builder = parser->getBuilder(); Attribute predicate; @@ -142,7 +142,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) { parser->parseType(type) || parser->resolveOperand(lhs, type, result->operands) || parser->resolveOperand(rhs, type, result->operands)) - return true; + return failure(); // Replace the string attribute `predicate` with an integer attribute. auto predicateStr = predicate.dyn_cast(); @@ -173,7 +173,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) { result->attributes = attrs; result->addTypes({resultType}); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -195,7 +195,7 @@ static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) { // ::= `llvm.alloca` ssa-use `x` type attribute-dict? // `:` type `,` type -static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; OpAsmParser::OperandType arraySize; Type type, elemType; @@ -204,7 +204,7 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) { parser->parseType(elemType) || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return true; + return failure(); // Extract the result type from the trailing function type. auto funcType = type.dyn_cast(); @@ -215,11 +215,11 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) { "expected trailing function type with one argument and one result"); if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands)) - return true; + return failure(); result->attributes = attrs; result->addTypes({funcType.getResult(0)}); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -242,7 +242,7 @@ static void printGEPOp(OpAsmPrinter *p, GEPOp &op) { // ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` // attribute-dict? `:` type -static bool parseGEPOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; OpAsmParser::OperandType base; SmallVector indices; @@ -253,7 +253,7 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return true; + return failure(); // Deconstruct the trailing function type to extract the types of the base // pointer and result (same type) and the types of the indices. @@ -267,11 +267,11 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) { if (parser->resolveOperand(base, funcType.getInput(0), result->operands) || parser->resolveOperands(indices, funcType.getInputs().drop_front(), parser->getNameLoc(), result->operands)) - return true; + return failure(); result->attributes = attrs; result->addTypes(funcType.getResults()); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -302,7 +302,7 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type, } // ::= `llvm.load` ssa-use attribute-dict? `:` type -static bool parseLoadOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; OpAsmParser::OperandType addr; Type type; @@ -312,13 +312,13 @@ static bool parseLoadOp(OpAsmParser *parser, OperationState *result) { parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type) || parser->resolveOperand(addr, type, result->operands)) - return true; + return failure(); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); result->attributes = attrs; result->addTypes(elemTy); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -332,7 +332,7 @@ static void printStoreOp(OpAsmPrinter *p, StoreOp &op) { } // ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type -static bool parseStoreOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; OpAsmParser::OperandType addr, value; Type type; @@ -342,18 +342,18 @@ static bool parseStoreOp(OpAsmParser *parser, OperationState *result) { parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return true; + return failure(); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); if (!elemTy) - return true; + return failure(); if (parser->resolveOperand(value, elemTy, result->operands) || parser->resolveOperand(addr, type, result->operands)) - return true; + return failure(); result->attributes = attrs; - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -367,7 +367,7 @@ static void printBitcastOp(OpAsmPrinter *p, BitcastOp &op) { } // ::= `llvm.bitcast` ssa-use attribute-dict? `:` type `to` type -static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseBitcastOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; OpAsmParser::OperandType arg; Type sourceType, type; @@ -376,11 +376,11 @@ static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) { parser->parseColonType(sourceType) || parser->parseKeyword("to") || parser->parseType(type) || parser->resolveOperand(arg, sourceType, result->operands)) - return true; + return failure(); result->attributes = attrs; result->addTypes(type); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -438,7 +438,7 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) { // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` // attribute-dict? `:` function-type -static bool parseCallOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { SmallVector attrs; SmallVector operands; Type type; @@ -450,19 +450,19 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) { // direct call, there will be no operands and the parser will stop at the // function identifier without complaining. if (parser->parseOperandList(operands)) - return true; + return failure(); bool isDirect = operands.empty(); // Optionally parse a function identifier. if (isDirect) if (parser->parseFunctionName(calleeName, calleeLoc)) - return true; + return failure(); if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) - return true; + return failure(); auto funcType = type.dyn_cast(); if (!funcType) @@ -471,14 +471,14 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) { // Add the direct callee as an Op attribute. Function *func; if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func)) - return true; + return failure(); auto funcAttr = parser->getBuilder().getFunctionAttr(func); attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr)); // Make sure types match. if (parser->resolveOperands(operands, funcType.getInputs(), parser->getNameLoc(), result->operands)) - return true; + return failure(); result->addTypes(funcType.getResults()); } else { // Construct the LLVM IR Dialect function type that the first operand @@ -528,13 +528,13 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) { result->operands) || parser->resolveOperands(funcArguments, funcType.getInputs(), parser->getNameLoc(), result->operands)) - return true; + return failure(); result->addTypes(wrappedResultType); } result->attributes = attrs; - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -601,7 +601,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, // ::= `llvm.extractvalue` ssa-use // `[` integer-literal (`,` integer-literal)* `]` // attribute-dict? `:` type -static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseExtractValueOp(OpAsmParser *parser, + OperationState *result) { SmallVector attrs; OpAsmParser::OperandType container; Type containerType; @@ -615,16 +616,16 @@ static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) { parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(containerType) || parser->resolveOperand(container, containerType, result->operands)) - return true; + return failure(); auto elementType = getInsertExtractValueElementType( parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); if (!elementType) - return true; + return failure(); result->attributes = attrs; result->addTypes(elementType); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -641,7 +642,8 @@ static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) { // ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use // `[` integer-literal (`,` integer-literal)* `]` // attribute-dict? `:` type -static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseInsertValueOp(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType container, value; Type containerType; Attribute positionAttr; @@ -654,19 +656,19 @@ static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(containerType)) - return true; + return failure(); auto valueType = getInsertExtractValueElementType( parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); if (!valueType) - return true; + return failure(); if (parser->resolveOperand(container, containerType, result->operands) || parser->resolveOperand(value, valueType, result->operands)) - return true; + return failure(); result->addTypes(containerType); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -682,7 +684,7 @@ static void printSelectOp(OpAsmPrinter *p, SelectOp &op) { // ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use // attribute-dict? `:` type, type -static bool parseSelectOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType condition, trueValue, falseValue; Type conditionType, argType; @@ -692,15 +694,15 @@ static bool parseSelectOp(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(conditionType) || parser->parseComma() || parser->parseType(argType)) - return true; + return failure(); if (parser->resolveOperand(condition, conditionType, result->operands) || parser->resolveOperand(trueValue, argType, result->operands) || parser->resolveOperand(falseValue, argType, result->operands)) - return true; + return failure(); result->addTypes(argType); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -715,15 +717,15 @@ static void printBrOp(OpAsmPrinter *p, BrOp &op) { // ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? // attribute-dict? -static bool parseBrOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) { Block *dest; SmallVector operands; if (parser->parseSuccessorAndUseList(dest, operands) || parser->parseOptionalAttributeDict(result->attributes)) - return true; + return failure(); result->addSuccessor(dest, operands); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -741,7 +743,7 @@ static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) { // ::= `llvm.cond_br` ssa-use `,` // bb-id (`[` ssa-use-and-type-list `]`)? `,` // bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? -static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) { Block *trueDest; Block *falseDest; SmallVector trueOperands; @@ -760,11 +762,11 @@ static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) { parser->parseSuccessorAndUseList(falseDest, falseOperands) || parser->parseOptionalAttributeDict(result->attributes) || parser->resolveOperand(condition, i1Type, result->operands)) - return true; + return failure(); result->addSuccessor(trueDest, trueOperands); result->addSuccessor(falseDest, falseOperands); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -784,20 +786,20 @@ static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) { // ::= `llvm.return` ssa-use-list attribute-dict? `:` // type-list-no-parens -static bool parseReturnOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { SmallVector operands; Type type; if (parser->parseOperandList(operands) || parser->parseOptionalAttributeDict(result->attributes)) - return true; + return failure(); if (operands.empty()) - return false; + return success(); if (parser->parseColonType(type) || parser->resolveOperand(operands[0], type, result->operands)) - return true; - return false; + return failure(); + return success(); } //===----------------------------------------------------------------------===// @@ -811,15 +813,15 @@ static void printUndefOp(OpAsmPrinter *p, UndefOp &op) { } // ::= `llvm.undef` attribute-dict? : type -static bool parseUndefOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) { Type type; if (parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); result->addTypes(type); - return false; + return success(); } //===----------------------------------------------------------------------===// @@ -845,7 +847,8 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) { } // ::= `llvm.constant` `(` attribute `)` attribute-list? : type -static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseConstantOp(OpAsmParser *parser, + OperationState *result) { Attribute valueAttr; Type type; @@ -854,10 +857,10 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { parser->parseRParen() || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); result->addTypes(type); - return false; + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/LLVMIR/IR/NVVMDialect.cpp index c86bcf7776ff..f586f0e5c7ce 100644 --- a/mlir/lib/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/NVVMDialect.cpp @@ -53,15 +53,15 @@ static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) { } // ::= `llvm.nvvm.XYZ` : type -static bool parseNVVMSpecialRegisterOp(OpAsmParser *parser, - OperationState *result) { +static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser, + OperationState *result) { Type type; if (parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); result->addTypes(type); - return false; + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 958218d1a877..daa2cd31b438 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -58,18 +58,19 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *size() << " : " << getType(); } -bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType sizeInfo; BufferType bufferType; auto indexTy = parser->getBuilder().getIndexType(); if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) - return true; + return failure(); if (bufferType.getElementType() != parser->getBuilder().getF32Type()) return parser->emitError( parser->getNameLoc(), "Only buffer supported until mlir::Parser pieces are exposed"); - return parser->resolveOperands(sizeInfo, indexTy, result->operands) || - parser->addTypeToList(bufferType, result->types); + return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) || + parser->addTypeToList(bufferType, result->types)); } ////////////////////////////////////////////////////////////////////////////// @@ -95,11 +96,13 @@ void mlir::BufferDeallocOp::print(OpAsmPrinter *p) { *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType(); } -bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType sizeInfo; BufferType bufferType; - return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || - parser->resolveOperands(sizeInfo, bufferType, result->operands); + return failure( + parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || + parser->resolveOperands(sizeInfo, bufferType, result->operands)); } ////////////////////////////////////////////////////////////////////////////// // RangeOp @@ -131,15 +134,16 @@ void mlir::RangeOp::print(OpAsmPrinter *p) { << " : " << getType(); } -bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector rangeInfo(3); RangeType type; auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || - parser->parseOperand(rangeInfo[1]) || parser->parseColon() || - parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || - parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || - parser->addTypeToList(type, result->types); + return failure( + parser->parseOperand(rangeInfo[0]) || parser->parseColon() || + parser->parseOperand(rangeInfo[1]) || parser->parseColon() || + parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || + parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || + parser->addTypeToList(type, result->types)); } ////////////////////////////////////////////////////////////////////////////// @@ -189,7 +193,7 @@ LogicalResult mlir::SliceOp::verify() { return success(); } -bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType baseInfo; SmallVector indexingsInfo; SmallVector types; @@ -198,7 +202,7 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) - return true; + return failure(); if (types.size() != 2 + indexingsInfo.size()) return parser->emitError(parser->getNameLoc(), @@ -221,12 +225,13 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "expected " + Twine(baseViewType.getRank()) + " indexing types"); - return parser->resolveOperand(baseInfo, baseViewType, result->operands) || - (!indexingsInfo.empty() && - parser->resolveOperands(indexingsInfo, indexingTypes, - indexingsInfo.front().location, - result->operands)) || - parser->addTypeToList(viewType, result->types); + return failure( + parser->resolveOperand(baseInfo, baseViewType, result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, indexingTypes, + indexingsInfo.front().location, + result->operands)) || + parser->addTypeToList(viewType, result->types)); } // A SliceOp prints as: @@ -306,7 +311,7 @@ LogicalResult mlir::ViewOp::verify() { return success(); } -bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType bufferInfo; SmallVector indexingsInfo; Type type; @@ -315,7 +320,7 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); ViewType viewType = type.dyn_cast(); if (!viewType) @@ -324,15 +329,15 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "expected" + Twine(viewType.getRank()) + " range indexings"); - return parser->resolveOperand( - bufferInfo, - BufferType::get(type.getContext(), viewType.getElementType()), - result->operands) || - (!indexingsInfo.empty() && - parser->resolveOperands(indexingsInfo, - RangeType::get(type.getContext()), - result->operands)) || - parser->addTypeToList(viewType, result->types); + return failure( + parser->resolveOperand( + bufferInfo, + BufferType::get(type.getContext(), viewType.getElementType()), + result->operands) || + (!indexingsInfo.empty() && + parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()), + result->operands)) || + parser->addTypeToList(viewType, result->types)); } // A ViewOp prints as: @@ -354,9 +359,9 @@ void mlir::ViewOp::print(OpAsmPrinter *p) { namespace mlir { namespace impl { void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op); -bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result); +ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result); void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op); -bool parseBufferSizeOp(OpAsmParser *parser, OperationState *result); +ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result); } // namespace impl /// Buffer size prints as: @@ -372,16 +377,16 @@ void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) { *p << " : " << op->getOperand(0)->getType(); } -bool mlir::impl::parseBufferSizeOp(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType op; Type type; - return parser->parseOperand(op) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(op, type, result->operands) || - parser->addTypeToList(parser->getBuilder().getIndexType(), - result->types); + return failure(parser->parseOperand(op) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(op, type, result->operands) || + parser->addTypeToList(parser->getBuilder().getIndexType(), + result->types)); } #define GET_OP_CLASSES @@ -415,15 +420,16 @@ void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) { [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); } -bool mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser, - OperationState *result) { +ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser, + OperationState *result) { SmallVector ops; SmallVector types; - return parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonTypeList(types) || - parser->resolveOperands(ops, types, parser->getNameLoc(), - result->operands); + return failure( + parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonTypeList(types) || + parser->resolveOperands(ops, types, parser->getNameLoc(), + result->operands)); } // Ideally this should all be Tablegen'd but there is no good story for diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 8a4c5ad05fd5..d637994b4c73 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -45,17 +45,6 @@ using llvm::MemoryBuffer; using llvm::SMLoc; using llvm::SourceMgr; -/// Simple wrapper class around LogicalResult that allows for explicit -/// conversion to bool. This allows for the parser to chain together parse rules -/// without the clutter of "failed/succeeded". -class ParseResult : public LogicalResult { -public: - ParseResult(LogicalResult result = success()) : LogicalResult(result) {} - - /// Failure is true in a boolean context. - explicit operator bool() const { return failed(*this); } -}; - namespace { class Parser; @@ -2266,8 +2255,8 @@ public: ParseResult parseFunctionBody(bool hadNamedArguments); /// Parse a single operation successor and it's operand list. - bool parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands); + ParseResult parseSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands); /// Parse a comma-separated list of operation successors in brackets. ParseResult @@ -2809,11 +2798,12 @@ Block *FunctionParser::defineBlockNamed(StringRef name, SMLoc loc, /// successor ::= block-id branch-use-list? /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// -bool FunctionParser::parseSuccessorAndUseList( - Block *&dest, SmallVectorImpl &operands) { +ParseResult +FunctionParser::parseSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) { // Verify branch is identifier and get the matching block. if (!getToken().is(Token::caret_identifier)) - return emitError("expected block name"), true; + return emitError("expected block name"); dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); consumeToken(); @@ -2821,10 +2811,10 @@ bool FunctionParser::parseSuccessorAndUseList( if (consumeIf(Token::l_paren) && (parseOptionalSSAUseAndTypeList(operands) || parseToken(Token::r_paren, "expected ')' to close argument list"))) { - return true; + return failure(); } - return false; + return success(); } /// Parse a comma-separated list of operation successors in brackets. @@ -2840,10 +2830,10 @@ ParseResult FunctionParser::parseSuccessors( auto parseElt = [this, &destinations, &operands]() { Block *dest; SmallVector destOperands; - bool r = parseSuccessorAndUseList(dest, destOperands); + auto res = parseSuccessorAndUseList(dest, destOperands); destinations.push_back(dest); operands.push_back(destOperands); - return r ? failure() : success(); + return res; }; return parseCommaSeparatedListUntil(Token::r_square, parseElt, /*allowEmptyList=*/false); @@ -3105,10 +3095,10 @@ public: CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser) : nameLoc(nameLoc), opName(opName), parser(parser) {} - bool parseOperation(const AbstractOperation *opDefinition, - OperationState *opState) { + ParseResult parseOperation(const AbstractOperation *opDefinition, + OperationState *opState) { if (opDefinition->parseAssembly(this, opState)) - return true; + return failure(); // Check that none of the operands of the current operation reference an // entry block argument for any of the region. @@ -3116,53 +3106,53 @@ public: if (llvm::is_contained(opState->operands, entryArg)) return emitError(nameLoc, "operand use before it's defined"); - return false; + return success(); } //===--------------------------------------------------------------------===// // High level parsing methods. //===--------------------------------------------------------------------===// - bool getCurrentLocation(llvm::SMLoc *loc) override { + ParseResult getCurrentLocation(llvm::SMLoc *loc) override { *loc = parser.getToken().getLoc(); - return false; + return success(); } - bool parseComma() override { - return failed(parser.parseToken(Token::comma, "expected ','")); + ParseResult parseComma() override { + return parser.parseToken(Token::comma, "expected ','"); } - bool parseColon() override { - return failed(parser.parseToken(Token::colon, "expected ':'")); + ParseResult parseColon() override { + return parser.parseToken(Token::colon, "expected ':'"); } - bool parseEqual() override { - return failed(parser.parseToken(Token::equal, "expected '='")); + ParseResult parseEqual() override { + return parser.parseToken(Token::equal, "expected '='"); } - bool parseType(Type &result) override { - return !(result = parser.parseType()); + ParseResult parseType(Type &result) override { + return failure(!(result = parser.parseType())); } - bool parseColonType(Type &result) override { - return parser.parseToken(Token::colon, "expected ':'") || - !(result = parser.parseType()); + ParseResult parseColonType(Type &result) override { + return failure(parser.parseToken(Token::colon, "expected ':'") || + !(result = parser.parseType())); } - bool parseColonTypeList(SmallVectorImpl &result) override { + ParseResult parseColonTypeList(SmallVectorImpl &result) override { if (parser.parseToken(Token::colon, "expected ':'")) - return true; + return failure(); do { if (auto type = parser.parseType()) result.push_back(type); else - return true; + return failure(); } while (parser.consumeIf(Token::comma)); - return false; + return success(); } - bool parseTrailingOperandList(SmallVectorImpl &result, - int requiredOperandCount, - Delimiter delimiter) override { + ParseResult parseTrailingOperandList(SmallVectorImpl &result, + int requiredOperandCount, + Delimiter delimiter) override { if (parser.getToken().is(Token::comma)) { parseComma(); return parseOperandList(result, requiredOperandCount, delimiter); @@ -3170,101 +3160,105 @@ public: if (requiredOperandCount != -1) return emitError(parser.getToken().getLoc(), "expected " + Twine(requiredOperandCount) + " operands"); - return false; + return success(); } - bool parseOptionalComma() override { return !parser.consumeIf(Token::comma); } + ParseResult parseOptionalComma() override { + return success(parser.consumeIf(Token::comma)); + } /// Parse an optional keyword. - bool parseOptionalKeyword(const char *keyword) override { + ParseResult parseOptionalKeyword(const char *keyword) override { // Check that the current token is a bare identifier or keyword. if (parser.getToken().isNot(Token::bare_identifier) && !parser.getToken().isKeyword()) - return true; + return failure(); if (parser.getTokenSpelling() == keyword) { parser.consumeToken(); - return false; + return success(); } - return true; + return failure(); } /// Parse an arbitrary attribute of a given type and return it in result. This /// also adds the attribute to the specified attribute list with the specified /// name. - bool parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) override { + ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, + SmallVectorImpl &attrs) override { result = parser.parseAttribute(type); if (!result) - return true; + return failure(); attrs.push_back(parser.builder.getNamedAttr(attrName, result)); - return false; + return success(); } /// Parse an arbitrary attribute and return it in result. This also adds /// the attribute to the specified attribute list with the specified name. - bool parseAttribute(Attribute &result, StringRef attrName, - SmallVectorImpl &attrs) override { + ParseResult parseAttribute(Attribute &result, StringRef attrName, + SmallVectorImpl &attrs) override { return parseAttribute(result, Type(), attrName, attrs); } /// If a named attribute list is present, parse is into result. - bool + ParseResult parseOptionalAttributeDict(SmallVectorImpl &result) override { if (parser.getToken().isNot(Token::l_brace)) - return false; - return failed(parser.parseAttributeDict(result)); + return success(); + return parser.parseAttributeDict(result); } /// Parse a function name like '@foo' and return the name in a form that can /// be passed to resolveFunctionName when a function type is available. - virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) { + virtual ParseResult parseFunctionName(StringRef &result, llvm::SMLoc &loc) { if (parseOptionalFunctionName(result, loc)) return emitError(loc, "expected function name"); - return false; + return success(); } /// Parse a function name like '@foo` if present and return the name without /// the sigil in `result`. Return true if the next token is not a function /// name and keep `result` unchanged. - bool parseOptionalFunctionName(StringRef &result, llvm::SMLoc &loc) override { + ParseResult parseOptionalFunctionName(StringRef &result, + llvm::SMLoc &loc) override { loc = parser.getToken().getLoc(); if (parser.getToken().isNot(Token::at_identifier)) - return true; + return failure(); result = parser.getTokenSpelling(); parser.consumeToken(Token::at_identifier); - return false; + return success(); } - bool parseOperand(OperandType &result) override { + ParseResult parseOperand(OperandType &result) override { FunctionParser::SSAUseInfo useInfo; if (parser.parseSSAUse(useInfo)) - return true; + return failure(); result = {useInfo.loc, useInfo.name, useInfo.number}; - return false; + return success(); } - bool parseSuccessorAndUseList(Block *&dest, - SmallVectorImpl &operands) override { + ParseResult + parseSuccessorAndUseList(Block *&dest, + SmallVectorImpl &operands) override { // Defer successor parsing to the function parsers. return parser.parseSuccessorAndUseList(dest, operands); } - bool parseLParen() override { - return failed(parser.parseToken(Token::l_paren, "expected '('")); + ParseResult parseLParen() override { + return parser.parseToken(Token::l_paren, "expected '('"); } - bool parseRParen() override { - return failed(parser.parseToken(Token::r_paren, "expected ')'")); + ParseResult parseRParen() override { + return parser.parseToken(Token::r_paren, "expected ')'"); } - bool parseOperandList(SmallVectorImpl &result, - int requiredOperandCount = -1, - Delimiter delimiter = Delimiter::None) override { + ParseResult parseOperandList(SmallVectorImpl &result, + int requiredOperandCount = -1, + Delimiter delimiter = Delimiter::None) override { auto startLoc = parser.getToken().getLoc(); // Handle delimiters. @@ -3284,19 +3278,19 @@ public: return emitError(startLoc, "invalid operand"); case Delimiter::OptionalParen: if (parser.getToken().isNot(Token::l_paren)) - return false; + return success(); LLVM_FALLTHROUGH; case Delimiter::Paren: if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) - return true; + return failure(); break; case Delimiter::OptionalSquare: if (parser.getToken().isNot(Token::l_square)) - return false; + return success(); LLVM_FALLTHROUGH; case Delimiter::Square: if (parser.parseToken(Token::l_square, "expected '[' in operand list")) - return true; + return failure(); break; } @@ -3305,7 +3299,7 @@ public: do { OperandType operand; if (parseOperand(operand)) - return true; + return failure(); result.push_back(operand); } while (parser.consumeIf(Token::comma)); } @@ -3318,32 +3312,32 @@ public: case Delimiter::OptionalParen: case Delimiter::Paren: if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) - return true; + return failure(); break; case Delimiter::OptionalSquare: case Delimiter::Square: if (parser.parseToken(Token::r_square, "expected ']' in operand list")) - return true; + return failure(); break; } if (requiredOperandCount != -1 && result.size() != requiredOperandCount) return emitError(startLoc, "expected " + Twine(requiredOperandCount) + " operands"); - return false; + return success(); } /// Resolve a parse function name and a type into a function reference. - virtual bool resolveFunctionName(StringRef name, FunctionType type, - llvm::SMLoc loc, Function *&result) { + virtual ParseResult resolveFunctionName(StringRef name, FunctionType type, + llvm::SMLoc loc, Function *&result) { result = parser.resolveFunctionReference(name, loc, type); - return result == nullptr; + return failure(result == nullptr); } /// Parse a region that takes `arguments` of `argTypes` types. This /// effectively defines the SSA values of `arguments` and assignes their type. - bool parseRegion(Region ®ion, ArrayRef arguments, - ArrayRef argTypes) override { + ParseResult parseRegion(Region ®ion, ArrayRef arguments, + ArrayRef argTypes) override { assert(arguments.size() == argTypes.size() && "mismatching number of arguments and types"); @@ -3359,26 +3353,26 @@ public: // references to region arguments. Value *value = parser.resolveSSAUse(operandInfo, type); if (!value) - return true; + return failure(); parsedRegionEntryArgumentPlaceholders.emplace_back(value); } - return failed(parser.parseOperationRegion(region, regionArguments)); + return parser.parseOperationRegion(region, regionArguments); } /// Parse a region argument. Region arguments define new values, so this also /// checks if the values with the same name has not been defined yet. The /// type of the argument will be resolved later by a call to `parseRegion`. - bool parseRegionArgument(OperandType &argument) { + ParseResult parseRegionArgument(OperandType &argument) { // Use parseOperand to fill in the OperandType structure. if (parseOperand(argument)) - return true; + return failure(); if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) { parser.emitError(argument.location, "redefinition of SSA value '" + argument.name + "'"); - return parser.emitError(*defLoc, "previously defined here"), true; + return parser.emitError(*defLoc, "previously defined here"); } - return false; + return success(); } //===--------------------------------------------------------------------===// @@ -3389,22 +3383,22 @@ public: llvm::SMLoc getNameLoc() const override { return nameLoc; } - bool resolveOperand(const OperandType &operand, Type type, - SmallVectorImpl &result) override { + ParseResult resolveOperand(const OperandType &operand, Type type, + SmallVectorImpl &result) override { FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, operand.location}; if (auto *value = parser.resolveSSAUse(operandInfo, type)) { result.push_back(value); - return false; + return success(); } - return true; + return failure(); } - /// Emit a diagnostic at the specified location and return true. - bool emitError(llvm::SMLoc loc, const Twine &message) override { - parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); + /// Emit a diagnostic at the specified location and return failure. + ParseResult emitError(llvm::SMLoc loc, const Twine &message) override { emittedError = true; - return true; + return parser.emitError(loc, + "custom op '" + Twine(opName) + "' " + message); } bool didEmitError() const { return emittedError; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 969875ff755a..9de36a769c9e 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -87,12 +87,12 @@ void mlir::printDimAndSymbolList(Operation::operand_iterator begin, // Parses dimension and symbol list, and sets 'numDims' to the number of // dimension operands parsed. // Returns 'false' on success and 'true' on error. -bool mlir::parseDimAndSymbolList(OpAsmParser *parser, - SmallVector &operands, - unsigned &numDims) { +ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser, + SmallVector &operands, + unsigned &numDims) { SmallVector opInfos; if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) - return true; + return failure(); // Store number of dimensions for validation by caller. numDims = opInfos.size(); @@ -101,8 +101,8 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser, if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::OptionalSquare) || parser->resolveOperands(opInfos, affineIntTy, operands)) - return true; - return false; + return failure(); + return success(); } /// Matches a ConstantIndexOp. @@ -223,7 +223,7 @@ void AllocOp::print(OpAsmPrinter *p) { *p << " : " << type; } -bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult AllocOp::parse(OpAsmParser *parser, OperationState *result) { MemRefType type; // Parse the dimension operands and optional symbol operands, followed by a @@ -232,7 +232,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); // Check numDynamicDims against number of question marks in memref type. // Note: this check remains here (instead of in verify()), because the @@ -246,7 +246,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { "dynamic dimension count"); } result->types.push_back(type); - return false; + return success(); } LogicalResult AllocOp::verify() { @@ -385,13 +385,13 @@ void BranchOp::build(Builder *builder, OperationState *result, Block *dest, result->addSuccessor(dest, operands); } -bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) { Block *dest; SmallVector destOperands; if (parser->parseSuccessorAndUseList(dest, destOperands)) - return true; + return failure(); result->addSuccessor(dest, destOperands); - return false; + return success(); } void BranchOp::print(OpAsmPrinter *p) { @@ -420,7 +420,7 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee, result->addTypes(callee->getType().getResults()); } -bool CallOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) { StringRef calleeName; llvm::SMLoc calleeLoc; FunctionType calleeType; @@ -435,10 +435,10 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) { parser->addTypesToList(calleeType.getResults(), result->types) || parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, result->operands)) - return true; + return failure(); result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); - return false; + return success(); } void CallOp::print(OpAsmPrinter *p) { @@ -517,21 +517,22 @@ void CallIndirectOp::build(Builder *builder, OperationState *result, result->addTypes(fnType.getResults()); } -bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { FunctionType calleeType; OpAsmParser::OperandType callee; llvm::SMLoc operandsLoc; SmallVector operands; - return parser->parseOperand(callee) || - parser->getCurrentLocation(&operandsLoc) || - parser->parseOperandList(operands, /*requiredOperandCount=*/-1, - OpAsmParser::Delimiter::Paren) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(calleeType) || - parser->resolveOperand(callee, calleeType, result->operands) || - parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result->operands) || - parser->addTypesToList(calleeType.getResults(), result->types); + return failure( + parser->parseOperand(callee) || + parser->getCurrentLocation(&operandsLoc) || + parser->parseOperandList(operands, /*requiredOperandCount=*/-1, + OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(calleeType) || + parser->resolveOperand(callee, calleeType, result->operands) || + parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result->operands) || + parser->addTypesToList(calleeType.getResults(), result->types)); } void CallIndirectOp::print(OpAsmPrinter *p) { @@ -678,7 +679,7 @@ void CmpIOp::build(Builder *build, OperationState *result, build->getI64IntegerAttr(static_cast(predicate))); } -bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; @@ -689,7 +690,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type) || parser->resolveOperands(ops, type, result->operands)) - return true; + return failure(); if (!predicateNameAttr.isa()) return parser->emitError(parser->getNameLoc(), @@ -713,7 +714,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { result->attributes = attrs; result->addTypes({i1Type}); - return false; + return success(); } void CmpIOp::print(OpAsmPrinter *p) { @@ -856,7 +857,7 @@ void CmpFOp::build(Builder *build, OperationState *result, build->getI64IntegerAttr(static_cast(predicate))); } -bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Attribute predicateNameAttr; @@ -867,7 +868,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(attrs) || parser->parseColonType(type) || parser->resolveOperands(ops, type, result->operands)) - return true; + return failure(); if (!predicateNameAttr.isa()) return parser->emitError(parser->getNameLoc(), @@ -891,7 +892,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) { result->attributes = attrs; result->addTypes({i1Type}); - return false; + return success(); } void CmpFOp::print(OpAsmPrinter *p) { @@ -1044,7 +1045,7 @@ void CondBranchOp::build(Builder *builder, OperationState *result, result->addSuccessor(falseDest, falseOperands); } -bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector destOperands; Block *dest; OpAsmParser::OperandType condInfo; @@ -1059,18 +1060,17 @@ bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { // Parse the true successor. if (parser->parseSuccessorAndUseList(dest, destOperands)) - return true; + return failure(); result->addSuccessor(dest, destOperands); // Parse the false successor. destOperands.clear(); if (parser->parseComma() || parser->parseSuccessorAndUseList(dest, destOperands)) - return true; + return failure(); result->addSuccessor(dest, destOperands); - // Return false on success. - return false; + return success(); } void CondBranchOp::print(OpAsmPrinter *p) { @@ -1132,13 +1132,14 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) { *p << " : " << op.getType(); } -static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { +static ParseResult parseConstantOp(OpAsmParser *parser, + OperationState *result) { Attribute valueAttr; Type type; if (parser->parseOptionalAttributeDict(result->attributes) || parser->parseAttribute(valueAttr, "value", result->attributes)) - return true; + return failure(); // 'constant' taking a function reference doesn't get a redundant type // specifier. The attribute itself carries it. @@ -1150,7 +1151,7 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { } else if (auto fpAttr = valueAttr.dyn_cast()) { type = fpAttr.getType(); } else if (parser->parseColonType(type)) { - return true; + return failure(); } return parser->addTypeToList(type, result->types); } @@ -1298,12 +1299,13 @@ void DeallocOp::print(OpAsmPrinter *p) { *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); } -bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult DeallocOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; MemRefType type; - return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands); + return failure(parser->parseOperand(memrefInfo) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands)); } LogicalResult DeallocOp::verify() { @@ -1338,19 +1340,19 @@ void DimOp::print(OpAsmPrinter *p) { *p << " : " << getOperand()->getType(); } -bool DimOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType operandInfo; IntegerAttr indexAttr; Type type; Type indexType = parser->getBuilder().getIndexType(); - return parser->parseOperand(operandInfo) || parser->parseComma() || - parser->parseAttribute(indexAttr, indexType, "index", - result->attributes) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(operandInfo, type, result->operands) || - parser->addTypeToList(indexType, result->types); + return failure(parser->parseOperand(operandInfo) || parser->parseComma() || + parser->parseAttribute(indexAttr, indexType, "index", + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types)); } LogicalResult DimOp::verify() { @@ -1491,7 +1493,7 @@ void DmaStartOp::print(OpAsmPrinter *p) { // memref<1024 x f32, 2>, // memref<1 x i32> // -bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; @@ -1518,11 +1520,11 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseComma() || parser->parseOperand(tagMemrefInfo) || parser->parseOperandList(tagIndexInfos, -1, OpAsmParser::Delimiter::Square)) - return true; + return failure(); // Parse optional stride and elements per stride. if (parser->parseTrailingOperandList(strideInfo)) { - return true; + return failure(); } if (!strideInfo.empty() && strideInfo.size() != 2) { return parser->emitError(parser->getNameLoc(), @@ -1531,7 +1533,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { bool isStrided = strideInfo.size() == 2; if (parser->parseColonTypeList(types)) - return true; + return failure(); if (types.size() != 3) return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); @@ -1545,7 +1547,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || // tag indices should be index. parser->resolveOperands(tagIndexInfos, indexType, result->operands)) - return true; + return failure(); if (!types[0].isa()) return parser->emitError(parser->getNameLoc(), @@ -1562,7 +1564,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { if (isStrided) { if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || parser->resolveOperand(strideInfo[1], indexType, result->operands)) - return true; + return failure(); } // Check that source/destination index list size matches associated rank. @@ -1575,7 +1577,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); - return false; + return success(); } LogicalResult DmaStartOp::verify() { @@ -1628,7 +1630,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) { // Eg: // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> // -bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; Type type; @@ -1644,7 +1646,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { parser->resolveOperand(tagMemrefInfo, type, result->operands) || parser->resolveOperands(tagIndexInfos, indexType, result->operands) || parser->resolveOperand(numElementsInfo, indexType, result->operands)) - return true; + return failure(); if (!type.isa()) return parser->emitError(parser->getNameLoc(), @@ -1654,7 +1656,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc(), "tag memref rank not equal to indices count"); - return false; + return success(); } void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, @@ -1684,20 +1686,21 @@ void ExtractElementOp::print(OpAsmPrinter *p) { *p << " : " << getAggregate()->getType(); } -bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult ExtractElementOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType aggregateInfo; SmallVector indexInfo; VectorOrTensorType type; auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(aggregateInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(aggregateInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types); + return failure( + parser->parseOperand(aggregateInfo) || + parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(aggregateInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); } LogicalResult ExtractElementOp::verify() { @@ -1771,20 +1774,20 @@ void LoadOp::print(OpAsmPrinter *p) { *p << " : " << getMemRefType(); } -bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; MemRefType type; auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(type) || - parser->resolveOperand(memrefInfo, type, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands) || - parser->addTypeToList(type.getElementType(), result->types); + return failure( + parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); } LogicalResult LoadOp::verify() { @@ -1963,13 +1966,14 @@ void ReturnOp::build(Builder *builder, OperationState *result, result->addOperands(results); } -bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult ReturnOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector opInfo; SmallVector types; llvm::SMLoc loc; - return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || - (!opInfo.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(opInfo, types, loc, result->operands); + return failure(parser->getCurrentLocation(&loc) || + parser->parseOperandList(opInfo) || + (!opInfo.empty() && parser->parseColonTypeList(types)) || + parser->resolveOperands(opInfo, types, loc, result->operands)); } void ReturnOp::print(OpAsmPrinter *p) { @@ -2012,7 +2016,7 @@ void SelectOp::build(Builder *builder, OperationState *result, Value *condition, result->addTypes(trueValue->getType()); } -bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector ops; SmallVector attrs; Type type; @@ -2020,7 +2024,7 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { if (parser->parseOperandList(ops, 3) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonType(type)) - return true; + return failure(); auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); if (!i1Type) @@ -2028,9 +2032,9 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { "expected type with valid i1 shape"); SmallVector types = {i1Type, type, type}; - return parser->resolveOperands(ops, types, parser->getNameLoc(), - result->operands) || - parser->addTypeToList(type, result->types); + return failure(parser->resolveOperands(ops, types, parser->getNameLoc(), + result->operands) || + parser->addTypeToList(type, result->types)); } void SelectOp::print(OpAsmPrinter *p) { @@ -2090,23 +2094,23 @@ void StoreOp::print(OpAsmPrinter *p) { *p << " : " << getMemRefType(); } -bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; MemRefType memrefType; auto affineIntTy = parser->getBuilder().getIndexType(); - return parser->parseOperand(storeValueInfo) || parser->parseComma() || - parser->parseOperand(memrefInfo) || - parser->parseOperandList(indexInfo, -1, - OpAsmParser::Delimiter::Square) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(memrefType) || - parser->resolveOperand(storeValueInfo, memrefType.getElementType(), - result->operands) || - parser->resolveOperand(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, affineIntTy, result->operands); + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || + parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(memrefType) || + parser->resolveOperand(storeValueInfo, memrefType.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, affineIntTy, result->operands)); } LogicalResult StoreOp::verify() { diff --git a/mlir/lib/VectorOps/VectorOps.cpp b/mlir/lib/VectorOps/VectorOps.cpp index 6f416930b1f4..05af0293989d 100644 --- a/mlir/lib/VectorOps/VectorOps.cpp +++ b/mlir/lib/VectorOps/VectorOps.cpp @@ -120,7 +120,8 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) { *p << ", " << getResultType(); } -bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult VectorTransferReadOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; SmallVector paddingInfo; @@ -133,7 +134,7 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { OpAsmParser::Delimiter::Paren) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) - return true; + return failure(); // Resolution. if (types.size() != 2) @@ -160,12 +161,12 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { paddingType = vectorType.getElementType(); } auto indexType = parser->getBuilder().getIndexType(); - return parser->resolveOperand(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, indexType, result->operands) || - (hasOptionalPaddingValue && - parser->resolveOperand(paddingInfo[0], paddingType, - result->operands)) || - parser->addTypeToList(vectorType, result->types); + return failure( + parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, indexType, result->operands) || + (hasOptionalPaddingValue && + parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) || + parser->addTypeToList(vectorType, result->types)); } LogicalResult VectorTransferReadOp::verify() { @@ -286,7 +287,8 @@ void VectorTransferWriteOp::print(OpAsmPrinter *p) { p->printType(getMemRefType()); } -bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; @@ -297,7 +299,7 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || parser->parseOptionalAttributeDict(result->attributes) || parser->parseColonTypeList(types)) - return true; + return failure(); if (types.size() != 2) return parser->emitError(parser->getNameLoc(), "expected 2 types"); @@ -308,10 +310,10 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { if (!memrefType) return parser->emitError(parser->getNameLoc(), "memRef type expected"); - return parser->resolveOperands(storeValueInfo, vectorType, - result->operands) || - parser->resolveOperands(memrefInfo, memrefType, result->operands) || - parser->resolveOperands(indexInfo, indexType, result->operands); + return failure( + parser->resolveOperands(storeValueInfo, vectorType, result->operands) || + parser->resolveOperands(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, indexType, result->operands)); } LogicalResult VectorTransferWriteOp::verify() { @@ -390,15 +392,16 @@ void VectorTypeCastOp::build(Builder *builder, OperationState *result, result->addTypes(dstType); } -bool VectorTypeCastOp::parse(OpAsmParser *parser, OperationState *result) { +ParseResult VectorTypeCastOp::parse(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType operand; Type srcType, dstType; - return parser->parseOperand(operand) || - parser->parseOptionalAttributeDict(result->attributes) || - parser->parseColonType(srcType) || parser->parseComma() || - parser->parseType(dstType) || - parser->addTypeToList(dstType, result->types) || - parser->resolveOperand(operand, srcType, result->operands); + return failure(parser->parseOperand(operand) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(srcType) || parser->parseComma() || + parser->parseType(dstType) || + parser->addTypeToList(dstType, result->types) || + parser->resolveOperand(operand, srcType, result->operands)); } void VectorTypeCastOp::print(OpAsmPrinter *p) { diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 88cbf0fc4eee..528caf6d745e 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -51,7 +51,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> { // CHECK: static void build(Value *val); // CHECK: static void build(Builder *, OperationState *tblgen_state, Type r, ArrayRef s, Value *a, ArrayRef b, IntegerAttr attr1, /*optional*/FloatAttr attr2); // CHECK: static void build(Builder *, OperationState *tblgen_state, ArrayRef resultTypes, ArrayRef operands, ArrayRef attributes); -// CHECK: static bool parse(OpAsmParser *parser, OperationState *result); +// CHECK: static ParseResult parse(OpAsmParser *parser, OperationState *result); // CHECK: void print(OpAsmPrinter *p); // CHECK: LogicalResult verify(); // CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 3944145b0e0c..36464f7f0429 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -863,7 +863,7 @@ void OpEmitter::genParser() { return; auto &method = opClass.newMethod( - "bool", "parse", "OpAsmParser *parser, OperationState *result", + "ParseResult", "parse", "OpAsmParser *parser, OperationState *result", OpMethod::MP_Static); auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); method.body() << " " << parser;