NFC: Make ParseResult public and update the OpAsmParser(and thus all of the custom operation parsers) to use it instead of bool.

--

PiperOrigin-RevId: 246955523
This commit is contained in:
River Riddle 2019-05-06 22:01:31 -07:00 committed by Mehdi Amini
parent 6ccf90147c
commit b7dc252683
30 changed files with 591 additions and 543 deletions

View File

@ -41,7 +41,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result, static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *min, mlir::Value *max, mlir::Value *step); mlir::Value *min, mlir::Value *max, mlir::Value *step);
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -40,7 +40,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result, static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *view, mlir::Value *indexing, unsigned dim); mlir::Value *view, mlir::Value *indexing, unsigned dim);
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -43,7 +43,8 @@ public:
mlir::Value *memRef, mlir::Value *memRef,
llvm::ArrayRef<mlir::Value *> indexings); llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -48,15 +48,17 @@ mlir::LogicalResult linalg::RangeOp::verify() {
return mlir::success(); return mlir::success();
} }
bool linalg::RangeOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult linalg::RangeOp::parse(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3); SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type; RangeType type;
auto indexTy = parser->getBuilder().getIndexType(); auto indexTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || return failure(
parser->parseOperand(rangeInfo[1]) || parser->parseColon() || parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->resolveOperands(rangeInfo, indexTy, result->operands) || parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->addTypeToList(type, result->types); parser->resolveOperands(rangeInfo, indexTy, result->operands) ||
parser->addTypeToList(type, result->types));
} }
// A RangeOp prints as: // A RangeOp prints as:

View File

@ -74,7 +74,8 @@ mlir::LogicalResult linalg::SliceOp::verify() {
return mlir::success(); return mlir::success();
} }
bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult linalg::SliceOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType viewInfo; OpAsmParser::OperandType viewInfo;
SmallVector<OpAsmParser::OperandType, 1> indexingInfo; SmallVector<OpAsmParser::OperandType, 1> indexingInfo;
SmallVector<Type, 8> types; SmallVector<Type, 8> types;
@ -83,7 +84,7 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types)) parser->parseColonTypeList(types))
return true; return failure();
if (indexingInfo.size() != 1) if (indexingInfo.size() != 1)
return parser->emitError(parser->getNameLoc(), "expected 1 indexing type"); return parser->emitError(parser->getNameLoc(), "expected 1 indexing type");
@ -107,10 +108,10 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
ViewType resultViewType = ViewType resultViewType =
ViewType::get(viewType.getContext(), viewType.getElementType(), rank); ViewType::get(viewType.getContext(), viewType.getElementType(), rank);
return parser->resolveOperand(viewInfo, viewType, result->operands) || return failure(parser->resolveOperand(viewInfo, viewType, result->operands) ||
parser->resolveOperands(indexingInfo[0], types.back(), parser->resolveOperands(indexingInfo[0], types.back(),
result->operands) || result->operands) ||
parser->addTypeToList(resultViewType, result->types); parser->addTypeToList(resultViewType, result->types));
} }
// A SliceOp prints as: // A SliceOp prints as:

View File

@ -89,7 +89,7 @@ LogicalResult linalg::ViewOp::verify() {
return success(); return success();
} }
bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memRefInfo; OpAsmParser::OperandType memRefInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo; SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types; SmallVector<Type, 8> types;
@ -98,7 +98,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types)) parser->parseColonTypeList(types))
return true; return failure();
if (types.size() != 2 + indexingsInfo.size()) if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -120,12 +120,13 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memRefType.getRank()) + "expected " + Twine(memRefType.getRank()) +
" indexing types"); " indexing types");
return parser->resolveOperand(memRefInfo, memRefType, result->operands) || return failure(
(!indexingsInfo.empty() && parser->resolveOperand(memRefInfo, memRefType, result->operands) ||
parser->resolveOperands(indexingsInfo, indexingTypes, (!indexingsInfo.empty() &&
indexingsInfo.front().location, parser->resolveOperands(indexingsInfo, indexingTypes,
result->operands)) || indexingsInfo.front().location,
parser->addTypeToList(viewType, result->types); result->operands)) ||
parser->addTypeToList(viewType, result->types));
} }
// A ViewOp prints as: // A ViewOp prints as:

View File

@ -83,8 +83,9 @@ mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
} }
template <class ConcreteOp> template <class ConcreteOp>
bool linalg::TensorContractionBase<ConcreteOp>::parse( mlir::ParseResult
mlir::OpAsmParser *parser, mlir::OperationState *result) { linalg::TensorContractionBase<ConcreteOp>::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
} }

View File

@ -38,7 +38,8 @@ protected:
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// Generic implementation of hooks that should be called from `ConcreteType`s /// Generic implementation of hooks that should be called from `ConcreteType`s
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
public: public:
@ -118,7 +119,8 @@ public:
return build(b, result, {A, B, C}); return build(b, result, {A, B, C});
} }
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -179,7 +181,8 @@ public:
return build(b, result, {A, B, C}); return build(b, result, {A, B, C});
} }
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -240,7 +243,8 @@ public:
return build(b, result, {A, B, C}); return build(b, result, {A, B, C});
} }
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -58,8 +58,8 @@ LogicalResult linalg::DotOp::verify() {
} }
// Parsing of the linalg dialect is not supported in this tutorial. // Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::DotOp::parse(mlir::OpAsmParser *parser, ParseResult linalg::DotOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) { mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result); return TensorContractionBaseType::parse(parser, result);
} }
@ -92,8 +92,8 @@ LogicalResult linalg::MatvecOp::verify() {
} }
// Parsing of the linalg dialect is not supported in this tutorial. // Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser, ParseResult linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) { mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result); return TensorContractionBaseType::parse(parser, result);
} }
@ -123,8 +123,8 @@ LogicalResult linalg::MatmulOp::verify() {
} }
// Parsing of the linalg dialect is not supported in this tutorial. // Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser, ParseResult linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) { mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result); return TensorContractionBaseType::parse(parser, result);
} }

View File

@ -41,7 +41,8 @@ public:
mlir::Value *view, mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {}); mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -71,7 +72,8 @@ public:
mlir::Value *valueToStore, mlir::Value *view, mlir::Value *valueToStore, mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {}); mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -49,9 +49,9 @@ void linalg::LoadOp::print(OpAsmPrinter *p) {
*p << " : " << getViewType(); *p << " : " << getViewType();
} }
bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial"); llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return false; return success();
} }
LogicalResult linalg::LoadOp::verify() { LogicalResult linalg::LoadOp::verify() {
@ -101,9 +101,10 @@ void linalg::StoreOp::print(OpAsmPrinter *p) {
*p << " : " << getViewType(); *p << " : " << getViewType();
} }
bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult linalg::StoreOp::parse(OpAsmParser *parser,
OperationState *result) {
assert(false && "NYI"); assert(false && "NYI");
return false; return success();
} }
LogicalResult linalg::StoreOp::verify() { LogicalResult linalg::StoreOp::verify() {

View File

@ -80,7 +80,7 @@ public:
static StringRef getOperationName() { return "affine.apply"; } static StringRef getOperationName() { return "affine.apply"; }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context); Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -130,7 +130,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t lb, static void build(Builder *builder, OperationState *result, int64_t lb,
int64_t ub, int64_t step = 1); int64_t ub, int64_t step = 1);
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -326,7 +326,7 @@ public:
Region &getElseBlocks(); Region &getElseBlocks();
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
}; };

View File

@ -80,7 +80,7 @@ public:
/// Custom syntax support. /// Custom syntax support.
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
static StringRef getOperationName() { return "gpu.launch"; } static StringRef getOperationName() { return "gpu.launch"; }

View File

