Add Is training flag for BN

This commit is contained in:
Haim Moushkatel 2022-03-20 11:13:41 +02:00
parent 048d089f9a
commit e3bbc18e9f
12 changed files with 57 additions and 21 deletions

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
#define MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_
#include "nnacl/op_base.h"
@@ -27,6 +27,7 @@ typedef struct BatchNormParameter {
int units_;
int channel_;
bool fused_;
bool is_training_;
} BatchNormParameter;
#endif // MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_

View File

@@ -25,11 +25,11 @@ void var2Invar(float *save_var, int size, float eps) {
#ifdef _MSC_VER
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dbias, float *dscale, float *dx) {
int ch, float *dbias, float *dscale, float *dx, bool is_train) {
#else
void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
float *restrict dscale, float *restrict dx) {
float *restrict dscale, float *restrict dx, bool is_train) {
#endif
NNACL_CHECK_ZERO_RETURN(size);
float N = (float)size;
@@ -47,7 +47,9 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
// dx_2
int ix = i * ch + c;
dx[ix] = yt[ix];
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
if (is_train) {
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
}
dx[ix] *= scale[c] * invar[c];
}
}
@@ -74,11 +76,11 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
#ifdef _MSC_VER
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale,
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx) {
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train) {
#else
void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict dscale, const float *restrict dbias,
const float *restrict scale, int size, int total_size, int ch, float *restrict dx) {
const float *restrict scale, int size, int total_size, int ch, float *restrict dx, bool is_train) {
#endif
NNACL_CHECK_ZERO_RETURN(total_size);
const float N = (float)total_size;
@@ -87,7 +89,9 @@ void backwardP2(const float *restrict in, const float *restrict yt, const float
// dx_2
int ix = i * ch + c;
dx[ix] = yt[ix];
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
if (is_train) {
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
}
dx[ix] *= scale[c] * invar[c];
}
}

View File

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
#define MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_
#include "nnacl/fp32_grad/batch_norm_parameter.h"
@@ -25,13 +25,13 @@ extern "C" {
void var2Invar(float *save_var, int size, float eps);
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dbias, float *dscale, float *dx);
int ch, float *dbias, float *dscale, float *dx, bool is_train);
void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dbias, float *dscale);
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale,
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx);
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_

View File

@@ -14,14 +14,15 @@
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
#define MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct BNGradParameter {
OpParameter op_parameter_;
float epsilon_;
bool is_training_;
} BNGradParameter;
#endif // MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_

View File

