forked from mindspore-Ecosystem/mindspore
modify batchnorm info
This commit is contained in:
parent
0e42c14d66
commit
7980b807dc
|
@ -21,6 +21,7 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/ms_context.h"
|
||||
#include "frontend/parallel/device_matrix.h"
|
||||
#include "frontend/parallel/strategy.h"
|
||||
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
|
||||
|
@ -196,6 +197,8 @@ Status BatchNormInfo::InferForwardCommunication() {
|
|||
}
|
||||
|
||||
Status BatchNormInfo::InferReplaceOps() {
|
||||
replace_op_.clear();
|
||||
|
||||
if (!is_training_) {
|
||||
MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
|
||||
return SUCCESS;
|
||||
|
@ -206,6 +209,15 @@ Status BatchNormInfo::InferReplaceOps() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
|
||||
if (backend != kAscendDevice && backend != kDavinciDevice) {
|
||||
MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
ValuePtr epsilon = MakeValue(epsilon_);
|
||||
ValuePtr momentum = MakeValue(momentum_);
|
||||
ValuePtr group = MakeValue(forward_allreduce_group_[0].name());
|
||||
|
|
Loading…
Reference in New Issue