forked from OSchip/llvm-project
[MLIR][Linalg] Make detensoring cost-model more flexible.
So far, the CF cost-model for detensoring was limited to discovering pure CF structures. This means, if while discovering the CF component, the cost-model found any op that is not detensorable, it gives up on detensoring altogether. This patch makes it a bit more flexible by cleaning-up the detensorable component from non-detensorable ops without giving up entirely. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D109965
This commit is contained in:
parent
7f6a4826ac
commit
bdcf4b9b96
|
@ -272,25 +272,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
|
||||
/// Detensorize linalg ops involved in control-flow within a function.
|
||||
///
|
||||
/// This model starts from CondBranchOps within a function. For each cond_br,
|
||||
/// the model then walks the use-def chain for the branch's condition
|
||||
/// backwards in order to understand where the condition's value comes from.
|
||||
/// If the condition value is (indirectly) computed by a linalg op that can be
|
||||
/// detensored, the model then continues walking the use-def chain in order to
|
||||
/// understand where the linalg op's operands come from. This leads to
|
||||
/// discovering a "detensoring component". A detensoring component is the set
|
||||
/// of operations + block arguments that are involved in control-flow AND can
|
||||
/// be detensored.
|
||||
///
|
||||
/// For examples where this model succeeds to discover a detensoring
|
||||
/// component, see:
|
||||
/// - test/Dialect/Linalg/detensorize_while.mlir
|
||||
/// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
|
||||
///
|
||||
/// For an example where this model marks control-flow as "non-detensorable",
|
||||
/// see:
|
||||
/// - test/Dialect/Linalg/detensorize_while_failure.mlir
|
||||
class PureControlFlowDetectionModel : public CostModel {
|
||||
/// This model starts from BranchOps and CondBranchOps within a function. For
|
||||
/// each such branch, the model then walks the use-def chain for the branch's
|
||||
/// condition backwards in order to understand where the condition's value
|
||||
/// comes from. If the condition value is (indirectly) computed by a linalg op
|
||||
/// that can be detensored, the model then continues walking the use-def chain
|
||||
/// in order to understand where the linalg op's operands come from. This
|
||||
/// leads to discovering a "detensoring component". A detensoring component is
|
||||
/// the set of operations + block arguments that are involved in control-flow
|
||||
/// AND can be detensored.
|
||||
class ControlFlowDetectionModel : public CostModel {
|
||||
public:
|
||||
void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
|
||||
DenseSet<Operation *> &opsToDetensor,
|
||||
|
@ -376,19 +367,19 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
|
||||
for (PredecessorIterator pred = ownerBlock->pred_begin();
|
||||
pred != ownerBlock->pred_end(); ++pred) {
|
||||
BranchOpInterface terminator =
|
||||
BranchOpInterface predTerminator =
|
||||
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
|
||||
|
||||
// TODO: For now, we give up if any of the control-flow components
|
||||
// in a function is not detensorable. Fix that.
|
||||
if (!terminator) {
|
||||
if (!predTerminator) {
|
||||
opsToDetensor.clear();
|
||||
blockArgsToDetensor.clear();
|
||||
return;
|
||||
}
|
||||
|
||||
auto ownerBlockOperands =
|
||||
terminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
|
||||
if (!ownerBlockOperands || ownerBlockOperands->empty())
|
||||
continue;
|
||||
|
@ -418,12 +409,10 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
if (opsToDetensor.count(genericOp))
|
||||
continue;
|
||||
|
||||
// TODO: For now, we give up if any of the control-flow components
|
||||
// in a function is not detensorable. Fix that.
|
||||
// The op should not be detensored, give up on it but continue with
|
||||
// discovering the rest of the control-flow component.
|
||||
if (!shouldBeDetensored(genericOp, typeConverter)) {
|
||||
opsToDetensor.clear();
|
||||
blockArgsToDetensor.clear();
|
||||
return;
|
||||
continue;
|
||||
}
|
||||
|
||||
opsToDetensor.insert(genericOp);
|
||||
|
@ -452,6 +441,47 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
|
||||
workList.push_back(scalarOpOperand);
|
||||
}
|
||||
|
||||
// Since the cost model gives up on some ops (see the details of step 2.2
|
||||
// above), block arguments that correspond to the values produced by those
|
||||
// ops should not be detensored as well.
|
||||
|
||||
DenseSet<BlockArgument> blockArgsToRemove;
|
||||
|
||||
for (auto &blockArg : blockArgsToDetensor) {
|
||||
Block *block = blockArg.getParentBlock();
|
||||
|
||||
// For the potentially detensorable block argument, find the
|
||||
// correpsonding operands in predecessor blocks.
|
||||
for (PredecessorIterator pred = block->pred_begin();
|
||||
pred != block->pred_end(); ++pred) {
|
||||
BranchOpInterface terminator =
|
||||
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
|
||||
auto blockOperands =
|
||||
terminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
|
||||
if (!blockOperands || blockOperands->empty())
|
||||
continue;
|
||||
|
||||
Operation *definingOp =
|
||||
terminator
|
||||
->getOperand(blockOperands->getBeginOperandIndex() +
|
||||
blockArg.getArgNumber())
|
||||
.getDefiningOp();
|
||||
|
||||
// If the operand is defined by a GenericOp that will not be
|
||||
// detensored, then do not detensor the corresponding block argument.
|
||||
if (dyn_cast_or_null<GenericOp>(definingOp) &&
|
||||
opsToDetensor.count(definingOp) == 0) {
|
||||
blockArgsToRemove.insert(blockArg);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &blockArg : blockArgsToRemove) {
|
||||
blockArgsToDetensor.erase(blockArg);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -487,7 +517,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
|||
blockArgsToDetensor);
|
||||
|
||||
} else {
|
||||
PureControlFlowDetectionModel costModel;
|
||||
ControlFlowDetectionModel costModel;
|
||||
costModel.compute(getFunction(), typeConverter, opsToDetensor,
|
||||
blockArgsToDetensor);
|
||||
}
|
||||
|
|
|
@ -93,15 +93,14 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
|
|||
// DET-ALL: return %{{.*}} : tensor<i32>
|
||||
// DET-ALL: }
|
||||
|
||||
// Try to detensor pure control-flow. However, that fails since the potential
|
||||
// detensorable component contains some ops that cannot be detensored.
|
||||
//
|
||||
// DET-CF-LABEL: func @main
|
||||
// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
|
||||
// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
|
||||
// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
|
||||
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
|
||||
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
|
||||
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
|
||||
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
|
||||
// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
|
||||
// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
|
||||
// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
|
||||
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
|
Loading…
Reference in New Issue