diff --git a/mlir/include/mlir/IR/StandardOps.h b/mlir/include/mlir/IR/StandardOps.h index 4c71b7edd4ef..99cbc59e84d5 100644 --- a/mlir/include/mlir/IR/StandardOps.h +++ b/mlir/include/mlir/IR/StandardOps.h @@ -299,7 +299,6 @@ private: explicit StoreOp(const Operation *state) : OpBase(state) {} }; -/// TODO: change comment. /// The "return" operation represents a return statement of an ML function. /// The operation takes variable number of operands and produces no results. /// The operand number and types must match the signature of the ML function diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp index c3f815f03fb2..e78290e929d1 100644 --- a/mlir/lib/IR/StandardOps.cpp +++ b/mlir/lib/IR/StandardOps.cpp @@ -369,7 +369,6 @@ const char *LoadOp::verify() const { bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) { SmallVector opInfo; SmallVector types; - SmallVector operands; return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) || (!opInfo.empty() && parser->parseColonTypeList(types)) || @@ -391,17 +390,15 @@ void ReturnOp::print(OpAsmPrinter *p) const { const char *ReturnOp::verify() const { // ReturnOp must be part of an ML function. if (auto *stmt = dyn_cast(getOperation())) { - StmtBlock *block = stmt->getBlock(); - - if (!block || !isa(block) || - &cast(block)->back() != stmt) + MLFunction *func = dyn_cast_or_null(stmt->getBlock()); + if (!func || &func->back() != stmt) return "must be the last statement in the ML function"; // Return success. Checking that operand types match those in the function // signature is performed in the ML function verifier. return nullptr; } - return "cannot occur in a CFG function."; + return "cannot occur in a CFG function"; } //===----------------------------------------------------------------------===// @@ -470,6 +467,6 @@ const char *StoreOp::verify() const { /// Install the standard operations in the specified operation set. void mlir::registerStandardOperations(OperationSet &opSet) { opSet.addOperations( + ReturnOp, StoreOp>( /*prefix=*/""); }