From: @xutianchun
Reviewed-by: @HilbertDavid, @zhanghaibo5
Signed-off-by: @HilbertDavid
mindspore-ci-bot 2021-04-27 09:33:10 +08:00 committed by Gitee
commit 784dc36a79
9 changed files with 46 additions and 20 deletions

View File

@@ -30,7 +30,7 @@ class AccuracyMonitor : public session::TrainLoopCallBack {
  public:
   explicit AccuracyMonitor(mindspore::dataset::Dataset *dataset, int check_every_n, int max_steps = -1)
       : ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
+  ~AccuracyMonitor() = default;
   void Begin(const session::TrainLoopCallBackData &cb_data) override;
   int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
   const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
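For orientation, a usage sketch of this callback follows. It is not standalone: the `TrainLoop::Train` call and the `ds`/`loop` setup are assumptions about the surrounding lite training API rather than part of this diff, and GraphPoint is assumed to be a (step, value)-style pair; only AccuracyMonitor's own interface comes from this header.

```cpp
// Usage sketch (hedged): Train()'s exact signature and the ds/loop objects are
// assumptions about the surrounding lite training API, not part of this diff.
AccuracyMonitor acc_monitor(ds, /*check_every_n=*/2);  // evaluate every 2nd epoch
loop->Train(/*epochs=*/10, ds, {&acc_monitor}, nullptr);
for (const auto &point : acc_monitor.GetAccuracyPoints()) {
  // GraphPoint is assumed here to be a (step, value)-style pair.
  std::cout << "epoch " << point.first << ": accuracy " << point.second << std::endl;
}
```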

View File

@@ -32,6 +32,8 @@ class CkptSaver : public session::TrainLoopCallBack {
   CkptSaver(int save_every_n, const std::string &filename_prefix, mindspore::lite::Model *model)
       : save_every_n_(save_every_n), filename_prefix_(filename_prefix), model_(model) {}
+  ~CkptSaver() = default;
   int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
     if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
       auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
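Since `cb_data.epoch_` is 0-based, the `(cb_data.epoch_ + 1) % save_every_n_` test fires on 1-based epoch boundaries. A minimal standalone sketch of the naming cadence ("lenet" is a made-up prefix for illustration):

```cpp
#include <iostream>
#include <string>

// Standalone sketch of CkptSaver's save cadence; "lenet" is a made-up prefix.
int main() {
  const int save_every_n = 3;
  const std::string prefix = "lenet";
  for (int epoch = 0; epoch < 9; ++epoch) {  // epoch is 0-based, as in cb_data.epoch_
    if ((epoch + 1) % save_every_n == 0) {
      std::cout << prefix + "_trained_" + std::to_string(epoch + 1) + ".ms" << std::endl;
    }
  }
  // Prints: lenet_trained_3.ms, lenet_trained_6.ms, lenet_trained_9.ms
  return 0;
}
```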

View File

@@ -56,7 +56,7 @@ class TrainLoop {
   /// \brief Accessor to the TrainSession
   ///
   /// \return pointer of the train_session
-  virtual session::TrainSession *train_session() = 0;
+  const virtual session::TrainSession *train_session() = 0;
   /// \brief Initialize object with metrics
   ///

View File

@@ -151,7 +151,7 @@ public class TrainSession {
     }
     public boolean setLossName(String lossName) {
-        return this.setLossName(this.sessionPtr,lossName);
+        return this.setLossName(this.sessionPtr, lossName);
     }
@@ -191,5 +191,5 @@ public class TrainSession {
     private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum);
-    private native boolean setLossName(long sessionPtr,String lossName);
+    private native boolean setLossName(long sessionPtr, String lossName);
 }

View File

@@ -34,7 +34,7 @@ class TrainLoop : virtual public session::TrainLoop {
  public:
   explicit TrainLoop(session::TrainSession *session) : train_session_(session) {}
-  session::TrainSession *train_session() override { return train_session_; }
+  const session::TrainSession *train_session() override { return train_session_; }
   int Reset() override {
     epoch_ = 0;
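Note that the interface (above) and this implementation change together, which keeps the `override` well-formed. The effect of the const-qualified return type, in a standalone sketch (names are stand-ins, not the real classes):

```cpp
#include <cstdio>

// Standalone sketch: returning a pointer-to-const restricts callers to const
// member functions, mirroring the train_session() change above.
struct Session {
  void Train() {}                  // non-const: mutates the session
  int epoch() const { return 0; }  // const: read-only
};

struct Loop {
  explicit Loop(Session *s) : session_(s) {}
  const Session *session() { return session_; }  // const-qualified return
  Session *session_;
};

int main() {
  Session s;
  Loop loop(&s);
  const Session *p = loop.session();
  std::printf("%d\n", p->epoch());
  // p->Train();  // would not compile: non-const member through pointer-to-const
  return 0;
}
```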

View File

@@ -40,6 +40,7 @@ OpParameter *PopulateSmoothL1LossParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_SmoothL1Loss();
+  MS_ASSERT(value != nullptr);
   p->op_parameter_.type_ = primitive->value_type();
   p->beta_ = value->beta();
   return reinterpret_cast<OpParameter *>(p);
@@ -53,6 +54,7 @@ OpParameter *PopulateSmoothL1LossGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_SmoothL1LossGrad();
+  MS_ASSERT(value != nullptr);
   p->op_parameter_.type_ = primitive->value_type();
   p->beta_ = value->beta();
   return reinterpret_cast<OpParameter *>(p);
@@ -92,6 +94,7 @@ OpParameter *PopulateBCEGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_BinaryCrossEntropyGrad();
+  MS_ASSERT(value != nullptr);
   *reduction = value->reduction();
   return reinterpret_cast<OpParameter *>(reduction);
 }
@@ -104,6 +107,7 @@ OpParameter *PopulateAdamParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Adam();
+  MS_ASSERT(value != nullptr);
   p->op_parameter_.type_ = primitive->value_type();
   p->use_nesterov_ = value->use_nesterov();
   return reinterpret_cast<OpParameter *>(p);
@@ -117,6 +121,7 @@ OpParameter *PopulateSgdParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_SGD();
+  MS_ASSERT(value != nullptr);
   p->op_parameter_.type_ = primitive->value_type();
   p->weight_decay_ = value->weight_decay();
   p->dampening_ = value->dampening();
@@ -134,6 +139,7 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyWithLogitsParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_SparseSoftmaxCrossEntropyWithLogits();
+  MS_ASSERT(value != nullptr);
   sce_param->op_parameter_.type_ = primitive->value_type();
   sce_param->is_grad_ = value->is_grad();
   return reinterpret_cast<OpParameter *>(sce_param);
@@ -160,6 +166,7 @@ OpParameter *PopulateMaxPoolGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_MaxPoolGrad();
+  MS_ASSERT(value != nullptr);
   pooling_param->op_parameter_.type_ = primitive->value_type();
   pooling_param->global_ = false;
@@ -197,6 +204,7 @@ OpParameter *PopulateAvgPoolGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_AvgPoolGrad();
+  MS_ASSERT(value != nullptr);
   pooling_param->op_parameter_.type_ = primitive->value_type();
   pooling_param->global_ = false;
@@ -245,6 +253,7 @@ OpParameter *PopulateActivationGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_ActivationGrad();
+  MS_ASSERT(value != nullptr);
   act_param->op_parameter_.type_ = primitive->value_type();
   act_param->type_ = static_cast<int>(value->activation_type());
   act_param->alpha_ = value->alpha();
@@ -260,6 +269,7 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *prim) {
   memset(param, 0, sizeof(ConvParameter));
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Conv2DBackpropFilterFusion();
+  MS_ASSERT(value != nullptr);
   param->op_parameter_.type_ = primitive->value_type();
   param->kernel_h_ = value->kernel_size()->Get(0);
@@ -296,6 +306,7 @@ OpParameter *PopulateConvolutionGradInputParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Conv2DBackpropInputFusion();
+  MS_ASSERT(value != nullptr);
   param->op_parameter_.type_ = primitive->value_type();
   param->kernel_h_ = value->kernel_size()->Get(0);
@@ -332,6 +343,7 @@ OpParameter *PopulatePowerGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_PowerGrad();
+  MS_ASSERT(value != nullptr);
   power_param->op_parameter_.type_ = primitive->value_type();
   power_param->power_ = value->power();
   power_param->scale_ = value->scale();
@@ -358,6 +370,7 @@ OpParameter *PopulateBNGradParameter(const void *prim) {
   }
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_BatchNormGrad();
+  MS_ASSERT(value != nullptr);
   bnGrad_param->op_parameter_.type_ = primitive->value_type();
   bnGrad_param->epsilon_ = value->epsilon();
   return reinterpret_cast<OpParameter *>(bnGrad_param);
@@ -372,6 +385,7 @@ OpParameter *PopulateDropoutParameter(const void *prim) {
   memset(dropout_parameter, 0, sizeof(DropoutParameter));
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_Dropout();
+  MS_ASSERT(value != nullptr);
   dropout_parameter->op_parameter_.type_ = primitive->value_type();
   dropout_parameter->ratio_ = value->keep_prob();
   if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
@@ -391,6 +405,7 @@ OpParameter *PopulateDropoutGradParameter(const void *prim) {
   memset(dropoutgrad_parameter, 0, sizeof(DropoutParameter));
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_DropoutGrad();
+  MS_ASSERT(value != nullptr);
   dropoutgrad_parameter->op_parameter_.type_ = primitive->value_type();
   dropoutgrad_parameter->ratio_ = value->keep_prob();
   if (dropoutgrad_parameter->ratio_ < 0.f || dropoutgrad_parameter->ratio_ > 1.f) {
@@ -423,7 +438,7 @@ OpParameter *PopulateResizeGradParameter(const void *prim) {
   auto primitive = static_cast<const schema::Primitive *>(prim);
   resize_grad_param->op_parameter_.type_ = primitive->value_type();
   auto param = primitive->value_as_ResizeGrad();
+  MS_ASSERT(param != nullptr);
   resize_grad_param->method = static_cast<int>(param->method());
   resize_grad_param->align_corners_ = param->align_corners();
@@ -441,6 +456,7 @@ OpParameter *PopulateStridedSliceGradParameter(const void *prim) {
   auto primitive = static_cast<const schema::Primitive *>(prim);
   auto value = primitive->value_as_StridedSliceGrad();
+  MS_ASSERT(value != nullptr);
   strided_slice_param->op_parameter_.type_ = primitive->value_type();
   strided_slice_param->begins_mask_ = value->begin_mask();
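Every hunk above follows the same shape: `value_as_X()` is a generated flatbuffers union accessor that returns nullptr when the primitive actually holds a different union member, and the new MS_ASSERT guards the dereference that follows (a debug-build check; MS_ASSERT is typically compiled out of release builds). A standalone sketch of the pattern, with Foo standing in for the generated schema types:

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Standalone sketch of the populate pattern hardened above; FooParameter,
// FooValue, and value_as_Foo() are stand-ins for the generated schema types.
struct OpParameter { int type_; };
struct FooParameter { OpParameter op_parameter_; float beta_; };
struct FooValue { float beta_; };

// Mimics primitive->value_as_Foo(): nullptr when the union holds another type.
const FooValue *value_as_Foo(bool union_matches) {
  static FooValue v{0.5f};
  return union_matches ? &v : nullptr;
}

OpParameter *PopulateFooParameter(bool union_matches) {
  auto *p = static_cast<FooParameter *>(malloc(sizeof(FooParameter)));
  if (p == nullptr) {
    return nullptr;  // the real code logs via MS_LOG(ERROR) here
  }
  memset(p, 0, sizeof(FooParameter));
  const FooValue *value = value_as_Foo(union_matches);
  assert(value != nullptr);  // stands in for the MS_ASSERT added by this commit
  p->beta_ = value->beta_;   // the dereference the assert protects
  return reinterpret_cast<OpParameter *>(p);
}

int main() {
  free(PopulateFooParameter(true));
  return 0;
}
```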

View File

@@ -104,7 +104,7 @@ OpParameter *PopulateSmoothL1LossParameter(const void *primitive) {
   p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1Loss;
   auto smoothL1Loss_prim = prim->value_as_SmoothL1Loss();
+  MS_ASSERT(smoothL1Loss_prim != nullptr);
   p->beta_ = smoothL1Loss_prim->beta();
   return reinterpret_cast<OpParameter *>(p);
 }
@@ -123,7 +123,7 @@ OpParameter *PopulateSmoothL1LossGradParameter(const void *primitive) {
   p->op_parameter_.type_ = schema::PrimitiveType_SmoothL1LossGrad;
   auto smoothL1LossGrad_prim = prim->value_as_SmoothL1LossGrad();
+  MS_ASSERT(smoothL1LossGrad_prim != nullptr);
   p->beta_ = smoothL1LossGrad_prim->beta();
   return reinterpret_cast<OpParameter *>(p);
 }
@@ -142,7 +142,7 @@ OpParameter *PopulateApplyMomentumParameter(const void *primitive) {
   p->op_parameter_.type_ = schema::PrimitiveType_ApplyMomentum;
   auto applyMomentum_prim = prim->value_as_ApplyMomentum();
+  MS_ASSERT(applyMomentum_prim != nullptr);
   p->grad_scale_ = applyMomentum_prim->gradientScale();
   p->use_nesterov_ = applyMomentum_prim->useNesterov();
@@ -157,6 +157,7 @@ OpParameter *PopulateBCEParameter(const void *primitive) {
     return nullptr;
   }
   auto bCE_prim = prim->value_as_BinaryCrossEntropy();
+  MS_ASSERT(bCE_prim != nullptr);
   *reduction = bCE_prim->reduction();
   return reinterpret_cast<OpParameter *>(reduction);
 }
@@ -169,7 +170,7 @@ OpParameter *PopulateBCEGradParameter(const void *primitive) {
     return nullptr;
   }
   auto bCEGrad_prim = prim->value_as_BinaryCrossEntropyGrad();
+  MS_ASSERT(bCEGrad_prim != nullptr);
   *reduction = bCEGrad_prim->reduction();
   return reinterpret_cast<OpParameter *>(reduction);
 }
@@ -188,7 +189,7 @@ OpParameter *PopulateAdamParameter(const void *primitive) {
   p->op_parameter_.type_ = schema::PrimitiveType_Adam;
   auto adam_prim = prim->value_as_Adam();
+  MS_ASSERT(adam_prim != nullptr);
   p->use_nesterov_ = adam_prim->useNesterov();
   return reinterpret_cast<OpParameter *>(p);
 }
@@ -207,7 +208,7 @@ OpParameter *PopulateSgdParameter(const void *primitive) {
   p->op_parameter_.type_ = schema::PrimitiveType_SGD;
   auto sgd_prim = prim->value_as_Sgd();
+  MS_ASSERT(sgd_prim != nullptr);
   p->weight_decay_ = sgd_prim->weightDecay();
   p->dampening_ = sgd_prim->dampening();
   p->use_nesterov_ = sgd_prim->useNesterov();
@@ -228,7 +229,7 @@ OpParameter *PopulateSparseSoftmaxCrossEntropyParameter(const void *primitive) {
     return nullptr;
   }
   auto sparseSoftmaxCrossEntropy_prim = prim->value_as_SparseSoftmaxCrossEntropy();
+  MS_ASSERT(sparseSoftmaxCrossEntropy_prim != nullptr);
   sce_param->is_grad_ = sparseSoftmaxCrossEntropy_prim->isGrad();
   sce_param->op_parameter_.type_ = schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
@@ -264,7 +265,7 @@ OpParameter *PopulatePoolingGradParameter(const void *primitive) {
   }
   auto poolingGrad_prim = prim->value_as_PoolingGrad();
+  MS_ASSERT(poolingGrad_prim != nullptr);
   pooling_param->global_ = poolingGrad_prim->global();
   pooling_param->window_w_ = poolingGrad_prim->windowW();
   pooling_param->window_h_ = poolingGrad_prim->windowH();
@@ -321,7 +322,7 @@ OpParameter *PopulateActivationGradParameter(const void *primitive) {
   }
   act_param->op_parameter_.type_ = schema::PrimitiveType_ActivationGrad;
   auto activationGrad_prim = prim->value_as_ActivationGrad();
+  MS_ASSERT(activationGrad_prim != nullptr);
   act_param->type_ = static_cast<int>(activationGrad_prim->type());
   act_param->alpha_ = activationGrad_prim->alpha();
   return reinterpret_cast<OpParameter *>(act_param);
@@ -342,7 +343,9 @@ OpParameter *PopulateConvolutionGradFilterParameter(const void *primitive) {
   param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropFilterFusion;
   auto convolutionGradFilter_prim = prim->value_as_Conv2DGradFilter();
+  MS_ASSERT(convolutionGradFilter_prim != nullptr);
   auto fb_vector = convolutionGradFilter_prim->filter_shape();
+  MS_ASSERT(fb_vector != nullptr);
   auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
   if (filter_shape.size() > MAX_SHAPE_SIZE) {
     free(param);
@@ -390,7 +393,9 @@ OpParameter *PopulateConvolutionGradInputParameter(const void *primitive) {
   param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
   auto convolutionGradInput_prim = prim->value_as_Conv2DGradInput();
+  MS_ASSERT(convolutionGradInput_prim != nullptr);
   auto fb_vector = convolutionGradInput_prim->input_shape();
+  MS_ASSERT(fb_vector != nullptr);
   auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
   if (filter_shape.size() > MAX_SHAPE_SIZE) {
     free(param);
@@ -438,7 +443,9 @@ OpParameter *PopulateGroupConvolutionGradInputParameter(const void *primitive) {
   param->op_parameter_.type_ = schema::PrimitiveType_Conv2DBackpropInputFusion;
   auto groupConvolutionGradInput_prim = prim->value_as_GroupConv2DGradInput();
+  MS_ASSERT(groupConvolutionGradInput_prim != nullptr);
   auto fb_vector = groupConvolutionGradInput_prim->input_shape();
+  MS_ASSERT(fb_vector != nullptr);
   auto filter_shape = std::vector<int>(fb_vector->begin(), fb_vector->end());
   if (filter_shape.size() > MAX_SHAPE_SIZE) {
     free(param);
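The three convolution-gradient hunks above add a second guard beyond the union check: `filter_shape()`/`input_shape()` return a pointer to a flatbuffers vector field, which is null when the field was absent from the serialized model, and calling begin()/end() through it would be undefined behavior. A standalone sketch (std::vector stands in for flatbuffers::Vector):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Standalone sketch of the fb_vector guard; std::vector stands in for the
// flatbuffers::Vector returned by filter_shape()/input_shape().
const std::vector<int> *input_shape(bool field_present) {
  static std::vector<int> shape{1, 3, 32, 32};
  return field_present ? &shape : nullptr;  // null when the field was absent
}

int main() {
  const std::vector<int> *fb_vector = input_shape(true);
  assert(fb_vector != nullptr);  // the added MS_ASSERT; begin()/end() on null is UB
  std::vector<int> filter_shape(fb_vector->begin(), fb_vector->end());
  std::printf("%zu dims\n", filter_shape.size());
  return 0;
}
```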
@@ -485,7 +492,7 @@ OpParameter *PopulatePowerGradParameter(const void *primitive) {
   }
   power_param->op_parameter_.type_ = schema::PrimitiveType_PowerGrad;
   auto powerGrad_prim = prim->value_as_PowerGrad();
+  MS_ASSERT(powerGrad_prim != nullptr);
   power_param->power_ = powerGrad_prim->power();
   power_param->scale_ = powerGrad_prim->scale();
   power_param->shift_ = powerGrad_prim->shift();
@@ -521,7 +528,7 @@ OpParameter *PopulateBNGradParameter(const void *primitive) {
   }
   bnGrad_param->op_parameter_.type_ = schema::PrimitiveType_BatchNormGrad;
   auto bNGrad_prim = prim->value_as_BNGrad();
+  MS_ASSERT(bNGrad_prim != nullptr);
   bnGrad_param->epsilon_ = bNGrad_prim->eps();
   return reinterpret_cast<OpParameter *>(bnGrad_param);
 }
@@ -536,7 +543,7 @@ OpParameter *PopulateDropoutParameter(const void *primitive) {
   memset(dropout_parameter, 0, sizeof(DropoutParameter));
   dropout_parameter->op_parameter_.type_ = schema::PrimitiveType_Dropout;
   auto dropout_prim = prim->value_as_Dropout();
+  MS_ASSERT(dropout_prim != nullptr);
   dropout_parameter->ratio_ = dropout_prim->ratio();
   if (dropout_parameter->ratio_ < 0.f || dropout_parameter->ratio_ > 1.f) {
     MS_LOG(ERROR) << "Dropout ratio must be between 0 to 1, got " << dropout_parameter->ratio_;
@@ -556,7 +563,7 @@ OpParameter *PopulateDropoutGradParameter(const void *primitive) {
   memset(dropoutGrad_parameter, 0, sizeof(DropoutParameter));
   dropoutGrad_parameter->op_parameter_.type_ = schema::PrimitiveType_DropoutGrad;
   auto dropoutGrad_prim = prim->value_as_DropoutGrad();
+  MS_ASSERT(dropoutGrad_prim != nullptr);
   dropoutGrad_parameter->ratio_ = dropoutGrad_prim->ratio();
   if (dropoutGrad_parameter->ratio_ < 0.f || dropoutGrad_parameter->ratio_ > 1.f) {
     MS_LOG(ERROR) << "Dropout Grad ratio must be between 0 to 1, got " << dropoutGrad_parameter->ratio_;

View File

@@ -73,6 +73,7 @@ int WeightDecoder::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor
   if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
     return RET_NO_CHANGE;
   }
+  MS_ASSERT(src_tensor.data() != nullptr);
   auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
   MS_ASSERT(data != nullptr);
   std::string encode_str(data, src_tensor.data()->size());
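The added assertion checks a different pointer than the existing one: `src_tensor.data()` is the flatbuffers byte-vector field itself, which is null when the tensor carries no payload, while the pre-existing MS_ASSERT only checks the raw pointer obtained from that vector. A standalone sketch of the two-level check (FakeTensor stands in for schema::Tensor):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone sketch; FakeTensor::data() stands in for the flatbuffers
// accessor schema::Tensor::data(), which returns null when no payload exists.
struct FakeTensor {
  const std::vector<uint8_t> *payload = nullptr;
  const std::vector<uint8_t> *data() const { return payload; }
};

int main() {
  std::vector<uint8_t> bytes{0x12, 0x34};
  FakeTensor src_tensor{&bytes};
  assert(src_tensor.data() != nullptr);  // new guard: the vector field exists
  auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
  assert(data != nullptr);               // pre-existing guard: raw pointer valid
  (void)data;
  return 0;
}
```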

View File

@@ -143,7 +143,7 @@ class WeightDecoder {
     if (is_last && remainder > 0) {
       for (size_t i = 0; i < remainder; i++) {
         bool bit = unpack_bit_data->front();
-        uint_result = (static_cast<int>(bit) << i) + uint_result;
+        uint_result = (static_cast<unsigned int>(bit) << i) + uint_result;
         unpack_bit_data->pop();
       }
       result = static_cast<T1>(uint_result - static_cast<T2>(pow(2, origin_bit - 1)));
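The cast change matters once `i` can reach the sign bit: left-shifting a 1 into bit 31 of a signed 32-bit int is undefined behavior in C++11 (and at best implementation-defined before C++20), whereas the unsigned shift is well defined for every shift amount below the bit width. A standalone illustration:

```cpp
#include <cstdio>

// Standalone illustration of the signed-vs-unsigned shift fix above.
int main() {
  bool bit = true;
  unsigned int uint_result = 0;
  for (unsigned int i = 0; i < 32; ++i) {
    // With static_cast<int>(bit) this would hit the sign bit at i == 31;
    // the unsigned cast keeps every iteration well defined.
    uint_result = (static_cast<unsigned int>(bit) << i) + uint_result;
  }
  std::printf("0x%08X\n", uint_result);  // 0xFFFFFFFF
  return 0;
}
```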