NFC: Make ParseResult public and update the OpAsmParser(and thus all of the custom operation parsers) to use it instead of bool.

--

PiperOrigin-RevId: 246955523
This commit is contained in:
River Riddle 2019-05-06 22:01:31 -07:00 committed by Mehdi Amini
parent 6ccf90147c
commit b7dc252683
30 changed files with 591 additions and 543 deletions

View File

@ -41,7 +41,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *min, mlir::Value *max, mlir::Value *step);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////

View File

@ -40,7 +40,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *view, mlir::Value *indexing, unsigned dim);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////

View File

@ -43,7 +43,8 @@ public:
mlir::Value *memRef,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////

View File

@ -48,15 +48,17 @@ mlir::LogicalResult linalg::RangeOp::verify() {
return mlir::success();
}
bool linalg::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult linalg::RangeOp::parse(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto indexTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, indexTy, result->operands) ||
parser->addTypeToList(type, result->types);
return failure(
parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, indexTy, result->operands) ||
parser->addTypeToList(type, result->types));
}
// A RangeOp prints as:

View File

@ -74,7 +74,8 @@ mlir::LogicalResult linalg::SliceOp::verify() {
return mlir::success();
}
bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult linalg::SliceOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType viewInfo;
SmallVector<OpAsmParser::OperandType, 1> indexingInfo;
SmallVector<Type, 8> types;
@ -83,7 +84,7 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
return failure();
if (indexingInfo.size() != 1)
return parser->emitError(parser->getNameLoc(), "expected 1 indexing type");
@ -107,10 +108,10 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
ViewType resultViewType =
ViewType::get(viewType.getContext(), viewType.getElementType(), rank);
return parser->resolveOperand(viewInfo, viewType, result->operands) ||
parser->resolveOperands(indexingInfo[0], types.back(),
result->operands) ||
parser->addTypeToList(resultViewType, result->types);
return failure(parser->resolveOperand(viewInfo, viewType, result->operands) ||
parser->resolveOperands(indexingInfo[0], types.back(),
result->operands) ||
parser->addTypeToList(resultViewType, result->types));
}
// A SliceOp prints as:

View File

@ -89,7 +89,7 @@ LogicalResult linalg::ViewOp::verify() {
return success();
}
bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memRefInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types;
@ -98,7 +98,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
return failure();
if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
@ -120,12 +120,13 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(memRefType.getRank()) +
" indexing types");
return parser->resolveOperand(memRefInfo, memRefType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types);
return failure(
parser->resolveOperand(memRefInfo, memRefType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A ViewOp prints as:

View File

@ -83,8 +83,9 @@ mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
}
template <class ConcreteOp>
bool linalg::TensorContractionBase<ConcreteOp>::parse(
mlir::OpAsmParser *parser, mlir::OperationState *result) {
mlir::ParseResult
linalg::TensorContractionBase<ConcreteOp>::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
}

View File

@ -38,7 +38,8 @@ protected:
//////////////////////////////////////////////////////////////////////////////
/// Generic implementation of hooks that should be called from `ConcreteType`s
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
public:
@ -118,7 +119,8 @@ public:
return build(b, result, {A, B, C});
}
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
@ -179,7 +181,8 @@ public:
return build(b, result, {A, B, C});
}
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
@ -240,7 +243,8 @@ public:
return build(b, result, {A, B, C});
}
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////

View File

@ -58,8 +58,8 @@ LogicalResult linalg::DotOp::verify() {
}
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::DotOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
ParseResult linalg::DotOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result);
}
@ -92,8 +92,8 @@ LogicalResult linalg::MatvecOp::verify() {
}
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
ParseResult linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result);
}
@ -123,8 +123,8 @@ LogicalResult linalg::MatmulOp::verify() {
}
// Parsing of the linalg dialect is not supported in this tutorial.
bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
ParseResult linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
mlir::OperationState *result) {
return TensorContractionBaseType::parse(parser, result);
}

View File

@ -41,7 +41,8 @@ public:
mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////
@ -71,7 +72,8 @@ public:
mlir::Value *valueToStore, mlir::Value *view,
mlir::ArrayRef<mlir::Value *> indices = {});
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static mlir::ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
//////////////////////////////////////////////////////////////////////////////

View File

@ -49,9 +49,9 @@ void linalg::LoadOp::print(OpAsmPrinter *p) {
*p << " : " << getViewType();
}
bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
return false;
return success();
}
LogicalResult linalg::LoadOp::verify() {
@ -101,9 +101,10 @@ void linalg::StoreOp::print(OpAsmPrinter *p) {
*p << " : " << getViewType();
}
bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult linalg::StoreOp::parse(OpAsmParser *parser,
OperationState *result) {
assert(false && "NYI");
return false;
return success();
}
LogicalResult linalg::StoreOp::verify() {

View File

@ -80,7 +80,7 @@ public:
static StringRef getOperationName() { return "affine.apply"; }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -130,7 +130,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t lb,
int64_t ub, int64_t step = 1);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -326,7 +326,7 @@ public:
Region &getElseBlocks();
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
};

View File

@ -80,7 +80,7 @@ public:
/// Custom syntax support.
void print(OpAsmPrinter *p);
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
static StringRef getOperationName() { return "gpu.launch"; }

View File