@@ -336,6 +336,7 @@ table BatchNorm {
table BatchNormGrad {
epsilon: float;
is_training: bool;
}
table BatchToSpace {

View File

@@ -336,6 +336,7 @@ OP_SCHEMA_DEF_END(BatchNorm)
OP_SCHEMA_DEF(BatchNormGrad)
OP_ATTR(epsilon, float)
OP_ATTR(is_training, bool)
OP_SCHEMA_DEF_END(BatchNormGrad)
OP_SCHEMA_DEF(BatchToSpace)

View File

@@ -38,6 +38,7 @@ OpParameter *PopulateBatchNorm(const void *prim) {
param->op_parameter_.type_ = primitive->value_type();
param->epsilon_ = value->epsilon();
param->fused_ = false;
param->is_training_ = value->is_training();
return reinterpret_cast<OpParameter *>(param);
}

View File

@@ -39,6 +39,7 @@ OpParameter *PopulateFusedBatchNorm(const void *prim) {
param->epsilon_ = value->epsilon();
param->momentum_ = value->momentum();
param->fused_ = true;
param->is_training_ = static_cast<bool>(value->mode());
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR)

View File

@@ -153,7 +153,7 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
int FusedBatchnormCPUKernel::Run() {
auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
MS_ASSERT(param != nullptr);
if (IsTrain() && IsTrainable() && in_tensors_.size() >= DIMENSION_5D) {
if (IsTrain() && param->is_training_ && in_tensors_.size() >= DIMENSION_5D) {
float *in = static_cast<float *>(in_tensors_.at(FIRST_INPUT)->data());
float *scale = static_cast<float *>(in_tensors_.at(SECOND_INPUT)->data());
float *offset = static_cast<float *>(in_tensors_.at(THIRD_INPUT)->data());
@@ -185,6 +185,13 @@ int FusedBatchnormCPUKernel::Run() {
(void)memcpy(offset_, offset, in_tensors_.at(THIRD_INPUT)->Size());
trained_ = true; // trained at least once
} else {
if (out_tensors_.size() >= DIMENSION_5D) {
(void)memcpy(out_tensors_.at(SECOND_INPUT)->data(), scale_, out_tensors_.at(SECOND_INPUT)->Size());
(void)memcpy(out_tensors_.at(THIRD_INPUT)->data(), offset_, out_tensors_.at(THIRD_INPUT)->Size());
(void)memcpy(out_tensors_.at(FOURTH_INPUT)->data(), mean_, out_tensors_.at(FOURTH_INPUT)->Size());
(void)memcpy(out_tensors_.at(FIFTH_INPUT)->data(), variance_, out_tensors_.at(FIFTH_INPUT)->Size());
}
}
auto ret = ParallelLaunch(this->ms_context_, BatchNormRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {

View File

@@ -126,7 +126,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
}
}
if (thread_num == 1) {
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx);
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, (IsTrain()));
}
break;
}
@@ -136,7 +136,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
}
case 2: {
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, dscale, dbias,
scale, count, total, channels, dx + task_id * stride * channels);
scale, count, total, channels, dx + task_id * stride * channels, (IsTrain()));
break;
}
default:

View File

@@ -409,6 +409,7 @@ OpParameter *PopulateBNGradParameter(const void *prim) {
MS_ASSERT(value != nullptr);
bnGrad_param->op_parameter_.type_ = primitive->value_type();
bnGrad_param->epsilon_ = value->epsilon();
bnGrad_param->is_training_ = value->is_training();
return reinterpret_cast<OpParameter *>(bnGrad_param);
}

View File

@@ -548,6 +548,24 @@ int MoveAttrMapResizeGrad(const CNodePtr &cnode) {
value_node->set_value(dst_prim);
return lite::RET_OK;
}
// Rewrites a BatchNorm CNode into a FusedBatchNorm primitive, carrying the
// source node's attributes across and mapping its `is_training` flag onto the
// FusedBatchNorm `mode` field (1 = training, 0 = inference).
// Returns lite::RET_OK on success, lite::RET_ERROR when the value node holds
// no primitive, RET_NULL_PTR when allocation of the new primitive fails.
int MoveAttrBatchNorm(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  MS_ASSERT(value_node != nullptr);
  auto src_prim = GetValueNode<PrimitivePtr>(value_node);
  if (src_prim == nullptr) {
    MS_LOG(ERROR) << "value node is invalid.";
    return lite::RET_ERROR;
  }
  auto dst_prim = std::make_shared<ops::FusedBatchNorm>();
  MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_NULL_PTR, "dst_prim is nullptr.");
  dst_prim->SetAttrs(src_prim->attrs());
  // GetAttr returns null when the attribute is absent; the original code passed
  // that straight into GetValue<bool>, dereferencing a null ValuePtr. Guard it
  // and default to inference mode (mode == 0) when is_training is not set.
  auto is_training_attr = src_prim->GetAttr(ops::kIsTraining);
  bool is_training = is_training_attr != nullptr && GetValue<bool>(is_training_attr);
  dst_prim->set_mode(static_cast<int64_t>(is_training));
  value_node->set_value(dst_prim);
  return lite::RET_OK;
}
} // namespace
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
@@ -620,7 +638,7 @@ REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradGpu, MoveAttrPoolGrad)
REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradCpu, MoveAttrPoolGrad)
REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon<ops::MatMulFusion>)
REGIST_PRIMITIVE_ADJUST(kNameMatMul, MoveAttrMapCommon<ops::MatMulFusion>)
REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrBatchNorm)
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)