diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc
index 489232fab3c..e8247fead95 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc
@@ -21,6 +21,7 @@
 #include
 #include
 
+#include "utils/ms_context.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
@@ -196,6 +197,8 @@ Status BatchNormInfo::InferForwardCommunication() {
 }
 
 Status BatchNormInfo::InferReplaceOps() {
+  replace_op_.clear();
+
   if (!is_training_) {
     MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
     return SUCCESS;
   }
@@ -206,6 +209,15 @@ Status BatchNormInfo::InferReplaceOps() {
     return SUCCESS;
   }
 
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+
+  if (backend != kAscendDevice && backend != kDavinciDevice) {
+    MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator";
+    return SUCCESS;
+  }
+
   ValuePtr epsilon = MakeValue(epsilon_);
   ValuePtr momentum = MakeValue(momentum_);
   ValuePtr group = MakeValue(forward_allreduce_group_[0].name());