[mlir] NFC - Add help functions to scf.ForOp

This revision adds 2 helperr functions that help tie OpOperands and
BlockArguments in scf.ForOp without having to use the internal implementation
details.
This commit is contained in:
Nicolas Vasilache 2021-04-09 10:37:44 +00:00
parent 3d816537df
commit ca0e250ec6
1 changed files with 22 additions and 2 deletions

View File

@ -166,7 +166,7 @@ def ForOp : SCF_Op<"for",
Value getInductionVar() { return getBody()->getArgument(0); }
Block::BlockArgListType getRegionIterArgs() {
return getBody()->getArguments().drop_front();
return getBody()->getArguments().drop_front(getNumInductionVars());
}
Operation::operand_range getIterOperands() {
return getOperands().drop_front(getNumControlOperands());
@ -176,9 +176,11 @@ def ForOp : SCF_Op<"for",
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }
/// Number of region arguments for loop-carried values
unsigned getNumRegionIterArgs() {
return getBody()->getNumArguments() - 1;
return getBody()->getNumArguments() - getNumInductionVars();
}
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
@ -190,6 +192,24 @@ def ForOp : SCF_Op<"for",
unsigned getNumIterOperands() {
return getOperation()->getNumOperands() - getNumControlOperands();
}
/// Get the region iter arg that corresponds to an OpOperand.
BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
"expected an iter args operand");
assert(opOperand.getOwner() == getOperation() &&
"opOperand does not belong to this scf::ForOp operation");
return getRegionIterArgs()[
opOperand.getOperandNumber() - getNumControlOperands()];
}
/// Get the OpOperand& that corresponds to a region iter arg.
OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
assert(bbArg.getArgNumber() >= getNumInductionVars() &&
"expected a bbArg that is not an induction variable");
assert(bbArg.getOwner()->getParentOp() == getOperation() &&
"bbArg does not belong to the scf::ForOp body");
return getOperation()->getOpOperand(
getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
}
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those exclusing the