@ -54,6 +54,18 @@ template <typename OpType> struct IsSingleResult {
OpType *, OpTrait::OneResult<typename OpType::ConcreteOpType> *>::value;
};
/// This class represents success/failure for operation parsing. It is
/// essentially a simple wrapper class around LogicalResult that allows for
/// explicit conversion to bool. This allows for the parser to chain together
/// parse rules without the clutter of "failed/succeeded".
class ParseResult : public LogicalResult {
public:
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
/// Failure is true in a boolean context.
explicit operator bool() const { return failed(*this); }
};
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
@ -132,10 +144,9 @@ protected:
LogicalResult verify() { return success(); }
/// Unless overridden, the custom assembly form of an op is always rejected.
/// Op implementations should implement this to return true on failure.
/// On success, they should return false and fill in result with the fields to
/// use.
static bool parse(OpAsmParser *parser, OperationState *result);
/// Op implementations should implement this to return failure.
/// On success, they should fill in result with the fields to use.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
// The fallback for the printer is to print it the generic assembly form.
void print(OpAsmPrinter *p);
@ -768,9 +779,10 @@ public:
/// This is the hook used by the AsmParser to parse the custom form of this
/// op from an .mlir file. Op implementations should provide a parse method,
/// which returns boolean true on failure. On success, they should return
/// false and fill in result with the fields to use.
static bool parseAssembly(OpAsmParser *parser, OperationState *result) {
/// which returns failure. On success, they should return fill in result with
/// the fields to use.
static ParseResult parseAssembly(OpAsmParser *parser,
OperationState *result) {
return ConcreteType::parse(parser, result);
}
@ -854,7 +866,7 @@ private:
namespace impl {
void buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
Value *rhs);
bool parseBinaryOp(OpAsmParser *parser, OperationState *result);
ParseResult parseBinaryOp(OpAsmParser *parser, OperationState *result);
// Prints the given binary `op` in custom assembly form if both the two operands
// and the result have the same time. Otherwise, prints the generic assembly
// form.
@ -866,7 +878,7 @@ void printBinaryOp(Operation *op, OpAsmPrinter *p);
namespace impl {
void buildCastOp(Builder *builder, OperationState *result, Value *source,
Type destType);
bool parseCastOp(OpAsmParser *parser, OperationState *result);
ParseResult parseCastOp(OpAsmParser *parser, OperationState *result);
void printCastOp(Operation *op, OpAsmPrinter *p);
Value *foldCastOp(Operation *op);
} // namespace impl
@ -888,7 +900,7 @@ public:
Type destType) {
impl::buildCastOp(builder, result, source, destType);
}
static bool parse(OpAsmParser *parser, OperationState *result) {
static ParseResult parse(OpAsmParser *parser, OperationState *result) {
return impl::parseCastOp(parser, result);
}
void print(OpAsmPrinter *p) {

View File

@ -148,134 +148,138 @@ public:
// High level parsing methods.
//===--------------------------------------------------------------------===//
// These emit an error and return true on failure, or return false on success.
// These emit an error and return failure or success.
// This allows these to be chained together into a linear sequence of ||
// expressions in many cases.
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
virtual bool getCurrentLocation(llvm::SMLoc *loc) = 0;
virtual ParseResult getCurrentLocation(llvm::SMLoc *loc) = 0;
/// This parses... a comma!
virtual bool parseComma() = 0;
virtual ParseResult parseComma() = 0;
/// Parses a comma if present.
virtual bool parseOptionalComma() = 0;
virtual ParseResult parseOptionalComma() = 0;
/// Parse a `:` token.
virtual bool parseColon() = 0;
virtual ParseResult parseColon() = 0;
/// Parse a '(' token.
virtual bool parseLParen() = 0;
virtual ParseResult parseLParen() = 0;
/// Parse a ')' token.
virtual bool parseRParen() = 0;
virtual ParseResult parseRParen() = 0;
/// This parses an equal(=) token!
virtual bool parseEqual() = 0;
virtual ParseResult parseEqual() = 0;
/// Parse a type.
virtual bool parseType(Type &result) = 0;
virtual ParseResult parseType(Type &result) = 0;
/// Parse a colon followed by a type.
virtual bool parseColonType(Type &result) = 0;
virtual ParseResult parseColonType(Type &result) = 0;
/// Parse a type of a specific kind, e.g. a FunctionType.
template <typename TypeType> bool parseColonType(TypeType &result) {
template <typename TypeType> ParseResult parseColonType(TypeType &result) {
llvm::SMLoc loc;
getCurrentLocation(&loc);
// Parse any kind of type.
Type type;
if (parseColonType(type))
return true;
return failure();
// Check for the right kind of attribute.
result = type.dyn_cast<TypeType>();
if (!result)
return emitError(loc, "invalid kind of type specified");
return false;
return success();
}
/// Parse a colon followed by a type list, which must have at least one type.
virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
/// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type &result) {
return parseKeyword(keyword) || parseType(result);
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
}
/// Parse a keyword.
bool parseKeyword(const char *keyword, const Twine &msg = "") {
ParseResult parseKeyword(const char *keyword, const Twine &msg = "") {
if (parseOptionalKeyword(keyword))
return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg);
return false;
return success();
}
/// If a keyword is present, then parse it.
virtual bool parseOptionalKeyword(const char *keyword) = 0;
virtual ParseResult parseOptionalKeyword(const char *keyword) = 0;
/// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.
bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
result.push_back(type);
return false;
return success();
}
/// Add the specified types to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.
bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypesToList(ArrayRef<Type> types,
SmallVectorImpl<Type> &result) {
result.append(types.begin(), types.end());
return false;
return success();
}
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name.
virtual bool parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
virtual ParseResult
parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
virtual bool parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
virtual ParseResult
parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
bool parseAttribute(AttrType &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc;
getCurrentLocation(&loc);
// Parse any kind of attribute.
Attribute attr;
if (parseAttribute(attr, type, attrName, attrs))
return true;
return failure();
// Check for the right kind of attribute.
result = attr.dyn_cast<AttrType>();
if (!result)
return emitError(loc, "invalid kind of constant specified");
return false;
return success();
}
/// If a named attribute dictionary is present, parse it into result.
virtual bool
virtual ParseResult
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) = 0;
/// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) = 0;
virtual ParseResult parseFunctionName(StringRef &result,
llvm::SMLoc &loc) = 0;
/// Parse a function name like '@foo` if present and return the name without
/// the sigil in `result`. Return true if the next token is not a function
/// name and keep `result` unchanged.
virtual bool parseOptionalFunctionName(StringRef &result,
llvm::SMLoc &loc) = 0;
virtual ParseResult parseOptionalFunctionName(StringRef &result,
llvm::SMLoc &loc) = 0;
/// This is the representation of an operand reference.
struct OperandType {
@ -285,11 +289,12 @@ public:
};
/// Parse a single operand.
virtual bool parseOperand(OperandType &result) = 0;
virtual ParseResult parseOperand(OperandType &result) = 0;
/// Parse a single operation successor and it's operand list.
virtual bool parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) = 0;
virtual ParseResult
parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) = 0;
/// These are the supported delimiters around operand lists, used by
/// parseOperandList.
@ -308,14 +313,15 @@ public:
/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
virtual bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult
parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
/// Parse zero or more trailing SSA comma-separated trailing operand
/// references with a specified surrounding delimiter, and an optional
/// required operand count. A leading comma is expected before the operands.
virtual bool
virtual ParseResult
parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
@ -323,12 +329,13 @@ public:
/// Parses a region. Any parsed blocks are appended to "region" and must be
/// moved to the op regions after the op is created. The first block of the
/// region takes "arguments" of types "argTypes".
virtual bool parseRegion(Region &region, ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
virtual ParseResult parseRegion(Region &region,
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
/// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet.
virtual bool parseRegionArgument(OperandType &argument) = 0;
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
@ -341,46 +348,45 @@ public:
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;
/// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure.
virtual bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) = 0;
/// Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning
/// true on failure, or appending the results to the list on success.
/// This method should be used when all operands have the same type.
virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<Value *> &result) {
/// Resolve a list of operands to SSA values, emitting an error on failure, or
/// appending the results to the list on success. This method should be used
/// when all operands have the same type.
virtual ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
SmallVectorImpl<Value *> &result) {
for (auto elt : operands)
if (resolveOperand(elt, type, result))
return true;
return false;
return failure();
return success();
}
/// Resolve a list of operands and a list of operand types to SSA values,
/// emitting an error and returning true on failure, or appending the results
/// emitting an error and returning failure, or appending the results
/// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<Value *> &result) {
virtual ParseResult resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type> types, llvm::SMLoc loc,
SmallVectorImpl<Value *> &result) {
if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) +
" operands present, but expected " +
Twine(types.size()));
for (unsigned i = 0, e = operands.size(); i != e; ++i) {
for (unsigned i = 0, e = operands.size(); i != e; ++i)
if (resolveOperand(operands[i], types[i], result))
return true;
}
return false;
return failure();
return success();
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) = 0;
virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc,
Function *&result) = 0;
/// Emit a diagnostic at the specified location and return true.
virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0;
/// Emit a diagnostic at the specified location and return failure.
virtual ParseResult emitError(llvm::SMLoc loc, const Twine &message) = 0;
};
} // end namespace mlir

View File