@ -54,6 +54,18 @@ template <typename OpType> struct IsSingleResult {
OpType *, OpTrait::OneResult<typename OpType::ConcreteOpType> *>::value; OpType *, OpTrait::OneResult<typename OpType::ConcreteOpType> *>::value;
}; };
/// This class represents success/failure for operation parsing. It is
/// essentially a simple wrapper class around LogicalResult that allows for
/// explicit conversion to bool. This allows for the parser to chain together
/// parse rules without the clutter of "failed/succeeded".
class ParseResult : public LogicalResult {
public:
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
/// Failure is true in a boolean context.
explicit operator bool() const { return failed(*this); }
};
/// This is the concrete base class that holds the operation pointer and has /// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them /// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them. /// instantiated on template types that don't affect them.
@ -132,10 +144,9 @@ protected:
LogicalResult verify() { return success(); } LogicalResult verify() { return success(); }
/// Unless overridden, the custom assembly form of an op is always rejected. /// Unless overridden, the custom assembly form of an op is always rejected.
/// Op implementations should implement this to return true on failure. /// Op implementations should implement this to return failure.
/// On success, they should return false and fill in result with the fields to /// On success, they should fill in result with the fields to use.
/// use. static ParseResult parse(OpAsmParser *parser, OperationState *result);
static bool parse(OpAsmParser *parser, OperationState *result);
// The fallback for the printer is to print it the generic assembly form. // The fallback for the printer is to print it the generic assembly form.
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
@ -768,9 +779,10 @@ public:
/// This is the hook used by the AsmParser to parse the custom form of this /// This is the hook used by the AsmParser to parse the custom form of this
/// op from an .mlir file. Op implementations should provide a parse method, /// op from an .mlir file. Op implementations should provide a parse method,
/// which returns boolean true on failure. On success, they should return /// which returns failure. On success, they should return fill in result with
/// false and fill in result with the fields to use. /// the fields to use.
static bool parseAssembly(OpAsmParser *parser, OperationState *result) { static ParseResult parseAssembly(OpAsmParser *parser,
OperationState *result) {
return ConcreteType::parse(parser, result); return ConcreteType::parse(parser, result);
} }
@ -854,7 +866,7 @@ private:
namespace impl { namespace impl {
void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
Value *rhs); Value *rhs);
bool parseBinaryOp(OpAsmParser *parser, OperationState *result); ParseResult parseBinaryOp(OpAsmParser *parser, OperationState *result);
// Prints the given binary `op` in custom assembly form if both the two operands // Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same time. Otherwise, prints the generic assembly // and the result have the same time. Otherwise, prints the generic assembly
// form. // form.
@ -866,7 +878,7 @@ void printBinaryOp(Operation *op, OpAsmPrinter *p);
namespace impl { namespace impl {
void buildCastOp(Builder *builder, OperationState *result, Value *source, void buildCastOp(Builder *builder, OperationState *result, Value *source,
Type destType); Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result); ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(Operation *op, OpAsmPrinter *p); void printCastOp(Operation *op, OpAsmPrinter *p);
Value *foldCastOp(Operation *op); Value *foldCastOp(Operation *op);
} // namespace impl } // namespace impl
@ -888,7 +900,7 @@ public:
Type destType) { Type destType) {
impl::buildCastOp(builder, result, source, destType); impl::buildCastOp(builder, result, source, destType);
} }
static bool parse(OpAsmParser *parser, OperationState *result) { static ParseResult parse(OpAsmParser *parser, OperationState *result) {
return impl::parseCastOp(parser, result); return impl::parseCastOp(parser, result);
} }
void print(OpAsmPrinter *p) { void print(OpAsmPrinter *p) {

View File

@ -148,134 +148,138 @@ public:
// High level parsing methods. // High level parsing methods.
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// These emit an error and return true on failure, or return false on success. // These emit an error and return failure or success.
// This allows these to be chained together into a linear sequence of || // This allows these to be chained together into a linear sequence of ||
// expressions in many cases. // expressions in many cases.
/// Get the location of the next token and store it into the argument. This /// Get the location of the next token and store it into the argument. This
/// always succeeds. /// always succeeds.
virtual bool getCurrentLocation(llvm::SMLoc *loc) = 0; virtual ParseResult getCurrentLocation(llvm::SMLoc *loc) = 0;
/// This parses... a comma! /// This parses... a comma!
virtual bool parseComma() = 0; virtual ParseResult parseComma() = 0;
/// Parses a comma if present. /// Parses a comma if present.
virtual bool parseOptionalComma() = 0; virtual ParseResult parseOptionalComma() = 0;
/// Parse a `:` token. /// Parse a `:` token.
virtual bool parseColon() = 0; virtual ParseResult parseColon() = 0;
/// Parse a '(' token. /// Parse a '(' token.
virtual bool parseLParen() = 0; virtual ParseResult parseLParen() = 0;
/// Parse a ')' token. /// Parse a ')' token.
virtual bool parseRParen() = 0; virtual ParseResult parseRParen() = 0;
/// This parses an equal(=) token! /// This parses an equal(=) token!
virtual bool parseEqual() = 0; virtual ParseResult parseEqual() = 0;
/// Parse a type. /// Parse a type.
virtual bool parseType(Type &result) = 0; virtual ParseResult parseType(Type &result) = 0;
/// Parse a colon followed by a type. /// Parse a colon followed by a type.
virtual bool parseColonType(Type &result) = 0; virtual ParseResult parseColonType(Type &result) = 0;
/// Parse a type of a specific kind, e.g. a FunctionType. /// Parse a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> bool parseColonType(TypeType &result) { template <typename TypeType> ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc; llvm::SMLoc loc;
getCurrentLocation(&loc); getCurrentLocation(&loc);
// Parse any kind of type. // Parse any kind of type.
Type type; Type type;
if (parseColonType(type)) if (parseColonType(type))
return true; return failure();
// Check for the right kind of attribute. // Check for the right kind of attribute.
result = type.dyn_cast<TypeType>(); result = type.dyn_cast<TypeType>();
if (!result) if (!result)
return emitError(loc, "invalid kind of type specified"); return emitError(loc, "invalid kind of type specified");
return false; return success();
} }
/// Parse a colon followed by a type list, which must have at least one type. /// Parse a colon followed by a type list, which must have at least one type.
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0; virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type. /// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type &result) { ParseResult parseKeywordType(const char *keyword, Type &result) {
return parseKeyword(keyword) || parseType(result); return failure(parseKeyword(keyword) || parseType(result));
} }
/// Parse a keyword. /// Parse a keyword.
bool parseKeyword(const char *keyword, const Twine &msg = "") { ParseResult parseKeyword(const char *keyword, const Twine &msg = "") {
if (parseOptionalKeyword(keyword)) if (parseOptionalKeyword(keyword))
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg); return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg);
return false; return success();
} }
/// If a keyword is present, then parse it. /// If a keyword is present, then parse it.
virtual bool parseOptionalKeyword(const char *keyword) = 0; virtual ParseResult parseOptionalKeyword(const char *keyword) = 0;
/// Add the specified type to the end of the specified type list and return /// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and /// success. This is a helper designed to allow parse methods to be simple
/// chain through || operators. /// and chain through || operators.
bool addTypeToList(Type type, SmallVectorImpl<Type> &result) { ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type); result.push_back(type);
return false; return success();
} }
/// Add the specified types to the end of the specified type list and return /// Add the specified types to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and /// success. This is a helper designed to allow parse methods to be simple
/// chain through || operators. /// and chain through || operators.
bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) { ParseResult addTypesToList(ArrayRef<Type> types,
SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end()); result.append(types.begin(), types.end());
return false; return success();
} }
/// Parse an arbitrary attribute and return it in result. This also adds the /// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. /// attribute to the specified attribute list with the specified name.
virtual bool parseAttribute(Attribute &result, StringRef attrName, virtual ParseResult
SmallVectorImpl<NamedAttribute> &attrs) = 0; parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an arbitrary attribute of a given type and return it in result. This /// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified /// also adds the attribute to the specified attribute list with the specified
/// name. /// name.
virtual bool parseAttribute(Attribute &result, Type type, StringRef attrName, virtual ParseResult
SmallVectorImpl<NamedAttribute> &attrs) = 0; parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind and type. /// Parse an attribute of a specific kind and type.
template <typename AttrType> template <typename AttrType>
bool parseAttribute(AttrType &result, Type type, StringRef attrName, ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) { SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc; llvm::SMLoc loc;
getCurrentLocation(&loc); getCurrentLocation(&loc);
// Parse any kind of attribute. // Parse any kind of attribute.
Attribute attr; Attribute attr;
if (parseAttribute(attr, type, attrName, attrs)) if (parseAttribute(attr, type, attrName, attrs))
return true; return failure();
// Check for the right kind of attribute. // Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>(); result = attr.dyn_cast<AttrType>();
if (!result) if (!result)
return emitError(loc, "invalid kind of constant specified"); return emitError(loc, "invalid kind of constant specified");
return false; return success();
} }
/// If a named attribute dictionary is present, parse it into result. /// If a named attribute dictionary is present, parse it into result.
virtual bool virtual ParseResult
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) = 0; parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) = 0;
/// Parse a function name like '@foo' and return the name in a form that can /// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available. /// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) = 0; virtual ParseResult parseFunctionName(StringRef &result,
llvm::SMLoc &loc) = 0;
/// Parse a function name like '@foo` if present and return the name without /// Parse a function name like '@foo` if present and return the name without
/// the sigil in `result`. Return true if the next token is not a function /// the sigil in `result`. Return true if the next token is not a function
/// name and keep `result` unchanged. /// name and keep `result` unchanged.
virtual bool parseOptionalFunctionName(StringRef &result, virtual ParseResult parseOptionalFunctionName(StringRef &result,
llvm::SMLoc &loc) = 0; llvm::SMLoc &loc) = 0;
/// This is the representation of an operand reference. /// This is the representation of an operand reference.
struct OperandType { struct OperandType {
@ -285,11 +289,12 @@ public:
}; };
/// Parse a single operand. /// Parse a single operand.
virtual bool parseOperand(OperandType &result) = 0; virtual ParseResult parseOperand(OperandType &result) = 0;
/// Parse a single operation successor and it's operand list. /// Parse a single operation successor and it's operand list.
virtual bool parseSuccessorAndUseList(Block *&dest, virtual ParseResult
SmallVectorImpl<Value *> &operands) = 0; parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) = 0;
/// These are the supported delimiters around operand lists, used by /// These are the supported delimiters around operand lists, used by
/// parseOperandList. /// parseOperandList.
@ -308,14 +313,15 @@ public:
/// Parse zero or more SSA comma-separated operand references with a specified /// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count. /// surrounding delimiter, and an optional required operand count.
virtual bool parseOperandList(SmallVectorImpl<OperandType> &result, virtual ParseResult
int requiredOperandCount = -1, parseOperandList(SmallVectorImpl<OperandType> &result,
Delimiter delimiter = Delimiter::None) = 0; int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
/// Parse zero or more trailing SSA comma-separated trailing operand /// Parse zero or more trailing SSA comma-separated trailing operand
/// references with a specified surrounding delimiter, and an optional /// references with a specified surrounding delimiter, and an optional
/// required operand count. A leading comma is expected before the operands. /// required operand count. A leading comma is expected before the operands.
virtual bool virtual ParseResult
parseTrailingOperandList(SmallVectorImpl<OperandType> &result, parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1, int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0; Delimiter delimiter = Delimiter::None) = 0;
@ -323,12 +329,13 @@ public:
/// Parses a region. Any parsed blocks are appended to "region" and must be /// Parses a region. Any parsed blocks are appended to "region" and must be
/// moved to the op regions after the op is created. The first block of the /// moved to the op regions after the op is created. The first block of the
/// region takes "arguments" of types "argTypes". /// region takes "arguments" of types "argTypes".
virtual bool parseRegion(Region &region, ArrayRef<OperandType> arguments, virtual ParseResult parseRegion(Region &region,
ArrayRef<Type> argTypes) = 0; ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
/// Parse a region argument. Region arguments define new values, so this also /// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet. /// checks if the values with the same name has not been defined yet.
virtual bool parseRegionArgument(OperandType &argument) = 0; virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Methods for interacting with the parser // Methods for interacting with the parser
@ -341,46 +348,45 @@ public:
/// Return the location of the original name token. /// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0; virtual llvm::SMLoc getNameLoc() const = 0;
/// Resolve an operand to an SSA value, emitting an error and returning true /// Resolve an operand to an SSA value, emitting an error on failure.
/// on failure. virtual ParseResult resolveOperand(const OperandType &operand, Type type,
virtual bool resolveOperand(const OperandType &operand, Type type, SmallVectorImpl<Value *> &result) = 0;
SmallVectorImpl<Value *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning /// Resolve a list of operands to SSA values, emitting an error on failure, or
/// true on failure, or appending the results to the list on success. /// appending the results to the list on success. This method should be used
/// This method should be used when all operands have the same type. /// when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type, virtual ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<Value *> &result) { SmallVectorImpl<Value *> &result) {
for (auto elt : operands) for (auto elt : operands)
if (resolveOperand(elt, type, result)) if (resolveOperand(elt, type, result))
return true; return failure();
return false; return success();
} }
/// Resolve a list of operands and a list of operand types to SSA values, /// Resolve a list of operands and a list of operand types to SSA values,
/// emitting an error and returning true on failure, or appending the results /// emitting an error and returning failure, or appending the results
/// to the list on success. /// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands, virtual ParseResult resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc, ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<Value *> &result) { SmallVectorImpl<Value *> &result) {
if (operands.size() != types.size()) if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) + return emitError(loc, Twine(operands.size()) +
" operands present, but expected " + " operands present, but expected " +
Twine(types.size())); Twine(types.size()));
for (unsigned i = 0, e = operands.size(); i != e; ++i) { for (unsigned i = 0, e = operands.size(); i != e; ++i)
if (resolveOperand(operands[i], types[i], result)) if (resolveOperand(operands[i], types[i], result))
return true; return failure();
} return success();
return false;
} }
/// Resolve a parse function name and a type into a function reference. /// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType type, virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) = 0; llvm::SMLoc loc,
Function *&result) = 0;
/// Emit a diagnostic at the specified location and return true. /// Emit a diagnostic at the specified location and return failure.
virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0; virtual ParseResult emitError(llvm::SMLoc loc, const Twine &message) = 0;
}; };
} // end namespace mlir } // end namespace mlir

View File

@ -42,6 +42,7 @@ struct OperationState;
class OpAsmParser; class OpAsmParser;
class OpAsmParserResult; class OpAsmParserResult;
class OpAsmPrinter; class OpAsmPrinter;
class ParseResult;
class Pattern; class Pattern;
class Region; class Region;
class RewritePattern; class RewritePattern;
@ -85,7 +86,7 @@ public:
bool (&isClassFor)(Operation *op); bool (&isClassFor)(Operation *op);
/// Use the specified object to parse this ops custom assembly format. /// Use the specified object to parse this ops custom assembly format.
bool (&parseAssembly)(OpAsmParser *parser, OperationState *result); ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result);
/// This hook implements the AsmPrinter for this operation. /// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(Operation *op, OpAsmPrinter *p); void (&printAssembly)(Operation *op, OpAsmPrinter *p);
@ -150,7 +151,7 @@ private:
AbstractOperation( AbstractOperation(
StringRef name, Dialect &dialect, OperationProperties opProperties, StringRef name, Dialect &dialect, OperationProperties opProperties,
bool (&isClassFor)(Operation *op), bool (&isClassFor)(Operation *op),
bool (&parseAssembly)(OpAsmParser *parser, OperationState *result), ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(Operation *op, OpAsmPrinter *p), void (&printAssembly)(Operation *op, OpAsmPrinter *p),
LogicalResult (&verifyInvariants)(Operation *op), LogicalResult (&verifyInvariants)(Operation *op),
LogicalResult (&constantFoldHook)(Operation *op, LogicalResult (&constantFoldHook)(Operation *op,

View File

@ -43,7 +43,7 @@ public:
static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; } static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
static void build(Builder *b, OperationState *result, Type type, Value *size); static void build(Builder *b, OperationState *result, Type type, Value *size);
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
// Op-specific functionality. // Op-specific functionality.
@ -67,7 +67,7 @@ public:
static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; } static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
static void build(Builder *b, OperationState *result, Value *buffer); static void build(Builder *b, OperationState *result, Value *buffer);
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
// Op-specific functionality. // Op-specific functionality.
@ -94,7 +94,7 @@ public:
static void build(Builder *b, OperationState *result, Value *min, Value *max, static void build(Builder *b, OperationState *result, Value *min, Value *max,
Value *step); Value *step);
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
// Op-specific functionality. // Op-specific functionality.
@ -156,7 +156,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result, static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings); mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
// Op-specific functionality. // Op-specific functionality.
@ -207,7 +208,8 @@ public:
mlir::Value *buffer, mlir::Value *buffer,
llvm::ArrayRef<mlir::Value *> indexings); llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify(); mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); static ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p); void print(mlir::OpAsmPrinter *p);
// Op-specific functionality. // Op-specific functionality.

