!68928 [auto parallel]Fix select check error
Merge pull request !68928 from huangxinjing/fix_select
This commit is contained in:
commit
807ebe1512
|
@ -29,6 +29,22 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
|
||||
bool SelectInfo::CheckIsBroadcast(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2) {
|
||||
int size1 = shape1.size();
|
||||
int size2 = shape2.size();
|
||||
|
||||
if (size1 > size2 || size1 < size2) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < size1 && i < size2; ++i) {
|
||||
if (shape1[i] != shape2[i] && (shape1[i] != 1 && shape2[i] != 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Status SelectInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
MS_EXCEPTION_IF_NULL(strategy);
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
|
@ -40,8 +56,17 @@ Status SelectInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
MS_LOG(ERROR) << name_ << ": The size of inputs shape must be 3";
|
||||
return FAILED;
|
||||
}
|
||||
if ((inputs_shape_[0] != inputs_shape_[1]) && !CheckIsBroadcast(inputs_shape_[0], inputs_shape_[1])) {
|
||||
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if ((inputs_shape_[0] != inputs_shape_[1]) || (inputs_shape_[1] != inputs_shape_[2])) {
|
||||
if ((inputs_shape_[0] != inputs_shape_[2]) && !CheckIsBroadcast(inputs_shape_[0], inputs_shape_[2])) {
|
||||
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (inputs_shape_[1] != inputs_shape_[2]) {
|
||||
MS_LOG(ERROR) << name_ << ": Now we only support the case that all three input shapes are equal";
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -46,6 +46,7 @@ class SelectInfo : public OperatorInfo {
|
|||
Status InferForwardCommunication() override { return SUCCESS; }
|
||||
Status InferDevMatrixShape() override;
|
||||
Status InferTensorMap() override;
|
||||
bool CheckIsBroadcast(const std::vector<int64_t> &shape1, const std::vector<int64_t> &shape2);
|
||||
};
|
||||
|
||||
class BetaincInfo : public SelectInfo {
|
||||
|
|
Loading…
Reference in New Issue