forked from mindspore-Ecosystem/mindspore
!15715 fix codex
From: @xutianchun Reviewed-by: @HilbertDavid,@zhanghaibo5 Signed-off-by: @HilbertDavid
This commit is contained in:
commit
784dc36a79
|
@ -30,7 +30,7 @@ class AccuracyMonitor : public session::TrainLoopCallBack {
|
||||||
public:
|
public:
|
||||||
explicit AccuracyMonitor(mindspore::dataset::Dataset *dataset, int check_every_n, int max_steps = -1)
|
explicit AccuracyMonitor(mindspore::dataset::Dataset *dataset, int check_every_n, int max_steps = -1)
|
||||||
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
|
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
|
||||||
|
~AccuracyMonitor() = default;
|
||||||
void Begin(const session::TrainLoopCallBackData &cb_data) override;
|
void Begin(const session::TrainLoopCallBackData &cb_data) override;
|
||||||
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
|
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
|
||||||
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
|
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
|
||||||
|
|
|
@ -32,6 +32,8 @@ class CkptSaver : public session::TrainLoopCallBack {
|
||||||
CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
|
CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
|
||||||
: save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
|
: save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
|
||||||
|
|
||||||
|
~CkptSaver() = default;
|
||||||
|
|
||||||
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
|
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
|
||||||
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
|
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
|
||||||
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
|
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
|
||||||
|
|
|
@ -56,7 +56,7 @@ class TrainLoop {
|
||||||
/// \brief Accessor to the TrainSession
|
/// \brief Accessor to the TrainSession
|
||||||
///
|
///
|
||||||
/// \return pointer of the train_session
|
/// \return pointer of the train_session
|
||||||
virtual session::TrainSession *train_session() = 0;
|
const virtual session::TrainSession *train_session() = 0;
|
||||||
|
|
||||||
/// \brief Initialize object with metrics
|
/// \brief Initialize object with metrics
|
||||||
///
|
///
|
||||||
|
|
|
@ -151,7 +151,7 @@ public class TrainSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean setLossName(String lossName) {
|
public boolean setLossName(String lossName) {
|
||||||
return this.setLossName(this.sessionPtr,lossName);
|
return this.setLossName(this.sessionPtr, lossName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,5 +191,5 @@ public class TrainSession {
|
||||||
|
|
||||||
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
|
||||||
|
|
||||||
private native boolean setLossName(long sessionPtr,String lossName);
|
private native boolean setLossName(long sessionPtr, String lossName);
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ class TrainLoop : virtual public session::TrainLoop {
|
||||||
public:
|
public:
|
||||||
explicit TrainLoop(session::TrainSession *session) : train_session_(session) {}
|
explicit TrainLoop(session::TrainSession *session) : train_session_(session) {}
|
||||||
|
|
||||||
session::TrainSession *train_session() override { return train_session_; }
|
const session::TrainSession *train_session() override { return train_session_; }
|
||||||
|
|
||||||
int Reset() override {
|
int Reset() override {
|
||||||
epoch_ = 0;
|
epoch_ = 0;
|
||||||
|
|
|
@ -40,6 +40,7 @@ OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_SmoothL1Loss();
|
auto value = primitive->value_as_SmoothL1Loss();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
p->op_parameter_.type_ = primitive->value_type();
|
p->op_parameter_.type_ = primitive->value_type();
|
||||||
p->beta_ = value->beta();
|
p->beta_ = value->beta();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
|
@ -53,6 +54,7 @@ OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_SmoothL1LossGrad();
|
auto value = primitive->value_as_SmoothL1LossGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
p->op_parameter_.type_ = primitive->value_type();
|
p->op_parameter_.type_ = primitive->value_type();
|
||||||
p->beta_ = value->beta();
|
p->beta_ = value->beta();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
|
@ -92,6 +94,7 @@ OpParameter *PopulateBCEGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_BinaryCrossEntropyGrad();
|
auto value = primitive->value_as_BinaryCrossEntropyGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
*reduction = value->reduction();
|
*reduction = value->reduction();
|
||||||
return reinterpret_cast<OpParameter *>(reduction);
|
return reinterpret_cast<OpParameter *>(reduction);
|
||||||
}
|
}
|
||||||
|
@ -104,6 +107,7 @@ OpParameter *PopulateAdamParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_Adam();
|
auto value = primitive->value_as_Adam();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
p->op_parameter_.type_ = primitive->value_type();
|
p->op_parameter_.type_ = primitive->value_type();
|
||||||
p->use_nesterov_ = value->use_nesterov();
|
p->use_nesterov_ = value->use_nesterov();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
|
@ -117,6 +121,7 @@ OpParameter *PopulateSgdParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_SGD();
|
auto value = primitive->value_as_SGD();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
p->op_parameter_.type_ = primitive->value_type();
|
p->op_parameter_.type_ = primitive->value_type();
|
||||||
p->weight_decay_ = value->weight_decay();
|
p->weight_decay_ = value->weight_decay();
|
||||||
p->dampening_ = value->dampening();
|
p->dampening_ = value->dampening();
|
||||||
|
@ -134,6 +139,7 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *pr
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
|
auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
sce_param->op_parameter_.type_ = primitive->value_type();
|
sce_param->op_parameter_.type_ = primitive->value_type();
|
||||||
sce_param->is_grad_ = value->is_grad();
|
sce_param->is_grad_ = value->is_grad();
|
||||||
return reinterpret_cast<OpParameter *>(sce_param);
|
return reinterpret_cast<OpParameter *>(sce_param);
|
||||||
|
@ -160,6 +166,7 @@ OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_MaxPoolGrad();
|
auto value = primitive->value_as_MaxPoolGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
pooling_param->op_parameter_.type_ = primitive->value_type();
|
pooling_param->op_parameter_.type_ = primitive->value_type();
|
||||||
|
|
||||||
pooling_param->global_ = false;
|
pooling_param->global_ = false;
|
||||||
|
@ -197,6 +204,7 @@ OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_AvgPoolGrad();
|
auto value = primitive->value_as_AvgPoolGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
pooling_param->op_parameter_.type_ = primitive->value_type();
|
pooling_param->op_parameter_.type_ = primitive->value_type();
|
||||||
|
|
||||||
pooling_param->global_ = false;
|
pooling_param->global_ = false;
|
||||||
|
@ -245,6 +253,7 @@ OpParameter *PopulateActivationGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_ActivationGrad();
|
auto value = primitive->value_as_ActivationGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
act_param->op_parameter_.type_ = primitive->value_type();
|
act_param->op_parameter_.type_ = primitive->value_type();
|
||||||
act_param->type_ = static_cast<int>(value->activation_type());
|
act_param->type_ = static_cast<int>(value->activation_type());
|
||||||
act_param->alpha_ = value->alpha();
|
act_param->alpha_ = value->alpha();
|
||||||
|
@ -260,6 +269,7 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
|
||||||
memset(param, 0, sizeof(ConvParameter));
|
memset(param, 0, sizeof(ConvParameter));
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_Conv2DBackpropFilterFusion();
|
auto value = primitive->value_as_Conv2DBackpropFilterFusion();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
param->op_parameter_.type_ = primitive->value_type();
|
param->op_parameter_.type_ = primitive->value_type();
|
||||||
|
|
||||||
param->kernel_h_ = value->kernel_size()->Get(0);
|
param->kernel_h_ = value->kernel_size()->Get(0);
|
||||||
|
@ -296,6 +306,7 @@ OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_Conv2DBackpropInputFusion();
|
auto value = primitive->value_as_Conv2DBackpropInputFusion();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
param->op_parameter_.type_ = primitive->value_type();
|
param->op_parameter_.type_ = primitive->value_type();
|
||||||
|
|
||||||
param->kernel_h_ = value->kernel_size()->Get(0);
|
param->kernel_h_ = value->kernel_size()->Get(0);
|
||||||
|
@ -332,6 +343,7 @@ OpParameter *PopulatePowerGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_PowerGrad();
|
auto value = primitive->value_as_PowerGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
power_param->op_parameter_.type_ = primitive->value_type();
|
power_param->op_parameter_.type_ = primitive->value_type();
|
||||||
power_param->power_ = value->power();
|
power_param->power_ = value->power();
|
||||||
power_param->scale_ = value->scale();
|
power_param->scale_ = value->scale();
|
||||||
|
@ -358,6 +370,7 @@ OpParameter *PopulateBNGradParameter(const void *prim) {
|
||||||
}
|
}
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_BatchNormGrad();
|
auto value = primitive->value_as_BatchNormGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
bnGrad_param->op_parameter_.type_ = primitive->value_type();
|
bnGrad_param->op_parameter_.type_ = primitive->value_type();
|
||||||
bnGrad_param->epsilon_ = value->epsilon();
|
bnGrad_param->epsilon_ = value->epsilon();
|
||||||
return reinterpret_cast<OpParameter *>(bnGrad_param);
|
return reinterpret_cast<OpParameter *>(bnGrad_param);
|
||||||
|
@ -372,6 +385,7 @@ OpParameter *PopulateDropoutParameter(const void *prim) {
|
||||||
memset(dropout_parameter, 0, sizeof(DropoutParameter));
|
memset(dropout_parameter, 0, sizeof(DropoutParameter));
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_Dropout();
|
auto value = primitive->value_as_Dropout();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
dropout_parameter->op_parameter_.type_ = primitive->value_type();
|
dropout_parameter->op_parameter_.type_ = primitive->value_type();
|
||||||
dropout_parameter->ratio_ = value->keep_prob();
|
dropout_parameter->ratio_ = value->keep_prob();
|
||||||
if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
|
if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
|
||||||
|
@ -391,6 +405,7 @@ OpParameter *PopulateDropoutGradParameter(const void *prim) {
|
||||||
memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
|
memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_DropoutGrad();
|
auto value = primitive->value_as_DropoutGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
|
dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
|
||||||
dropoutgrad_parameter->ratio_ = value->keep_prob();
|
dropoutgrad_parameter->ratio_ = value->keep_prob();
|
||||||
if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
|
if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
|
||||||
|
@ -423,7 +438,7 @@ OpParameter *PopulateResizeGradParameter(const void *prim) {
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
resize_grad_param->op_parameter_.type_ = primitive->value_type();
|
resize_grad_param->op_parameter_.type_ = primitive->value_type();
|
||||||
auto param = primitive->value_as_ResizeGrad();
|
auto param = primitive->value_as_ResizeGrad();
|
||||||
|
MS_ASSERT(param != nullptr);
|
||||||
resize_grad_param->method = static_cast<int>(param->method());
|
resize_grad_param->method = static_cast<int>(param->method());
|
||||||
resize_grad_param->align_corners_ = param->align_corners();
|
resize_grad_param->align_corners_ = param->align_corners();
|
||||||
|
|
||||||
|
@ -441,6 +456,7 @@ OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
|
||||||
|
|
||||||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||||
auto value = primitive->value_as_StridedSliceGrad();
|
auto value = primitive->value_as_StridedSliceGrad();
|
||||||
|
MS_ASSERT(value != nullptr);
|
||||||
strided_slice_param->op_parameter_.type_ = primitive->value_type();
|
strided_slice_param->op_parameter_.type_ = primitive->value_type();
|
||||||
|
|
||||||
strided_slice_param->begins_mask_ = value->begin_mask();
|
strided_slice_param->begins_mask_ = value->begin_mask();
|
||||||
|
|
|
@ -104,7 +104,7 @@ OpParameter *PopulateSmoothL1LossParameter(const void *primitive) {
|
||||||
p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1Loss;
|
p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1Loss;
|
||||||
|
|
||||||
auto smoothL1Loss_prim = prim->value_as_SmoothL1Loss();
|
auto smoothL1Loss_prim = prim->value_as_SmoothL1Loss();
|
||||||
|
MS_ASSERT(smoothL1Loss_prim != nullptr);
|
||||||
p->beta_ = smoothL1Loss_prim->beta();
|
p->beta_ = smoothL1Loss_prim->beta();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ OpParameter *PopulateSmoothL1LossGradParameter(const void *primitive) {
|
||||||
p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1LossGrad;
|
p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1LossGrad;
|
||||||
|
|
||||||
auto smoothL1LossGrad_prim = prim->value_as_SmoothL1LossGrad();
|
auto smoothL1LossGrad_prim = prim->value_as_SmoothL1LossGrad();
|
||||||
|
MS_ASSERT(smoothL1LossGrad_prim != nullptr);
|
||||||
p->beta_ = smoothL1LossGrad_prim->beta();
|
p->beta_ = smoothL1LossGrad_prim->beta();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
}
|
}
|
||||||
|
@ -142,7 +142,7 @@ OpParameter *PopulateApplyMomentumParameter(const void *primitive) {
|
||||||
p->op_parameter_.type_ = schema::PrimitiveType_ApplyMomentum;
|
p->op_parameter_.type_ = schema::PrimitiveType_ApplyMomentum;
|
||||||
|
|
||||||
auto applyMomentum_prim = prim->value_as_ApplyMomentum();
|
auto applyMomentum_prim = prim->value_as_ApplyMomentum();
|
||||||
|
MS_ASSERT(applyMomentum_prim != nullptr);
|
||||||
p->grad_scale_ = applyMomentum_prim->gradientScale();
|
p->grad_scale_ = applyMomentum_prim->gradientScale();
|
||||||
p->use_nesterov_ = applyMomentum_prim->useNesterov();
|
p->use_nesterov_ = applyMomentum_prim->useNesterov();
|
||||||
|
|
||||||
|
@ -157,6 +157,7 @@ OpParameter *PopulateBCEParameter(const void *primitive) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto bCE_prim = prim->value_as_BinaryCrossEntropy();
|
auto bCE_prim = prim->value_as_BinaryCrossEntropy();
|
||||||
|
MS_ASSERT(bCE_prim != nullptr);
|
||||||
*reduction = bCE_prim->reduction();
|
*reduction = bCE_prim->reduction();
|
||||||
return reinterpret_cast<OpParameter *>(reduction);
|
return reinterpret_cast<OpParameter *>(reduction);
|
||||||
}
|
}
|
||||||
|
@ -169,7 +170,7 @@ OpParameter *PopulateBCEGradParameter(const void *primitive) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto bCEGrad_prim = prim->value_as_BinaryCrossEntropyGrad();
|
auto bCEGrad_prim = prim->value_as_BinaryCrossEntropyGrad();
|
||||||
|
MS_ASSERT(bCEGrad_prim != nullptr);
|
||||||
*reduction = bCEGrad_prim->reduction();
|
*reduction = bCEGrad_prim->reduction();
|
||||||
return reinterpret_cast<OpParameter *>(reduction);
|
return reinterpret_cast<OpParameter *>(reduction);
|
||||||
}
|
}
|
||||||
|
@ -188,7 +189,7 @@ OpParameter *PopulateAdamParameter(const void *primitive) {
|
||||||
p->op_parameter_.type_ = schema::PrimitiveType_Adam;
|
p->op_parameter_.type_ = schema::PrimitiveType_Adam;
|
||||||
|
|
||||||
auto adam_prim = prim->value_as_Adam();
|
auto adam_prim = prim->value_as_Adam();
|
||||||
|
MS_ASSERT(adam_prim != nullptr);
|
||||||
p->use_nesterov_ = adam_prim->useNesterov();
|
p->use_nesterov_ = adam_prim->useNesterov();
|
||||||
return reinterpret_cast<OpParameter *>(p);
|
return reinterpret_cast<OpParameter *>(p);
|
||||||
}
|
}
|
||||||
|
@ -207,7 +208,7 @@ OpParameter *PopulateSgdParameter(const void *primitive) {
|
||||||
p->op_parameter_.type_ = schema::PrimitiveType_SGD;
|
p->op_parameter_.type_ = schema::PrimitiveType_SGD;
|
||||||
|
|
||||||
auto sgd_prim = prim->value_as_Sgd();
|
auto sgd_prim = prim->value_as_Sgd();
|
||||||
|
MS_ASSERT(sgd_prim != nullptr);
|
||||||
p->weight_decay_ = sgd_prim->weightDecay();
|
p->weight_decay_ = sgd_prim->weightDecay();
|
||||||
p->dampening_ = sgd_prim->dampening();
|
p->dampening_ = sgd_prim->dampening();
|
||||||
p->use_nesterov_ = sgd_prim->useNesterov();
|
p->use_nesterov_ = sgd_prim->useNesterov();
|
||||||
|
@ -228,7 +229,7 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *primitive) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto sparseSoftmaxCrossEntropy_prim = prim->value_as_SparseSoftmaxCrossEntropy();
|
auto sparseSoftmaxCrossEntropy_prim = prim->value_as_SparseSoftmaxCrossEntropy();
|
||||||
|
MS_ASSERT(sparseSoftmaxCrossEntropy_prim != nullptr);
|
||||||
sce_param->is_grad_ = sparseSoftmaxCrossEntropy_prim->isGrad();
|
sce_param->is_grad_ = sparseSoftmaxCrossEntropy_prim->isGrad();
|
||||||
|
|
||||||
sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
|
sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
|
||||||
|
@ -264,7 +265,7 @@ OpParameter *PopulatePoolingGradParameter(const void *primitive) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto poolingGrad_prim = prim->value_as_PoolingGrad();
|
auto poolingGrad_prim = prim->value_as_PoolingGrad();
|
||||||
|
MS_ASSERT(poolingGrad_prim != nullptr);
|
||||||
pooling_param->global_ = poolingGrad_prim->global();
|
pooling_param->global_ = poolingGrad_prim->global();
|
||||||
pooling_param->window_w_ = poolingGrad_prim->windowW();
|
pooling_param->window_w_ = poolingGrad_prim->windowW();
|
||||||
pooling_param->window_h_ = poolingGrad_prim->windowH();
|
pooling_param->window_h_ = poolingGrad_prim->windowH();
|
||||||
|
@ -321,7 +322,7 @@ OpParameter *PopulateActivationGradParameter(const void *primitive) {
|
||||||
}
|
}
|
||||||
act_param->op_parameter_.type_ = schema::PrimitiveType_ActivationGrad;
|
act_param->op_parameter_.type_ = schema::PrimitiveType_ActivationGrad;
|
||||||
auto activationGrad_prim = prim->value_as_ActivationGrad();
|
auto activationGrad_prim = prim->value_as_ActivationGrad();
|
||||||
|
MS_ASSERT(activationGrad_prim != nullptr);
|
||||||
act_param->type_ = static_cast<int>(activationGrad_prim->type());
|
act_param->type_ = static_cast<int>(activationGrad_prim->type());
|
||||||
act_param->alpha_ = activationGrad_prim->alpha();
|
act_param->alpha_ = activationGrad_prim->alpha();
|
||||||
return reinterpret_cast<OpParameter *>(act_param);
|
return reinterpret_cast<OpParameter *>(act_param);
|
||||||
|
@ -342,7 +343,9 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *primitive) {
|
||||||
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropFilterFusion;
|
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropFilterFusion;
|
||||||
|
|
||||||
auto convolutionGradFilter_prim = prim->value_as_Conv2DGradFilter();
|
auto convolutionGradFilter_prim = prim->value_as_Conv2DGradFilter();
|
||||||
|
MS_ASSERT(convolutionGradFilter_prim != nullptr);
|
||||||
auto fb_vector = convolutionGradFilter_prim->filter_shape();
|
auto fb_vector = convolutionGradFilter_prim->filter_shape();
|
||||||
|
MS_ASSERT(fb_vector != nullptr);
|
||||||
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
||||||
free(param);
|
free(param);
|
||||||
|
@ -390,7 +393,9 @@ OpParameter *PopulateConvolutionGradInputParameter(const void *primitive) {
|
||||||
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
|
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
|
||||||
|
|
||||||
auto convolutionGradInput_prim = prim->value_as_Conv2DGradInput();
|
auto convolutionGradInput_prim = prim->value_as_Conv2DGradInput();
|
||||||
|
MS_ASSERT(convolutionGradInput_prim != nullptr);
|
||||||
auto fb_vector = convolutionGradInput_prim->input_shape();
|
auto fb_vector = convolutionGradInput_prim->input_shape();
|
||||||
|
MS_ASSERT(fb_vector != nullptr);
|
||||||
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
||||||
free(param);
|
free(param);
|
||||||
|
@ -438,7 +443,9 @@ OpParameter *PopulateGroupConvolutionGradInputParameter(const void *primitive) {
|
||||||
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
|
param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
|
||||||
|
|
||||||
auto groupConvolutionGradInput_prim = prim->value_as_GroupConv2DGradInput();
|
auto groupConvolutionGradInput_prim = prim->value_as_GroupConv2DGradInput();
|
||||||
|
MS_ASSERT(groupConvolutionGradInput_prim != nullptr);
|
||||||
auto fb_vector = groupConvolutionGradInput_prim->input_shape();
|
auto fb_vector = groupConvolutionGradInput_prim->input_shape();
|
||||||
|
MS_ASSERT(fb_vector != nullptr);
|
||||||
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||||
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
if (filter_shape.size() > MAX_SHAPE_SIZE) {
|
||||||
free(param);
|
free(param);
|
||||||
|
@ -485,7 +492,7 @@ OpParameter *PopulatePowerGradParameter(const void *primitive) {
|
||||||
}
|
}
|
||||||
power_param->op_parameter_.type_ = schema::PrimitiveType_PowerGrad;
|
power_param->op_parameter_.type_ = schema::PrimitiveType_PowerGrad;
|
||||||
auto powerGrad_prim = prim->value_as_PowerGrad();
|
auto powerGrad_prim = prim->value_as_PowerGrad();
|
||||||
|
MS_ASSERT(powerGrad_prim != nullptr);
|
||||||
power_param->power_ = powerGrad_prim->power();
|
power_param->power_ = powerGrad_prim->power();
|
||||||
power_param->scale_ = powerGrad_prim->scale();
|
power_param->scale_ = powerGrad_prim->scale();
|
||||||
power_param->shift_ = powerGrad_prim->shift();
|
power_param->shift_ = powerGrad_prim->shift();
|
||||||
|
@ -521,7 +528,7 @@ OpParameter *PopulateBNGradParameter(const void *primitive) {
|
||||||
}
|
}
|
||||||
bnGrad_param->op_parameter_.type_ = schema::PrimitiveType_BatchNormGrad;
|
bnGrad_param->op_parameter_.type_ = schema::PrimitiveType_BatchNormGrad;
|
||||||
auto bNGrad_prim = prim->value_as_BNGrad();
|
auto bNGrad_prim = prim->value_as_BNGrad();
|
||||||
|
MS_ASSERT(bNGrad_prim != nullptr);
|
||||||
bnGrad_param->epsilon_ = bNGrad_prim->eps();
|
bnGrad_param->epsilon_ = bNGrad_prim->eps();
|
||||||
return reinterpret_cast<OpParameter *>(bnGrad_param);
|
return reinterpret_cast<OpParameter *>(bnGrad_param);
|
||||||
}
|
}
|
||||||
|
@ -536,7 +543,7 @@ OpParameter *PopulateDropoutParameter(const void *primitive) {
|
||||||
memset(dropout_parameter, 0, sizeof(DropoutParameter));
|
memset(dropout_parameter, 0, sizeof(DropoutParameter));
|
||||||
dropout_parameter->op_parameter_.type_ = schema::PrimitiveType_Dropout;
|
dropout_parameter->op_parameter_.type_ = schema::PrimitiveType_Dropout;
|
||||||
auto dropout_prim = prim->value_as_Dropout();
|
auto dropout_prim = prim->value_as_Dropout();
|
||||||
|
MS_ASSERT(dropout_prim != nullptr);
|
||||||
dropout_parameter->ratio_ = dropout_prim->ratio();
|
dropout_parameter->ratio_ = dropout_prim->ratio();
|
||||||
if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
|
if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
|
||||||
MS_LOG(ERROR) << "Dropout ratio must be between 0 to 1, got " << dropout_parameter->ratio_;
|
MS_LOG(ERROR) << "Dropout ratio must be between 0 to 1, got " << dropout_parameter->ratio_;
|
||||||
|
@ -556,7 +563,7 @@ OpParameter *PopulateDropoutGradParameter(const void *primitive) {
|
||||||
memset(dropoutGrad_parameter, 0, sizeof(DropoutParameter));
|
memset(dropoutGrad_parameter, 0, sizeof(DropoutParameter));
|
||||||
dropoutGrad_parameter->op_parameter_.type_ = schema::PrimitiveType_DropoutGrad;
|
dropoutGrad_parameter->op_parameter_.type_ = schema::PrimitiveType_DropoutGrad;
|
||||||
auto dropoutGrad_prim = prim->value_as_DropoutGrad();
|
auto dropoutGrad_prim = prim->value_as_DropoutGrad();
|
||||||
|
MS_ASSERT(dropoutGrad_prim != nullptr);
|
||||||
dropoutGrad_parameter->ratio_ = dropoutGrad_prim->ratio();
|
dropoutGrad_parameter->ratio_ = dropoutGrad_prim->ratio();
|
||||||
if (dropoutGrad_parameter->ratio_ < 0.f || dropoutGrad_parameter->ratio_ > 1.f) {
|
if (dropoutGrad_parameter->ratio_ < 0.f || dropoutGrad_parameter->ratio_ > 1.f) {
|
||||||
MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutGrad_parameter->ratio_;
|
MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutGrad_parameter->ratio_;
|
||||||
|
|
|
@ -73,6 +73,7 @@ int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Ten
|
||||||
if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
|
if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
|
||||||
return RET_NO_CHANGE;
|
return RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
|
MS_ASSERT(src_tensor.data() != nullptr);
|
||||||
auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
|
auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
|
||||||
MS_ASSERT(data != nullptr);
|
MS_ASSERT(data != nullptr);
|
||||||
std::string encode_str(data, src_tensor.data()->size());
|
std::string encode_str(data, src_tensor.data()->size());
|
||||||
|
|
|
@ -143,7 +143,7 @@ class WeightDecoder {
|
||||||
if (is_last && remainder > 0) {
|
if (is_last && remainder > 0) {
|
||||||
for (size_t i = 0; i < remainder; i++) {
|
for (size_t i = 0; i < remainder; i++) {
|
||||||
bool bit = unpack_bit_data->front();
|
bool bit = unpack_bit_data->front();
|
||||||
uint_result = (static_cast<int>(bit) << i) + uint_result;
|
uint_result = (static_cast<unsigned int>(bit) << i) + uint_result;
|
||||||
unpack_bit_data->pop();
|
unpack_bit_data->pop();
|
||||||
}
|
}
|
||||||
result = static_cast<T1>(uint_result - static_cast<T2>(pow(2, origin_bit - 1)));
|
result = static_cast<T1>(uint_result - static_cast<T2>(pow(2, origin_bit - 1)));
|
||||||
|
|
Loading…
Reference in New Issue