View File

@ -80,7 +80,7 @@ public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<Value *> operands = {}); MemRefType memrefType, ArrayRef<Value *> operands = {});
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context); MLIRContext *context);
@ -108,7 +108,7 @@ public:
ArrayRef<Value *> operands = {}); ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
/// Return the block this branch jumps to. /// Return the block this branch jumps to.
@ -149,7 +149,7 @@ public:
operand_iterator arg_operand_end() { return operand_end(); } operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
}; };
@ -183,7 +183,7 @@ public:
operand_iterator arg_operand_end() { return operand_end(); } operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -249,7 +249,7 @@ public:
static void build(Builder *builder, OperationState *result, CmpIPredicate, static void build(Builder *builder, OperationState *result, CmpIPredicate,
Value *lhs, Value *rhs); Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context); Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -324,7 +324,7 @@ public:
static void build(Builder *builder, OperationState *result, CmpFPredicate, static void build(Builder *builder, OperationState *result, CmpFPredicate,
Value *lhs, Value *rhs); Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context); Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -362,7 +362,7 @@ public:
Block *falseDest, ArrayRef<Value *> falseOperands); Block *falseDest, ArrayRef<Value *> falseOperands);
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
@ -521,7 +521,7 @@ public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, Value *memref); static void build(Builder *builder, OperationState *result, Value *memref);
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context); MLIRContext *context);
@ -553,7 +553,7 @@ public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
}; };
@ -682,7 +682,7 @@ public:
} }
static StringRef getOperationName() { return "std.dma_start"; } static StringRef getOperationName() { return "std.dma_start"; }
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
@ -748,7 +748,7 @@ public:
// Returns the number of elements transferred in the associated DMA operation. // Returns the number of elements transferred in the associated DMA operation.
Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); } Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context); MLIRContext *context);
@ -785,7 +785,7 @@ public:
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context); Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
}; };
@ -821,7 +821,7 @@ public:
static StringRef getOperationName() { return "std.load"; } static StringRef getOperationName() { return "std.load"; }
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context); MLIRContext *context);
@ -881,7 +881,7 @@ public:
ArrayRef<Value *> results = {}); ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
}; };
@ -906,7 +906,7 @@ public:
static StringRef getOperationName() { return "std.select"; } static StringRef getOperationName() { return "std.select"; }
static void build(Builder *builder, OperationState *result, Value *condition, static void build(Builder *builder, OperationState *result, Value *condition,
Value *trueValue, Value *falseValue); Value *trueValue, Value *falseValue);
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
@ -953,7 +953,7 @@ public:
static StringRef getOperationName() { return "std.store"; } static StringRef getOperationName() { return "std.store"; }
LogicalResult verify(); LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results, static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -994,9 +994,9 @@ void printDimAndSymbolList(Operation::operand_iterator begin,
OpAsmPrinter *p); OpAsmPrinter *p);
/// Parses dimension and symbol list and returns true if parsing failed. /// Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser, ParseResult parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands, SmallVector<Value *, 4> &operands,
unsigned &numDims); unsigned &numDims);
} // end namespace mlir } // end namespace mlir

View File

@ -115,7 +115,7 @@ public:
Optional<Value *> getPaddingValue(); Optional<Value *> getPaddingValue();
AffineMap getPermutationMap(); AffineMap getPermutationMap();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
}; };
@ -177,7 +177,7 @@ public:
operand_range getIndices(); operand_range getIndices();
AffineMap getPermutationMap(); AffineMap getPermutationMap();
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
}; };
@ -200,7 +200,7 @@ public:
static StringRef getOperationName() { return "vector.type_cast"; } static StringRef getOperationName() { return "vector.type_cast"; }
static void build(Builder *builder, OperationState *result, Value *srcVector, static void build(Builder *builder, OperationState *result, Value *srcVector,
Type dstType); Type dstType);
static bool parse(OpAsmParser *parser, OperationState *result); static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p); void print(OpAsmPrinter *p);
LogicalResult verify(); LogicalResult verify();
}; };

View File

