[MLIR][Linalg] Make detensoring cost-model more flexible.

So far, the CF cost-model for detensoring was limited to discovering
pure CF structures. This means that if, while discovering the CF component,
the cost-model found any op that is not detensorable, it gave up on
detensoring altogether. This patch makes it a bit more flexible by
removing non-detensorable ops from the detensorable component instead of
giving up entirely.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D109965
This commit is contained in:
KareemErgawy-TomTom 2021-09-20 09:35:42 +02:00
parent 7f6a4826ac
commit bdcf4b9b96
2 changed files with 61 additions and 32 deletions

View File

@ -272,25 +272,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
/// Detensorize linalg ops involved in control-flow within a function.
///
/// This model starts from CondBranchOps within a function. For each cond_br,
/// the model then walks the use-def chain for the branch's condition
/// backwards in order to understand where the condition's value comes from.
/// If the condition value is (indirectly) computed by a linalg op that can be
/// detensored, the model then continues walking the use-def chain in order to
/// understand where the linalg op's operands come from. This leads to
/// discovering a "detensoring component". A detensoring component is the set
/// of operations + block arguments that are involved in control-flow AND can
/// be detensored.
///
/// For examples where this model succeeds in discovering a detensoring
/// component, see:
/// - test/Dialect/Linalg/detensorize_while.mlir
/// - test/Dialect/Linalg/detensorize_while_pure_cf.mlir.
///
/// For an example where this model marks control-flow as "non-detensorable",
/// see:
/// - test/Dialect/Linalg/detensorize_while_failure.mlir
class PureControlFlowDetectionModel : public CostModel {
/// This model starts from BranchOps and CondBranchOps within a function. For
/// each such branch, the model then walks the use-def chain for the branch's
/// condition backwards in order to understand where the condition's value
/// comes from. If the condition value is (indirectly) computed by a linalg op
/// that can be detensored, the model then continues walking the use-def chain
/// in order to understand where the linalg op's operands come from. This
/// leads to discovering a "detensoring component". A detensoring component is
/// the set of operations + block arguments that are involved in control-flow
/// AND can be detensored.
class ControlFlowDetectionModel : public CostModel {
public:
void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
DenseSet<Operation *> &opsToDetensor,
@ -376,19 +367,19 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
for (PredecessorIterator pred = ownerBlock->pred_begin();
pred != ownerBlock->pred_end(); ++pred) {
BranchOpInterface terminator =
BranchOpInterface predTerminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
// TODO: For now, we give up if any of the control-flow components
// in a function is not detensorable. Fix that.
if (!terminator) {
if (!predTerminator) {
opsToDetensor.clear();
blockArgsToDetensor.clear();
return;
}
auto ownerBlockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!ownerBlockOperands || ownerBlockOperands->empty())
continue;
@ -418,12 +409,10 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
if (opsToDetensor.count(genericOp))
continue;
// TODO: For now, we give up if any of the control-flow components
// in a function is not detensorable. Fix that.
// The op should not be detensored; give up on it but continue
// discovering the rest of the control-flow component.
if (!shouldBeDetensored(genericOp, typeConverter)) {
opsToDetensor.clear();
blockArgsToDetensor.clear();
return;
continue;
}
opsToDetensor.insert(genericOp);
@ -452,6 +441,47 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
workList.push_back(scalarOpOperand);
}
// Since the cost model gives up on some ops (see the details of step 2.2
// above), block arguments that correspond to the values produced by those
// ops should not be detensored either.
DenseSet<BlockArgument> blockArgsToRemove;
for (auto &blockArg : blockArgsToDetensor) {
Block *block = blockArg.getParentBlock();
// For the potentially detensorable block argument, find the
// corresponding operands in predecessor blocks.
for (PredecessorIterator pred = block->pred_begin();
pred != block->pred_end(); ++pred) {
BranchOpInterface terminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!blockOperands || blockOperands->empty())
continue;
Operation *definingOp =
terminator
->getOperand(blockOperands->getBeginOperandIndex() +
blockArg.getArgNumber())
.getDefiningOp();
// If the operand is defined by a GenericOp that will not be
// detensored, then do not detensor the corresponding block argument.
if (dyn_cast_or_null<GenericOp>(definingOp) &&
opsToDetensor.count(definingOp) == 0) {
blockArgsToRemove.insert(blockArg);
break;
}
}
}
for (auto &blockArg : blockArgsToRemove) {
blockArgsToDetensor.erase(blockArg);
}
}
};
@ -487,7 +517,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
blockArgsToDetensor);
} else {
PureControlFlowDetectionModel costModel;
ControlFlowDetectionModel costModel;
costModel.compute(getFunction(), typeConverter, opsToDetensor,
blockArgsToDetensor);
}

View File

@ -93,15 +93,14 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
// Try to detensor pure control-flow. However, that fails since the potential
// detensorable component contains some ops that cannot be detensored.
//
// DET-CF-LABEL: func @main
// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {