modify batchnorm info

This commit is contained in:
yangzhenzhang 2021-06-23 14:11:04 +08:00
parent 0e42c14d66
commit 7980b807dc
1 changed files with 12 additions and 0 deletions

View File

@ -21,6 +21,7 @@
#include <utility>
#include <vector>
#include "utils/ms_context.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
@ -196,6 +197,8 @@ Status BatchNormInfo::InferForwardCommunication() {
}
Status BatchNormInfo::InferReplaceOps() {
replace_op_.clear();
if (!is_training_) {
MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
return SUCCESS;
@ -206,6 +209,15 @@ Status BatchNormInfo::InferReplaceOps() {
return SUCCESS;
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend != kAscendDevice && backend != kDavinciDevice) {
MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator";
return SUCCESS;
}
ValuePtr epsilon = MakeValue(epsilon_);
ValuePtr momentum = MakeValue(momentum_);
ValuePtr group = MakeValue(forward_allreduce_group_[0].name());