@ -131,7 +131,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
result->addAttribute("map", builder->getAffineMapAttr(map)); result->addAttribute("map", builder->getAffineMapAttr(map));
} }
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder(); auto &builder = parser->getBuilder();
auto affineIntTy = builder.getIndexType(); auto affineIntTy = builder.getIndexType();
@ -140,7 +140,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->parseAttribute(mapAttr, "map", result->attributes) || if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) || parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes)) parser->parseOptionalAttributeDict(result->attributes))
return true; return failure();
auto map = mapAttr.getValue(); auto map = mapAttr.getValue();
if (map.getNumDims() != numDims || if (map.getNumDims() != numDims ||
@ -150,7 +150,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
} }
result->types.append(map.getNumResults(), affineIntTy); result->types.append(map.getNumResults(), affineIntTy);
return false; return success();
} }
void AffineApplyOp::print(OpAsmPrinter *p) { void AffineApplyOp::print(OpAsmPrinter *p) {
@ -801,10 +801,12 @@ LogicalResult AffineForOp::verify() {
} }
/// Parse a for operation loop bounds. /// Parse a for operation loop bounds.
static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) { static ParseResult parseBound(bool isLower, OperationState *result,
OpAsmParser *p) {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results. // the map has multiple results.
bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min"); bool failedToParsedMinMax =
failed(p->parseOptionalKeyword(isLower ? "max" : "min"));
auto &builder = p->getBuilder(); auto &builder = p->getBuilder();
auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
@ -813,7 +815,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
// Parse ssa-id as identity map. // Parse ssa-id as identity map.
SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
if (p->parseOperandList(boundOpInfos)) if (p->parseOperandList(boundOpInfos))
return true; return failure();
if (!boundOpInfos.empty()) { if (!boundOpInfos.empty()) {
// Check that only one operand was parsed. // Check that only one operand was parsed.
@ -825,14 +827,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
// Currently it is 'use of value ... expects different type than prior uses' // Currently it is 'use of value ... expects different type than prior uses'
if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
result->operands)) result->operands))
return true; return failure();
// Create an identity map using symbol id. This representation is optimized // Create an identity map using symbol id. This representation is optimized
// for storage. Analysis passes may expand it into a multi-dimensional map // for storage. Analysis passes may expand it into a multi-dimensional map
// if desired. // if desired.
AffineMap map = builder.getSymbolIdentityMap(); AffineMap map = builder.getSymbolIdentityMap();
result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
return false; return success();
} }
// Get the attribute location. // Get the attribute location.
@ -842,14 +844,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
Attribute boundAttr; Attribute boundAttr;
if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
result->attributes)) result->attributes))
return true; return failure();
// Parse full form - affine map followed by dim and symbol list. // Parse full form - affine map followed by dim and symbol list.
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
unsigned currentNumOperands = result->operands.size(); unsigned currentNumOperands = result->operands.size();
unsigned numDims; unsigned numDims;
if (parseDimAndSymbolList(p, result->operands, numDims)) if (parseDimAndSymbolList(p, result->operands, numDims))
return true; return failure();
auto map = affineMapAttr.getValue(); auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims) if (map.getNumDims() != numDims)
@ -874,7 +876,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
return p->emitError(attrLoc, "upper loop bound affine map with multiple " return p->emitError(attrLoc, "upper loop bound affine map with multiple "
"results requires 'min' prefix"); "results requires 'min' prefix");
} }
return false; return success();
} }
// Parse custom assembly form. // Parse custom assembly form.
@ -883,7 +885,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
result->addAttribute( result->addAttribute(
boundAttrName, builder.getAffineMapAttr( boundAttrName, builder.getAffineMapAttr(
builder.getConstantAffineMap(integerAttr.getInt()))); builder.getConstantAffineMap(integerAttr.getInt())));
return false; return success();
} }
return p->emitError( return p->emitError(
@ -891,18 +893,18 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
"expected valid affine map representation for loop bounds"); "expected valid affine map representation for loop bounds");
} }
bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder(); auto &builder = parser->getBuilder();
OpAsmParser::OperandType inductionVariable; OpAsmParser::OperandType inductionVariable;
// Parse the induction variable followed by '='. // Parse the induction variable followed by '='.
if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
return true; return failure();
// Parse loop bounds. // Parse loop bounds.
if (parseBound(/*isLower=*/true, result, parser) || if (parseBound(/*isLower=*/true, result, parser) ||
parser->parseKeyword("to", " between bounds") || parser->parseKeyword("to", " between bounds") ||
parseBound(/*isLower=*/false, result, parser)) parseBound(/*isLower=*/false, result, parser))
return true; return failure();
// Parse the optional loop step, we default to 1 if one is not present. // Parse the optional loop step, we default to 1 if one is not present.
if (parser->parseOptionalKeyword("step")) { if (parser->parseOptionalKeyword("step")) {
@ -915,7 +917,7 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->getCurrentLocation(&stepLoc) || if (parser->getCurrentLocation(&stepLoc) ||
parser->parseAttribute(stepAttr, builder.getIndexType(), parser->parseAttribute(stepAttr, builder.getIndexType(),
getStepAttrName().data(), result->attributes)) getStepAttrName().data(), result->attributes))
return true; return failure();
if (stepAttr.getValue().getSExtValue() < 0) if (stepAttr.getValue().getSExtValue() < 0)
return parser->emitError( return parser->emitError(
@ -926,17 +928,17 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the body region. // Parse the body region.
Region *body = result->addRegion(); Region *body = result->addRegion();
if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
return true; return failure();
ensureAffineTerminator(*body, builder, result->location); ensureAffineTerminator(*body, builder, result->location);
// Parse the optional attribute list. // Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes)) if (parser->parseOptionalAttributeDict(result->attributes))
return true; return failure();
// Set the operands list as resizable so that we can freely modify the bounds. // Set the operands list as resizable so that we can freely modify the bounds.
result->setOperandListToResizable(); result->setOperandListToResizable();
return false; return success();
} }
static void printBound(AffineMapAttr boundMap, static void printBound(AffineMapAttr boundMap,
@ -1253,14 +1255,14 @@ LogicalResult AffineIfOp::verify() {
return success(); return success();
} }
bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the condition attribute set. // Parse the condition attribute set.
IntegerSetAttr conditionAttr; IntegerSetAttr conditionAttr;
unsigned numDims; unsigned numDims;
if (parser->parseAttribute(conditionAttr, getConditionAttrName(), if (parser->parseAttribute(conditionAttr, getConditionAttrName(),
result->attributes) || result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims)) parseDimAndSymbolList(parser, result->operands, numDims))
return true; return failure();
// Verify the condition operands. // Verify the condition operands.
auto set = conditionAttr.getValue(); auto set = conditionAttr.getValue();
@ -1281,21 +1283,21 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the 'then' region. // Parse the 'then' region.
if (parser->parseRegion(*thenRegion, {}, {})) if (parser->parseRegion(*thenRegion, {}, {}))
return true; return failure();
ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location); ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location);
// If we find an 'else' keyword then parse the 'else' region. // If we find an 'else' keyword then parse the 'else' region.
if (!parser->parseOptionalKeyword("else")) { if (!parser->parseOptionalKeyword("else")) {
if (parser->parseRegion(*elseRegion, {}, {})) if (parser->parseRegion(*elseRegion, {}, {}))
return true; return failure();
ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location); ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location);
} }
// Parse the optional attribute list. // Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes)) if (parser->parseOptionalAttributeDict(result->attributes))
return true; return failure();
return false; return success();
} }
void AffineIfOp::print(OpAsmPrinter *p) { void AffineIfOp::print(OpAsmPrinter *p) {

View File

@ -159,7 +159,7 @@ void LaunchOp::print(OpAsmPrinter *p) {
// where %region_arg are percent-identifiers for the region arguments to be // where %region_arg are percent-identifiers for the region arguments to be
// introduced futher (SSA defs), and %operand are percent-identifiers for the // introduced futher (SSA defs), and %operand are percent-identifiers for the
// SSA value uses. // SSA value uses.
static bool static ParseResult
parseSizeAssignment(OpAsmParser *parser, parseSizeAssignment(OpAsmParser *parser,
MutableArrayRef<OpAsmParser::OperandType> sizes, MutableArrayRef<OpAsmParser::OperandType> sizes,
MutableArrayRef<OpAsmParser::OperandType> regionSizes, MutableArrayRef<OpAsmParser::OperandType> regionSizes,
@ -169,14 +169,14 @@ parseSizeAssignment(OpAsmParser *parser,
parser->parseComma() || parser->parseRegionArgument(indices[2]) || parser->parseComma() || parser->parseRegionArgument(indices[2]) ||
parser->parseRParen() || parser->parseKeyword("in") || parser->parseRParen() || parser->parseKeyword("in") ||
parser->parseLParen()) parser->parseLParen())
return true; return failure();
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
if (i != 0 && parser->parseComma()) if (i != 0 && parser->parseComma())
return true; return failure();
if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() || if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() ||
parser->parseOperand(sizes[i])) parser->parseOperand(sizes[i]))
return true; return failure();
} }
return parser->parseRParen(); return parser->parseRParen();
@ -188,7 +188,7 @@ parseSizeAssignment(OpAsmParser *parser,
// (`args` ssa-reassignment `:` type-list)? // (`args` ssa-reassignment `:` type-list)?
// region attr-dict? // region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
// Sizes of the grid and block. // Sizes of the grid and block.
SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes( SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
kNumConfigOperands); kNumConfigOperands);
@ -217,7 +217,7 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
regionArgsRef.slice(3, 3)) || regionArgsRef.slice(3, 3)) ||
parser->resolveOperands(sizes, parser->getBuilder().getIndexType(), parser->resolveOperands(sizes, parser->getBuilder().getIndexType(),
result->operands)) result->operands))
return true; return failure();
// If kernel argument renaming segment is present, parse it. When present, // If kernel argument renaming segment is present, parse it. When present,
// the segment should have at least one element. If this segment is present, // the segment should have at least one element. If this segment is present,
@ -232,20 +232,20 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() || if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() ||
parser->parseRegionArgument(regionArgs.back()) || parser->parseRegionArgument(regionArgs.back()) ||
parser->parseEqual() || parser->parseOperand(dataOperands.back())) parser->parseEqual() || parser->parseOperand(dataOperands.back()))
return true; return failure();
while (!parser->parseOptionalComma()) { while (!parser->parseOptionalComma()) {
regionArgs.push_back({}); regionArgs.push_back({});
dataOperands.push_back({}); dataOperands.push_back({});
if (parser->parseRegionArgument(regionArgs.back()) || if (parser->parseRegionArgument(regionArgs.back()) ||
parser->parseEqual() || parser->parseOperand(dataOperands.back())) parser->parseEqual() || parser->parseOperand(dataOperands.back()))
return true; return failure();
} }
if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) || if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) ||
parser->resolveOperands(dataOperands, dataTypes, argsLoc, parser->resolveOperands(dataOperands, dataTypes, argsLoc,
result->operands)) result->operands))
return true; return failure();
} }
// Introduce the body region and parse it. The region has // Introduce the body region and parse it. The region has
@ -255,11 +255,10 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
Type index = parser->getBuilder().getIndexType(); Type index = parser->getBuilder().getIndexType();
dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index); dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
Region *body = result->addRegion(); Region *body = result->addRegion();
return parser->parseRegion(*body, regionArgs, dataTypes) || return failure(parser->parseRegion(*body, regionArgs, dataTypes) ||
parser->parseOptionalAttributeDict(result->attributes); parser->parseOptionalAttributeDict(result->attributes));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LaunchFuncOp // LaunchFuncOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -640,7 +640,7 @@ Operation *Operation::clone(MLIRContext *context) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// The fallback for the parser is to reject the custom assembly form. // The fallback for the parser is to reject the custom assembly form.
bool OpState::parse(OpAsmParser *parser, OperationState *result) { ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), "has no custom assembly form"); return parser->emitError(parser->getNameLoc(), "has no custom assembly form");
} }
@ -948,14 +948,14 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
result->types.push_back(lhs->getType()); result->types.push_back(lhs->getType());
} }
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<OpAsmParser::OperandType, 2> ops;
Type type; Type type;
return parser->parseOperandList(ops, 2) || return failure(parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands) || parser->resolveOperands(ops, type, result->operands) ||
parser->addTypeToList(type, result->types); parser->addTypeToList(type, result->types));
} }
void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) { void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
@ -988,13 +988,14 @@ void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
result->addTypes(destType); result->addTypes(destType);
} }
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo; OpAsmParser::OperandType srcInfo;
Type srcType, dstType; Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || return failure(parser->parseOperand(srcInfo) ||
parser->resolveOperand(srcInfo, srcType, result->operands) || parser->parseColonType(srcType) ||
parser->parseKeywordType("to", dstType) || parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->addTypeToList(dstType, result->types); parser->parseKeywordType("to", dstType) ||
parser->addTypeToList(dstType, result->types));
} }
void impl::printCastOp(Operation *op, OpAsmPrinter *p) { void impl::printCastOp(Operation *op, OpAsmPrinter *p) {

View File

@ -126,7 +126,7 @@ static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
// attribute-dict? `:` type // attribute-dict? `:` type
static bool parseICmpOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
Builder &builder = parser->getBuilder(); Builder &builder = parser->getBuilder();
Attribute predicate; Attribute predicate;
@ -142,7 +142,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) {
parser->parseType(type) || parser->parseType(type) ||
parser->resolveOperand(lhs, type, result->operands) || parser->resolveOperand(lhs, type, result->operands) ||
parser->resolveOperand(rhs, type, result->operands)) parser->resolveOperand(rhs, type, result->operands))
return true; return failure();
// Replace the string attribute `predicate` with an integer attribute. // Replace the string attribute `predicate` with an integer attribute.
auto predicateStr = predicate.dyn_cast<StringAttr>(); auto predicateStr = predicate.dyn_cast<StringAttr>();
@ -173,7 +173,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs; result->attributes = attrs;
result->addTypes({resultType}); result->addTypes({resultType});
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -195,7 +195,7 @@ static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
// `:` type `,` type // `:` type `,` type
static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType arraySize; OpAsmParser::OperandType arraySize;
Type type, elemType; Type type, elemType;
@ -204,7 +204,7 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) {
parser->parseType(elemType) || parser->parseType(elemType) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true; return failure();
// Extract the result type from the trailing function type. // Extract the result type from the trailing function type.
auto funcType = type.dyn_cast<FunctionType>(); auto funcType = type.dyn_cast<FunctionType>();
@ -215,11 +215,11 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) {
"expected trailing function type with one argument and one result"); "expected trailing function type with one argument and one result");
if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands)) if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands))
return true; return failure();
result->attributes = attrs; result->attributes = attrs;
result->addTypes({funcType.getResult(0)}); result->addTypes({funcType.getResult(0)});
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -242,7 +242,7 @@ static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` // <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
// attribute-dict? `:` type // attribute-dict? `:` type
static bool parseGEPOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType base; OpAsmParser::OperandType base;
SmallVector<OpAsmParser::OperandType, 8> indices; SmallVector<OpAsmParser::OperandType, 8> indices;
@ -253,7 +253,7 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true; return failure();
// Deconstruct the trailing function type to extract the types of the base // Deconstruct the trailing function type to extract the types of the base
// pointer and result (same type) and the types of the indices. // pointer and result (same type) and the types of the indices.
@ -267,11 +267,11 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) {
if (parser->resolveOperand(base, funcType.getInput(0), result->operands) || if (parser->resolveOperand(base, funcType.getInput(0), result->operands) ||
parser->resolveOperands(indices, funcType.getInputs().drop_front(), parser->resolveOperands(indices, funcType.getInputs().drop_front(),
parser->getNameLoc(), result->operands)) parser->getNameLoc(), result->operands))
return true; return failure();
result->attributes = attrs; result->attributes = attrs;
result->addTypes(funcType.getResults()); result->addTypes(funcType.getResults());
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -302,7 +302,7 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
} }
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
static bool parseLoadOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr; OpAsmParser::OperandType addr;
Type type; Type type;
@ -312,13 +312,13 @@ static bool parseLoadOp(OpAsmParser *parser, OperationState *result) {
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(type) || parser->parseType(type) ||
parser->resolveOperand(addr, type, result->operands)) parser->resolveOperand(addr, type, result->operands))
return true; return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
result->attributes = attrs; result->attributes = attrs;
result->addTypes(elemTy); result->addTypes(elemTy);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -332,7 +332,7 @@ static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
} }
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
static bool parseStoreOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr, value; OpAsmParser::OperandType addr, value;
Type type; Type type;
@ -342,18 +342,18 @@ static bool parseStoreOp(OpAsmParser *parser, OperationState *result) {
parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(type)) parser->parseType(type))
return true; return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
if (!elemTy) if (!elemTy)
return true; return failure();
if (parser->resolveOperand(value, elemTy, result->operands) || if (parser->resolveOperand(value, elemTy, result->operands) ||
parser->resolveOperand(addr, type, result->operands)) parser->resolveOperand(addr, type, result->operands))
return true; return failure();
result->attributes = attrs; result->attributes = attrs;
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -367,7 +367,7 @@ static void printBitcastOp(OpAsmPrinter *p, BitcastOp &op) {
} }
// <operation> ::= `llvm.bitcast` ssa-use attribute-dict? `:` type `to` type // <operation> ::= `llvm.bitcast` ssa-use attribute-dict? `:` type `to` type
static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseBitcastOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType arg; OpAsmParser::OperandType arg;
Type sourceType, type; Type sourceType, type;
@ -376,11 +376,11 @@ static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(sourceType) || parser->parseKeyword("to") || parser->parseColonType(sourceType) || parser->parseKeyword("to") ||
parser->parseType(type) || parser->parseType(type) ||
parser->resolveOperand(arg, sourceType, result->operands)) parser->resolveOperand(arg, sourceType, result->operands))
return true; return failure();
result->attributes = attrs; result->attributes = attrs;
result->addTypes(type); result->addTypes(type);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -438,7 +438,7 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
// attribute-dict? `:` function-type // attribute-dict? `:` function-type
static bool parseCallOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
SmallVector<OpAsmParser::OperandType, 8> operands; SmallVector<OpAsmParser::OperandType, 8> operands;
Type type; Type type;
@ -450,19 +450,19 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
// direct call, there will be no operands and the parser will stop at the // direct call, there will be no operands and the parser will stop at the
// function identifier without complaining. // function identifier without complaining.
if (parser->parseOperandList(operands)) if (parser->parseOperandList(operands))
return true; return failure();
bool isDirect = operands.empty(); bool isDirect = operands.empty();
// Optionally parse a function identifier. // Optionally parse a function identifier.
if (isDirect) if (isDirect)
if (parser->parseFunctionName(calleeName, calleeLoc)) if (parser->parseFunctionName(calleeName, calleeLoc))
return true; return failure();
if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1, if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) || OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true; return failure();
auto funcType = type.dyn_cast<FunctionType>(); auto funcType = type.dyn_cast<FunctionType>();
if (!funcType) if (!funcType)
@ -471,14 +471,14 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
// Add the direct callee as an Op attribute. // Add the direct callee as an Op attribute.
Function *func; Function *func;
if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func)) if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func))
return true; return failure();
auto funcAttr = parser->getBuilder().getFunctionAttr(func); auto funcAttr = parser->getBuilder().getFunctionAttr(func);
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr)); attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
// Make sure types match. // Make sure types match.
if (parser->resolveOperands(operands, funcType.getInputs(), if (parser->resolveOperands(operands, funcType.getInputs(),
parser->getNameLoc(), result->operands)) parser->getNameLoc(), result->operands))
return true; return failure();
result->addTypes(funcType.getResults()); result->addTypes(funcType.getResults());
} else { } else {
// Construct the LLVM IR Dialect function type that the first operand // Construct the LLVM IR Dialect function type that the first operand
@ -528,13 +528,13 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
result->operands) || result->operands) ||
parser->resolveOperands(funcArguments, funcType.getInputs(), parser->resolveOperands(funcArguments, funcType.getInputs(),
parser->getNameLoc(), result->operands)) parser->getNameLoc(), result->operands))
return true; return failure();
result->addTypes(wrappedResultType); result->addTypes(wrappedResultType);
} }
result->attributes = attrs; result->attributes = attrs;
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -601,7 +601,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
// <operation> ::= `llvm.extractvalue` ssa-use // <operation> ::= `llvm.extractvalue` ssa-use
// `[` integer-literal (`,` integer-literal)* `]` // `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type // attribute-dict? `:` type
static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseExtractValueOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType container; OpAsmParser::OperandType container;
Type containerType; Type containerType;
@ -615,16 +616,16 @@ static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) {
parser->getCurrentLocation(&trailingTypeLoc) || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(containerType) || parser->parseType(containerType) ||
parser->resolveOperand(container, containerType, result->operands)) parser->resolveOperand(container, containerType, result->operands))
return true; return failure();
auto elementType = getInsertExtractValueElementType( auto elementType = getInsertExtractValueElementType(
parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
if (!elementType) if (!elementType)
return true; return failure();
result->attributes = attrs; result->attributes = attrs;
result->addTypes(elementType); result->addTypes(elementType);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -641,7 +642,8 @@ static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
// `[` integer-literal (`,` integer-literal)* `]` // `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type // attribute-dict? `:` type
static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseInsertValueOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType container, value; OpAsmParser::OperandType container, value;
Type containerType; Type containerType;
Attribute positionAttr; Attribute positionAttr;
@ -654,19 +656,19 @@ static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(containerType)) parser->parseType(containerType))
return true; return failure();
auto valueType = getInsertExtractValueElementType( auto valueType = getInsertExtractValueElementType(
parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
if (!valueType) if (!valueType)
return true; return failure();
if (parser->resolveOperand(container, containerType, result->operands) || if (parser->resolveOperand(container, containerType, result->operands) ||
parser->resolveOperand(value, valueType, result->operands)) parser->resolveOperand(value, valueType, result->operands))
return true; return failure();
result->addTypes(containerType); result->addTypes(containerType);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -682,7 +684,7 @@ static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use // <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
// attribute-dict? `:` type, type // attribute-dict? `:` type, type
static bool parseSelectOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType condition, trueValue, falseValue; OpAsmParser::OperandType condition, trueValue, falseValue;
Type conditionType, argType; Type conditionType, argType;
@ -692,15 +694,15 @@ static bool parseSelectOp(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(conditionType) || parser->parseComma() || parser->parseColonType(conditionType) || parser->parseComma() ||
parser->parseType(argType)) parser->parseType(argType))
return true; return failure();
if (parser->resolveOperand(condition, conditionType, result->operands) || if (parser->resolveOperand(condition, conditionType, result->operands) ||
parser->resolveOperand(trueValue, argType, result->operands) || parser->resolveOperand(trueValue, argType, result->operands) ||
parser->resolveOperand(falseValue, argType, result->operands)) parser->resolveOperand(falseValue, argType, result->operands))
return true; return failure();
result->addTypes(argType); result->addTypes(argType);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -715,15 +717,15 @@ static void printBrOp(OpAsmPrinter *p, BrOp &op) {
// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? // <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
// attribute-dict? // attribute-dict?
static bool parseBrOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) {
Block *dest; Block *dest;
SmallVector<Value *, 4> operands; SmallVector<Value *, 4> operands;
if (parser->parseSuccessorAndUseList(dest, operands) || if (parser->parseSuccessorAndUseList(dest, operands) ||
parser->parseOptionalAttributeDict(result->attributes)) parser->parseOptionalAttributeDict(result->attributes))
return true; return failure();
result->addSuccessor(dest, operands); result->addSuccessor(dest, operands);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -741,7 +743,7 @@ static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
// <operation> ::= `llvm.cond_br` ssa-use `,` // <operation> ::= `llvm.cond_br` ssa-use `,`
// bb-id (`[` ssa-use-and-type-list `]`)? `,` // bb-id (`[` ssa-use-and-type-list `]`)? `,`
// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? // bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
Block *trueDest; Block *trueDest;
Block *falseDest; Block *falseDest;
SmallVector<Value *, 4> trueOperands; SmallVector<Value *, 4> trueOperands;
@ -760,11 +762,11 @@ static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) {
parser->parseSuccessorAndUseList(falseDest, falseOperands) || parser->parseSuccessorAndUseList(falseDest, falseOperands) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->resolveOperand(condition, i1Type, result->operands)) parser->resolveOperand(condition, i1Type, result->operands))
return true; return failure();
result->addSuccessor(trueDest, trueOperands); result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands); result->addSuccessor(falseDest, falseOperands);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -784,20 +786,20 @@ static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
// type-list-no-parens // type-list-no-parens
static bool parseReturnOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 1> operands; SmallVector<OpAsmParser::OperandType, 1> operands;
Type type; Type type;
if (parser->parseOperandList(operands) || if (parser->parseOperandList(operands) ||
parser->parseOptionalAttributeDict(result->attributes)) parser->parseOptionalAttributeDict(result->attributes))
return true; return failure();
if (operands.empty()) if (operands.empty())
return false; return success();
if (parser->parseColonType(type) || if (parser->parseColonType(type) ||
parser->resolveOperand(operands[0], type, result->operands)) parser->resolveOperand(operands[0], type, result->operands))
return true; return failure();
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -811,15 +813,15 @@ static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
} }
// <operation> ::= `llvm.undef` attribute-dict? : type // <operation> ::= `llvm.undef` attribute-dict? : type
static bool parseUndefOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
Type type; Type type;
if (parser->parseOptionalAttributeDict(result->attributes) || if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
result->addTypes(type); result->addTypes(type);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -845,7 +847,8 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
} }
// <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type // <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type
static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr; Attribute valueAttr;
Type type; Type type;
@ -854,10 +857,10 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
parser->parseRParen() || parser->parseRParen() ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
result->addTypes(type); result->addTypes(type);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -53,15 +53,15 @@ static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
} }
// <operation> ::= `llvm.nvvm.XYZ` : type // <operation> ::= `llvm.nvvm.XYZ` : type
static bool parseNVVMSpecialRegisterOp(OpAsmParser *parser, static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser,
OperationState *result) { OperationState *result) {
Type type; Type type;
if (parser->parseOptionalAttributeDict(result->attributes) || if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
result->addTypes(type); result->addTypes(type);
return false; return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -58,18 +58,19 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *size() << " : " << getType(); *p << getOperationName() << " " << *size() << " : " << getType();
} }
bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo; OpAsmParser::OperandType sizeInfo;
BufferType bufferType; BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType(); auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return true; return failure();
if (bufferType.getElementType() != parser->getBuilder().getF32Type()) if (bufferType.getElementType() != parser->getBuilder().getF32Type())
return parser->emitError( return parser->emitError(
parser->getNameLoc(), parser->getNameLoc(),
"Only buffer<f32> supported until mlir::Parser pieces are exposed"); "Only buffer<f32> supported until mlir::Parser pieces are exposed");
return parser->resolveOperands(sizeInfo, indexTy, result->operands) || return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types); parser->addTypeToList(bufferType, result->types));
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -95,11 +96,13 @@ void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType(); *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
} }
bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo; OpAsmParser::OperandType sizeInfo;
BufferType bufferType; BufferType bufferType;
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || return failure(
parser->resolveOperands(sizeInfo, bufferType, result->operands); parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// RangeOp // RangeOp
@ -131,15 +134,16 @@ void mlir::RangeOp::print(OpAsmPrinter *p) {
<< " : " << getType(); << " : " << getType();
} }
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3); SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type; RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() || return failure(
parser->parseOperand(rangeInfo[1]) || parser->parseColon() || parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) || parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->addTypeToList(type, result->types); parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types));
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -189,7 +193,7 @@ LogicalResult mlir::SliceOp::verify() {
return success(); return success();
} }
bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType baseInfo; OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo; SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types; SmallVector<Type, 8> types;
@ -198,7 +202,7 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types)) parser->parseColonTypeList(types))
return true; return failure();
if (types.size() != 2 + indexingsInfo.size()) if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -221,12 +225,13 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"expected " + Twine(baseViewType.getRank()) + "expected " + Twine(baseViewType.getRank()) +
" indexing types"); " indexing types");
return parser->resolveOperand(baseInfo, baseViewType, result->operands) || return failure(
(!indexingsInfo.empty() && parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
parser->resolveOperands(indexingsInfo, indexingTypes, (!indexingsInfo.empty() &&
indexingsInfo.front().location, parser->resolveOperands(indexingsInfo, indexingTypes,
result->operands)) || indexingsInfo.front().location,
parser->addTypeToList(viewType, result->types); result->operands)) ||
parser->addTypeToList(viewType, result->types));
} }
// A SliceOp prints as: // A SliceOp prints as:
@ -306,7 +311,7 @@ LogicalResult mlir::ViewOp::verify() {
return success(); return success();
} }
bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo; OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo; SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type; Type type;
@ -315,7 +320,7 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) || OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
ViewType viewType = type.dyn_cast<ViewType>(); ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType) if (!viewType)
@ -324,15 +329,15 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"expected" + Twine(viewType.getRank()) + "expected" + Twine(viewType.getRank()) +
" range indexings"); " range indexings");
return parser->resolveOperand( return failure(
bufferInfo, parser->resolveOperand(
BufferType::get(type.getContext(), viewType.getElementType()), bufferInfo,
result->operands) || BufferType::get(type.getContext(), viewType.getElementType()),
(!indexingsInfo.empty() && result->operands) ||
parser->resolveOperands(indexingsInfo, (!indexingsInfo.empty() &&
RangeType::get(type.getContext()), parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) || result->operands)) ||
parser->addTypeToList(viewType, result->types); parser->addTypeToList(viewType, result->types));
} }
// A ViewOp prints as: // A ViewOp prints as:
@ -354,9 +359,9 @@ void mlir::ViewOp::print(OpAsmPrinter *p) {
namespace mlir { namespace mlir {
namespace impl { namespace impl {
void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op); void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result); ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op); void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op);
bool parseBufferSizeOp(OpAsmParser *parser, OperationState *result); ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
} // namespace impl } // namespace impl
/// Buffer size prints as: /// Buffer size prints as:
@ -372,16 +377,16 @@ void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
*p << " : " << op->getOperand(0)->getType(); *p << " : " << op->getOperand(0)->getType();
} }
bool mlir::impl::parseBufferSizeOp(OpAsmParser *parser, ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) { OperationState *result) {
OpAsmParser::OperandType op; OpAsmParser::OperandType op;
Type type; Type type;
return parser->parseOperand(op) || return failure(parser->parseOperand(op) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(op, type, result->operands) || parser->resolveOperand(op, type, result->operands) ||
parser->addTypeToList(parser->getBuilder().getIndexType(), parser->addTypeToList(parser->getBuilder().getIndexType(),
result->types); result->types));
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
@ -415,15 +420,16 @@ void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
[&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
} }
bool mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser, ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) { OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops; SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types; SmallVector<Type, 3> types;
return parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) || return failure(
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) ||
parser->parseColonTypeList(types) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->resolveOperands(ops, types, parser->getNameLoc(), parser->parseColonTypeList(types) ||
result->operands); parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands));
} }
// Ideally this should all be Tablegen'd but there is no good story for // Ideally this should all be Tablegen'd but there is no good story for

View File

@ -45,17 +45,6 @@ using llvm::MemoryBuffer;
using llvm::SMLoc; using llvm::SMLoc;
using llvm::SourceMgr; using llvm::SourceMgr;
/// Simple wrapper class around LogicalResult that allows for explicit
/// conversion to bool. This allows for the parser to chain together parse rules
/// without the clutter of "failed/succeeded".
class ParseResult : public LogicalResult {
public:
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
/// Failure is true in a boolean context.
explicit operator bool() const { return failed(*this); }
};
namespace { namespace {
class Parser; class Parser;
@ -2266,8 +2255,8 @@ public:
ParseResult parseFunctionBody(bool hadNamedArguments); ParseResult parseFunctionBody(bool hadNamedArguments);
/// Parse a single operation successor and it's operand list. /// Parse a single operation successor and it's operand list.
bool parseSuccessorAndUseList(Block *&dest, ParseResult parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands); SmallVectorImpl<Value *> &operands);
/// Parse a comma-separated list of operation successors in brackets. /// Parse a comma-separated list of operation successors in brackets.
ParseResult ParseResult
@ -2809,11 +2798,12 @@ Block *FunctionParser::defineBlockNamed(StringRef name, SMLoc loc,
/// successor ::= block-id branch-use-list? /// successor ::= block-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)` /// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
/// ///
bool FunctionParser::parseSuccessorAndUseList( ParseResult
Block *&dest, SmallVectorImpl<Value *> &operands) { FunctionParser::parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) {
// Verify branch is identifier and get the matching block. // Verify branch is identifier and get the matching block.
if (!getToken().is(Token::caret_identifier)) if (!getToken().is(Token::caret_identifier))
return emitError("expected block name"), true; return emitError("expected block name");
dest = getBlockNamed(getTokenSpelling(), getToken().getLoc()); dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
consumeToken(); consumeToken();
@ -2821,10 +2811,10 @@ bool FunctionParser::parseSuccessorAndUseList(
if (consumeIf(Token::l_paren) && if (consumeIf(Token::l_paren) &&
(parseOptionalSSAUseAndTypeList(operands) || (parseOptionalSSAUseAndTypeList(operands) ||
parseToken(Token::r_paren, "expected ')' to close argument list"))) { parseToken(Token::r_paren, "expected ')' to close argument list"))) {
return true; return failure();
} }
return false; return success();
} }
/// Parse a comma-separated list of operation successors in brackets. /// Parse a comma-separated list of operation successors in brackets.
@ -2840,10 +2830,10 @@ ParseResult FunctionParser::parseSuccessors(
auto parseElt = [this, &destinations, &operands]() { auto parseElt = [this, &destinations, &operands]() {
Block *dest; Block *dest;
SmallVector<Value *, 4> destOperands; SmallVector<Value *, 4> destOperands;
bool r = parseSuccessorAndUseList(dest, destOperands); auto res = parseSuccessorAndUseList(dest, destOperands);
destinations.push_back(dest); destinations.push_back(dest);
operands.push_back(destOperands); operands.push_back(destOperands);
return r ? failure() : success(); return res;
}; };
return parseCommaSeparatedListUntil(Token::r_square, parseElt, return parseCommaSeparatedListUntil(Token::r_square, parseElt,
/*allowEmptyList=*/false); /*allowEmptyList=*/false);
@ -3105,10 +3095,10 @@ public:
CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser) CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
: nameLoc(nameLoc), opName(opName), parser(parser) {} : nameLoc(nameLoc), opName(opName), parser(parser) {}
bool parseOperation(const AbstractOperation *opDefinition, ParseResult parseOperation(const AbstractOperation *opDefinition,
OperationState *opState) { OperationState *opState) {
if (opDefinition->parseAssembly(this, opState)) if (opDefinition->parseAssembly(this, opState))
return true; return failure();
// Check that none of the operands of the current operation reference an // Check that none of the operands of the current operation reference an
// entry block argument for any of the region. // entry block argument for any of the region.
@ -3116,53 +3106,53 @@ public:
if (llvm::is_contained(opState->operands, entryArg)) if (llvm::is_contained(opState->operands, entryArg))
return emitError(nameLoc, "operand use before it's defined"); return emitError(nameLoc, "operand use before it's defined");
return false; return success();
} }
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// High level parsing methods. // High level parsing methods.
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
bool getCurrentLocation(llvm::SMLoc *loc) override { ParseResult getCurrentLocation(llvm::SMLoc *loc) override {
*loc = parser.getToken().getLoc(); *loc = parser.getToken().getLoc();
return false; return success();
} }
bool parseComma() override { ParseResult parseComma() override {
return failed(parser.parseToken(Token::comma, "expected ','")); return parser.parseToken(Token::comma, "expected ','");
} }
bool parseColon() override { ParseResult parseColon() override {
return failed(parser.parseToken(Token::colon, "expected ':'")); return parser.parseToken(Token::colon, "expected ':'");
} }
bool parseEqual() override { ParseResult parseEqual() override {
return failed(parser.parseToken(Token::equal, "expected '='")); return parser.parseToken(Token::equal, "expected '='");
} }
bool parseType(Type &result) override { ParseResult parseType(Type &result) override {
return !(result = parser.parseType()); return failure(!(result = parser.parseType()));
} }
bool parseColonType(Type &result) override { ParseResult parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") || return failure(parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType()); !(result = parser.parseType()));
} }
bool parseColonTypeList(SmallVectorImpl<Type> &result) override { ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'")) if (parser.parseToken(Token::colon, "expected ':'"))
return true; return failure();
do { do {
if (auto type = parser.parseType()) if (auto type = parser.parseType())
result.push_back(type); result.push_back(type);
else else
return true; return failure();
} while (parser.consumeIf(Token::comma)); } while (parser.consumeIf(Token::comma));
return false; return success();
} }
bool parseTrailingOperandList(SmallVectorImpl<OperandType> &result, ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount, int requiredOperandCount,
Delimiter delimiter) override { Delimiter delimiter) override {
if (parser.getToken().is(Token::comma)) { if (parser.getToken().is(Token::comma)) {
parseComma(); parseComma();
return parseOperandList(result, requiredOperandCount, delimiter); return parseOperandList(result, requiredOperandCount, delimiter);
@ -3170,101 +3160,105 @@ public:
if (requiredOperandCount != -1) if (requiredOperandCount != -1)
return emitError(parser.getToken().getLoc(), return emitError(parser.getToken().getLoc(),
"expected " + Twine(requiredOperandCount) + " operands"); "expected " + Twine(requiredOperandCount) + " operands");
return false; return success();
} }
bool parseOptionalComma() override { return !parser.consumeIf(Token::comma); } ParseResult parseOptionalComma() override {
return success(parser.consumeIf(Token::comma));
}
/// Parse an optional keyword. /// Parse an optional keyword.
bool parseOptionalKeyword(const char *keyword) override { ParseResult parseOptionalKeyword(const char *keyword) override {
// Check that the current token is a bare identifier or keyword. // Check that the current token is a bare identifier or keyword.
if (parser.getToken().isNot(Token::bare_identifier) && if (parser.getToken().isNot(Token::bare_identifier) &&
!parser.getToken().isKeyword()) !parser.getToken().isKeyword())
return true; return failure();
if (parser.getTokenSpelling() == keyword) { if (parser.getTokenSpelling() == keyword) {
parser.consumeToken(); parser.consumeToken();
return false; return success();
} }
return true; return failure();
} }
/// Parse an arbitrary attribute of a given type and return it in result. This /// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified /// also adds the attribute to the specified attribute list with the specified
/// name. /// name.
bool parseAttribute(Attribute &result, Type type, StringRef attrName, ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override { SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute(type); result = parser.parseAttribute(type);
if (!result) if (!result)
return true; return failure();
attrs.push_back(parser.builder.getNamedAttr(attrName, result)); attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return false; return success();
} }
/// Parse an arbitrary attribute and return it in result. This also adds /// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name. /// the attribute to the specified attribute list with the specified name.
bool parseAttribute(Attribute &result, StringRef attrName, ParseResult parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override { SmallVectorImpl<NamedAttribute> &attrs) override {
return parseAttribute(result, Type(), attrName, attrs); return parseAttribute(result, Type(), attrName, attrs);
} }
/// If a named attribute list is present, parse is into result. /// If a named attribute list is present, parse is into result.
bool ParseResult
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override { parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override {
if (parser.getToken().isNot(Token::l_brace)) if (parser.getToken().isNot(Token::l_brace))
return false; return success();
return failed(parser.parseAttributeDict(result)); return parser.parseAttributeDict(result);
} }
/// Parse a function name like '@foo' and return the name in a form that can /// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available. /// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) { virtual ParseResult parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
if (parseOptionalFunctionName(result, loc)) if (parseOptionalFunctionName(result, loc))
return emitError(loc, "expected function name"); return emitError(loc, "expected function name");
return false; return success();
} }
/// Parse a function name like '@foo` if present and return the name without /// Parse a function name like '@foo` if present and return the name without
/// the sigil in `result`. Return true if the next token is not a function /// the sigil in `result`. Return true if the next token is not a function
/// name and keep `result` unchanged. /// name and keep `result` unchanged.
bool parseOptionalFunctionName(StringRef &result, llvm::SMLoc &loc) override { ParseResult parseOptionalFunctionName(StringRef &result,
llvm::SMLoc &loc) override {
loc = parser.getToken().getLoc(); loc = parser.getToken().getLoc();
if (parser.getToken().isNot(Token::at_identifier)) if (parser.getToken().isNot(Token::at_identifier))
return true; return failure();
result = parser.getTokenSpelling(); result = parser.getTokenSpelling();
parser.consumeToken(Token::at_identifier); parser.consumeToken(Token::at_identifier);
return false; return success();
} }
bool parseOperand(OperandType &result) override { ParseResult parseOperand(OperandType &result) override {
FunctionParser::SSAUseInfo useInfo; FunctionParser::SSAUseInfo useInfo;
if (parser.parseSSAUse(useInfo)) if (parser.parseSSAUse(useInfo))
return true; return failure();
result = {useInfo.loc, useInfo.name, useInfo.number}; result = {useInfo.loc, useInfo.name, useInfo.number};
return false; return success();
} }
bool parseSuccessorAndUseList(Block *&dest, ParseResult
SmallVectorImpl<Value *> &operands) override { parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) override {
// Defer successor parsing to the function parsers. // Defer successor parsing to the function parsers.
return parser.parseSuccessorAndUseList(dest, operands); return parser.parseSuccessorAndUseList(dest, operands);
} }
bool parseLParen() override { ParseResult parseLParen() override {
return failed(parser.parseToken(Token::l_paren, "expected '('")); return parser.parseToken(Token::l_paren, "expected '('");
} }
bool parseRParen() override { ParseResult parseRParen() override {
return failed(parser.parseToken(Token::r_paren, "expected ')'")); return parser.parseToken(Token::r_paren, "expected ')'");
} }
bool parseOperandList(SmallVectorImpl<OperandType> &result, ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1, int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override { Delimiter delimiter = Delimiter::None) override {
auto startLoc = parser.getToken().getLoc(); auto startLoc = parser.getToken().getLoc();
// Handle delimiters. // Handle delimiters.
@ -3284,19 +3278,19 @@ public:
return emitError(startLoc, "invalid operand"); return emitError(startLoc, "invalid operand");
case Delimiter::OptionalParen: case Delimiter::OptionalParen:
if (parser.getToken().isNot(Token::l_paren)) if (parser.getToken().isNot(Token::l_paren))
return false; return success();
LLVM_FALLTHROUGH; LLVM_FALLTHROUGH;
case Delimiter::Paren: case Delimiter::Paren:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list")) if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true; return failure();
break; break;
case Delimiter::OptionalSquare: case Delimiter::OptionalSquare:
if (parser.getToken().isNot(Token::l_square)) if (parser.getToken().isNot(Token::l_square))
return false; return success();
LLVM_FALLTHROUGH; LLVM_FALLTHROUGH;
case Delimiter::Square: case Delimiter::Square:
if (parser.parseToken(Token::l_square, "expected '[' in operand list")) if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true; return failure();
break; break;
} }
@ -3305,7 +3299,7 @@ public:
do { do {
OperandType operand; OperandType operand;
if (parseOperand(operand)) if (parseOperand(operand))
return true; return failure();
result.push_back(operand); result.push_back(operand);
} while (parser.consumeIf(Token::comma)); } while (parser.consumeIf(Token::comma));
} }
@ -3318,32 +3312,32 @@ public:
case Delimiter::OptionalParen: case Delimiter::OptionalParen:
case Delimiter::Paren: case Delimiter::Paren:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list")) if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true; return failure();
break; break;
case Delimiter::OptionalSquare: case Delimiter::OptionalSquare:
case Delimiter::Square: case Delimiter::Square:
if (parser.parseToken(Token::r_square, "expected ']' in operand list")) if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true; return failure();
break; break;
} }
if (requiredOperandCount != -1 && result.size() != requiredOperandCount) if (requiredOperandCount != -1 && result.size() != requiredOperandCount)
return emitError(startLoc, return emitError(startLoc,
"expected " + Twine(requiredOperandCount) + " operands"); "expected " + Twine(requiredOperandCount) + " operands");
return false; return success();
} }
/// Resolve a parse function name and a type into a function reference. /// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType type, virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) { llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type); result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr; return failure(result == nullptr);
} }
/// Parse a region that takes `arguments` of `argTypes` types. This /// Parse a region that takes `arguments` of `argTypes` types. This
/// effectively defines the SSA values of `arguments` and assignes their type. /// effectively defines the SSA values of `arguments` and assignes their type.
bool parseRegion(Region &region, ArrayRef<OperandType> arguments, ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) override { ArrayRef<Type> argTypes) override {
assert(arguments.size() == argTypes.size() && assert(arguments.size() == argTypes.size() &&
"mismatching number of arguments and types"); "mismatching number of arguments and types");
@ -3359,26 +3353,26 @@ public:
// references to region arguments. // references to region arguments.
Value *value = parser.resolveSSAUse(operandInfo, type); Value *value = parser.resolveSSAUse(operandInfo, type);
if (!value) if (!value)
return true; return failure();
parsedRegionEntryArgumentPlaceholders.emplace_back(value); parsedRegionEntryArgumentPlaceholders.emplace_back(value);
} }
return failed(parser.parseOperationRegion(region, regionArguments)); return parser.parseOperationRegion(region, regionArguments);
} }
/// Parse a region argument. Region arguments define new values, so this also /// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet. The /// checks if the values with the same name has not been defined yet. The
/// type of the argument will be resolved later by a call to `parseRegion`. /// type of the argument will be resolved later by a call to `parseRegion`.
bool parseRegionArgument(OperandType &argument) { ParseResult parseRegionArgument(OperandType &argument) {
// Use parseOperand to fill in the OperandType structure. // Use parseOperand to fill in the OperandType structure.
if (parseOperand(argument)) if (parseOperand(argument))
return true; return failure();
if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) { if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) {
parser.emitError(argument.location, parser.emitError(argument.location,
"redefinition of SSA value '" + argument.name + "'"); "redefinition of SSA value '" + argument.name + "'");
return parser.emitError(*defLoc, "previously defined here"), true; return parser.emitError(*defLoc, "previously defined here");
} }
return false; return success();
} }
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
@ -3389,22 +3383,22 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; } llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type type, ParseResult resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) override { SmallVectorImpl<Value *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number, FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location}; operand.location};
if (auto *value = parser.resolveSSAUse(operandInfo, type)) { if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
result.push_back(value); result.push_back(value);
return false; return success();
} }
return true; return failure();
} }
/// Emit a diagnostic at the specified location and return true. /// Emit a diagnostic at the specified location and return failure.
bool emitError(llvm::SMLoc loc, const Twine &message) override { ParseResult emitError(llvm::SMLoc loc, const Twine &message) override {
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
emittedError = true; emittedError = true;
return true; return parser.emitError(loc,
"custom op '" + Twine(opName) + "' " + message);
} }
bool didEmitError() const { return emittedError; } bool didEmitError() const { return emittedError; }

