check strategy for conv2d transpose

This commit is contained in:
yangzhenzhang 2021-07-08 16:40:55 +08:00
parent a27c7b2436
commit 9569cc665f
2 changed files with 31 additions and 6 deletions

View File

@ -190,15 +190,16 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
MS_EXCEPTION_IF_NULL(context_ptr);
std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
if (!ret) {
MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
return Status::FAILED;
}
std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
group_info_.push_back(group_info);
bool ret = CreateGroupByExecutor(device_name, group_name, ranks, device_id);
if (!ret) {
MS_LOG(WARNING) << "Create group failed, group name is " << group_name;
return Status::FAILED;
}
MS_LOG(INFO) << "Create group success, group name is " << group_name;
return Status::SUCCESS;
}

View File

@ -641,7 +641,31 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { return SUCCESS; }
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
if (pad_mode_ != 1) { // only support same mode
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
return FAILED;
}
if (h_strategy > 1) {
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape";
return FAILED;
}
if (kernel_size_[0] > stride_[2]) {
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride";
return FAILED;
}
}
if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape";
return FAILED;
}
return SUCCESS;
}
Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
// the strategy is ((n, o, h, w), (o, i, 1, 1))