!18758 don't use SyncBatchNorm in Gpu backend

Merge pull request !18758 from yangzhenzhang/modify-batch-norm-info
This commit is contained in:
i-robot 2021-06-24 01:18:37 +00:00 committed by Gitee
commit 4270142bd1
1 changed files with 12 additions and 0 deletions

View File

@ -21,6 +21,7 @@
#include <utility>
#include <vector>
#include "utils/ms_context.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
@ -196,6 +197,8 @@ Status BatchNormInfo::InferForwardCommunication() {
}
Status BatchNormInfo::InferReplaceOps() {
replace_op_.clear();
if (!is_training_) {
MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
return SUCCESS;
@ -206,6 +209,15 @@ Status BatchNormInfo::InferReplaceOps() {
return SUCCESS;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend != kAscendDevice && backend != kDavinciDevice) {
MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator";
return SUCCESS;
}
ValuePtr epsilon = MakeValue(epsilon_);
ValuePtr momentum = MakeValue(momentum_);
ValuePtr group = MakeValue(forward_allreduce_group_[0].name());