[AsmParser] Rework logic around "region argument parsing"

The asm parser had a notional distinction between parsing an
operand (like "%foo" or "%4#3") and parsing a region argument
(which isn't supposed to allow a result number like #3).

Unfortunately the implementation has two problems:

1) It didn't actually check for the result number and reject
   it.  parseRegionArgument and parseOperand were identical.
2) It had a lot of machinery built up around it that paralleled
   operand parsing.  This also was functionally identical, but
   also had some subtle differences (e.g. the parseOptional
   stuff had a different result type).

I thought about just removing all of this, but decided that the
missing error checking was important, so I reimplemented it with
a `allowResultNumber` flag on parseOperand.  This keeps the
codepaths unified and adds the missing error checks.

Differential Revision: https://reviews.llvm.org/D124470
This commit is contained in:
Chris Lattner 2022-04-26 12:03:03 -07:00
parent 6c81b57237
commit 5dedf911de
15 changed files with 93 additions and 107 deletions

View File

@ -1563,7 +1563,8 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) ||
if (parser.parseLParen() ||
parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return mlir::failure();
@ -1581,8 +1582,9 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput;
if (parser.parseKeyword("and") || parser.parseLParen() ||
parser.parseRegionArgument(iterateVar) || parser.parseEqual() ||
parser.parseOperand(iterateInput) || parser.parseRParen() ||
parser.parseOperand(iterateVar, /*allowResultNumber=*/false) ||
parser.parseEqual() || parser.parseOperand(iterateInput) ||
parser.parseRParen() ||
parser.resolveOperand(iterateInput, i1Type, result.operands))
return mlir::failure();
@ -1876,7 +1878,8 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
auto &builder = parser.getBuilder();
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return mlir::failure();
// Parse loop bounds.

View File

@ -584,7 +584,7 @@ public:
}
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList and parseRegionArgumentList.
/// argument lists, used by parseOperandList.
enum class Delimiter {
/// Zero or more operands with no delimiters.
None,
@ -1110,22 +1110,27 @@ public:
Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
Optional<FunctionType> parsedFnType = llvm::None) = 0;
/// Parse a single operand.
virtual ParseResult parseOperand(UnresolvedOperand &result) = 0;
/// Parse a single SSA value operand name along with a result number if
/// `allowResultNumber` is true.
virtual ParseResult parseOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;
/// Parse a single operand if present.
virtual OptionalParseResult
parseOptionalOperand(UnresolvedOperand &result) = 0;
parseOptionalOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;
/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
virtual ParseResult
parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult parseOperandList(
SmallVectorImpl<UnresolvedOperand> &result, int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None, bool allowResultNumber = true) = 0;
ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
Delimiter delimiter) {
return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
Delimiter delimiter,
bool allowResultNumber = true) {
return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter,
allowResultNumber);
}
/// Parse zero or more trailing SSA comma-separated trailing operand
@ -1243,29 +1248,6 @@ public:
ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;
/// Parse a region argument, this argument is resolved when calling
/// 'parseRegion'.
virtual ParseResult parseRegionArgument(UnresolvedOperand &argument) = 0;
/// Parse zero or more region arguments with a specified surrounding
/// delimiter, and an optional required argument count. Region arguments
/// define new values; so this also checks if values with the same names have
/// not been defined yet.
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result,
Delimiter delimiter) {
return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
delimiter);
}
/// Parse a region argument if present.
virtual ParseResult
parseOptionalRegionArgument(UnresolvedOperand &argument) = 0;
//===--------------------------------------------------------------------===//
// Successor Parsing
//===--------------------------------------------------------------------===//

View File

@ -1433,7 +1433,8 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand inductionVariable;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return failure();
// Parse loop bounds.
@ -3527,8 +3528,8 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false) ||
parser.parseEqual() ||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
parser.parseKeyword("to") ||

View File

