forked from mindspore-Ecosystem/mindspore
!32294 [MS][LITE][TRAIN] fix bn grad bug
Merge pull request !32294 from jianghui58/train_dev
This commit is contained in:
commit
ccd3b6c028
|
@ -53,11 +53,14 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
|
|||
for (int c = 0; c < ch; c++) {
|
||||
int ix = i * ch + c;
|
||||
dbias[c] += yt[ix];
|
||||
// dscale
|
||||
float x_hat = (in[ix] - mean[c]) * invar[c];
|
||||
// Strictly, x_hat is (in[ix] - mean[c]) * invar[c]; the multiplication by invar[c] is deferred and applied once per channel to dscale[c] after accumulation.
|
||||
float x_hat = in[ix] - mean[c];
|
||||
dscale[c] += (yt[ix] * x_hat);
|
||||
}
|
||||
}
|
||||
for (int c = 0; c < ch; c++) {
|
||||
dscale[c] *= invar[c];
|
||||
}
|
||||
backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train);
|
||||
}
|
||||
|
||||
|
@ -73,11 +76,14 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
|
|||
for (int c = 0; c < ch; c++) {
|
||||
int ix = i * ch + c;
|
||||
dbias[c] += yt[ix];
|
||||
// dscale
|
||||
float x_hat = (in[ix] - mean[c]) * invar[c];
|
||||
// Strictly, x_hat is (in[ix] - mean[c]) * invar[c]; the multiplication by invar[c] is deferred and applied once per channel to dscale[c] after accumulation.
|
||||
float x_hat = in[ix] - mean[c];
|
||||
dscale[c] += (yt[ix] * x_hat);
|
||||
}
|
||||
}
|
||||
for (int c = 0; c < ch; c++) {
|
||||
dscale[c] *= invar[c];
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
|
|
@ -70,6 +70,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
|
|||
input_var = in_tensors_.at(5);
|
||||
}
|
||||
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
|
||||
CHECK_NULL_RETURN(bn_param);
|
||||
int stage = stage_;
|
||||
int thread_num = thread_num_;
|
||||
float *save_mean = reinterpret_cast<float *>(input_mean->MutableData());
|
||||
|
@ -126,7 +127,8 @@ int BNGradCPUKernel::DoExecute(int task_id) {
|
|||
}
|
||||
}
|
||||
if (thread_num == 1) {
|
||||
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, (IsTrain()));
|
||||
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx,
|
||||
(IsTrain() && bn_param->is_training_));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -136,7 +138,8 @@ int BNGradCPUKernel::DoExecute(int task_id) {
|
|||
}
|
||||
case 2: {
|
||||
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, dscale, dbias,
|
||||
scale, count, total, channels, dx + task_id * stride * channels, (IsTrain()));
|
||||
scale, count, total, channels, dx + task_id * stride * channels,
|
||||
(IsTrain() && bn_param->is_training_));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -583,6 +583,26 @@ int MoveAttrBatchNorm(const CNodePtr &cnode) {
|
|||
value_node->set_value(dst_prim_c);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int MoveAttrMapBatchNormGrad(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value_node is nullptr");
|
||||
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
MS_CHECK_TRUE_MSG(src_prim != nullptr, RET_NULL_PTR, "value_node is nullptr");
|
||||
auto dst_prim = std::make_shared<ops::BatchNormGrad>();
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_NULL_PTR, "dst_prim is nullptr.");
|
||||
auto dst_prim_c = dst_prim->GetPrim();
|
||||
MS_CHECK_TRUE_MSG(dst_prim_c != nullptr, RET_NULL_PTR, "dst_prim_c is nullptr.");
|
||||
dst_prim_c->SetAttrs(src_prim->attrs());
|
||||
auto is_training_attr = src_prim->GetAttr(ops::kIsTraining);
|
||||
if (is_training_attr == nullptr) {
|
||||
MS_LOG(INFO) << "no \"is_training\" attr found in BatchNormGrad, will set it to true by default.";
|
||||
dst_prim->set_is_training(true);
|
||||
}
|
||||
value_node->set_value(dst_prim_c);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
|
||||
|
@ -676,8 +696,8 @@ REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
|
|||
REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapCommon<ops::BatchNormGrad>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapBatchNormGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapBatchNormGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)
|
||||
|
|
Loading…
Reference in New Issue