forked from mindspore-Ecosystem/mindspore
Add is_training flag for BatchNorm (BN)
This commit is contained in:
parent
048d089f9a
commit
e3bbc18e9f
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
|
||||
#define MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
|
@ -27,6 +27,7 @@ typedef struct BatchNormParameter {
|
|||
int units_;
|
||||
int channel_;
|
||||
bool fused_;
|
||||
bool is_training_;
|
||||
} BatchNormParameter;
|
||||
|
||||
#endif // MINDSPORE_NNACL_BATCHNORM_PARAMETER_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_BATCHNORM_PARAMETER_H_
|
||||
|
|
|
@ -25,11 +25,11 @@ void var2Invar(float *save_var, int size, float eps) {
|
|||
|
||||
#ifdef _MSC_VER
|
||||
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int ch, float *dbias, float *dscale, float *dx) {
|
||||
int ch, float *dbias, float *dscale, float *dx, bool is_train) {
|
||||
#else
|
||||
void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias,
|
||||
float *restrict dscale, float *restrict dx) {
|
||||
float *restrict dscale, float *restrict dx, bool is_train) {
|
||||
#endif
|
||||
NNACL_CHECK_ZERO_RETURN(size);
|
||||
float N = (float)size;
|
||||
|
@ -47,7 +47,9 @@ void backwardAll(const float *restrict in, const float *restrict yt, const float
|
|||
// dx_2
|
||||
int ix = i * ch + c;
|
||||
dx[ix] = yt[ix];
|
||||
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
|
||||
if (is_train) {
|
||||
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
|
||||
}
|
||||
dx[ix] *= scale[c] * invar[c];
|
||||
}
|
||||
}
|
||||
|
@ -74,11 +76,11 @@ void backwardP1(const float *restrict in, const float *restrict yt, const float
|
|||
|
||||
#ifdef _MSC_VER
|
||||
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale,
|
||||
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx) {
|
||||
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train) {
|
||||
#else
|
||||
void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean,
|
||||
const float *restrict invar, const float *restrict dscale, const float *restrict dbias,
|
||||
const float *restrict scale, int size, int total_size, int ch, float *restrict dx) {
|
||||
const float *restrict scale, int size, int total_size, int ch, float *restrict dx, bool is_train) {
|
||||
#endif
|
||||
NNACL_CHECK_ZERO_RETURN(total_size);
|
||||
const float N = (float)total_size;
|
||||
|
@ -87,7 +89,9 @@ void backwardP2(const float *restrict in, const float *restrict yt, const float
|
|||
// dx_2
|
||||
int ix = i * ch + c;
|
||||
dx[ix] = yt[ix];
|
||||
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
|
||||
if (is_train) {
|
||||
dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N;
|
||||
}
|
||||
dx[ix] *= scale[c] * invar[c];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
#define MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
|
||||
#include "nnacl/fp32_grad/batch_norm_parameter.h"
|
||||
|
||||
|
@ -25,13 +25,13 @@ extern "C" {
|
|||
|
||||
void var2Invar(float *save_var, int size, float eps);
|
||||
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int ch, float *dbias, float *dscale, float *dx);
|
||||
int ch, float *dbias, float *dscale, float *dx, bool is_train);
|
||||
void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
|
||||
int ch, float *dbias, float *dscale);
|
||||
void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale,
|
||||
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx);
|
||||
const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_
|
||||
|
|
|
@ -14,14 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
|
||||
#define MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
// Parameter block for the batch-norm gradient (BNGrad) op.
// PopulateBNGradParameter fills it in and returns it reinterpret_cast to
// OpParameter *, so op_parameter_ presumably has to stay the first member.
typedef struct BNGradParameter {
|
||||
// Common op header; its type_ is set from the primitive's value_type().
OpParameter op_parameter_;
|
||||
// Copied from the BatchNormGrad primitive's epsilon attribute.
float epsilon_;
|
||||
// Copied from the BatchNormGrad primitive's is_training attribute.
bool is_training_;
|
||||
} BNGradParameter;
|
||||
|
||||
#endif // MINDSPORE_NNACL_FP32_GRAD_BATCH_NORM_PARAMATER_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_
|
||||
|
|
|
@ -336,6 +336,7 @@ table BatchNorm {
|
|||
|
||||
table BatchNormGrad {
|
||||
epsilon: float;
|
||||
is_training: bool;
|
||||
}
|
||||
|
||||
table BatchToSpace {
|
||||
|
|
|
@ -336,6 +336,7 @@ OP_SCHEMA_DEF_END(BatchNorm)
|
|||
|
||||
OP_SCHEMA_DEF(BatchNormGrad)
|
||||
OP_ATTR(epsilon, float)
|
||||
OP_ATTR(is_training, bool)
|
||||
OP_SCHEMA_DEF_END(BatchNormGrad)
|
||||
|
||||
OP_SCHEMA_DEF(BatchToSpace)
|
||||
|
|
|
@ -38,6 +38,7 @@ OpParameter *PopulateBatchNorm(const void *prim) {
|
|||
param->op_parameter_.type_ = primitive->value_type();
|
||||
param->epsilon_ = value->epsilon();
|
||||
param->fused_ = false;
|
||||
param->is_training_ = value->is_training();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ OpParameter *PopulateFusedBatchNorm(const void *prim) {
|
|||
param->epsilon_ = value->epsilon();
|
||||
param->momentum_ = value->momentum();
|
||||
param->fused_ = true;
|
||||
param->is_training_ = static_cast<bool>(value->mode());
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
REG_POPULATE(PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm, SCHEMA_CUR)
|
||||
|
|
|
@ -153,7 +153,7 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
|
|||
int FusedBatchnormCPUKernel::Run() {
|
||||
auto param = reinterpret_cast<BatchNormParameter *>(op_parameter_);
|
||||
MS_ASSERT(param != nullptr);
|
||||
if (IsTrain() && IsTrainable() && in_tensors_.size() >= DIMENSION_5D) {
|
||||
if (IsTrain() && param->is_training_ && in_tensors_.size() >= DIMENSION_5D) {
|
||||
float *in = static_cast<float *>(in_tensors_.at(FIRST_INPUT)->data());
|
||||
float *scale = static_cast<float *>(in_tensors_.at(SECOND_INPUT)->data());
|
||||
float *offset = static_cast<float *>(in_tensors_.at(THIRD_INPUT)->data());
|
||||
|
@ -185,6 +185,13 @@ int FusedBatchnormCPUKernel::Run() {
|
|||
(void)memcpy(offset_, offset, in_tensors_.at(THIRD_INPUT)->Size());
|
||||
|
||||
trained_ = true; // trained at least once
|
||||
} else {
|
||||
if (out_tensors_.size() >= DIMENSION_5D) {
|
||||
(void)memcpy(out_tensors_.at(SECOND_INPUT)->data(), scale_, out_tensors_.at(SECOND_INPUT)->Size());
|
||||
(void)memcpy(out_tensors_.at(THIRD_INPUT)->data(), offset_, out_tensors_.at(THIRD_INPUT)->Size());
|
||||
(void)memcpy(out_tensors_.at(FOURTH_INPUT)->data(), mean_, out_tensors_.at(FOURTH_INPUT)->Size());
|
||||
(void)memcpy(out_tensors_.at(FIFTH_INPUT)->data(), variance_, out_tensors_.at(FIFTH_INPUT)->Size());
|
||||
}
|
||||
}
|
||||
auto ret = ParallelLaunch(this->ms_context_, BatchNormRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -126,7 +126,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
|
|||
}
|
||||
}
|
||||
if (thread_num == 1) {
|
||||
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx);
|
||||
backwardAll(x, yt, save_mean, save_var, scale, total, channels, dbias, dscale, dx, (IsTrain()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ int BNGradCPUKernel::DoExecute(int task_id) {
|
|||
}
|
||||
case 2: {
|
||||
backwardP2(x + task_id * stride * channels, yt + task_id * stride * channels, save_mean, save_var, dscale, dbias,
|
||||
scale, count, total, channels, dx + task_id * stride * channels);
|
||||
scale, count, total, channels, dx + task_id * stride * channels, (IsTrain()));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -409,6 +409,7 @@ OpParameter *PopulateBNGradParameter(const void *prim) {
|
|||
MS_ASSERT(value != nullptr);
|
||||
bnGrad_param->op_parameter_.type_ = primitive->value_type();
|
||||
bnGrad_param->epsilon_ = value->epsilon();
|
||||
bnGrad_param->is_training_ = value->is_training();
|
||||
return reinterpret_cast<OpParameter *>(bnGrad_param);
|
||||
}
|
||||
|
||||
|
|
|
@ -548,6 +548,24 @@ int MoveAttrMapResizeGrad(const CNodePtr &cnode) {
|
|||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
int MoveAttrBatchNorm(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto src_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (src_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "value node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto dst_prim = std::make_shared<ops::FusedBatchNorm>();
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, RET_NULL_PTR, "dst_prim is nullptr.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
bool is_training = GetValue<bool>(src_prim->GetAttr(ops::kIsTraining));
|
||||
dst_prim->set_mode(static_cast<int64_t>(is_training));
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool PrimitiveAdjust::Run(const FuncGraphPtr &func_graphs) {
|
||||
|
@ -620,7 +638,7 @@ REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradGpu, MoveAttrPoolGrad)
|
|||
REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradCpu, MoveAttrPoolGrad)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon<ops::MatMulFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameMatMul, MoveAttrMapCommon<ops::MatMulFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon<ops::FusedBatchNorm>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrBatchNorm)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon<ops::Conv2DBackpropFilterFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropInput, MoveAttrMapCommon<ops::Conv2DBackpropInputFusion>)
|
||||
REGIST_PRIMITIVE_ADJUST(kNameConv2D, MoveAttrMapConv2D)
|
||||
|
|
Loading…
Reference in New Issue