@ -42,6 +42,7 @@ struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class ParseResult;
class Pattern;
class Region;
class RewritePattern;
@ -85,7 +86,7 @@ public:
bool (&isClassFor)(Operation *op);
/// Use the specified object to parse this ops custom assembly format.
bool (&parseAssembly)(OpAsmParser *parser, OperationState *result);
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result);
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(Operation *op, OpAsmPrinter *p);
@ -150,7 +151,7 @@ private:
AbstractOperation(
StringRef name, Dialect &dialect, OperationProperties opProperties,
bool (&isClassFor)(Operation *op),
bool (&parseAssembly)(OpAsmParser *parser, OperationState *result),
ParseResult (&parseAssembly)(OpAsmParser *parser, OperationState *result),
void (&printAssembly)(Operation *op, OpAsmPrinter *p),
LogicalResult (&verifyInvariants)(Operation *op),
LogicalResult (&constantFoldHook)(Operation *op,

View File

@ -43,7 +43,7 @@ public:
static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; }
static void build(Builder *b, OperationState *result, Type type, Value *size);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
@ -67,7 +67,7 @@ public:
static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; }
static void build(Builder *b, OperationState *result, Value *buffer);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
@ -94,7 +94,7 @@ public:
static void build(Builder *b, OperationState *result, Value *min, Value *max,
Value *step);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
// Op-specific functionality.
@ -156,7 +156,8 @@ public:
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.
@ -207,7 +208,8 @@ public:
mlir::Value *buffer,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
static ParseResult parse(mlir::OpAsmParser *parser,
mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.

View File

@ -80,7 +80,7 @@ public:
static void build(Builder *builder, OperationState *result,
MemRefType memrefType, ArrayRef<Value *> operands = {});
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
@ -108,7 +108,7 @@ public:
ArrayRef<Value *> operands = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
/// Return the block this branch jumps to.
@ -149,7 +149,7 @@ public:
operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};
@ -183,7 +183,7 @@ public:
operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -249,7 +249,7 @@ public:
static void build(Builder *builder, OperationState *result, CmpIPredicate,
Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -324,7 +324,7 @@ public:
static void build(Builder *builder, OperationState *result, CmpFPredicate,
Value *lhs, Value *rhs);
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
@ -362,7 +362,7 @@ public:
Block *falseDest, ArrayRef<Value *> falseOperands);
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
@ -521,7 +521,7 @@ public:
// Hooks to customize behavior of this op.
static void build(Builder *builder, OperationState *result, Value *memref);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
@ -553,7 +553,7 @@ public:
// Hooks to customize behavior of this op.
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
};
@ -682,7 +682,7 @@ public:
}
static StringRef getOperationName() { return "std.dma_start"; }
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
@ -748,7 +748,7 @@ public:
// Returns the number of elements transferred in the associated DMA operation.
Value *getNumElements() { return getOperand(1 + getTagMemRefRank()); }
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
@ -785,7 +785,7 @@ public:
// Hooks to customize behavior of this op.
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
Attribute constantFold(ArrayRef<Attribute> operands, MLIRContext *context);
};
@ -821,7 +821,7 @@ public:
static StringRef getOperationName() { return "std.load"; }
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
@ -881,7 +881,7 @@ public:
ArrayRef<Value *> results = {});
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};
@ -906,7 +906,7 @@ public:
static StringRef getOperationName() { return "std.select"; }
static void build(Builder *builder, OperationState *result, Value *condition,
Value *trueValue, Value *falseValue);
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
@ -953,7 +953,7 @@ public:
static StringRef getOperationName() { return "std.store"; }
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -994,9 +994,9 @@ void printDimAndSymbolList(Operation::operand_iterator begin,
OpAsmPrinter *p);
/// Parses dimension and symbol list and returns true if parsing failed.
bool parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
ParseResult parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims);
} // end namespace mlir

View File

@ -115,7 +115,7 @@ public:
Optional<Value *> getPaddingValue();
AffineMap getPermutationMap();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};
@ -177,7 +177,7 @@ public:
operand_range getIndices();
AffineMap getPermutationMap();
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};
@ -200,7 +200,7 @@ public:
static StringRef getOperationName() { return "vector.type_cast"; }
static void build(Builder *builder, OperationState *result, Value *srcVector,
Type dstType);
static bool parse(OpAsmParser *parser, OperationState *result);
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};

View File

@ -131,7 +131,7 @@ void AffineApplyOp::build(Builder *builder, OperationState *result,
result->addAttribute("map", builder->getAffineMapAttr(map));
}
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto affineIntTy = builder.getIndexType();
@ -140,7 +140,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
return failure();
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
@ -150,7 +150,7 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
}
result->types.append(map.getNumResults(), affineIntTy);
return false;
return success();
}
void AffineApplyOp::print(OpAsmPrinter *p) {
@ -801,10 +801,12 @@ LogicalResult AffineForOp::verify() {
}
/// Parse a for operation loop bounds.
static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
static ParseResult parseBound(bool isLower, OperationState *result,
OpAsmParser *p) {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results.
bool failedToParsedMinMax = p->parseOptionalKeyword(isLower ? "max" : "min");
bool failedToParsedMinMax =
failed(p->parseOptionalKeyword(isLower ? "max" : "min"));
auto &builder = p->getBuilder();
auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
@ -813,7 +815,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
// Parse ssa-id as identity map.
SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
if (p->parseOperandList(boundOpInfos))
return true;
return failure();
if (!boundOpInfos.empty()) {
// Check that only one operand was parsed.
@ -825,14 +827,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
// Currently it is 'use of value ... expects different type than prior uses'
if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
result->operands))
return true;
return failure();
// Create an identity map using symbol id. This representation is optimized
// for storage. Analysis passes may expand it into a multi-dimensional map
// if desired.
AffineMap map = builder.getSymbolIdentityMap();
result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
return false;
return success();
}
// Get the attribute location.
@ -842,14 +844,14 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
Attribute boundAttr;
if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
result->attributes))
return true;
return failure();
// Parse full form - affine map followed by dim and symbol list.
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
unsigned currentNumOperands = result->operands.size();
unsigned numDims;
if (parseDimAndSymbolList(p, result->operands, numDims))
return true;
return failure();
auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims)
@ -874,7 +876,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
return p->emitError(attrLoc, "upper loop bound affine map with multiple "
"results requires 'min' prefix");
}
return false;
return success();
}
// Parse custom assembly form.
@ -883,7 +885,7 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
result->addAttribute(
boundAttrName, builder.getAffineMapAttr(
builder.getConstantAffineMap(integerAttr.getInt())));
return false;
return success();
}
return p->emitError(
@ -891,18 +893,18 @@ static bool parseBound(bool isLower, OperationState *result, OpAsmParser *p) {
"expected valid affine map representation for loop bounds");
}
bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
OpAsmParser::OperandType inductionVariable;
// Parse the induction variable followed by '='.
if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
return true;
return failure();
// Parse loop bounds.
if (parseBound(/*isLower=*/true, result, parser) ||
parser->parseKeyword("to", " between bounds") ||
parseBound(/*isLower=*/false, result, parser))
return true;
return failure();
// Parse the optional loop step, we default to 1 if one is not present.
if (parser->parseOptionalKeyword("step")) {
@ -915,7 +917,7 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->getCurrentLocation(&stepLoc) ||
parser->parseAttribute(stepAttr, builder.getIndexType(),
getStepAttrName().data(), result->attributes))
return true;
return failure();
if (stepAttr.getValue().getSExtValue() < 0)
return parser->emitError(
@ -926,17 +928,17 @@ bool AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the body region.
Region *body = result->addRegion();
if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
return true;
return failure();
ensureAffineTerminator(*body, builder, result->location);
// Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes))
return true;
return failure();
// Set the operands list as resizable so that we can freely modify the bounds.
result->setOperandListToResizable();
return false;
return success();
}
static void printBound(AffineMapAttr boundMap,
@ -1253,14 +1255,14 @@ LogicalResult AffineIfOp::verify() {
return success();
}
bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the condition attribute set.
IntegerSetAttr conditionAttr;
unsigned numDims;
if (parser->parseAttribute(conditionAttr, getConditionAttrName(),
result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims))
return true;
return failure();
// Verify the condition operands.
auto set = conditionAttr.getValue();
@ -1281,21 +1283,21 @@ bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the 'then' region.
if (parser->parseRegion(*thenRegion, {}, {}))
return true;
return failure();
ensureAffineTerminator(*thenRegion, parser->getBuilder(), result->location);
// If we find an 'else' keyword then parse the 'else' region.
if (!parser->parseOptionalKeyword("else")) {
if (parser->parseRegion(*elseRegion, {}, {}))
return true;
return failure();
ensureAffineTerminator(*elseRegion, parser->getBuilder(), result->location);
}
// Parse the optional attribute list.
if (parser->parseOptionalAttributeDict(result->attributes))
return true;
return failure();
return false;
return success();
}
void AffineIfOp::print(OpAsmPrinter *p) {

View File

@ -159,7 +159,7 @@ void LaunchOp::print(OpAsmPrinter *p) {
// where %region_arg are percent-identifiers for the region arguments to be
// introduced futher (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static bool
static ParseResult
parseSizeAssignment(OpAsmParser *parser,
MutableArrayRef<OpAsmParser::OperandType> sizes,
MutableArrayRef<OpAsmParser::OperandType> regionSizes,
@ -169,14 +169,14 @@ parseSizeAssignment(OpAsmParser *parser,
parser->parseComma() || parser->parseRegionArgument(indices[2]) ||
parser->parseRParen() || parser->parseKeyword("in") ||
parser->parseLParen())
return true;
return failure();
for (int i = 0; i < 3; ++i) {
if (i != 0 && parser->parseComma())
return true;
return failure();
if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() ||
parser->parseOperand(sizes[i]))
return true;
return failure();
}
return parser->parseRParen();
@ -188,7 +188,7 @@ parseSizeAssignment(OpAsmParser *parser,
// (`args` ssa-reassignment `:` type-list)?
// region attr-dict?
// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
// Sizes of the grid and block.
SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
kNumConfigOperands);
@ -217,7 +217,7 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
regionArgsRef.slice(3, 3)) ||
parser->resolveOperands(sizes, parser->getBuilder().getIndexType(),
result->operands))
return true;
return failure();
// If kernel argument renaming segment is present, parse it. When present,
// the segment should have at least one element. If this segment is present,
@ -232,20 +232,20 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() ||
parser->parseRegionArgument(regionArgs.back()) ||
parser->parseEqual() || parser->parseOperand(dataOperands.back()))
return true;
return failure();
while (!parser->parseOptionalComma()) {
regionArgs.push_back({});
dataOperands.push_back({});
if (parser->parseRegionArgument(regionArgs.back()) ||
parser->parseEqual() || parser->parseOperand(dataOperands.back()))
return true;
return failure();
}
if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) ||
parser->resolveOperands(dataOperands, dataTypes, argsLoc,
result->operands))
return true;
return failure();
}
// Introduce the body region and parse it. The region has
@ -255,11 +255,10 @@ bool LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
Type index = parser->getBuilder().getIndexType();
dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
Region *body = result->addRegion();
return parser->parseRegion(*body, regionArgs, dataTypes) ||
parser->parseOptionalAttributeDict(result->attributes);
return failure(parser->parseRegion(*body, regionArgs, dataTypes) ||
parser->parseOptionalAttributeDict(result->attributes));
}
//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

View File

@ -640,7 +640,7 @@ Operation *Operation::clone(MLIRContext *context) {
//===----------------------------------------------------------------------===//
// The fallback for the parser is to reject the custom assembly form.
bool OpState::parse(OpAsmParser *parser, OperationState *result) {
ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(), "has no custom assembly form");
}
@ -948,14 +948,14 @@ void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
result->types.push_back(lhs->getType());
}
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
return parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands) ||
parser->addTypeToList(type, result->types);
return failure(parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands) ||
parser->addTypeToList(type, result->types));
}
void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
@ -988,13 +988,14 @@ void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
result->addTypes(destType);
}
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo;
Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
parser->addTypeToList(dstType, result->types);
return failure(parser->parseOperand(srcInfo) ||
parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
parser->addTypeToList(dstType, result->types));
}
void impl::printCastOp(Operation *op, OpAsmPrinter *p) {

View File

@ -126,7 +126,7 @@ static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
// attribute-dict? `:` type
static bool parseICmpOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseICmpOp(OpAsmParser *parser, OperationState *result) {
Builder &builder = parser->getBuilder();
Attribute predicate;
@ -142,7 +142,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) {
parser->parseType(type) ||
parser->resolveOperand(lhs, type, result->operands) ||
parser->resolveOperand(rhs, type, result->operands))
return true;
return failure();
// Replace the string attribute `predicate` with an integer attribute.
auto predicateStr = predicate.dyn_cast<StringAttr>();
@ -173,7 +173,7 @@ static bool parseICmpOp(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs;
result->addTypes({resultType});
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -195,7 +195,7 @@ static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
// `:` type `,` type
static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType arraySize;
Type type, elemType;
@ -204,7 +204,7 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) {
parser->parseType(elemType) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true;
return failure();
// Extract the result type from the trailing function type.
auto funcType = type.dyn_cast<FunctionType>();
@ -215,11 +215,11 @@ static bool parseAllocaOp(OpAsmParser *parser, OperationState *result) {
"expected trailing function type with one argument and one result");
if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands))
return true;
return failure();
result->attributes = attrs;
result->addTypes({funcType.getResult(0)});
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -242,7 +242,7 @@ static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
// attribute-dict? `:` type
static bool parseGEPOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType base;
SmallVector<OpAsmParser::OperandType, 8> indices;
@ -253,7 +253,7 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true;
return failure();
// Deconstruct the trailing function type to extract the types of the base
// pointer and result (same type) and the types of the indices.
@ -267,11 +267,11 @@ static bool parseGEPOp(OpAsmParser *parser, OperationState *result) {
if (parser->resolveOperand(base, funcType.getInput(0), result->operands) ||
parser->resolveOperands(indices, funcType.getInputs().drop_front(),
parser->getNameLoc(), result->operands))
return true;
return failure();
result->attributes = attrs;
result->addTypes(funcType.getResults());
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -302,7 +302,7 @@ static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
}
// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
static bool parseLoadOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr;
Type type;
@ -312,13 +312,13 @@ static bool parseLoadOp(OpAsmParser *parser, OperationState *result) {
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(type) ||
parser->resolveOperand(addr, type, result->operands))
return true;
return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
result->attributes = attrs;
result->addTypes(elemTy);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -332,7 +332,7 @@ static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
}
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
static bool parseStoreOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType addr, value;
Type type;
@ -342,18 +342,18 @@ static bool parseStoreOp(OpAsmParser *parser, OperationState *result) {
parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(type))
return true;
return failure();
Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
if (!elemTy)
return true;
return failure();
if (parser->resolveOperand(value, elemTy, result->operands) ||
parser->resolveOperand(addr, type, result->operands))
return true;
return failure();
result->attributes = attrs;
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -367,7 +367,7 @@ static void printBitcastOp(OpAsmPrinter *p, BitcastOp &op) {
}
// <operation> ::= `llvm.bitcast` ssa-use attribute-dict? `:` type `to` type
static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseBitcastOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType arg;
Type sourceType, type;
@ -376,11 +376,11 @@ static bool parseBitcastOp(OpAsmParser *parser, OperationState *result) {
parser->parseColonType(sourceType) || parser->parseKeyword("to") ||
parser->parseType(type) ||
parser->resolveOperand(arg, sourceType, result->operands))
return true;
return failure();
result->attributes = attrs;
result->addTypes(type);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -438,7 +438,7 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
// attribute-dict? `:` function-type
static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
SmallVector<OpAsmParser::OperandType, 8> operands;
Type type;
@ -450,19 +450,19 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
// direct call, there will be no operands and the parser will stop at the
// function identifier without complaining.
if (parser->parseOperandList(operands))
return true;
return failure();
bool isDirect = operands.empty();
// Optionally parse a function identifier.
if (isDirect)
if (parser->parseFunctionName(calleeName, calleeLoc))
return true;
return failure();
if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
return true;
return failure();
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType)
@ -471,14 +471,14 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
// Add the direct callee as an Op attribute.
Function *func;
if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func))
return true;
return failure();
auto funcAttr = parser->getBuilder().getFunctionAttr(func);
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
// Make sure types match.
if (parser->resolveOperands(operands, funcType.getInputs(),
parser->getNameLoc(), result->operands))
return true;
return failure();
result->addTypes(funcType.getResults());
} else {
// Construct the LLVM IR Dialect function type that the first operand
@ -528,13 +528,13 @@ static bool parseCallOp(OpAsmParser *parser, OperationState *result) {
result->operands) ||
parser->resolveOperands(funcArguments, funcType.getInputs(),
parser->getNameLoc(), result->operands))
return true;
return failure();
result->addTypes(wrappedResultType);
}
result->attributes = attrs;
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -601,7 +601,8 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
// <operation> ::= `llvm.extractvalue` ssa-use
// `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type
static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseExtractValueOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<NamedAttribute, 4> attrs;
OpAsmParser::OperandType container;
Type containerType;
@ -615,16 +616,16 @@ static bool parseExtractValueOp(OpAsmParser *parser, OperationState *result) {
parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(containerType) ||
parser->resolveOperand(container, containerType, result->operands))
return true;
return failure();
auto elementType = getInsertExtractValueElementType(
parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
if (!elementType)
return true;
return failure();
result->attributes = attrs;
result->addTypes(elementType);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -641,7 +642,8 @@ static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
// `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type
static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseInsertValueOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType container, value;
Type containerType;
Attribute positionAttr;
@ -654,19 +656,19 @@ static bool parseInsertValueOp(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
parser->parseType(containerType))
return true;
return failure();
auto valueType = getInsertExtractValueElementType(
parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
if (!valueType)
return true;
return failure();
if (parser->resolveOperand(container, containerType, result->operands) ||
parser->resolveOperand(value, valueType, result->operands))
return true;
return failure();
result->addTypes(containerType);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -682,7 +684,7 @@ static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
// attribute-dict? `:` type, type
static bool parseSelectOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType condition, trueValue, falseValue;
Type conditionType, argType;
@ -692,15 +694,15 @@ static bool parseSelectOp(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(conditionType) || parser->parseComma() ||
parser->parseType(argType))
return true;
return failure();
if (parser->resolveOperand(condition, conditionType, result->operands) ||
parser->resolveOperand(trueValue, argType, result->operands) ||
parser->resolveOperand(falseValue, argType, result->operands))
return true;
return failure();
result->addTypes(argType);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -715,15 +717,15 @@ static void printBrOp(OpAsmPrinter *p, BrOp &op) {
// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
// attribute-dict?
static bool parseBrOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) {
Block *dest;
SmallVector<Value *, 4> operands;
if (parser->parseSuccessorAndUseList(dest, operands) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
return failure();
result->addSuccessor(dest, operands);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -741,7 +743,7 @@ static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
// <operation> ::= `llvm.cond_br` ssa-use `,`
// bb-id (`[` ssa-use-and-type-list `]`)? `,`
// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
Block *trueDest;
Block *falseDest;
SmallVector<Value *, 4> trueOperands;
@ -760,11 +762,11 @@ static bool parseCondBrOp(OpAsmParser *parser, OperationState *result) {
parser->parseSuccessorAndUseList(falseDest, falseOperands) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->resolveOperand(condition, i1Type, result->operands))
return true;
return failure();
result->addSuccessor(trueDest, trueOperands);
result->addSuccessor(falseDest, falseOperands);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -784,20 +786,20 @@ static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
// type-list-no-parens
static bool parseReturnOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 1> operands;
Type type;
if (parser->parseOperandList(operands) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
return failure();
if (operands.empty())
return false;
return success();
if (parser->parseColonType(type) ||
parser->resolveOperand(operands[0], type, result->operands))
return true;
return false;
return failure();
return success();
}
//===----------------------------------------------------------------------===//
@ -811,15 +813,15 @@ static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
}
// <operation> ::= `llvm.undef` attribute-dict? : type
static bool parseUndefOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
result->addTypes(type);
return false;
return success();
}
//===----------------------------------------------------------------------===//
@ -845,7 +847,8 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
}
// <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type
static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr;
Type type;
@ -854,10 +857,10 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
parser->parseRParen() ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
result->addTypes(type);
return false;
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -53,15 +53,15 @@ static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
}
// <operation> ::= `llvm.nvvm.XYZ` : type
static bool parseNVVMSpecialRegisterOp(OpAsmParser *parser,
OperationState *result) {
static ParseResult parseNVVMSpecialRegisterOp(OpAsmParser *parser,
OperationState *result) {
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
result->addTypes(type);
return false;
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -58,18 +58,19 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *size() << " : " << getType();
}
bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult mlir::BufferAllocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return true;
return failure();
if (bufferType.getElementType() != parser->getBuilder().getF32Type())
return parser->emitError(
parser->getNameLoc(),
"Only buffer<f32> supported until mlir::Parser pieces are exposed");
return parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types);
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
//////////////////////////////////////////////////////////////////////////////
@ -95,11 +96,13 @@ void mlir::BufferDeallocOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBuffer() << " : " << getBufferType();
}
bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult mlir::BufferDeallocOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
return failure(
parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands));
}
//////////////////////////////////////////////////////////////////////////////
// RangeOp
@ -131,15 +134,16 @@ void mlir::RangeOp::print(OpAsmPrinter *p) {
<< " : " << getType();
}
bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
return failure(
parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types));
}
//////////////////////////////////////////////////////////////////////////////
@ -189,7 +193,7 @@ LogicalResult mlir::SliceOp::verify() {
return success();
}
bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types;
@ -198,7 +202,7 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
return failure();
if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
@ -221,12 +225,13 @@ bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(baseViewType.getRank()) +
" indexing types");
return parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types);
return failure(
parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A SliceOp prints as:
@ -306,7 +311,7 @@ LogicalResult mlir::ViewOp::verify() {
return success();
}
bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
@ -315,7 +320,7 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
@ -324,15 +329,15 @@ bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"expected" + Twine(viewType.getRank()) +
" range indexings");
return parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo,
RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types);
return failure(
parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}
// A ViewOp prints as:
@ -354,9 +359,9 @@ void mlir::ViewOp::print(OpAsmPrinter *p) {
namespace mlir {
namespace impl {
void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
void printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op);
bool parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
} // namespace impl
/// Buffer size prints as:
@ -372,16 +377,16 @@ void mlir::impl::printBufferSizeOp(mlir::OpAsmPrinter *p, Operation *op) {
*p << " : " << op->getOperand(0)->getType();
}
bool mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
ParseResult mlir::impl::parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
Type type;
return parser->parseOperand(op) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(op, type, result->operands) ||
parser->addTypeToList(parser->getBuilder().getIndexType(),
result->types);
return failure(parser->parseOperand(op) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(op, type, result->operands) ||
parser->addTypeToList(parser->getBuilder().getIndexType(),
result->types));
}
#define GET_OP_CLASSES
@ -415,15 +420,16 @@ void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
[&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
}
bool mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) {
ParseResult mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<Type, 3> types;
return parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types) ||
parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands);
return failure(
parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types) ||
parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands));
}
// Ideally this should all be Tablegen'd but there is no good story for

