forked from OSchip/llvm-project
[fir][NFC] Update fir.iterate_while op
Add getFinalValueAttrName() and remove specified number of inlined elements for SmallVector. This patch is mainly motivated to help the upstreaming effort. This patch is part of the upstreaming effort from fir-dev branch. Co-authored-by: Jean Perier <jperier@nvidia.com> Co-authored-by: Valentin Clement <clementval@gmail.com> Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D110710
This commit is contained in:
parent
9a640a1cb8
commit
6e2afdb7f5
|
@ -2577,6 +2577,9 @@ def fir_IterWhileOp : region_Op<"iterate_while",
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
static constexpr llvm::StringRef getFinalValueAttrName() {
|
||||||
|
return "finalValue";
|
||||||
|
}
|
||||||
mlir::Block *getBody() { return ®ion().front(); }
|
mlir::Block *getBody() { return ®ion().front(); }
|
||||||
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
|
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
|
||||||
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
|
mlir::Value getInductionVar() { return getBody()->getArgument(0); }
|
||||||
|
|
|
@ -799,7 +799,7 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
|
||||||
result.addOperands({lb, ub, step, iterate});
|
result.addOperands({lb, ub, step, iterate});
|
||||||
if (finalCountValue) {
|
if (finalCountValue) {
|
||||||
result.addTypes(builder.getIndexType());
|
result.addTypes(builder.getIndexType());
|
||||||
result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr());
|
result.addAttribute(getFinalValueAttrName(), builder.getUnitAttr());
|
||||||
}
|
}
|
||||||
result.addTypes(iterate.getType());
|
result.addTypes(iterate.getType());
|
||||||
result.addOperands(iterArgs);
|
result.addOperands(iterArgs);
|
||||||
|
@ -841,7 +841,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
// Parse the initial iteration arguments.
|
// Parse the initial iteration arguments.
|
||||||
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs;
|
llvm::SmallVector<mlir::OpAsmParser::OperandType> regionArgs;
|
||||||
auto prependCount = false;
|
auto prependCount = false;
|
||||||
|
|
||||||
// Induction variable.
|
// Induction variable.
|
||||||
|
@ -849,8 +849,8 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
|
||||||
regionArgs.push_back(iterateVar);
|
regionArgs.push_back(iterateVar);
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
|
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
|
||||||
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
|
llvm::SmallVector<mlir::OpAsmParser::OperandType> operands;
|
||||||
llvm::SmallVector<mlir::Type, 4> regionTypes;
|
llvm::SmallVector<mlir::Type> regionTypes;
|
||||||
// Parse assignment list and results type list.
|
// Parse assignment list and results type list.
|
||||||
if (parser.parseAssignmentList(regionArgs, operands) ||
|
if (parser.parseAssignmentList(regionArgs, operands) ||
|
||||||
parser.parseArrowTypeList(regionTypes))
|
parser.parseArrowTypeList(regionTypes))
|
||||||
|
@ -860,9 +860,9 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
|
||||||
llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
|
llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
|
||||||
resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
|
resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
|
||||||
// Resolve input operands.
|
// Resolve input operands.
|
||||||
for (auto operand_type : llvm::zip(operands, resTypes))
|
for (auto operandType : llvm::zip(operands, resTypes))
|
||||||
if (parser.resolveOperand(std::get<0>(operand_type),
|
if (parser.resolveOperand(std::get<0>(operandType),
|
||||||
std::get<1>(operand_type), result.operands))
|
std::get<1>(operandType), result.operands))
|
||||||
return failure();
|
return failure();
|
||||||
if (prependCount) {
|
if (prependCount) {
|
||||||
result.addTypes(regionTypes);
|
result.addTypes(regionTypes);
|
||||||
|
@ -871,7 +871,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
|
||||||
result.addTypes(resTypes);
|
result.addTypes(resTypes);
|
||||||
}
|
}
|
||||||
} else if (succeeded(parser.parseOptionalArrow())) {
|
} else if (succeeded(parser.parseOptionalArrow())) {
|
||||||
llvm::SmallVector<mlir::Type, 4> typeList;
|
llvm::SmallVector<mlir::Type> typeList;
|
||||||
if (parser.parseLParen() || parser.parseTypeList(typeList) ||
|
if (parser.parseLParen() || parser.parseTypeList(typeList) ||
|
||||||
parser.parseRParen())
|
parser.parseRParen())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -888,10 +888,10 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
|
||||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> argTypes;
|
llvm::SmallVector<mlir::Type> argTypes;
|
||||||
// Induction variable (hidden)
|
// Induction variable (hidden)
|
||||||
if (prependCount)
|
if (prependCount)
|
||||||
result.addAttribute(IterWhileOp::finalValueAttrName(result.name),
|
result.addAttribute(IterWhileOp::getFinalValueAttrName(),
|
||||||
builder.getUnitAttr());
|
builder.getUnitAttr());
|
||||||
else
|
else
|
||||||
argTypes.push_back(indexType);
|
argTypes.push_back(indexType);
|
||||||
|
@ -984,7 +984,8 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) {
|
||||||
} else if (op.finalValue()) {
|
} else if (op.finalValue()) {
|
||||||
p << " -> (" << op.getResultTypes() << ')';
|
p << " -> (" << op.getResultTypes() << ')';
|
||||||
}
|
}
|
||||||
p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"finalValue"});
|
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
|
||||||
|
{IterWhileOp::getFinalValueAttrName()});
|
||||||
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/true);
|
/*printBlockTerminators=*/true);
|
||||||
}
|
}
|
||||||
|
@ -997,7 +998,7 @@ bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
|
||||||
|
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult
|
||||||
fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
|
fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
|
||||||
for (auto op : ops)
|
for (auto *op : ops)
|
||||||
op->moveBefore(*this);
|
op->moveBefore(*this);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue