!9648 Fix the bug of toolbox

From: @zhangxinfeng3
Reviewed-by: @sunnybeike,@zichun_ye
Signed-off-by: @sunnybeike
This commit is contained in:
mindspore-ci-bot 2020-12-09 09:29:38 +08:00 committed by Gitee
commit 56ce0f4a27
2 changed files with 15 additions and 7 deletions

View File

@ -61,6 +61,7 @@ class UncertaintyEvaluation:
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
>>> load_param_into_net(network, param_dict)
>>> ds_train = create_dataset('workspace/mnist/train')
>>> ds_eval = create_dataset('workspace/mnist/test')
>>> evaluation = UncertaintyEvaluation(model=network,
... train_dataset=ds_train,
... task_type='classification',
@ -69,8 +70,10 @@ class UncertaintyEvaluation:
... epi_uncer_model_path=None,
... ale_uncer_model_path=None,
... save_model=False)
>>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
>>> for eval_data in ds_eval.create_dict_iterator(output_numpy=True, num_epochs=1):
... eval_data = Tensor(eval_data['image'], mstype.float32)
... epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
... aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
>>> output = epistemic_uncertainty.shape
>>> print(output)
(32, 10)
@ -128,9 +131,11 @@ class UncertaintyEvaluation:
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
directory=self.epi_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
callbacks=[ckpoint_cb, LossMonitor()])
elif self.epi_uncer_model_path is None:
model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
@ -173,9 +178,11 @@ class UncertaintyEvaluation:
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
directory=self.ale_uncer_model_path,
config=config_ck)
model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False,
callbacks=[ckpoint_cb, LossMonitor()])
elif self.ale_uncer_model_path is None:
model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False,
callbacks=[LossMonitor()])
else:
uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
load_param_into_net(self.ale_uncer_model, uncer_param_dict)

View File

@ -64,7 +64,8 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ocean_model/README.md)
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
- [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)