@ -489,7 +489,8 @@ static ParseResult parseSwitchOpCases(
parser.parseSuccessor(defaultDestination))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseRegionArgumentList(defaultOperands) ||
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}
@ -509,7 +510,8 @@ static ParseResult parseSwitchOpCases(
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();

View File

@ -539,8 +539,8 @@ parseSizeAssignment(OpAsmParser &parser,
MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
assert(indices.size() == 3 && "space for three indices expected");
SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
OpAsmParser::Delimiter::Paren) ||
if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false) ||
parser.parseKeyword("in") || parser.parseLParen())
return failure();
std::move(args.begin(), args.end(), indices.begin());
@ -548,8 +548,8 @@ parseSizeAssignment(OpAsmParser &parser,
for (int i = 0; i < 3; ++i) {
if (i != 0 && parser.parseComma())
return failure();
if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
parser.parseOperand(sizes[i]))
if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
parser.parseEqual() || parser.parseOperand(sizes[i]))
return failure();
}
@ -869,7 +869,8 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
OpAsmParser::UnresolvedOperand arg;
Type type;
if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
parser.parseColonType(type))
return failure();
args.push_back(arg);

View File

@ -332,7 +332,8 @@ static ParseResult parseSwitchOpCases(
if (parser.parseColon() || parser.parseSuccessor(destination))
return failure();
if (!parser.parseOptionalLParen()) {
if (parser.parseRegionArgumentList(operands) ||
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(operandTypes) || parser.parseRParen())
return failure();
}

View File

@ -70,7 +70,8 @@ parseOperandList(OpAsmParser &parser, StringRef keyword,
OpAsmParser::UnresolvedOperand arg;
Type type;
if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
parser.parseColonType(type))
return failure();
args.push_back(arg);

View File

@ -524,8 +524,8 @@ parseWsLoopControl(OpAsmParser &parser, Region &region,
SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();
size_t numIVs = ivs.size();
@ -587,8 +587,8 @@ void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();
int numIVs = static_cast<int>(ivs.size());
Type loopVarType;

View File

@ -103,7 +103,7 @@ ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the loop variable followed by type.
OpAsmParser::UnresolvedOperand loopVariable;
Type loopVariableType;
if (parser.parseRegionArgument(loopVariable) ||
if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) ||
parser.parseColonType(loopVariableType))
return failure();

View File

@ -401,7 +401,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return failure();
// Parse loop bounds.
@ -1975,8 +1976,8 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();
// Parse loop bounds.

View File

@ -4698,7 +4698,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand laneId;
// Parse predicate operand.
if (parser.parseLParen() || parser.parseRegionArgument(laneId) ||
if (parser.parseLParen() ||
parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
parser.parseRParen())
return failure();

View File

@ -30,8 +30,12 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
// Parse argument name if present.
OpAsmParser::UnresolvedOperand argument;
Type argumentType;
if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
!argument.name.empty()) {
auto hadSSAValue = parser.parseOptionalOperand(argument,
/*allowResultNumber=*/false);
if (hadSSAValue.hasValue()) {
if (failed(hadSSAValue.getValue()))
return failure(); // Argument was present but malformed.
// Reject this if the preceding argument was missing a name.
if (argNames.empty() && !argTypes.empty())
return parser.emitError(loc, "expected type instead of SSA identifier");

View File

@ -268,8 +268,10 @@ public:
ParseResult
parseOptionalSSAUseList(SmallVectorImpl<UnresolvedOperand> &results);
/// Parse a single SSA use into 'result'.
ParseResult parseSSAUse(UnresolvedOperand &result);
/// Parse a single SSA use into 'result'. If 'allowResultNumber' is true then
/// we allow #42 syntax.
ParseResult parseSSAUse(UnresolvedOperand &result,
bool allowResultNumber = true);
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
@ -699,7 +701,8 @@ ParseResult OperationParser::parseOptionalSSAUseList(
///
/// ssa-use ::= ssa-id
///
ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result) {
ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result,
bool allowResultNumber) {
result.name = getTokenSpelling();
result.number = 0;
result.location = getToken().getLoc();
@ -708,6 +711,9 @@ ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result) {
// If we have an attribute ID, it is a result number.
if (getToken().is(Token::hash_identifier)) {
if (!allowResultNumber)
return emitError("result number not allowed in argument list");
if (auto value = getToken().getHashIdentifierNumber())
result.number = value.getValue();
else
@ -1267,9 +1273,10 @@ public:
//===--------------------------------------------------------------------===//
/// Parse a single operand.
ParseResult parseOperand(UnresolvedOperand &result) override {
ParseResult parseOperand(UnresolvedOperand &result,
bool allowResultNumber = true) override {
OperationParser::UnresolvedOperand useInfo;
if (parser.parseSSAUse(useInfo))
if (parser.parseSSAUse(useInfo, allowResultNumber))
return failure();
result = {useInfo.location, useInfo.name, useInfo.number, {}};
@ -1279,9 +1286,11 @@ public:
}
/// Parse a single operand if present.
OptionalParseResult parseOptionalOperand(UnresolvedOperand &result) override {
OptionalParseResult
parseOptionalOperand(UnresolvedOperand &result,
bool allowResultNumber = true) override {
if (parser.getToken().is(Token::percent_identifier))
return parseOperand(result);
return parseOperand(result, allowResultNumber);
return llvm::None;
}
@ -1289,17 +1298,8 @@ public:
/// surrounding delimiter, and an optional required operand count.
ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override {
return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
requiredOperandCount, delimiter);
}
/// Parse zero or more SSA comma-separated operand or region arguments with
/// optional surrounding delimiter and required operand count.
ParseResult
parseOperandOrRegionArgList(SmallVectorImpl<UnresolvedOperand> &result,
bool isOperandList, int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) {
Delimiter delimiter = Delimiter::None,
bool allowResultNumber = true) override {
auto startLoc = parser.getToken().getLoc();
// The no-delimiter case has some special handling for better diagnostics.
@ -1322,8 +1322,7 @@ public:
auto parseOneOperand = [&]() -> ParseResult {
UnresolvedOperand operandOrArg;
if (isOperandList ? parseOperand(operandOrArg)
: parseRegionArgument(operandOrArg))
if (parseOperand(operandOrArg, allowResultNumber))
return failure();
result.push_back(operandOrArg);
return success();
@ -1472,28 +1471,6 @@ public:
return success();
}
/// Parse a region argument. The type of the argument will be resolved later
/// by a call to `parseRegion`.
ParseResult parseRegionArgument(UnresolvedOperand &argument) override {
return parseOperand(argument);
}
/// Parse a region argument if present.
ParseResult
parseOptionalRegionArgument(UnresolvedOperand &argument) override {
if (parser.getToken().isNot(Token::percent_identifier))
return success();
return parseRegionArgument(argument);
}
ParseResult
parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override {
return parseOperandOrRegionArgList(result, /*isOperandList=*/false,
requiredOperandCount, delimiter);
}
//===--------------------------------------------------------------------===//
// Successor Parsing
//===--------------------------------------------------------------------===//
@ -1539,8 +1516,8 @@ public:
auto parseElt = [&]() -> ParseResult {
UnresolvedOperand regionArg, operand;
if (parseRegionArgument(regionArg) || parseEqual() ||
parseOperand(operand))
if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
parseEqual() || parseOperand(operand))
return failure();
lhs.push_back(regionArg);
rhs.push_back(operand);
@ -1561,8 +1538,9 @@ public:
auto parseElt = [&]() -> ParseResult {
UnresolvedOperand regionArg, operand;
Type type;
if (parseRegionArgument(regionArg) || parseEqual() ||
parseOperand(operand) || parseColon() || parseType(type))
if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
parseEqual() || parseOperand(operand) || parseColon() ||
parseType(type))
return failure();
lhs.push_back(regionArg);
rhs.push_back(operand);

View File

@ -380,3 +380,13 @@ func.func @affine_for_iter_args_mismatch(%buffer: memref<1024xf32>) -> f32 {
}
return %res : f32
}
// -----
func.func @result_number() {
// expected-error@+1 {{result number not allowed}}
affine.for %n0#0 = 0 to 7 {
}
return
}

View File

@ -932,7 +932,8 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
// Parse list of region arguments without a delimiter.
if (parser.parseRegionArgumentList(ivsInfo))
if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false))
return failure();
// Parse the body region.