forked from mindspore-Ecosystem/mindspore
check strategy for conv2d transpose
This commit is contained in:
parent
a27c7b2436
commit
9569cc665f
|
@ -190,15 +190,16 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
|
||||
return Status::FAILED;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
|
||||
group_info_.push_back(group_info);
|
||||
|
||||
bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
|
||||
if (!ret) {
|
||||
MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
|
||||
return Status::FAILED;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Create group success, group name is " << group_name;
|
||||
return Status::SUCCESS;
|
||||
}
|
||||
|
|
|
@ -641,7 +641,31 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { return SUCCESS; }
|
||||
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
|
||||
if (pad_mode_ != 1) { // only support same mode
|
||||
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (h_strategy > 1) {
|
||||
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
if (kernel_size_[0] > stride_[2]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
|
||||
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
|
||||
// the strategy is ((n, o, h, w), (o, i, 1, 1))
|
||||
|
|
Loading…
Reference in New Issue