View File

@ -45,17 +45,6 @@ using llvm::MemoryBuffer;
using llvm::SMLoc;
using llvm::SourceMgr;
/// Simple wrapper class around LogicalResult that allows for explicit
/// conversion to bool. This allows for the parser to chain together parse rules
/// without the clutter of "failed/succeeded".
class ParseResult : public LogicalResult {
public:
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
/// Failure is true in a boolean context.
explicit operator bool() const { return failed(*this); }
};
namespace {
class Parser;
@ -2266,8 +2255,8 @@ public:
ParseResult parseFunctionBody(bool hadNamedArguments);
/// Parse a single operation successor and it's operand list.
bool parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands);
ParseResult parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands);
/// Parse a comma-separated list of operation successors in brackets.
ParseResult
@ -2809,11 +2798,12 @@ Block *FunctionParser::defineBlockNamed(StringRef name, SMLoc loc,
/// successor ::= block-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
///
bool FunctionParser::parseSuccessorAndUseList(
Block *&dest, SmallVectorImpl<Value *> &operands) {
ParseResult
FunctionParser::parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) {
// Verify branch is identifier and get the matching block.
if (!getToken().is(Token::caret_identifier))
return emitError("expected block name"), true;
return emitError("expected block name");
dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
consumeToken();
@ -2821,10 +2811,10 @@ bool FunctionParser::parseSuccessorAndUseList(
if (consumeIf(Token::l_paren) &&
(parseOptionalSSAUseAndTypeList(operands) ||
parseToken(Token::r_paren, "expected ')' to close argument list"))) {
return true;
return failure();
}
return false;
return success();
}
/// Parse a comma-separated list of operation successors in brackets.
@ -2840,10 +2830,10 @@ ParseResult FunctionParser::parseSuccessors(
auto parseElt = [this, &destinations, &operands]() {
Block *dest;
SmallVector<Value *, 4> destOperands;
bool r = parseSuccessorAndUseList(dest, destOperands);
auto res = parseSuccessorAndUseList(dest, destOperands);
destinations.push_back(dest);
operands.push_back(destOperands);
return r ? failure() : success();
return res;
};
return parseCommaSeparatedListUntil(Token::r_square, parseElt,
/*allowEmptyList=*/false);
@ -3105,10 +3095,10 @@ public:
CustomOpAsmParser(SMLoc nameLoc, StringRef opName, FunctionParser &parser)
: nameLoc(nameLoc), opName(opName), parser(parser) {}
bool parseOperation(const AbstractOperation *opDefinition,
OperationState *opState) {
ParseResult parseOperation(const AbstractOperation *opDefinition,
OperationState *opState) {
if (opDefinition->parseAssembly(this, opState))
return true;
return failure();
// Check that none of the operands of the current operation reference an
// entry block argument for any of the region.
@ -3116,53 +3106,53 @@ public:
if (llvm::is_contained(opState->operands, entryArg))
return emitError(nameLoc, "operand use before it's defined");
return false;
return success();
}
//===--------------------------------------------------------------------===//
// High level parsing methods.
//===--------------------------------------------------------------------===//
bool getCurrentLocation(llvm::SMLoc *loc) override {
ParseResult getCurrentLocation(llvm::SMLoc *loc) override {
*loc = parser.getToken().getLoc();
return false;
return success();
}
bool parseComma() override {
return failed(parser.parseToken(Token::comma, "expected ','"));
ParseResult parseComma() override {
return parser.parseToken(Token::comma, "expected ','");
}
bool parseColon() override {
return failed(parser.parseToken(Token::colon, "expected ':'"));
ParseResult parseColon() override {
return parser.parseToken(Token::colon, "expected ':'");
}
bool parseEqual() override {
return failed(parser.parseToken(Token::equal, "expected '='"));
ParseResult parseEqual() override {
return parser.parseToken(Token::equal, "expected '='");
}
bool parseType(Type &result) override {
return !(result = parser.parseType());
ParseResult parseType(Type &result) override {
return failure(!(result = parser.parseType()));
}
bool parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType());
ParseResult parseColonType(Type &result) override {
return failure(parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType()));
}
bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return true;
return failure();
do {
if (auto type = parser.parseType())
result.push_back(type);
else
return true;
return failure();
} while (parser.consumeIf(Token::comma));
return false;
return success();
}
bool parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount,
Delimiter delimiter) override {
ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount,
Delimiter delimiter) override {
if (parser.getToken().is(Token::comma)) {
parseComma();
return parseOperandList(result, requiredOperandCount, delimiter);
@ -3170,101 +3160,105 @@ public:
if (requiredOperandCount != -1)
return emitError(parser.getToken().getLoc(),
"expected " + Twine(requiredOperandCount) + " operands");
return false;
return success();
}
bool parseOptionalComma() override { return !parser.consumeIf(Token::comma); }
ParseResult parseOptionalComma() override {
return success(parser.consumeIf(Token::comma));
}
/// Parse an optional keyword.
bool parseOptionalKeyword(const char *keyword) override {
ParseResult parseOptionalKeyword(const char *keyword) override {
// Check that the current token is a bare identifier or keyword.
if (parser.getToken().isNot(Token::bare_identifier) &&
!parser.getToken().isKeyword())
return true;
return failure();
if (parser.getTokenSpelling() == keyword) {
parser.consumeToken();
return false;
return success();
}
return true;
return failure();
}
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
bool parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute(type);
if (!result)
return true;
return failure();
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return false;
return success();
}
/// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name.
bool parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
ParseResult parseAttribute(Attribute &result, StringRef attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
return parseAttribute(result, Type(), attrName, attrs);
}
/// If a named attribute list is present, parse is into result.
bool
ParseResult
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override {
if (parser.getToken().isNot(Token::l_brace))
return false;
return failed(parser.parseAttributeDict(result));
return success();
return parser.parseAttributeDict(result);
}
/// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
virtual ParseResult parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
if (parseOptionalFunctionName(result, loc))
return emitError(loc, "expected function name");
return false;
return success();
}
/// Parse a function name like '@foo` if present and return the name without
/// the sigil in `result`. Return true if the next token is not a function
/// name and keep `result` unchanged.
bool parseOptionalFunctionName(StringRef &result, llvm::SMLoc &loc) override {
ParseResult parseOptionalFunctionName(StringRef &result,
llvm::SMLoc &loc) override {
loc = parser.getToken().getLoc();
if (parser.getToken().isNot(Token::at_identifier))
return true;
return failure();
result = parser.getTokenSpelling();
parser.consumeToken(Token::at_identifier);
return false;
return success();
}
bool parseOperand(OperandType &result) override {
ParseResult parseOperand(OperandType &result) override {
FunctionParser::SSAUseInfo useInfo;
if (parser.parseSSAUse(useInfo))
return true;
return failure();
result = {useInfo.loc, useInfo.name, useInfo.number};
return false;
return success();
}
bool parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) override {
ParseResult
parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value *> &operands) override {
// Defer successor parsing to the function parsers.
return parser.parseSuccessorAndUseList(dest, operands);
}
bool parseLParen() override {
return failed(parser.parseToken(Token::l_paren, "expected '('"));
ParseResult parseLParen() override {
return parser.parseToken(Token::l_paren, "expected '('");
}
bool parseRParen() override {
return failed(parser.parseToken(Token::r_paren, "expected ')'"));
ParseResult parseRParen() override {
return parser.parseToken(Token::r_paren, "expected ')'");
}
bool parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override {
ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override {
auto startLoc = parser.getToken().getLoc();
// Handle delimiters.
@ -3284,19 +3278,19 @@ public:
return emitError(startLoc, "invalid operand");
case Delimiter::OptionalParen:
if (parser.getToken().isNot(Token::l_paren))
return false;
return success();
LLVM_FALLTHROUGH;
case Delimiter::Paren:
if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
return true;
return failure();
break;
case Delimiter::OptionalSquare:
if (parser.getToken().isNot(Token::l_square))
return false;
return success();
LLVM_FALLTHROUGH;
case Delimiter::Square:
if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
return true;
return failure();
break;
}
@ -3305,7 +3299,7 @@ public:
do {
OperandType operand;
if (parseOperand(operand))
return true;
return failure();
result.push_back(operand);
} while (parser.consumeIf(Token::comma));
}
@ -3318,32 +3312,32 @@ public:
case Delimiter::OptionalParen:
case Delimiter::Paren:
if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
return true;
return failure();
break;
case Delimiter::OptionalSquare:
case Delimiter::Square:
if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
return true;
return failure();
break;
}
if (requiredOperandCount != -1 && result.size() != requiredOperandCount)
return emitError(startLoc,
"expected " + Twine(requiredOperandCount) + " operands");
return false;
return success();
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) {
virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr;
return failure(result == nullptr);
}
/// Parse a region that takes `arguments` of `argTypes` types. This
/// effectively defines the SSA values of `arguments` and assignes their type.
bool parseRegion(Region &region, ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) override {
ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) override {
assert(arguments.size() == argTypes.size() &&
"mismatching number of arguments and types");
@ -3359,26 +3353,26 @@ public:
// references to region arguments.
Value *value = parser.resolveSSAUse(operandInfo, type);
if (!value)
return true;
return failure();
parsedRegionEntryArgumentPlaceholders.emplace_back(value);
}
return failed(parser.parseOperationRegion(region, regionArguments));
return parser.parseOperationRegion(region, regionArguments);
}
/// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet. The
/// type of the argument will be resolved later by a call to `parseRegion`.
bool parseRegionArgument(OperandType &argument) {
ParseResult parseRegionArgument(OperandType &argument) {
// Use parseOperand to fill in the OperandType structure.
if (parseOperand(argument))
return true;
return failure();
if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) {
parser.emitError(argument.location,
"redefinition of SSA value '" + argument.name + "'");
return parser.emitError(*defLoc, "previously defined here"), true;
return parser.emitError(*defLoc, "previously defined here");
}
return false;
return success();
}
//===--------------------------------------------------------------------===//
@ -3389,22 +3383,22 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) override {
ParseResult resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<Value *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
result.push_back(value);
return false;
return success();
}
return true;
return failure();
}
/// Emit a diagnostic at the specified location and return true.
bool emitError(llvm::SMLoc loc, const Twine &message) override {
parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message);
/// Emit a diagnostic at the specified location and return failure.
ParseResult emitError(llvm::SMLoc loc, const Twine &message) override {
emittedError = true;
return true;
return parser.emitError(loc,
"custom op '" + Twine(opName) + "' " + message);
}
bool didEmitError() const { return emittedError; }

View File

@ -87,12 +87,12 @@ void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
// Parses dimension and symbol list, and sets 'numDims' to the number of
// dimension operands parsed.
// Returns 'false' on success and 'true' on error.
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims) {
ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser,
SmallVector<Value *, 4> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
return true;
return failure();
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
@ -101,8 +101,8 @@ bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
return true;
return false;
return failure();
return success();
}
/// Matches a ConstantIndexOp.
@ -223,7 +223,7 @@ void AllocOp::print(OpAsmPrinter *p) {
*p << " : " << type;
}
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult AllocOp::parse(OpAsmParser *parser, OperationState *result) {
MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a
@ -232,7 +232,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
// Check numDynamicDims against number of question marks in memref type.
// Note: this check remains here (instead of in verify()), because the
@ -246,7 +246,7 @@ bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
"dynamic dimension count");
}
result->types.push_back(type);
return false;
return success();
}
LogicalResult AllocOp::verify() {
@ -385,13 +385,13 @@ void BranchOp::build(Builder *builder, OperationState *result, Block *dest,
result->addSuccessor(dest, operands);
}
bool BranchOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult BranchOp::parse(OpAsmParser *parser, OperationState *result) {
Block *dest;
SmallVector<Value *, 4> destOperands;
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
return failure();
result->addSuccessor(dest, destOperands);
return false;
return success();
}
void BranchOp::print(OpAsmPrinter *p) {
@ -420,7 +420,7 @@ void CallOp::build(Builder *builder, OperationState *result, Function *callee,
result->addTypes(callee->getType().getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
FunctionType calleeType;
@ -435,10 +435,10 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return true;
return failure();
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
return false;
return success();
}
void CallOp::print(OpAsmPrinter *p) {
@ -517,21 +517,22 @@ void CallIndirectOp::build(Builder *builder, OperationState *result,
result->addTypes(fnType.getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
return parser->parseOperand(callee) ||
parser->getCurrentLocation(&operandsLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) ||
parser->addTypesToList(calleeType.getResults(), result->types);
return failure(
parser->parseOperand(callee) ||
parser->getCurrentLocation(&operandsLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) ||
parser->addTypesToList(calleeType.getResults(), result->types));
}
void CallIndirectOp::print(OpAsmPrinter *p) {
@ -678,7 +679,7 @@ void CmpIOp::build(Builder *build, OperationState *result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
SmallVector<NamedAttribute, 4> attrs;
Attribute predicateNameAttr;
@ -689,7 +690,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands))
return true;
return failure();
if (!predicateNameAttr.isa<StringAttr>())
return parser->emitError(parser->getNameLoc(),
@ -713,7 +714,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs;
result->addTypes({i1Type});
return false;
return success();
}
void CmpIOp::print(OpAsmPrinter *p) {
@ -856,7 +857,7 @@ void CmpFOp::build(Builder *build, OperationState *result,
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
}
bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
SmallVector<NamedAttribute, 4> attrs;
Attribute predicateNameAttr;
@ -867,7 +868,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOptionalAttributeDict(attrs) ||
parser->parseColonType(type) ||
parser->resolveOperands(ops, type, result->operands))
return true;
return failure();
if (!predicateNameAttr.isa<StringAttr>())
return parser->emitError(parser->getNameLoc(),
@ -891,7 +892,7 @@ bool CmpFOp::parse(OpAsmParser *parser, OperationState *result) {
result->attributes = attrs;
result->addTypes({i1Type});
return false;
return success();
}
void CmpFOp::print(OpAsmPrinter *p) {
@ -1044,7 +1045,7 @@ void CondBranchOp::build(Builder *builder, OperationState *result,
result->addSuccessor(falseDest, falseOperands);
}
bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<Value *, 4> destOperands;
Block *dest;
OpAsmParser::OperandType condInfo;
@ -1059,18 +1060,17 @@ bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
// Parse the true successor.
if (parser->parseSuccessorAndUseList(dest, destOperands))
return true;
return failure();
result->addSuccessor(dest, destOperands);
// Parse the false successor.
destOperands.clear();
if (parser->parseComma() ||
parser->parseSuccessorAndUseList(dest, destOperands))
return true;
return failure();
result->addSuccessor(dest, destOperands);
// Return false on success.
return false;
return success();
}
void CondBranchOp::print(OpAsmPrinter *p) {
@ -1132,13 +1132,14 @@ static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
*p << " : " << op.getType();
}
static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr;
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes))
return true;
return failure();
// 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it.
@ -1150,7 +1151,7 @@ static bool parseConstantOp(OpAsmParser *parser, OperationState *result) {
} else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) {
type = fpAttr.getType();
} else if (parser->parseColonType(type)) {
return true;
return failure();
}
return parser->addTypeToList(type, result->types);
}
@ -1298,12 +1299,13 @@ void DeallocOp::print(OpAsmPrinter *p) {
*p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
}
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands);
return failure(parser->parseOperand(memrefInfo) ||
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands));
}
LogicalResult DeallocOp::verify() {
@ -1338,19 +1340,19 @@ void DimOp::print(OpAsmPrinter *p) {
*p << " : " << getOperand()->getType();
}
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
Type type;
Type indexType = parser->getBuilder().getIndexType();
return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(indexType, result->types);
return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(indexType, result->types));
}
LogicalResult DimOp::verify() {
@ -1491,7 +1493,7 @@ void DmaStartOp::print(OpAsmPrinter *p) {
// memref<1024 x f32, 2>,
// memref<1 x i32>
//
bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcMemRefInfo;
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
OpAsmParser::OperandType dstMemRefInfo;
@ -1518,11 +1520,11 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
parser->parseOperandList(tagIndexInfos, -1,
OpAsmParser::Delimiter::Square))
return true;
return failure();
// Parse optional stride and elements per stride.
if (parser->parseTrailingOperandList(strideInfo)) {
return true;
return failure();
}
if (!strideInfo.empty() && strideInfo.size() != 2) {
return parser->emitError(parser->getNameLoc(),
@ -1531,7 +1533,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
bool isStrided = strideInfo.size() == 2;
if (parser->parseColonTypeList(types))
return true;
return failure();
if (types.size() != 3)
return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
@ -1545,7 +1547,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
// tag indices should be index.
parser->resolveOperands(tagIndexInfos, indexType, result->operands))
return true;
return failure();
if (!types[0].isa<MemRefType>())
return parser->emitError(parser->getNameLoc(),
@ -1562,7 +1564,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
if (isStrided) {
if (parser->resolveOperand(strideInfo[0], indexType, result->operands) ||
parser->resolveOperand(strideInfo[1], indexType, result->operands))
return true;
return failure();
}
// Check that source/destination index list size matches associated rank.
@ -1575,7 +1577,7 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
return false;
return success();
}
LogicalResult DmaStartOp::verify() {
@ -1628,7 +1630,7 @@ void DmaWaitOp::print(OpAsmPrinter *p) {
// Eg:
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type type;
@ -1644,7 +1646,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
return failure();
if (!type.isa<MemRefType>())
return parser->emitError(parser->getNameLoc(),
@ -1654,7 +1656,7 @@ bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
return false;
return success();
}
void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@ -1684,20 +1686,21 @@ void ExtractElementOp::print(OpAsmPrinter *p) {
*p << " : " << getAggregate()->getType();
}
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult ExtractElementOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) ||
parser->parseOperandList(indexInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types);
return failure(
parser->parseOperand(aggregateInfo) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types));
}
LogicalResult ExtractElementOp::verify() {
@ -1771,20 +1774,20 @@ void LoadOp::print(OpAsmPrinter *p) {
*p << " : " << getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types);
return failure(
parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type.getElementType(), result->types));
}
LogicalResult LoadOp::verify() {
@ -1963,13 +1966,14 @@ void ReturnOp::build(Builder *builder, OperationState *result,
result->addOperands(results);
}
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands);
return failure(parser->getCurrentLocation(&loc) ||
parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, loc, result->operands));
}
void ReturnOp::print(OpAsmPrinter *p) {
@ -2012,7 +2016,7 @@ void SelectOp::build(Builder *builder, OperationState *result, Value *condition,
result->addTypes(trueValue->getType());
}
bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult SelectOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> ops;
SmallVector<NamedAttribute, 4> attrs;
Type type;
@ -2020,7 +2024,7 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
if (parser->parseOperandList(ops, 3) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
return failure();
auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type);
if (!i1Type)
@ -2028,9 +2032,9 @@ bool SelectOp::parse(OpAsmParser *parser, OperationState *result) {
"expected type with valid i1 shape");
SmallVector<Type, 3> types = {i1Type, type, type};
return parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands) ||
parser->addTypeToList(type, result->types);
return failure(parser->resolveOperands(ops, types, parser->getNameLoc(),
result->operands) ||
parser->addTypeToList(type, result->types));
}
void SelectOp::print(OpAsmPrinter *p) {
@ -2090,23 +2094,23 @@ void StoreOp::print(OpAsmPrinter *p) {
*p << " : " << getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
return failure(
parser->parseOperand(storeValueInfo) || parser->parseComma() ||
parser->parseOperand(memrefInfo) ||
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands));
}
LogicalResult StoreOp::verify() {

View File

@ -120,7 +120,8 @@ void VectorTransferReadOp::print(OpAsmPrinter *p) {
*p << ", " << getResultType();
}
bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<OpAsmParser::OperandType, 8> paddingInfo;
@ -133,7 +134,7 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
return failure();
// Resolution.
if (types.size() != 2)
@ -160,12 +161,12 @@ bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
paddingType = vectorType.getElementType();
}
auto indexType = parser->getBuilder().getIndexType();
return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, indexType, result->operands) ||
(hasOptionalPaddingValue &&
parser->resolveOperand(paddingInfo[0], paddingType,
result->operands)) ||
parser->addTypeToList(vectorType, result->types);
return failure(
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, indexType, result->operands) ||
(hasOptionalPaddingValue &&
parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) ||
parser->addTypeToList(vectorType, result->types));
}
LogicalResult VectorTransferReadOp::verify() {
@ -286,7 +287,8 @@ void VectorTransferWriteOp::print(OpAsmPrinter *p) {
p->printType(getMemRefType());
}
bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
@ -297,7 +299,7 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
return failure();
if (types.size() != 2)
return parser->emitError(parser->getNameLoc(), "expected 2 types");
@ -308,10 +310,10 @@ bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
if (!memrefType)
return parser->emitError(parser->getNameLoc(), "memRef type expected");
return parser->resolveOperands(storeValueInfo, vectorType,
result->operands) ||
parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, indexType, result->operands);
return failure(
parser->resolveOperands(storeValueInfo, vectorType, result->operands) ||
parser->resolveOperands(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, indexType, result->operands));
}
LogicalResult VectorTransferWriteOp::verify() {
@ -390,15 +392,16 @@ void VectorTypeCastOp::build(Builder *builder, OperationState *result,
result->addTypes(dstType);
}
bool VectorTypeCastOp::parse(OpAsmParser *parser, OperationState *result) {
ParseResult VectorTypeCastOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType operand;
Type srcType, dstType;
return parser->parseOperand(operand) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(srcType) || parser->parseComma() ||
parser->parseType(dstType) ||
parser->addTypeToList(dstType, result->types) ||
parser->resolveOperand(operand, srcType, result->operands);
return failure(parser->parseOperand(operand) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(srcType) || parser->parseComma() ||
parser->parseType(dstType) ||
parser->addTypeToList(dstType, result->types) ||
parser->resolveOperand(operand, srcType, result->operands));
}
void VectorTypeCastOp::print(OpAsmPrinter *p) {

View File

@ -51,7 +51,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
// CHECK: static void build(Value *val);
// CHECK: static void build(Builder *, OperationState *tblgen_state, Type r, ArrayRef<Type> s, Value *a, ArrayRef<Value *> b, IntegerAttr attr1, /*optional*/FloatAttr attr2);
// CHECK: static void build(Builder *, OperationState *tblgen_state, ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes);
// CHECK: static bool parse(OpAsmParser *parser, OperationState *result);
// CHECK: static ParseResult parse(OpAsmParser *parser, OperationState *result);
// CHECK: void print(OpAsmPrinter *p);
// CHECK: LogicalResult verify();
// CHECK: static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context);

View File

@ -863,7 +863,7 @@ void OpEmitter::genParser() {
return;
auto &method = opClass.newMethod(
"bool", "parse", "OpAsmParser *parser, OperationState *result",
"ParseResult", "parse", "OpAsmParser *parser, OperationState *result",
OpMethod::MP_Static);
auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
method.body() << " " << parser;