!32294 [MS][LITE][TRAIN] fix bn grad bug

Merge pull request !32294 from jianghui58/train_dev
This commit is contained in:
i-robot 2022-03-31 12:46:10 +00:00 committed by Gitee
commit ccd3b6c028
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 37 additions and 8 deletions

View File

@ -53,11 +53,14 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
for (int c = 0; c < ch; c++) {
int ix = i * ch + c;
dbias[c] += yt[ix];
// dscale
float x_hat = (in[ix] - mean[c]) * invar[c];
// In fact, x_hat should also be multiplied by invar[c]; that step is deferred to the loop below.
float x_hat = in[ix] - mean[c];
dscale[c] += (yt[ix] * x_hat);
}
}
for (int c = 0; c < ch; c++) {
dscale[c] *= invar[c];
}
backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train);
}
@ -73,11 +76,14 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
for (int c = 0; c < ch; c++) {
int ix = i * ch + c;
dbias[c] += yt[ix];
// dscale
float x_hat = (in[ix] - mean[c]) * invar[c];
// In fact, x_hat should also be multiplied by invar[c]; that step is deferred to the loop below.
float x_hat = in[ix] - mean[c];
dscale[c] += (yt[ix] * x_hat);
}
}
for (int c = 0; c < ch; c++) {
dscale[c] *= invar[c];
}
}
#ifdef _MSC_VER

View File

@ -70,6 +70,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
input_var = in_tensors_.at(5);
}
auto bn_param = reinterpret_cast<BNGradParameter *>(op_parameter_);
CHECK_NULL_RETURN(bn_param);
int stage = stage_;
int thread_num = thread_num_;
float *save_mean = reinterpret_cast<float *>(input_mean->MutableData());
@ -126,7 +127,8 @@ int BNGradCPUKernel::DoExecute(int task_id) {
}
}
if (thread_num == 1) {
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, (IsTrain()));
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx,
(IsTrain() && bn_param->is_training_));
}
break;
}
@ -136,7 +138,8 @@ int BNGradCPUKernel::DoExecute(int task_id) {
}
case 2: {
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, dscale, dbias,
scale, count, total, channels, dx + task_id * stride * channels, (IsTrain()));
scale, count, total, channels, dx + task_id * stride * channels,
(IsTrain() && bn_param->is_training_));
break;
}
default:

View File

@ -583,6 +583,26 @@ int MoveAttrBatchNorm(const CNodePtr &cnode) {
value_node->set_value(dst_prim_c);
return lite::RET_OK;
}
// Rewrites the primitive held by a BatchNormGrad CNode: copies all attributes
// from the source primitive onto a fresh ops::BatchNormGrad, defaulting the
// "is_training" attribute to true when the source did not carry one.
// @param cnode  the CNode whose input(0) is the primitive ValueNode to adjust.
// @return lite::RET_OK on success, RET_NULL_PTR if any required node/primitive is null.
int MoveAttrMapBatchNormGrad(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value_node is nullptr");
  auto src_prim = GetValueNode<PrimitivePtr>(value_node);
  // NOTE: message fixed — this check guards src_prim, not value_node.
  MS_CHECK_TRUE_MSG(src_prim != nullptr, RET_NULL_PTR, "src_prim is nullptr");
  auto dst_prim = std::make_shared<ops::BatchNormGrad>();
  MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_NULL_PTR, "dst_prim is nullptr.");
  auto dst_prim_c = dst_prim->GetPrim();
  MS_CHECK_TRUE_MSG(dst_prim_c != nullptr, RET_NULL_PTR, "dst_prim_c is nullptr.");
  // Carry every source attribute over, then patch in the training flag if absent.
  dst_prim_c->SetAttrs(src_prim->attrs());
  auto is_training_attr = src_prim->GetAttr(ops::kIsTraining);
  if (is_training_attr == nullptr) {
    MS_LOG(INFO) << "no \"is_training\" attr found in BatchNormGrad, will set it to true by default.";
    dst_prim->set_is_training(true);
  }
  value_node->set_value(dst_prim_c);
  return lite::RET_OK;
}
} // namespace
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
@ -676,8 +696,8 @@ REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation)
REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad)
REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon<ops::ExpFusion>)
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon<ops::FusedBatchNorm>)
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon<ops::BatchNormGrad>)
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapCommon<ops::BatchNormGrad>)
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapBatchNormGrad)
REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradCPU, MoveAttrMapBatchNormGrad)
REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation)
REGIST_PRIMITIVE_ADJUST(kNameGeLUGrad, MoveAttrMapActivationGrad)
REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation)