remove select check

This commit is contained in:
huangxinjing 2024-04-29 21:00:38 +08:00
parent a87635b6f6
commit 94dbe55add
2 changed files with 27 additions and 1 deletions

View File

@ -29,6 +29,22 @@
namespace mindspore {
namespace parallel {
bool SelectInfo::CheckIsBroadcast(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2) {
int size1 = shape1.size();
int size2 = shape2.size();
if (size1 > size2 || size1 < size2) {
return false;
}
for (int i = 0; i < size1 && i < size2; ++i) {
if (shape1[i] != shape2[i] && (shape1[i] != 1 && shape2[i] != 1)) {
return false;
}
}
return true;
}
Status SelectInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
@ -40,8 +56,17 @@ Status SelectInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_LOG(ERROR) << name_ << ": The size of inputs shape must be 3";
return FAILED;
}
if ((inputs_shape_[0] != inputs_shape_[1]) && !CheckIsBroadcast(inputs_shape_[0], inputs_shape_[1])) {
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
return FAILED;
}
if ((inputs_shape_[0] != inputs_shape_[1]) || (inputs_shape_[1] != inputs_shape_[2])) {
if ((inputs_shape_[0] != inputs_shape_[2]) && !CheckIsBroadcast(inputs_shape_[0], inputs_shape_[2])) {
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
return FAILED;
}
if (inputs_shape_[1] != inputs_shape_[2]) {
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
return FAILED;
}

View File

@ -46,6 +46,7 @@ class SelectInfo : public OperatorInfo {
Status InferForwardCommunication() override { return SUCCESS; }
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
bool CheckIsBroadcast(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2);
};
class BetaincInfo : public SelectInfo {