View File

@ -87,12 +87,12 @@ void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
// Parses dimension and symbol list, and sets 'numDims' to the number of // Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed. // dimension operands parsed.
// Returns 'false' on success and 'true' on error. // Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser, ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands, SmallVector<Value *, 4> &operands,
unsigned &numDims) { unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos; SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren)) if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true; return failure();
// Store number of dimensions for validation by caller. // Store number of dimensions for validation by caller.
numDims = opInfos.size(); numDims = opInfos.size();
@ -101,8 +101,8 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
if (parser->parseOperandList(opInfos, -1, if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) || OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands)) parser->resolveOperands(opInfos, affineIntTy, operands))
return true; return failure();
return false; return success();
} }
/// Matches a ConstantIndexOp. /// Matches a ConstantIndexOp.
@ -223,7 +223,7 @@ void AllocOp::print(OpAsmPrinter *p) {
*p << " : " << type; *p << " : " << type;
} }
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType type; MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a // Parse the dimension operands and optional symbol operands, followed by a
@ -232,7 +232,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
// Check numDynamicDims against number of question marks in memref type. // Check numDynamicDims against number of question marks in memref type.
// Note: this check remains here (instead of in verify()), because the // Note: this check remains here (instead of in verify()), because the
@ -246,7 +246,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
"dynamic dimension count"); "dynamic dimension count");
} }
result->types.push_back(type); result->types.push_back(type);
return false; return success();
} }
LogicalResult AllocOp::verify() { LogicalResult AllocOp::verify() {
@ -385,13 +385,13 @@ void BranchOp::build(Builder *builder, OperationState *result, Block *dest,
result->addSuccessor(dest, operands); result->addSuccessor(dest, operands);
} }
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) {
Block *dest; Block *dest;
SmallVector<Value *, 4> destOperands; SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands)) if (parser->parseSuccessorAndUseList(dest, destOperands))
return true; return failure();
result->addSuccessor(dest, destOperands); result->addSuccessor(dest, destOperands);
return false; return success();
} }
void BranchOp::print(OpAsmPrinter *p) { void BranchOp::print(OpAsmPrinter *p) {
@ -420,7 +420,7 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
result->addTypes(callee->getType().getResults()); result->addTypes(callee->getType().getResults());
} }
bool CallOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName; StringRef calleeName;
llvm::SMLoc calleeLoc; llvm::SMLoc calleeLoc;
FunctionType calleeType; FunctionType calleeType;
@ -435,10 +435,10 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
parser->addTypesToList(calleeType.getResults(), result->types) || parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands)) result->operands))
return true; return failure();
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee)); result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
return false; return success();
} }
void CallOp::print(OpAsmPrinter *p) { void CallOp::print(OpAsmPrinter *p) {
@ -517,21 +517,22 @@ void CallIndirectOp::build(Builder *builder, OperationState *result,
result->addTypes(fnType.getResults()); result->addTypes(fnType.getResults());
} }
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
FunctionType calleeType; FunctionType calleeType;
OpAsmParser::OperandType callee; OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc; llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands; SmallVector<OpAsmParser::OperandType, 4> operands;
return parser->parseOperand(callee) || return failure(
parser->getCurrentLocation(&operandsLoc) || parser->parseOperand(callee) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1, parser->getCurrentLocation(&operandsLoc) ||
OpAsmParser::Delimiter::Paren) || parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
parser->parseOptionalAttributeDict(result->attributes) || OpAsmParser::Delimiter::Paren) ||
parser->parseColonType(calleeType) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->resolveOperand(callee, calleeType, result->operands) || parser->parseColonType(calleeType) ||
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, parser->resolveOperand(callee, calleeType, result->operands) ||
result->operands) || parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
parser->addTypesToList(calleeType.getResults(), result->types); result->operands) ||
parser->addTypesToList(calleeType.getResults(), result->types));
} }
void CallIndirectOp::print(OpAsmPrinter *p) { void CallIndirectOp::print(OpAsmPrinter *p) {
@ -678,7 +679,7 @@ void CmpIOp::build(Builder *build, OperationState *result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate))); build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
} }
bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<OpAsmParser::OperandType, 2> ops;
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
Attribute predicateNameAttr; Attribute predicateNameAttr;
@ -689,7 +690,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(attrs) || parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands)) parser->resolveOperands(ops, type, result->operands))
return true; return failure();
if (!predicateNameAttr.isa<StringAttr>()) if (!predicateNameAttr.isa<StringAttr>())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -713,7 +714,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs; result->attributes = attrs;
result->addTypes({i1Type}); result->addTypes({i1Type});
return false; return success();
} }
void CmpIOp::print(OpAsmPrinter *p) { void CmpIOp::print(OpAsmPrinter *p) {
@ -856,7 +857,7 @@ void CmpFOp::build(Builder *build, OperationState *result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate))); build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
} }
bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops; SmallVector<OpAsmParser::OperandType, 2> ops;
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
Attribute predicateNameAttr; Attribute predicateNameAttr;
@ -867,7 +868,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(attrs) || parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands)) parser->resolveOperands(ops, type, result->operands))
return true; return failure();
if (!predicateNameAttr.isa<StringAttr>()) if (!predicateNameAttr.isa<StringAttr>())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -891,7 +892,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs; result->attributes = attrs;
result->addTypes({i1Type}); result->addTypes({i1Type});
return false; return success();
} }
void CmpFOp::print(OpAsmPrinter *p) { void CmpFOp::print(OpAsmPrinter *p) {
@ -1044,7 +1045,7 @@ void CondBranchOp::build(Builder *builder, OperationState *result,
result->addSuccessor(falseDest, falseOperands); result->addSuccessor(falseDest, falseOperands);
} }
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<Value *, 4> destOperands; SmallVector<Value *, 4> destOperands;
Block *dest; Block *dest;
OpAsmParser::OperandType condInfo; OpAsmParser::OperandType condInfo;
@ -1059,18 +1060,17 @@ bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the true successor. // Parse the true successor.
if (parser->parseSuccessorAndUseList(dest, destOperands)) if (parser->parseSuccessorAndUseList(dest, destOperands))
return true; return failure();
result->addSuccessor(dest, destOperands); result->addSuccessor(dest, destOperands);
// Parse the false successor. // Parse the false successor.
destOperands.clear(); destOperands.clear();
if (parser->parseComma() || if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands)) parser->parseSuccessorAndUseList(dest, destOperands))
return true; return failure();
result->addSuccessor(dest, destOperands); result->addSuccessor(dest, destOperands);
// Return false on success. return success();
return false;
} }
void CondBranchOp::print(OpAsmPrinter *p) { void CondBranchOp::print(OpAsmPrinter *p) {
@ -1132,13 +1132,14 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
*p << " : " << op.getType(); *p << " : " << op.getType();
} }
static bool parseConstantOp(OpAsmParser *parser, OperationState *result) { static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr; Attribute valueAttr;
Type type; Type type;
if (parser->parseOptionalAttributeDict(result->attributes) || if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes)) parser->parseAttribute(valueAttr, "value", result->attributes))
return true; return failure();
// 'constant' taking a function reference doesn't get a redundant type // 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it. // specifier. The attribute itself carries it.
@ -1150,7 +1151,7 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
} else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) { } else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) {
type = fpAttr.getType(); type = fpAttr.getType();
} else if (parser->parseColonType(type)) { } else if (parser->parseColonType(type)) {
return true; return failure();
} }
return parser->addTypeToList(type, result->types); return parser->addTypeToList(type, result->types);
} }
@ -1298,12 +1299,13 @@ void DeallocOp::print(OpAsmPrinter *p) {
*p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType(); *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
} }
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
MemRefType type; MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) || return failure(parser->parseOperand(memrefInfo) ||
parser->resolveOperand(memrefInfo, type, result->operands); parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands));
} }
LogicalResult DeallocOp::verify() { LogicalResult DeallocOp::verify() {
@ -1338,19 +1340,19 @@ void DimOp::print(OpAsmPrinter *p) {
*p << " : " << getOperand()->getType(); *p << " : " << getOperand()->getType();
} }
bool DimOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo; OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr; IntegerAttr indexAttr;
Type type; Type type;
Type indexType = parser->getBuilder().getIndexType(); Type indexType = parser->getBuilder().getIndexType();
return parser->parseOperand(operandInfo) || parser->parseComma() || return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, indexType, "index", parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) || result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) || parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(indexType, result->types); parser->addTypeToList(indexType, result->types));
} }
LogicalResult DimOp::verify() { LogicalResult DimOp::verify() {
@ -1491,7 +1493,7 @@ void DmaStartOp::print(OpAsmPrinter *p) {
// memref<1024 x f32, 2>, // memref<1024 x f32, 2>,
// memref<1 x i32> // memref<1 x i32>
// //
bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcMemRefInfo; OpAsmParser::OperandType srcMemRefInfo;
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
OpAsmParser::OperandType dstMemRefInfo; OpAsmParser::OperandType dstMemRefInfo;
@ -1518,11 +1520,11 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseComma() || parser->parseOperand(tagMemrefInfo) || parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
parser->parseOperandList(tagIndexInfos, -1, parser->parseOperandList(tagIndexInfos, -1,
OpAsmParser::Delimiter::Square)) OpAsmParser::Delimiter::Square))
return true; return failure();
// Parse optional stride and elements per stride. // Parse optional stride and elements per stride.
if (parser->parseTrailingOperandList(strideInfo)) { if (parser->parseTrailingOperandList(strideInfo)) {
return true; return failure();
} }
if (!strideInfo.empty() && strideInfo.size() != 2) { if (!strideInfo.empty() && strideInfo.size() != 2) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -1531,7 +1533,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
bool isStrided = strideInfo.size() == 2; bool isStrided = strideInfo.size() == 2;
if (parser->parseColonTypeList(types)) if (parser->parseColonTypeList(types))
return true; return failure();
if (types.size() != 3) if (types.size() != 3)
return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
@ -1545,7 +1547,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
// tag indices should be index. // tag indices should be index.
parser->resolveOperands(tagIndexInfos, indexType, result->operands)) parser->resolveOperands(tagIndexInfos, indexType, result->operands))
return true; return failure();
if (!types[0].isa<MemRefType>()) if (!types[0].isa<MemRefType>())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -1562,7 +1564,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
if (isStrided) { if (isStrided) {
if (parser->resolveOperand(strideInfo[0], indexType, result->operands) || if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
parser->resolveOperand(strideInfo[1], indexType, result->operands)) parser->resolveOperand(strideInfo[1], indexType, result->operands))
return true; return failure();
} }
// Check that source/destination index list size matches associated rank. // Check that source/destination index list size matches associated rank.
@ -1575,7 +1577,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count"); "tag memref rank not equal to indices count");
return false; return success();
} }
LogicalResult DmaStartOp::verify() { LogicalResult DmaStartOp::verify() {
@ -1628,7 +1630,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) {
// Eg: // Eg:
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
// //
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo; OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type type; Type type;
@ -1644,7 +1646,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(tagMemrefInfo, type, result->operands) || parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
parser->resolveOperands(tagIndexInfos, indexType, result->operands) || parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
parser->resolveOperand(numElementsInfo, indexType, result->operands)) parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true; return failure();
if (!type.isa<MemRefType>()) if (!type.isa<MemRefType>())
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
@ -1654,7 +1656,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count"); "tag memref rank not equal to indices count");
return false; return success();
} }
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -1684,20 +1686,21 @@ void ExtractElementOp::print(OpAsmPrinter *p) {
*p << " : " << getAggregate()->getType(); *p << " : " << getAggregate()->getType();
} }
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult ExtractElementOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType aggregateInfo; OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
VectorOrTensorType type; VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) || return failure(
parser->parseOperandList(indexInfo, -1, parser->parseOperand(aggregateInfo) ||
OpAsmParser::Delimiter::Square) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) || parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types); parser->addTypeToList(type.getElementType(), result->types));
} }
LogicalResult ExtractElementOp::verify() { LogicalResult ExtractElementOp::verify() {
@ -1771,20 +1774,20 @@ void LoadOp::print(OpAsmPrinter *p) {
*p << " : " << getMemRefType(); *p << " : " << getMemRefType();
} }
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType type; MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) || return failure(
parser->parseOperandList(indexInfo, -1, parser->parseOperand(memrefInfo) ||
OpAsmParser::Delimiter::Square) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) || parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) || parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types); parser->addTypeToList(type.getElementType(), result->types));
} }
LogicalResult LoadOp::verify() { LogicalResult LoadOp::verify() {
@ -1963,13 +1966,14 @@ void ReturnOp::build(Builder *builder, OperationState *result,
result->addOperands(results); result->addOperands(results);
} }
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo; SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types; SmallVector<Type, 2> types;
llvm::SMLoc loc; llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) || return failure(parser->getCurrentLocation(&loc) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) || parser->parseOperandList(opInfo) ||
parser->resolveOperands(opInfo, types, loc, result->operands); (!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands));
} }
void ReturnOp::print(OpAsmPrinter *p) { void ReturnOp::print(OpAsmPrinter *p) {
@ -2012,7 +2016,7 @@ void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
result->addTypes(trueValue->getType()); result->addTypes(trueValue->getType());
} }
bool SelectOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops; SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<NamedAttribute, 4> attrs; SmallVector<NamedAttribute, 4> attrs;
Type type; Type type;
@ -2020,7 +2024,7 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->parseOperandList(ops, 3) || if (parser->parseOperandList(ops, 3) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type)) parser->parseColonType(type))
return true; return failure();
auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type);
if (!i1Type) if (!i1Type)
@ -2028,9 +2032,9 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
"expected type with valid i1 shape"); "expected type with valid i1 shape");
SmallVector<Type, 3> types = {i1Type, type, type}; SmallVector<Type, 3> types = {i1Type, type, type};
return parser->resolveOperands(ops, types, parser->getNameLoc(), return failure(parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands) || result->operands) ||
parser->addTypeToList(type, result->types); parser->addTypeToList(type, result->types));
} }
void SelectOp::print(OpAsmPrinter *p) { void SelectOp::print(OpAsmPrinter *p) {
@ -2090,23 +2094,23 @@ void StoreOp::print(OpAsmPrinter *p) {
*p << " : " << getMemRefType(); *p << " : " << getMemRefType();
} }
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType memrefType; MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType(); auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() || return failure(
parser->parseOperand(memrefInfo) || parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperandList(indexInfo, -1, parser->parseOperand(memrefInfo) ||
OpAsmParser::Delimiter::Square) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) || parser->parseColonType(memrefType) ||
parser->resolveOperand(storeValueInfo, memrefType.getElementType(), parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) || result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands); parser->resolveOperands(indexInfo, affineIntTy, result->operands));
} }
LogicalResult StoreOp::verify() { LogicalResult StoreOp::verify() {

View File

@ -120,7 +120,8 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) {
*p << ", " << getResultType(); *p << ", " << getResultType();
} }
bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo; SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<OpAsmParser::OperandType, 8> paddingInfo; SmallVector<OpAsmParser::OperandType, 8> paddingInfo;
@ -133,7 +134,7 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Paren) || OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types)) parser->parseColonTypeList(types))
return true; return failure();
// Resolution. // Resolution.
if (types.size() != 2) if (types.size() != 2)
@ -160,12 +161,12 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
paddingType = vectorType.getElementType(); paddingType = vectorType.getElementType();
} }
auto indexType = parser->getBuilder().getIndexType(); auto indexType = parser->getBuilder().getIndexType();
return parser->resolveOperand(memrefInfo, memrefType, result->operands) || return failure(
parser->resolveOperands(indexInfo, indexType, result->operands) || parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
(hasOptionalPaddingValue && parser->resolveOperands(indexInfo, indexType, result->operands) ||
parser->resolveOperand(paddingInfo[0], paddingType, (hasOptionalPaddingValue &&
result->operands)) || parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) ||
parser->addTypeToList(vectorType, result->types); parser->addTypeToList(vectorType, result->types));
} }
LogicalResult VectorTransferReadOp::verify() { LogicalResult VectorTransferReadOp::verify() {
@ -286,7 +287,8 @@ void VectorTransferWriteOp::print(OpAsmPrinter *p) {
p->printType(getMemRefType()); p->printType(getMemRefType());
} }
bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo; OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo;
@ -297,7 +299,7 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) || parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types)) parser->parseColonTypeList(types))
return true; return failure();
if (types.size() != 2) if (types.size() != 2)
return parser->emitError(parser->getNameLoc(), "expected 2 types"); return parser->emitError(parser->getNameLoc(), "expected 2 types");
@ -308,10 +310,10 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
if (!memrefType) if (!memrefType)
return parser->emitError(parser->getNameLoc(), "memRef type expected"); return parser->emitError(parser->getNameLoc(), "memRef type expected");
return parser->resolveOperands(storeValueInfo, vectorType, return failure(
result->operands) || parser->resolveOperands(storeValueInfo, vectorType, result->operands) ||
parser->resolveOperands(memrefInfo, memrefType, result->operands) || parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, indexType, result->operands); parser->resolveOperands(indexInfo, indexType, result->operands));
} }
LogicalResult VectorTransferWriteOp::verify() { LogicalResult VectorTransferWriteOp::verify() {
@ -390,15 +392,16 @@ void VectorTypeCastOp::build(Builder *builder, OperationState *result,
result->addTypes(dstType); result->addTypes(dstType);
} }
bool VectorTypeCastOp::parse(OpAsmParser *parser, OperationState *result) { ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType operand; OpAsmParser::OperandType operand;
Type srcType, dstType; Type srcType, dstType;
return parser->parseOperand(operand) || return failure(parser->parseOperand(operand) ||
parser->parseOptionalAttributeDict(result->attributes) || parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(srcType) || parser->parseComma() || parser->parseColonType(srcType) || parser->parseComma() ||
parser->parseType(dstType) || parser->parseType(dstType) ||
parser->addTypeToList(dstType, result->types) || parser->addTypeToList(dstType, result->types) ||
parser->resolveOperand(operand, srcType, result->operands); parser->resolveOperand(operand, srcType, result->operands));
} }
void VectorTypeCastOp::print(OpAsmPrinter *p) { void VectorTypeCastOp::print(OpAsmPrinter *p) {

View File

@ -51,7 +51,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
// CHECK: static void build(Value *val); // CHECK: static void build(Value *val);
// CHECK: static void build(Builder *, OperationState *tblgen_state, Type r, ArrayRef<Type> s, Value *a, ArrayRef<Value *> b, IntegerAttr attr1, /*optional*/FloatAttr attr2); // CHECK: static void build(Builder *, OperationState *tblgen_state, Type r, ArrayRef<Type> s, Value *a, ArrayRef<Value *> b, IntegerAttr attr1, /*optional*/FloatAttr attr2);
// CHECK: static void build(Builder *, OperationState *tblgen_state, ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes); // CHECK: static void build(Builder *, OperationState *tblgen_state, ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes);
// CHECK: static bool parse(OpAsmParser *parser, OperationState *result); // CHECK: static ParseResult parse(OpAsmParser *parser, OperationState *result);
// CHECK: void print(OpAsmPrinter *p); // CHECK: void print(OpAsmPrinter *p);
// CHECK: LogicalResult verify(); // CHECK: LogicalResult verify();
// CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); // CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);

View File

@ -863,7 +863,7 @@ void OpEmitter::genParser() {
return; return;
auto &method = opClass.newMethod( auto &method = opClass.newMethod(
"bool", "parse", "OpAsmParser *parser, OperationState *result", "ParseResult", "parse", "OpAsmParser *parser, OperationState *result",
OpMethod::MP_Static); OpMethod::MP_Static);
auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
method.body() << " " << parser; method.body() << " " << parser;