[fir][NFC] Update fir.iterate_while op

Add getFinalValueAttrName() and remove specified number of
inlined elements for SmallVector. This patch is mainly motivated
to help the upstreaming effort.

This patch is part of the upstreaming effort from fir-dev branch.

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Valentin Clement <clementval@gmail.com>

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D110710
This commit is contained in:
Eric Schweitz 2021-09-29 18:12:17 +02:00 committed by Valentin Clement
parent 9a640a1cb8
commit 6e2afdb7f5
No known key found for this signature in database
GPG Key ID: 086D54783C928776
2 changed files with 16 additions and 12 deletions

View File

@ -2577,6 +2577,9 @@ def fir_IterWhileOp : region_Op<"iterate_while",
];
let extraClassDeclaration = [{
static constexpr llvm::StringRef getFinalValueAttrName() {
return "finalValue";
}
mlir::Block *getBody() { return &region().front(); }
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
mlir::Value getInductionVar() { return getBody()->getArgument(0); }

View File

@ -799,7 +799,7 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
result.addOperands({lb, ub, step, iterate});
if (finalCountValue) {
result.addTypes(builder.getIndexType());
result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr());
result.addAttribute(getFinalValueAttrName(), builder.getUnitAttr());
}
result.addTypes(iterate.getType());
result.addOperands(iterArgs);
@ -841,7 +841,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
return mlir::failure();
// Parse the initial iteration arguments.
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs;
llvm::SmallVector<mlir::OpAsmParser::OperandType> regionArgs;
auto prependCount = false;
// Induction variable.
@ -849,8 +849,8 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
regionArgs.push_back(iterateVar);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
llvm::SmallVector<mlir::Type, 4> regionTypes;
llvm::SmallVector<mlir::OpAsmParser::OperandType> operands;
llvm::SmallVector<mlir::Type> regionTypes;
// Parse assignment list and results type list.
if (parser.parseAssignmentList(regionArgs, operands) ||
parser.parseArrowTypeList(regionTypes))
@ -860,9 +860,9 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
// Resolve input operands.
for (auto operand_type : llvm::zip(operands, resTypes))
if (parser.resolveOperand(std::get<0>(operand_type),
std::get<1>(operand_type), result.operands))
for (auto operandType : llvm::zip(operands, resTypes))
if (parser.resolveOperand(std::get<0>(operandType),
std::get<1>(operandType), result.operands))
return failure();
if (prependCount) {
result.addTypes(regionTypes);
@ -871,7 +871,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
result.addTypes(resTypes);
}
} else if (succeeded(parser.parseOptionalArrow())) {
llvm::SmallVector<mlir::Type, 4> typeList;
llvm::SmallVector<mlir::Type> typeList;
if (parser.parseLParen() || parser.parseTypeList(typeList) ||
parser.parseRParen())
return failure();
@ -888,10 +888,10 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return mlir::failure();
llvm::SmallVector<mlir::Type, 4> argTypes;
llvm::SmallVector<mlir::Type> argTypes;
// Induction variable (hidden)
if (prependCount)
result.addAttribute(IterWhileOp::finalValueAttrName(result.name),
result.addAttribute(IterWhileOp::getFinalValueAttrName(),
builder.getUnitAttr());
else
argTypes.push_back(indexType);
@ -984,7 +984,8 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) {
} else if (op.finalValue()) {
p << " -> (" << op.getResultTypes() << ')';
}
p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"finalValue"});
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
{IterWhileOp::getFinalValueAttrName()});
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
@ -997,7 +998,7 @@ bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
mlir::LogicalResult
fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
for (auto op : ops)
for (auto *op : ops)
op->moveBefore(*this);
return success();
}