forked from OSchip/llvm-project
[AsmParser] Introduce a new "Argument" abstraction + supporting logic
MLIR has a common pattern for "arguments" that uses syntax like `%x : i32 {attrs} loc("sourceloc")` which is implemented in adhoc ways throughout the codebase. The approach this uses is verbose (because it is implemented with parallel arrays) and inconsistent (e.g. lots of things drop source location info). Solve this by introducing OpAsmParser::Argument and make addRegion (which sets up BlockArguments for the region) take it. Convert the world to propagating this down. This means that we correctly capture and propagate source location information in a lot more cases (e.g. see the affine.for testcase example), and it also simplifies much code. Differential Revision: https://reviews.llvm.org/D124649
This commit is contained in:
parent
6e689cbaf4
commit
d85eb4e2d6
|
@ -1200,8 +1200,8 @@ mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser,
|
|||
result.addRegion();
|
||||
} else {
|
||||
// Parse the optional initializer body.
|
||||
auto parseResult = parser.parseOptionalRegion(
|
||||
*result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None);
|
||||
auto parseResult =
|
||||
parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{});
|
||||
if (parseResult.hasValue() && mlir::failed(*parseResult))
|
||||
return mlir::failure();
|
||||
}
|
||||
|
@ -1562,9 +1562,9 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
|
|||
mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
|
||||
mlir::OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
|
||||
if (parser.parseLParen() ||
|
||||
parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
|
||||
mlir::OpAsmParser::Argument inductionVariable, iterateVar;
|
||||
mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput;
|
||||
if (parser.parseLParen() || parser.parseArgument(inductionVariable) ||
|
||||
parser.parseEqual())
|
||||
return mlir::failure();
|
||||
|
||||
|
@ -1577,22 +1577,18 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
|
|||
parser.resolveOperand(ub, indexType, result.operands) ||
|
||||
parser.parseKeyword("step") || parser.parseOperand(step) ||
|
||||
parser.parseRParen() ||
|
||||
parser.resolveOperand(step, indexType, result.operands))
|
||||
return mlir::failure();
|
||||
|
||||
mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput;
|
||||
if (parser.parseKeyword("and") || parser.parseLParen() ||
|
||||
parser.parseOperand(iterateVar, /*allowResultNumber=*/false) ||
|
||||
parser.parseEqual() || parser.parseOperand(iterateInput) ||
|
||||
parser.parseRParen() ||
|
||||
parser.resolveOperand(step, indexType, result.operands) ||
|
||||
parser.parseKeyword("and") || parser.parseLParen() ||
|
||||
parser.parseArgument(iterateVar) || parser.parseEqual() ||
|
||||
parser.parseOperand(iterateInput) || parser.parseRParen() ||
|
||||
parser.resolveOperand(iterateInput, i1Type, result.operands))
|
||||
return mlir::failure();
|
||||
|
||||
// Parse the initial iteration arguments.
|
||||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs;
|
||||
auto prependCount = false;
|
||||
|
||||
// Induction variable.
|
||||
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
|
||||
regionArgs.push_back(inductionVariable);
|
||||
regionArgs.push_back(iterateVar);
|
||||
|
||||
|
@ -1652,7 +1648,10 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
|
|||
parser.getNameLoc(),
|
||||
"mismatch in number of loop-carried values and defined values");
|
||||
|
||||
if (parser.parseRegion(*body, regionArgs, argTypes))
|
||||
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
|
||||
regionArgs[i].type = argTypes[i];
|
||||
|
||||
if (parser.parseRegion(*body, regionArgs))
|
||||
return mlir::failure();
|
||||
|
||||
fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
|
||||
|
@ -1876,10 +1875,10 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
|
|||
mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
|
||||
mlir::OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
|
||||
mlir::OpAsmParser::Argument inductionVariable;
|
||||
mlir::OpAsmParser::UnresolvedOperand lb, ub, step;
|
||||
// Parse the induction variable followed by '='.
|
||||
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
|
||||
parser.parseEqual())
|
||||
if (parser.parseArgument(inductionVariable) || parser.parseEqual())
|
||||
return mlir::failure();
|
||||
|
||||
// Parse loop bounds.
|
||||
|
@ -1896,7 +1895,8 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
|
|||
result.addAttribute("unordered", builder.getUnitAttr());
|
||||
|
||||
// Parse the optional initial iteration arguments.
|
||||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs, operands;
|
||||
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
|
||||
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
|
||||
llvm::SmallVector<mlir::Type> argTypes;
|
||||
bool prependCount = false;
|
||||
regionArgs.push_back(inductionVariable);
|
||||
|
@ -1939,8 +1939,10 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
|
|||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"mismatch in number of loop-carried values and defined values");
|
||||
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
|
||||
regionArgs[i].type = argTypes[i];
|
||||
|
||||
if (parser.parseRegion(*body, regionArgs, argTypes))
|
||||
if (parser.parseRegion(*body, regionArgs))
|
||||
return mlir::failure();
|
||||
|
||||
DoLoopOp::ensureTerminator(*body, builder, result.location);
|
||||
|
|
|
@ -41,8 +41,8 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
|
|||
ArrayRef<DictionaryAttr> argAttrs,
|
||||
ArrayRef<DictionaryAttr> resultAttrs);
|
||||
void addArgAndResultAttrs(Builder &builder, OperationState &result,
|
||||
ArrayRef<NamedAttrList> argAttrs,
|
||||
ArrayRef<NamedAttrList> resultAttrs);
|
||||
ArrayRef<OpAsmParser::Argument> argAttrs,
|
||||
ArrayRef<DictionaryAttr> resultAttrs);
|
||||
|
||||
/// Callback type for `parseFunctionOp`, the callback should produce the
|
||||
/// type that will be associated with a function-like operation from lists of
|
||||
|
@ -52,26 +52,20 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
|
|||
using FuncTypeBuilder = function_ref<Type(
|
||||
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
|
||||
|
||||
/// Parses function arguments using `parser`. The `allowVariadic` argument
|
||||
/// indicates whether functions with variadic arguments are supported. The
|
||||
/// trailing arguments are populated by this function with names, types,
|
||||
/// attributes and locations of the arguments.
|
||||
ParseResult parseFunctionArgumentList(
|
||||
OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
|
||||
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
|
||||
bool &isVariadic);
|
||||
|
||||
/// Parses a function signature using `parser`. The `allowVariadic` argument
|
||||
/// indicates whether functions with variadic arguments are supported. The
|
||||
/// trailing arguments are populated by this function with names, types,
|
||||
/// attributes and locations of the arguments and those of the results.
|
||||
ParseResult parseFunctionSignature(
|
||||
OpAsmParser &parser, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
|
||||
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
|
||||
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<NamedAttrList> &resultAttrs);
|
||||
ParseResult
|
||||
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::Argument> &arguments,
|
||||
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<DictionaryAttr> &resultAttrs);
|
||||
|
||||
/// Get a function type corresponding to an array of arguments (which have
|
||||
/// types) and a set of result types.
|
||||
Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
|
||||
ArrayRef<Type> resultTypes);
|
||||
|
||||
/// Parser implementation for function-like operations. Uses
|
||||
/// `funcTypeBuilder` to construct the custom function type given lists of
|
||||
|
|
|
@ -633,14 +633,14 @@ public:
|
|||
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
|
||||
/// context parameter.
|
||||
template <typename T, typename... ParamsT>
|
||||
T getChecked(SMLoc loc, ParamsT &&... params) {
|
||||
T getChecked(SMLoc loc, ParamsT &&...params) {
|
||||
return T::getChecked([&] { return emitError(loc); },
|
||||
std::forward<ParamsT>(params)...);
|
||||
}
|
||||
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
|
||||
/// errors.
|
||||
template <typename T, typename... ParamsT>
|
||||
T getChecked(ParamsT &&... params) {
|
||||
T getChecked(ParamsT &&...params) {
|
||||
return T::getChecked([&] { return emitError(getNameLoc()); },
|
||||
std::forward<ParamsT>(params)...);
|
||||
}
|
||||
|
@ -1093,7 +1093,6 @@ public:
|
|||
SMLoc location; // Location of the token.
|
||||
StringRef name; // Value name, e.g. %42 or %abc
|
||||
unsigned number; // Number, e.g. 12 for an operand like %xyz#12
|
||||
Optional<Location> sourceLoc; // Source location specifier if present.
|
||||
};
|
||||
|
||||
/// Parse different components, viz., use-info of operand(s), successor(s),
|
||||
|
@ -1219,34 +1218,64 @@ public:
|
|||
SmallVectorImpl<UnresolvedOperand> &symbOperands,
|
||||
AffineExpr &expr) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Argument Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
struct Argument {
|
||||
UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
|
||||
Type type; // Type.
|
||||
DictionaryAttr attrs; // Attributes if present.
|
||||
Optional<Location> sourceLoc; // Source location specifier if present.
|
||||
};
|
||||
|
||||
/// Parse a single argument with the following syntax:
|
||||
///
|
||||
/// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
|
||||
///
|
||||
/// If `allowType` is false or `allowAttrs` are false then the respective
|
||||
/// parts of the grammar are not parsed.
|
||||
virtual ParseResult parseArgument(Argument &result, bool allowType = false,
|
||||
bool allowAttrs = false) = 0;
|
||||
|
||||
/// Parse a single argument if present.
|
||||
virtual OptionalParseResult
|
||||
parseOptionalArgument(Argument &result, bool allowType = false,
|
||||
bool allowAttrs = false) = 0;
|
||||
|
||||
/// Parse zero or more arguments with a specified surrounding delimiter.
|
||||
virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
|
||||
Delimiter delimiter = Delimiter::None,
|
||||
bool allowType = false,
|
||||
bool allowAttrs = false) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Region Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parses a region. Any parsed blocks are appended to 'region' and must be
|
||||
/// moved to the op regions after the op is created. The first block of the
|
||||
/// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
|
||||
/// set to true, the argument names are allowed to shadow the names of other
|
||||
/// existing SSA values defined above the region scope. 'enableNameShadowing'
|
||||
/// can only be set to true for regions attached to operations that are
|
||||
/// 'IsolatedFromAbove'.
|
||||
/// region takes 'arguments'.
|
||||
///
|
||||
/// If 'enableNameShadowing' is set to true, the argument names are allowed to
|
||||
/// shadow the names of other existing SSA values defined above the region
|
||||
/// scope. 'enableNameShadowing' can only be set to true for regions attached
|
||||
/// to operations that are 'IsolatedFromAbove'.
|
||||
virtual ParseResult parseRegion(Region ®ion,
|
||||
ArrayRef<UnresolvedOperand> arguments = {},
|
||||
ArrayRef<Type> argTypes = {},
|
||||
ArrayRef<Argument> arguments = {},
|
||||
bool enableNameShadowing = false) = 0;
|
||||
|
||||
/// Parses a region if present.
|
||||
virtual OptionalParseResult parseOptionalRegion(
|
||||
Region ®ion, ArrayRef<UnresolvedOperand> arguments = {},
|
||||
ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
|
||||
virtual OptionalParseResult
|
||||
parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {},
|
||||
bool enableNameShadowing = false) = 0;
|
||||
|
||||
/// Parses a region if present. If the region is present, a new region is
|
||||
/// allocated and placed in `region`. If no region is present or on failure,
|
||||
/// `region` remains untouched.
|
||||
virtual OptionalParseResult
|
||||
parseOptionalRegion(std::unique_ptr<Region> ®ion,
|
||||
ArrayRef<UnresolvedOperand> arguments = {},
|
||||
ArrayRef<Type> argTypes = {},
|
||||
ArrayRef<Argument> arguments = {},
|
||||
bool enableNameShadowing = false) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -1269,7 +1298,7 @@ public:
|
|||
|
||||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1, %x2 = %y2, ...)
|
||||
ParseResult parseAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs) {
|
||||
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
|
||||
if (!result.hasValue())
|
||||
|
@ -1278,26 +1307,8 @@ public:
|
|||
}
|
||||
|
||||
virtual OptionalParseResult
|
||||
parseOptionalAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
|
||||
|
||||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
|
||||
ParseResult
|
||||
parseAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs,
|
||||
SmallVectorImpl<Type> &types) {
|
||||
OptionalParseResult result =
|
||||
parseOptionalAssignmentListWithTypes(lhs, rhs, types);
|
||||
if (!result.hasValue())
|
||||
return emitError(getCurrentLocation(), "expected '('");
|
||||
return result.getValue();
|
||||
}
|
||||
|
||||
virtual OptionalParseResult
|
||||
parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs,
|
||||
SmallVectorImpl<Type> &types) = 0;
|
||||
};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -1339,7 +1350,6 @@ public:
|
|||
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1431,10 +1431,10 @@ static ParseResult parseBound(bool isLower, OperationState &result,
|
|||
|
||||
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
OpAsmParser::UnresolvedOperand inductionVariable;
|
||||
OpAsmParser::Argument inductionVariable;
|
||||
inductionVariable.type = builder.getIndexType();
|
||||
// Parse the induction variable followed by '='.
|
||||
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
|
||||
parser.parseEqual())
|
||||
if (parser.parseArgument(inductionVariable) || parser.parseEqual())
|
||||
return failure();
|
||||
|
||||
// Parse loop bounds.
|
||||
|
@ -1463,8 +1463,10 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
// Parse the optional initial iteration arguments.
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
|
||||
SmallVector<Type, 4> argTypes;
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
|
||||
|
||||
// Induction variable.
|
||||
regionArgs.push_back(inductionVariable);
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
|
||||
|
@ -1473,23 +1475,23 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
// Resolve input operands.
|
||||
for (auto operandType : llvm::zip(operands, result.types))
|
||||
if (parser.resolveOperand(std::get<0>(operandType),
|
||||
std::get<1>(operandType), result.operands))
|
||||
for (auto argOperandType :
|
||||
llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
|
||||
Type type = std::get<2>(argOperandType);
|
||||
std::get<0>(argOperandType).type = type;
|
||||
if (parser.resolveOperand(std::get<1>(argOperandType), type,
|
||||
result.operands))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
// Induction variable.
|
||||
Type indexType = builder.getIndexType();
|
||||
argTypes.push_back(indexType);
|
||||
// Loop carried variables.
|
||||
argTypes.append(result.types.begin(), result.types.end());
|
||||
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
if (regionArgs.size() != argTypes.size())
|
||||
if (regionArgs.size() != result.types.size() + 1)
|
||||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"mismatch between the number of loop-carried values and results");
|
||||
if (parser.parseRegion(*body, regionArgs, argTypes))
|
||||
if (parser.parseRegion(*body, regionArgs))
|
||||
return failure();
|
||||
|
||||
AffineForOp::ensureTerminator(*body, builder, result.location);
|
||||
|
@ -1548,7 +1550,8 @@ unsigned AffineForOp::getNumIterOperands() {
|
|||
|
||||
void AffineForOp::print(OpAsmPrinter &p) {
|
||||
p << ' ';
|
||||
p.printOperand(getBody()->getArgument(0));
|
||||
p.printRegionArgument(getBody()->getArgument(0), /*argAtrs=*/{},
|
||||
/*omitType=*/true);
|
||||
p << " = ";
|
||||
printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
|
||||
p << " to ";
|
||||
|
@ -3527,9 +3530,8 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
|
|||
OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
auto indexType = builder.getIndexType();
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
|
||||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
|
||||
/*allowResultNumber=*/false) ||
|
||||
SmallVector<OpAsmParser::Argument, 4> ivs;
|
||||
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseEqual() ||
|
||||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
|
||||
parser.parseKeyword("to") ||
|
||||
|
@ -3600,8 +3602,9 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
|
|||
|
||||
// Now parse the body.
|
||||
Region *body = result.addRegion();
|
||||
SmallVector<Type, 4> types(ivs.size(), indexType);
|
||||
if (parser.parseRegion(*body, ivs, types) ||
|
||||
for (auto &iv : ivs)
|
||||
iv.type = indexType;
|
||||
if (parser.parseRegion(*body, ivs) ||
|
||||
parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -178,21 +178,19 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
// Parse async value operands (%value as %unwrapped : !async.value<!type>).
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> unwrappedArgs;
|
||||
SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
|
||||
SmallVector<Type, 4> valueTypes;
|
||||
SmallVector<Type, 4> unwrappedTypes;
|
||||
|
||||
// Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
|
||||
auto parseAsyncValueArg = [&]() -> ParseResult {
|
||||
if (parser.parseOperand(valueArgs.emplace_back()) ||
|
||||
parser.parseKeyword("as") ||
|
||||
parser.parseOperand(unwrappedArgs.emplace_back()) ||
|
||||
parser.parseArgument(unwrappedArgs.emplace_back()) ||
|
||||
parser.parseColonType(valueTypes.emplace_back()))
|
||||
return failure();
|
||||
|
||||
auto valueTy = valueTypes.back().dyn_cast<ValueType>();
|
||||
unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
|
||||
|
||||
unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
|
||||
return success();
|
||||
};
|
||||
|
||||
|
@ -227,12 +225,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
// Parse asynchronous region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
|
||||
/*argTypes=*/{unwrappedTypes},
|
||||
/*enableNameShadowing=*/false))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
|
||||
}
|
||||
|
||||
LogicalResult ExecuteOp::verifyRegions() {
|
||||
|
|
|
@ -622,8 +622,17 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
Type index = parser.getBuilder().getIndexType();
|
||||
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
|
||||
LaunchOp::kNumConfigRegionAttributes, index);
|
||||
|
||||
SmallVector<OpAsmParser::Argument> regionArguments;
|
||||
for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
|
||||
OpAsmParser::Argument arg;
|
||||
arg.ssaName = std::get<0>(ssaValueAndType);
|
||||
arg.type = std::get<1>(ssaValueAndType);
|
||||
regionArguments.push_back(arg);
|
||||
}
|
||||
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, regionArgs, dataTypes) ||
|
||||
if (parser.parseRegion(*body, regionArguments) ||
|
||||
parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
|
@ -758,11 +767,16 @@ static ParseResult parseLaunchFuncOperands(
|
|||
SmallVectorImpl<Type> &argTypes) {
|
||||
if (parser.parseOptionalKeyword("args"))
|
||||
return success();
|
||||
SmallVector<NamedAttrList> argAttrs;
|
||||
bool isVariadic = false;
|
||||
return function_interface_impl::parseFunctionArgumentList(
|
||||
parser, /*allowAttributes=*/false,
|
||||
/*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic);
|
||||
|
||||
SmallVector<OpAsmParser::Argument> args;
|
||||
if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
|
||||
/*allowType=*/true))
|
||||
return failure();
|
||||
for (auto &arg : args) {
|
||||
argNames.push_back(arg.ssaName);
|
||||
argTypes.push_back(arg.type);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
|
||||
|
@ -779,8 +793,6 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
|
|||
printer << ")";
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShuffleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -852,32 +864,13 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
|
|||
/// keyword provided as argument.
|
||||
static ParseResult
|
||||
parseAttributions(OpAsmParser &parser, StringRef keyword,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
|
||||
SmallVectorImpl<Type> &argTypes) {
|
||||
SmallVectorImpl<OpAsmParser::Argument> &args) {
|
||||
// If we could not parse the keyword, just assume empty list and succeed.
|
||||
if (failed(parser.parseOptionalKeyword(keyword)))
|
||||
return success();
|
||||
|
||||
if (failed(parser.parseLParen()))
|
||||
return failure();
|
||||
|
||||
// Early exit for an empty list.
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
do {
|
||||
OpAsmParser::UnresolvedOperand arg;
|
||||
Type type;
|
||||
|
||||
if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
|
||||
parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
args.push_back(arg);
|
||||
argTypes.push_back(type);
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
return parser.parseRParen();
|
||||
return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
|
||||
/*allowType=*/true);
|
||||
}
|
||||
|
||||
/// Parses a GPU function.
|
||||
|
@ -886,10 +879,8 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
|
|||
/// (`->` function-result-list)? memory-attribution `kernel`?
|
||||
/// function-attributes? region
|
||||
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
|
||||
SmallVector<NamedAttrList> argAttrs;
|
||||
SmallVector<NamedAttrList> resultAttrs;
|
||||
SmallVector<Type> argTypes;
|
||||
SmallVector<OpAsmParser::Argument> entryArgs;
|
||||
SmallVector<DictionaryAttr> resultAttrs;
|
||||
SmallVector<Type> resultTypes;
|
||||
bool isVariadic;
|
||||
|
||||
|
@ -901,34 +892,41 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
auto signatureLocation = parser.getCurrentLocation();
|
||||
if (failed(function_interface_impl::parseFunctionSignature(
|
||||
parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
|
||||
isVariadic, resultTypes, resultAttrs)))
|
||||
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
|
||||
resultAttrs)))
|
||||
return failure();
|
||||
|
||||
if (entryArgs.empty() && !argTypes.empty())
|
||||
if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
|
||||
return parser.emitError(signatureLocation)
|
||||
<< "gpu.func requires named arguments";
|
||||
|
||||
// Construct the function type. More types will be added to the region, but
|
||||
// not to the function type.
|
||||
Builder &builder = parser.getBuilder();
|
||||
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto &arg : entryArgs)
|
||||
argTypes.push_back(arg.type);
|
||||
auto type = builder.getFunctionType(argTypes, resultTypes);
|
||||
result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
|
||||
|
||||
function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
|
||||
resultAttrs);
|
||||
|
||||
// Parse workgroup memory attributions.
|
||||
if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
|
||||
entryArgs, argTypes)))
|
||||
entryArgs)))
|
||||
return failure();
|
||||
|
||||
// Store the number of operands we just parsed as the number of workgroup
|
||||
// memory attributions.
|
||||
unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
|
||||
unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
|
||||
result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
|
||||
builder.getI64IntegerAttr(numWorkgroupAttrs));
|
||||
|
||||
// Parse private memory attributions.
|
||||
if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
|
||||
entryArgs, argTypes)))
|
||||
if (failed(
|
||||
parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), entryArgs)))
|
||||
return failure();
|
||||
|
||||
// Parse the kernel attribute if present.
|
||||
|
@ -939,13 +937,11 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
// Parse attributes.
|
||||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
return failure();
|
||||
function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
|
||||
resultAttrs);
|
||||
|
||||
// Parse the region. If no argument names were provided, take all names
|
||||
// (including those of attributions) from the entry block.
|
||||
auto *body = result.addRegion();
|
||||
return parser.parseRegion(*body, entryArgs, argTypes);
|
||||
return parser.parseRegion(*body, entryArgs);
|
||||
}
|
||||
|
||||
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
|
||||
|
@ -1078,16 +1074,14 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
|
|||
ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
|
||||
// If module attributes are present, parse them.
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
result.attributes) ||
|
||||
// If module attributes are present, parse them.
|
||||
parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
|
||||
// Parse the module body.
|
||||
auto *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, None, None))
|
||||
if (parser.parseRegion(*body, {}))
|
||||
return failure();
|
||||
|
||||
// Ensure that this module has a valid terminator.
|
||||
|
|
|
@ -2152,10 +2152,8 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
parser, result, LLVM::Linkage::External)));
|
||||
|
||||
StringAttr nameAttr;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
|
||||
SmallVector<NamedAttrList> argAttrs;
|
||||
SmallVector<NamedAttrList> resultAttrs;
|
||||
SmallVector<Type> argTypes;
|
||||
SmallVector<OpAsmParser::Argument> entryArgs;
|
||||
SmallVector<DictionaryAttr> resultAttrs;
|
||||
SmallVector<Type> resultTypes;
|
||||
bool isVariadic;
|
||||
|
||||
|
@ -2163,10 +2161,13 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
result.attributes) ||
|
||||
function_interface_impl::parseFunctionSignature(
|
||||
parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
|
||||
isVariadic, resultTypes, resultAttrs))
|
||||
parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
|
||||
resultAttrs))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto &arg : entryArgs)
|
||||
argTypes.push_back(arg.type);
|
||||
auto type =
|
||||
buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
|
||||
function_interface_impl::VariadicFlag(isVariadic));
|
||||
|
@ -2178,11 +2179,11 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
|
||||
return failure();
|
||||
function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
|
||||
argAttrs, resultAttrs);
|
||||
entryArgs, resultAttrs);
|
||||
|
||||
auto *body = result.addRegion();
|
||||
OptionalParseResult parseResult = parser.parseOptionalRegion(
|
||||
*body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
|
||||
OptionalParseResult parseResult =
|
||||
parser.parseOptionalRegion(*body, entryArgs);
|
||||
return failure(parseResult.hasValue() && failed(*parseResult));
|
||||
}
|
||||
|
||||
|
|
|
@ -799,10 +799,8 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
failed(parser.parseOptionalAttrDict(result.attributes)))
|
||||
return failure();
|
||||
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 8> regionOperands;
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
SmallVector<Type, 8> operandTypes, regionTypes;
|
||||
if (parser.parseRegion(*region, regionOperands, regionTypes))
|
||||
if (parser.parseRegion(*region, {}))
|
||||
return failure();
|
||||
result.addRegion(std::move(region));
|
||||
|
||||
|
|
|
@ -275,7 +275,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
return failure();
|
||||
|
||||
// Parse the body region.
|
||||
if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
|
||||
return failure();
|
||||
AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
|
||||
result.location);
|
||||
|
@ -1215,7 +1215,7 @@ ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
|
|||
return failure();
|
||||
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, llvm::None, llvm::None) ||
|
||||
if (parser.parseRegion(*body, {}) ||
|
||||
parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
|
||||
|
|
|
@ -523,20 +523,16 @@ parseWsLoopControl(OpAsmParser &parser, Region ®ion,
|
|||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
|
||||
SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
|
||||
// Parse an opening `(` followed by induction variables followed by `)`
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
|
||||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
|
||||
/*allowResultNumber=*/false))
|
||||
return failure();
|
||||
|
||||
size_t numIVs = ivs.size();
|
||||
SmallVector<OpAsmParser::Argument> ivs;
|
||||
Type loopVarType;
|
||||
if (parser.parseColonType(loopVarType) ||
|
||||
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseColonType(loopVarType) ||
|
||||
// Parse loop bounds.
|
||||
parser.parseEqual() ||
|
||||
parser.parseOperandList(lowerBound, numIVs,
|
||||
parser.parseOperandList(lowerBound, ivs.size(),
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseKeyword("to") ||
|
||||
parser.parseOperandList(upperBound, numIVs,
|
||||
parser.parseOperandList(upperBound, ivs.size(),
|
||||
OpAsmParser::Delimiter::Paren))
|
||||
return failure();
|
||||
|
||||
|
@ -545,15 +541,14 @@ parseWsLoopControl(OpAsmParser &parser, Region ®ion,
|
|||
|
||||
// Parse step values.
|
||||
if (parser.parseKeyword("step") ||
|
||||
parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren))
|
||||
parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
|
||||
return failure();
|
||||
|
||||
// Now parse the body.
|
||||
loopVarTypes = SmallVector<Type>(numIVs, loopVarType);
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
|
||||
if (parser.parseRegion(region, blockArgs, loopVarTypes))
|
||||
return failure();
|
||||
return success();
|
||||
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
|
||||
for (auto &iv : ivs)
|
||||
iv.type = loopVarType;
|
||||
return parser.parseRegion(region, ivs);
|
||||
}
|
||||
|
||||
void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
|
||||
|
@ -582,33 +577,28 @@ void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
|
|||
/// clause ::= TODO
|
||||
ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Parse an opening `(` followed by induction variables followed by `)`
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
|
||||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
|
||||
/*allowResultNumber=*/false))
|
||||
return failure();
|
||||
int numIVs = static_cast<int>(ivs.size());
|
||||
SmallVector<OpAsmParser::Argument> ivs;
|
||||
Type loopVarType;
|
||||
if (parser.parseColonType(loopVarType))
|
||||
return failure();
|
||||
// Parse loop bounds.
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> lower;
|
||||
if (parser.parseEqual() ||
|
||||
parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(lower, loopVarType, result.operands))
|
||||
return failure();
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> upper;
|
||||
if (parser.parseKeyword("to") ||
|
||||
parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(upper, loopVarType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse step values.
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> steps;
|
||||
if (parser.parseKeyword("step") ||
|
||||
parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> lower, upper, steps;
|
||||
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseColonType(loopVarType) ||
|
||||
// Parse loop bounds.
|
||||
parser.parseEqual() ||
|
||||
parser.parseOperandList(lower, ivs.size(),
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(lower, loopVarType, result.operands) ||
|
||||
parser.parseKeyword("to") ||
|
||||
parser.parseOperandList(upper, ivs.size(),
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(upper, loopVarType, result.operands) ||
|
||||
// Parse step values.
|
||||
parser.parseKeyword("step") ||
|
||||
parser.parseOperandList(steps, ivs.size(),
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.resolveOperands(steps, loopVarType, result.operands))
|
||||
return failure();
|
||||
|
||||
int numIVs = static_cast<int>(ivs.size());
|
||||
SmallVector<int> segments{numIVs, numIVs, numIVs};
|
||||
// TODO: Add parseClauses() when we support clauses
|
||||
result.addAttribute("operand_segment_sizes",
|
||||
|
@ -616,11 +606,9 @@ ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
// Now parse the body.
|
||||
Region *body = result.addRegion();
|
||||
SmallVector<Type> ivTypes(numIVs, loopVarType);
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
|
||||
if (parser.parseRegion(*body, blockArgs, ivTypes))
|
||||
return failure();
|
||||
return success();
|
||||
for (auto &iv : ivs)
|
||||
iv.type = loopVarType;
|
||||
return parser.parseRegion(*body, ivs);
|
||||
}
|
||||
|
||||
void SimdLoopOp::print(OpAsmPrinter &p) {
|
||||
|
|
|
@ -101,41 +101,29 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
|||
|
||||
ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// Parse the loop variable followed by type.
|
||||
OpAsmParser::UnresolvedOperand loopVariable;
|
||||
Type loopVariableType;
|
||||
if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) ||
|
||||
parser.parseColonType(loopVariableType))
|
||||
return failure();
|
||||
|
||||
// Parse the "in" keyword.
|
||||
if (parser.parseKeyword("in", " after loop variable"))
|
||||
return failure();
|
||||
|
||||
// Parse the operand (value range).
|
||||
OpAsmParser::Argument loopVariable;
|
||||
OpAsmParser::UnresolvedOperand operandInfo;
|
||||
if (parser.parseOperand(operandInfo))
|
||||
if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
|
||||
parser.parseKeyword("in", " after loop variable") ||
|
||||
// Parse the operand (value range).
|
||||
parser.parseOperand(operandInfo))
|
||||
return failure();
|
||||
|
||||
// Resolve the operand.
|
||||
Type rangeType = pdl::RangeType::get(loopVariableType);
|
||||
Type rangeType = pdl::RangeType::get(loopVariable.type);
|
||||
if (parser.resolveOperand(operandInfo, rangeType, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
|
||||
return failure();
|
||||
|
||||
// Parse the attribute dictionary.
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
// Parse the successor.
|
||||
Block *successor;
|
||||
if (parser.parseArrow() || parser.parseSuccessor(successor))
|
||||
if (parser.parseRegion(*body, loopVariable) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
// Parse the successor.
|
||||
parser.parseArrow() || parser.parseSuccessor(successor))
|
||||
return failure();
|
||||
result.addSuccessors(successor);
|
||||
|
||||
result.addSuccessors(successor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -399,15 +399,16 @@ void ForOp::print(OpAsmPrinter &p) {
|
|||
|
||||
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
|
||||
// Parse the induction variable followed by '='.
|
||||
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
|
||||
parser.parseEqual())
|
||||
return failure();
|
||||
|
||||
// Parse loop bounds.
|
||||
Type indexType = builder.getIndexType();
|
||||
if (parser.parseOperand(lb) ||
|
||||
|
||||
OpAsmParser::Argument inductionVariable;
|
||||
inductionVariable.type = indexType;
|
||||
OpAsmParser::UnresolvedOperand lb, ub, step;
|
||||
|
||||
// Parse the induction variable followed by '='.
|
||||
if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
|
||||
// Parse loop bounds.
|
||||
parser.parseOperand(lb) ||
|
||||
parser.resolveOperand(lb, indexType, result.operands) ||
|
||||
parser.parseKeyword("to") || parser.parseOperand(ub) ||
|
||||
parser.resolveOperand(ub, indexType, result.operands) ||
|
||||
|
@ -416,8 +417,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
return failure();
|
||||
|
||||
// Parse the optional initial iteration arguments.
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
|
||||
SmallVector<Type, 4> argTypes;
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
|
||||
regionArgs.push_back(inductionVariable);
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
|
||||
|
@ -425,24 +426,26 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
if (parser.parseAssignmentList(regionArgs, operands) ||
|
||||
parser.parseArrowTypeList(result.types))
|
||||
return failure();
|
||||
|
||||
// Resolve input operands.
|
||||
for (auto operandType : llvm::zip(operands, result.types))
|
||||
if (parser.resolveOperand(std::get<0>(operandType),
|
||||
std::get<1>(operandType), result.operands))
|
||||
for (auto argOperandType :
|
||||
llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
|
||||
Type type = std::get<2>(argOperandType);
|
||||
std::get<0>(argOperandType).type = type;
|
||||
if (parser.resolveOperand(std::get<1>(argOperandType), type,
|
||||
result.operands))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
// Induction variable.
|
||||
argTypes.push_back(indexType);
|
||||
// Loop carried variables
|
||||
argTypes.append(result.types.begin(), result.types.end());
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
if (regionArgs.size() != argTypes.size())
|
||||
|
||||
if (regionArgs.size() != result.types.size() + 1)
|
||||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"mismatch in number of loop-carried values and defined values");
|
||||
|
||||
if (parser.parseRegion(*body, regionArgs, argTypes))
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, regionArgs))
|
||||
return failure();
|
||||
|
||||
ForOp::ensureTerminator(*body, builder, result.location);
|
||||
|
@ -1975,9 +1978,8 @@ LogicalResult ParallelOp::verify() {
|
|||
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
auto &builder = parser.getBuilder();
|
||||
// Parse an opening `(` followed by induction variables followed by `)`
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
|
||||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
|
||||
/*allowResultNumber=*/false))
|
||||
SmallVector<OpAsmParser::Argument, 4> ivs;
|
||||
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
|
||||
return failure();
|
||||
|
||||
// Parse loop bounds.
|
||||
|
@ -2016,8 +2018,9 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
// Now parse the body.
|
||||
Region *body = result.addRegion();
|
||||
SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
|
||||
if (parser.parseRegion(*body, ivs, types))
|
||||
for (auto &iv : ivs)
|
||||
iv.type = builder.getIndexType();
|
||||
if (parser.parseRegion(*body, ivs))
|
||||
return failure();
|
||||
|
||||
// Set `operand_segment_sizes` attribute.
|
||||
|
@ -2370,7 +2373,8 @@ void WhileOp::getSuccessorRegions(Optional<unsigned> index,
|
|||
/// assignment-list ::= assignment | assignment `,` assignment-list
|
||||
/// assignment ::= ssa-value `=` ssa-value
|
||||
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
|
||||
Region *before = result.addRegion();
|
||||
Region *after = result.addRegion();
|
||||
|
||||
|
@ -2399,10 +2403,13 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
result.operands)))
|
||||
return failure();
|
||||
|
||||
return failure(
|
||||
parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
|
||||
parser.parseKeyword("do") || parser.parseRegion(*after) ||
|
||||
parser.parseOptionalAttrDictWithKeyword(result.attributes));
|
||||
// Propagate the types into the region arguments.
|
||||
for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
|
||||
regionArgs[i].type = functionType.getInput(i);
|
||||
|
||||
return failure(parser.parseRegion(*before, regionArgs) ||
|
||||
parser.parseKeyword("do") || parser.parseRegion(*after) ||
|
||||
parser.parseOptionalAttrDictWithKeyword(result.attributes));
|
||||
}
|
||||
|
||||
/// Prints a `while` op.
|
||||
|
|
|
@ -2193,10 +2193,8 @@ LogicalResult spirv::UConvertOp::verify() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
|
||||
SmallVector<NamedAttrList> argAttrs;
|
||||
SmallVector<NamedAttrList> resultAttrs;
|
||||
SmallVector<Type> argTypes;
|
||||
SmallVector<OpAsmParser::Argument> entryArgs;
|
||||
SmallVector<DictionaryAttr> resultAttrs;
|
||||
SmallVector<Type> resultTypes;
|
||||
auto &builder = parser.getBuilder();
|
||||
|
||||
|
@ -2209,10 +2207,13 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
|
|||
// Parse the function signature.
|
||||
bool isVariadic = false;
|
||||
if (function_interface_impl::parseFunctionSignature(
|
||||
parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
|
||||
isVariadic, resultTypes, resultAttrs))
|
||||
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
|
||||
resultAttrs))
|
||||
return failure();
|
||||
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto &arg : entryArgs)
|
||||
argTypes.push_back(arg.type);
|
||||
auto fnType = builder.getFunctionType(argTypes, resultTypes);
|
||||
state.addAttribute(FunctionOpInterface::getTypeAttrName(),
|
||||
TypeAttr::get(fnType));
|
||||
|
@ -2227,15 +2228,13 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
|
|||
return failure();
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
assert(argAttrs.size() == argTypes.size());
|
||||
assert(resultAttrs.size() == resultTypes.size());
|
||||
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
|
||||
function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
|
||||
resultAttrs);
|
||||
|
||||
// Parse the optional function body.
|
||||
auto *body = state.addRegion();
|
||||
OptionalParseResult result = parser.parseOptionalRegion(
|
||||
*body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
|
||||
OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs);
|
||||
return failure(result.hasValue() && failed(*result));
|
||||
}
|
||||
|
||||
|
|
|
@ -13,83 +13,61 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
|
||||
OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
|
||||
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
|
||||
bool &isVariadic) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
static ParseResult
|
||||
parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::Argument> &arguments,
|
||||
bool &isVariadic) {
|
||||
|
||||
// The argument list either has to consistently have ssa-id's followed by
|
||||
// types, or just be a type list. It isn't ok to sometimes have SSA ID's and
|
||||
// sometimes not.
|
||||
auto parseArgument = [&]() -> ParseResult {
|
||||
SMLoc loc = parser.getCurrentLocation();
|
||||
|
||||
// Parse argument name if present.
|
||||
OpAsmParser::UnresolvedOperand argument;
|
||||
Type argumentType;
|
||||
auto hadSSAValue = parser.parseOptionalOperand(argument,
|
||||
/*allowResultNumber=*/false);
|
||||
if (hadSSAValue.hasValue()) {
|
||||
if (failed(hadSSAValue.getValue()))
|
||||
return failure(); // Argument was present but malformed.
|
||||
|
||||
// Reject this if the preceding argument was missing a name.
|
||||
if (argNames.empty() && !argTypes.empty())
|
||||
return parser.emitError(loc, "expected type instead of SSA identifier");
|
||||
|
||||
// Parse required type.
|
||||
if (parser.parseColonType(argumentType))
|
||||
return failure();
|
||||
} else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
|
||||
isVariadic = true;
|
||||
return success();
|
||||
} else if (!argNames.empty()) {
|
||||
// Reject this if the preceding argument had a name.
|
||||
return parser.emitError(loc, "expected SSA identifier");
|
||||
} else if (parser.parseType(argumentType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Add the argument type.
|
||||
argTypes.push_back(argumentType);
|
||||
|
||||
// Parse any argument attributes and source location information.
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseOptionalAttrDict(attrs) ||
|
||||
parser.parseOptionalLocationSpecifier(argument.sourceLoc))
|
||||
return failure();
|
||||
|
||||
if (!allowAttributes && !attrs.empty())
|
||||
return parser.emitError(loc, "expected arguments without attributes");
|
||||
argAttrs.push_back(attrs);
|
||||
|
||||
// If we had an argument name, then remember the parsed argument.
|
||||
if (!argument.name.empty())
|
||||
argNames.push_back(argument);
|
||||
return success();
|
||||
};
|
||||
|
||||
// Parse the function arguments.
|
||||
// Parse the function arguments. The argument list either has to consistently
|
||||
// have ssa-id's followed by types, or just be a type list. It isn't ok to
|
||||
// sometimes have SSA ID's and sometimes not.
|
||||
isVariadic = false;
|
||||
if (failed(parser.parseOptionalRParen())) {
|
||||
do {
|
||||
unsigned numTypedArguments = argTypes.size();
|
||||
if (parseArgument())
|
||||
return failure();
|
||||
|
||||
SMLoc loc = parser.getCurrentLocation();
|
||||
if (argTypes.size() == numTypedArguments &&
|
||||
succeeded(parser.parseOptionalComma()))
|
||||
return parser.emitError(
|
||||
loc, "variadic arguments must be in the end of the argument list");
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
parser.parseRParen();
|
||||
}
|
||||
return parser.parseCommaSeparatedList(
|
||||
OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
|
||||
// Ellipsis must be at end of the list.
|
||||
if (isVariadic)
|
||||
return parser.emitError(
|
||||
parser.getCurrentLocation(),
|
||||
"variadic arguments must be in the end of the argument list");
|
||||
|
||||
return success();
|
||||
// Handle ellipsis as a special case.
|
||||
if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
|
||||
// This is a variadic designator.
|
||||
isVariadic = true;
|
||||
return success(); // Stop parsing arguments.
|
||||
}
|
||||
// Parse argument name if present.
|
||||
OpAsmParser::Argument argument;
|
||||
auto argPresent = parser.parseOptionalArgument(
|
||||
argument, /*allowType=*/true, /*allowAttrs=*/true);
|
||||
if (argPresent.hasValue()) {
|
||||
if (failed(argPresent.getValue()))
|
||||
return failure(); // Present but malformed.
|
||||
|
||||
// Reject this if the preceding argument was missing a name.
|
||||
if (!arguments.empty() && arguments.back().ssaName.name.empty())
|
||||
return parser.emitError(argument.ssaName.location,
|
||||
"expected type instead of SSA identifier");
|
||||
|
||||
} else {
|
||||
argument.ssaName.location = parser.getCurrentLocation();
|
||||
// Otherwise we just have a type list without SSA names. Reject
|
||||
// this if the preceding argument had a name.
|
||||
if (!arguments.empty() && !arguments.back().ssaName.name.empty())
|
||||
return parser.emitError(argument.ssaName.location,
|
||||
"expected SSA identifier");
|
||||
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseType(argument.type) ||
|
||||
parser.parseOptionalAttrDict(attrs) ||
|
||||
parser.parseOptionalLocationSpecifier(argument.sourceLoc))
|
||||
return failure();
|
||||
argument.attrs = attrs.getDictionary(parser.getContext());
|
||||
}
|
||||
arguments.push_back(argument);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
/// Parse a function result list.
|
||||
|
@ -103,7 +81,7 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
|
|||
///
|
||||
static ParseResult
|
||||
parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<NamedAttrList> &resultAttrs) {
|
||||
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
|
||||
if (failed(parser.parseOptionalLParen())) {
|
||||
// We already know that there is no `(`, so parse a type.
|
||||
// Because there is no `(`, it cannot be a function type.
|
||||
|
@ -120,83 +98,74 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
|
|||
return success();
|
||||
|
||||
// Parse individual function results.
|
||||
do {
|
||||
resultTypes.emplace_back();
|
||||
resultAttrs.emplace_back();
|
||||
if (parser.parseType(resultTypes.back()) ||
|
||||
parser.parseOptionalAttrDict(resultAttrs.back())) {
|
||||
return failure();
|
||||
}
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
if (parser.parseCommaSeparatedList([&]() -> ParseResult {
|
||||
resultTypes.emplace_back();
|
||||
resultAttrs.emplace_back();
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseType(resultTypes.back()) ||
|
||||
parser.parseOptionalAttrDict(attrs))
|
||||
return failure();
|
||||
resultAttrs.back() = attrs.getDictionary(parser.getContext());
|
||||
return success();
|
||||
}))
|
||||
return failure();
|
||||
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
ParseResult mlir::function_interface_impl::parseFunctionSignature(
|
||||
OpAsmParser &parser, bool allowVariadic,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
|
||||
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
|
||||
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<NamedAttrList> &resultAttrs) {
|
||||
bool allowArgAttrs = true;
|
||||
if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
|
||||
argTypes, argAttrs, isVariadic))
|
||||
SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
|
||||
SmallVectorImpl<Type> &resultTypes,
|
||||
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
|
||||
if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalArrow()))
|
||||
return parseFunctionResultList(parser, resultTypes, resultAttrs);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Implementation of `addArgAndResultAttrs` that is attribute list type
|
||||
/// agnostic.
|
||||
template <typename AttrListT, typename AttrArrayBuildFnT>
|
||||
static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
|
||||
ArrayRef<AttrListT> argAttrs,
|
||||
ArrayRef<AttrListT> resultAttrs,
|
||||
AttrArrayBuildFnT &&buildAttrArrayFn) {
|
||||
auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
|
||||
ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
|
||||
result.addAttribute(function_interface_impl::getArgDictAttrName(),
|
||||
attrDicts);
|
||||
}
|
||||
// Add the attributes to the function results.
|
||||
if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
|
||||
ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
|
||||
result.addAttribute(function_interface_impl::getResultDictAttrName(),
|
||||
attrDicts);
|
||||
}
|
||||
}
|
||||
|
||||
void mlir::function_interface_impl::addArgAndResultAttrs(
|
||||
Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
|
||||
ArrayRef<DictionaryAttr> resultAttrs) {
|
||||
auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
|
||||
return ArrayRef<Attribute>(attrs.data(), attrs.size());
|
||||
auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
|
||||
return attrs && !attrs.empty();
|
||||
};
|
||||
addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
|
||||
// Convert the specified array of dictionary attrs (which may have null
|
||||
// entries) to an ArrayAttr of dictionaries.
|
||||
auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
|
||||
SmallVector<Attribute> attrs;
|
||||
for (auto &dict : dictAttrs)
|
||||
attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
|
||||
return builder.getArrayAttr(attrs);
|
||||
};
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
|
||||
result.addAttribute(function_interface_impl::getArgDictAttrName(),
|
||||
getArrayAttr(argAttrs));
|
||||
|
||||
// Add the attributes to the function results.
|
||||
if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
|
||||
result.addAttribute(function_interface_impl::getResultDictAttrName(),
|
||||
getArrayAttr(resultAttrs));
|
||||
}
|
||||
|
||||
void mlir::function_interface_impl::addArgAndResultAttrs(
|
||||
Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
|
||||
ArrayRef<NamedAttrList> resultAttrs) {
|
||||
MLIRContext *context = builder.getContext();
|
||||
auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
|
||||
return llvm::to_vector<8>(
|
||||
llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
|
||||
return attrList.getDictionary(context);
|
||||
}));
|
||||
};
|
||||
addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
|
||||
Builder &builder, OperationState &result,
|
||||
ArrayRef<OpAsmParser::Argument> args,
|
||||
ArrayRef<DictionaryAttr> resultAttrs) {
|
||||
SmallVector<DictionaryAttr> argAttrs;
|
||||
for (const auto &arg : args)
|
||||
argAttrs.push_back(arg.attrs);
|
||||
addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
|
||||
}
|
||||
|
||||
ParseResult mlir::function_interface_impl::parseFunctionOp(
|
||||
OpAsmParser &parser, OperationState &result, bool allowVariadic,
|
||||
FuncTypeBuilder funcTypeBuilder) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
|
||||
SmallVector<NamedAttrList> argAttrs;
|
||||
SmallVector<NamedAttrList> resultAttrs;
|
||||
SmallVector<Type> argTypes;
|
||||
SmallVector<OpAsmParser::Argument> entryArgs;
|
||||
SmallVector<DictionaryAttr> resultAttrs;
|
||||
SmallVector<Type> resultTypes;
|
||||
auto &builder = parser.getBuilder();
|
||||
|
||||
|
@ -212,11 +181,15 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
|
|||
// Parse the function signature.
|
||||
SMLoc signatureLocation = parser.getCurrentLocation();
|
||||
bool isVariadic = false;
|
||||
if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
|
||||
argAttrs, isVariadic, resultTypes, resultAttrs))
|
||||
if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic,
|
||||
resultTypes, resultAttrs))
|
||||
return failure();
|
||||
|
||||
std::string errorMessage;
|
||||
SmallVector<Type> argTypes;
|
||||
argTypes.reserve(entryArgs.size());
|
||||
for (auto &arg : entryArgs)
|
||||
argTypes.push_back(arg.type);
|
||||
Type type = funcTypeBuilder(builder, argTypes, resultTypes,
|
||||
VariadicFlag(isVariadic), errorMessage);
|
||||
if (!type) {
|
||||
|
@ -246,17 +219,16 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
|
|||
result.attributes.append(parsedAttributes);
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
assert(argAttrs.size() == argTypes.size());
|
||||
assert(resultAttrs.size() == resultTypes.size());
|
||||
addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
|
||||
addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
|
||||
|
||||
// Parse the optional function body. The printer will not print the body if
|
||||
// its empty, so disallow parsing of empty body in the parser.
|
||||
auto *body = result.addRegion();
|
||||
SMLoc loc = parser.getCurrentLocation();
|
||||
OptionalParseResult parseResult = parser.parseOptionalRegion(
|
||||
*body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
|
||||
/*enableNameShadowing=*/false);
|
||||
OptionalParseResult parseResult =
|
||||
parser.parseOptionalRegion(*body, entryArgs,
|
||||
/*enableNameShadowing=*/false);
|
||||
if (parseResult.hasValue()) {
|
||||
if (failed(*parseResult))
|
||||
return failure();
|
||||
|
|
|
@ -301,11 +301,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
|
|||
return success();
|
||||
};
|
||||
|
||||
if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
|
||||
" in attribute dictionary"))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
return parseCommaSeparatedList(Delimiter::Braces, parseElt,
|
||||
" in attribute dictionary");
|
||||
}
|
||||
|
||||
/// Parse a float attribute.
|
||||
|
|
|
@ -249,6 +249,7 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
using UnresolvedOperand = OpAsmParser::UnresolvedOperand;
|
||||
using Argument = OpAsmParser::Argument;
|
||||
|
||||
struct DeferredLocInfo {
|
||||
SMLoc loc;
|
||||
|
@ -364,16 +365,13 @@ public:
|
|||
/// Parse a region into 'region' with the provided entry block arguments.
|
||||
/// 'isIsolatedNameScope' indicates if the naming scope of this region is
|
||||
/// isolated from those above.
|
||||
ParseResult
|
||||
parseRegion(Region ®ion,
|
||||
ArrayRef<std::pair<UnresolvedOperand, Type>> entryArguments,
|
||||
bool isIsolatedNameScope = false);
|
||||
ParseResult parseRegion(Region ®ion, ArrayRef<Argument> entryArguments,
|
||||
bool isIsolatedNameScope = false);
|
||||
|
||||
/// Parse a region body into 'region'.
|
||||
ParseResult
|
||||
parseRegionBody(Region ®ion, SMLoc startLoc,
|
||||
ArrayRef<std::pair<UnresolvedOperand, Type>> entryArguments,
|
||||
bool isIsolatedNameScope);
|
||||
ParseResult parseRegionBody(Region ®ion, SMLoc startLoc,
|
||||
ArrayRef<Argument> entryArguments,
|
||||
bool isIsolatedNameScope);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Block Parsing
|
||||
|
@ -947,7 +945,7 @@ ParseResult OperationParser::parseOperation() {
|
|||
unsigned opResI = 0;
|
||||
for (ResultRecord &resIt : resultIDs) {
|
||||
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
|
||||
if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes, {}},
|
||||
if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes},
|
||||
op->getResult(opResI++)))
|
||||
return failure();
|
||||
}
|
||||
|
@ -1279,10 +1277,8 @@ public:
|
|||
if (parser.parseSSAUse(useInfo, allowResultNumber))
|
||||
return failure();
|
||||
|
||||
result = {useInfo.location, useInfo.name, useInfo.number, {}};
|
||||
|
||||
// Parse a source locator on the operand if present.
|
||||
return parseOptionalLocationSpecifier(result.sourceLoc);
|
||||
result = {useInfo.location, useInfo.name, useInfo.number};
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a single operand if present.
|
||||
|
@ -1321,11 +1317,7 @@ public:
|
|||
}
|
||||
|
||||
auto parseOneOperand = [&]() -> ParseResult {
|
||||
UnresolvedOperand operandOrArg;
|
||||
if (parseOperand(operandOrArg, allowResultNumber))
|
||||
return failure();
|
||||
result.push_back(operandOrArg);
|
||||
return success();
|
||||
return parseOperand(result.emplace_back(), allowResultNumber);
|
||||
};
|
||||
|
||||
if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list"))
|
||||
|
@ -1402,52 +1394,88 @@ public:
|
|||
return parser.parseAffineExprOfSSAIds(expr, parseElement);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Argument Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parse a single argument with the following syntax:
|
||||
///
|
||||
/// `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)`
|
||||
///
|
||||
/// If `allowType` is false or `allowAttrs` are false then the respective
|
||||
/// parts of the grammar are not parsed.
|
||||
ParseResult parseArgument(Argument &result, bool allowType = false,
|
||||
bool allowAttrs = false) override {
|
||||
NamedAttrList attrs;
|
||||
if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
|
||||
(allowType && parseColonType(result.type)) ||
|
||||
(allowAttrs && parseOptionalAttrDict(attrs)) ||
|
||||
parseOptionalLocationSpecifier(result.sourceLoc))
|
||||
return failure();
|
||||
result.attrs = attrs.getDictionary(getContext());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a single argument if present.
|
||||
OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
|
||||
bool allowAttrs) override {
|
||||
if (parser.getToken().is(Token::percent_identifier))
|
||||
return parseArgument(result, allowType, allowAttrs);
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
|
||||
Delimiter delimiter, bool allowType,
|
||||
bool allowAttrs) override {
|
||||
// The no-delimiter case has some special handling for the empty case.
|
||||
if (delimiter == Delimiter::None &&
|
||||
parser.getToken().isNot(Token::percent_identifier))
|
||||
return success();
|
||||
|
||||
auto parseOneArgument = [&]() -> ParseResult {
|
||||
return parseArgument(result.emplace_back(), allowType, allowAttrs);
|
||||
};
|
||||
return parseCommaSeparatedList(delimiter, parseOneArgument,
|
||||
" in argument list");
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Region Parsing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Parse a region that takes `arguments` of `argTypes` types. This
|
||||
/// effectively defines the SSA values of `arguments` and assigns their type.
|
||||
ParseResult parseRegion(Region ®ion, ArrayRef<UnresolvedOperand> arguments,
|
||||
ArrayRef<Type> argTypes,
|
||||
ParseResult parseRegion(Region ®ion, ArrayRef<Argument> arguments,
|
||||
bool enableNameShadowing) override {
|
||||
assert(arguments.size() == argTypes.size() &&
|
||||
"mismatching number of arguments and types");
|
||||
|
||||
SmallVector<std::pair<OperationParser::UnresolvedOperand, Type>, 2>
|
||||
regionArguments;
|
||||
for (auto pair : llvm::zip(arguments, argTypes))
|
||||
regionArguments.emplace_back(std::get<0>(pair), std::get<1>(pair));
|
||||
|
||||
// Try to parse the region.
|
||||
(void)isIsolatedFromAbove;
|
||||
assert((!enableNameShadowing || isIsolatedFromAbove) &&
|
||||
"name shadowing is only allowed on isolated regions");
|
||||
if (parser.parseRegion(region, regionArguments, enableNameShadowing))
|
||||
if (parser.parseRegion(region, arguments, enableNameShadowing))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parses a region if present.
|
||||
OptionalParseResult parseOptionalRegion(Region ®ion,
|
||||
ArrayRef<UnresolvedOperand> arguments,
|
||||
ArrayRef<Type> argTypes,
|
||||
ArrayRef<Argument> arguments,
|
||||
bool enableNameShadowing) override {
|
||||
if (parser.getToken().isNot(Token::l_brace))
|
||||
return llvm::None;
|
||||
return parseRegion(region, arguments, argTypes, enableNameShadowing);
|
||||
return parseRegion(region, arguments, enableNameShadowing);
|
||||
}
|
||||
|
||||
/// Parses a region if present. If the region is present, a new region is
|
||||
/// allocated and placed in `region`. If no region is present, `region`
|
||||
/// remains untouched.
|
||||
OptionalParseResult parseOptionalRegion(
|
||||
std::unique_ptr<Region> ®ion, ArrayRef<UnresolvedOperand> arguments,
|
||||
ArrayRef<Type> argTypes, bool enableNameShadowing = false) override {
|
||||
OptionalParseResult
|
||||
parseOptionalRegion(std::unique_ptr<Region> ®ion,
|
||||
ArrayRef<Argument> arguments,
|
||||
bool enableNameShadowing = false) override {
|
||||
if (parser.getToken().isNot(Token::l_brace))
|
||||
return llvm::None;
|
||||
std::unique_ptr<Region> newRegion = std::make_unique<Region>();
|
||||
if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
|
||||
if (parseRegion(*newRegion, arguments, enableNameShadowing))
|
||||
return failure();
|
||||
|
||||
region = std::move(newRegion);
|
||||
|
@ -1492,42 +1520,15 @@ public:
|
|||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1, %x2 = %y2, ...).
|
||||
OptionalParseResult parseOptionalAssignmentList(
|
||||
SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
SmallVectorImpl<Argument> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs) override {
|
||||
if (failed(parseOptionalLParen()))
|
||||
return llvm::None;
|
||||
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
UnresolvedOperand regionArg, operand;
|
||||
if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
|
||||
parseEqual() || parseOperand(operand))
|
||||
if (parseArgument(lhs.emplace_back()) || parseEqual() ||
|
||||
parseOperand(rhs.emplace_back()))
|
||||
return failure();
|
||||
lhs.push_back(regionArg);
|
||||
rhs.push_back(operand);
|
||||
return success();
|
||||
};
|
||||
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
|
||||
}
|
||||
|
||||
/// Parse a list of assignments of the form
|
||||
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
|
||||
OptionalParseResult
|
||||
parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
|
||||
SmallVectorImpl<UnresolvedOperand> &rhs,
|
||||
SmallVectorImpl<Type> &types) override {
|
||||
if (failed(parseOptionalLParen()))
|
||||
return llvm::None;
|
||||
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
UnresolvedOperand regionArg, operand;
|
||||
Type type;
|
||||
if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
|
||||
parseEqual() || parseOperand(operand) || parseColon() ||
|
||||
parseType(type))
|
||||
return failure();
|
||||
lhs.push_back(regionArg);
|
||||
rhs.push_back(operand);
|
||||
types.push_back(type);
|
||||
return success();
|
||||
};
|
||||
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
|
||||
|
@ -1749,11 +1750,9 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
|
|||
// Region Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult OperationParser::parseRegion(
|
||||
Region ®ion,
|
||||
ArrayRef<std::pair<OperationParser::UnresolvedOperand, Type>>
|
||||
entryArguments,
|
||||
bool isIsolatedNameScope) {
|
||||
ParseResult OperationParser::parseRegion(Region ®ion,
|
||||
ArrayRef<Argument> entryArguments,
|
||||
bool isIsolatedNameScope) {
|
||||
// Parse the '{'.
|
||||
Token lBraceTok = getToken();
|
||||
if (parseToken(Token::l_brace, "expected '{' to begin a region"))
|
||||
|
@ -1778,11 +1777,9 @@ ParseResult OperationParser::parseRegion(
|
|||
return success();
|
||||
}
|
||||
|
||||
ParseResult OperationParser::parseRegionBody(
|
||||
Region ®ion, SMLoc startLoc,
|
||||
ArrayRef<std::pair<OperationParser::UnresolvedOperand, Type>>
|
||||
entryArguments,
|
||||
bool isIsolatedNameScope) {
|
||||
ParseResult OperationParser::parseRegionBody(Region ®ion, SMLoc startLoc,
|
||||
ArrayRef<Argument> entryArguments,
|
||||
bool isIsolatedNameScope) {
|
||||
auto currentPt = opBuilder.saveInsertionPoint();
|
||||
|
||||
// Push a new named value scope.
|
||||
|
@ -1798,14 +1795,14 @@ ParseResult OperationParser::parseRegionBody(
|
|||
if (state.asmState && getToken().isNot(Token::caret_identifier))
|
||||
state.asmState->addDefinition(block, startLoc);
|
||||
|
||||
// Add arguments to the entry block.
|
||||
if (!entryArguments.empty()) {
|
||||
// Add arguments to the entry block if we had the form with explicit names.
|
||||
if (!entryArguments.empty() && !entryArguments[0].ssaName.name.empty()) {
|
||||
// If we had named arguments, then don't allow a block name.
|
||||
if (getToken().is(Token::caret_identifier))
|
||||
return emitError("invalid block name in region with named arguments");
|
||||
|
||||
for (auto &placeholderArgPair : entryArguments) {
|
||||
auto &argInfo = placeholderArgPair.first;
|
||||
for (auto &entryArg : entryArguments) {
|
||||
auto &argInfo = entryArg.ssaName;
|
||||
|
||||
// Ensure that the argument was not already defined.
|
||||
if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) {
|
||||
|
@ -1815,10 +1812,10 @@ ParseResult OperationParser::parseRegionBody(
|
|||
.attachNote(getEncodedSourceLocation(*defLoc))
|
||||
<< "previously referenced here";
|
||||
}
|
||||
Location loc = argInfo.sourceLoc.hasValue()
|
||||
? argInfo.sourceLoc.getValue()
|
||||
Location loc = entryArg.sourceLoc.hasValue()
|
||||
? entryArg.sourceLoc.getValue()
|
||||
: getEncodedSourceLocation(argInfo.location);
|
||||
BlockArgument arg = block->addArgument(placeholderArgPair.second, loc);
|
||||
BlockArgument arg = block->addArgument(entryArg.type, loc);
|
||||
|
||||
// Add a definition of this arg to the assembly state if provided.
|
||||
if (state.asmState)
|
||||
|
|
|
@ -202,7 +202,7 @@ module attributes {gpu.container_module} {
|
|||
|
||||
module attributes {gpu.container_module} {
|
||||
func.func @launch_func_kernel_operand_attr(%sz : index) {
|
||||
// expected-error@+1 {{expected arguments without attributes}}
|
||||
// expected-error@+1 {{expected ')' in argument list}}
|
||||
gpu.launch_func @foo::@bar blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%sz : index {foo})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -13,8 +13,9 @@ func.func @inline_notation() -> i32 {
|
|||
// CHECK: arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
|
||||
%2 = arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
|
||||
|
||||
// CHECK: affine.for %arg0 loc("IVlocation") = 0 to 8 {
|
||||
// CHECK: } loc(fused["foo", "mysource.cc":10:8])
|
||||
affine.for %i0 = 0 to 8 {
|
||||
affine.for %i0 loc("IVlocation") = 0 to 8 {
|
||||
} loc(fused["foo", "mysource.cc":10:8])
|
||||
|
||||
// CHECK: } loc(fused<"myPass">["foo", "foo2"])
|
||||
|
|
|
@ -691,18 +691,16 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
|
|||
|
||||
ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::UnresolvedOperand argInfo;
|
||||
Type argType = parser.getBuilder().getIndexType();
|
||||
|
||||
// Parse the input operand.
|
||||
if (parser.parseOperand(argInfo) ||
|
||||
parser.resolveOperand(argInfo, argType, result.operands))
|
||||
OpAsmParser::Argument argInfo;
|
||||
argInfo.type = parser.getBuilder().getIndexType();
|
||||
if (parser.parseOperand(argInfo.ssaName) ||
|
||||
parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
|
||||
return failure();
|
||||
|
||||
// Parse the body region, and reuse the operand info as the argument info.
|
||||
Region *body = result.addRegion();
|
||||
return parser.parseRegion(*body, argInfo, argType,
|
||||
/*enableNameShadowing=*/true);
|
||||
return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
|
||||
}
|
||||
|
||||
void IsolatedRegionOp::print(OpAsmPrinter &p) {
|
||||
|
@ -930,17 +928,16 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
|
||||
SmallVector<OpAsmParser::Argument, 4> ivsInfo;
|
||||
// Parse list of region arguments without a delimiter.
|
||||
if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None,
|
||||
/*allowResultNumber=*/false))
|
||||
if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
|
||||
return failure();
|
||||
|
||||
// Parse the body region.
|
||||
Region *body = result.addRegion();
|
||||
auto &builder = parser.getBuilder();
|
||||
SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
|
||||
return parser.parseRegion(*body, ivsInfo, argTypes);
|
||||
for (auto &iv : ivsInfo)
|
||||
iv.type = parser.getBuilder().getIndexType();
|
||||
return parser.parseRegion(*body, ivsInfo);
|
||||
}
|
||||
|
||||
void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
|
||||
|
|
Loading…
Reference in New Issue