From eb6d87eb8d20cbb7957e591a4ef2993a3b98bf55 Mon Sep 17 00:00:00 2001 From: zhangxinfeng3 Date: Tue, 8 Dec 2020 16:56:02 +0800 Subject: [PATCH] fix the bug of toolbox --- .../toolbox/uncertainty_evaluation.py | 19 +++++++++++++------ model_zoo/README.md | 3 ++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py index b624e079c8b..dcbca97761e 100644 --- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py +++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py @@ -61,6 +61,7 @@ class UncertaintyEvaluation: >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt') >>> load_param_into_net(network, param_dict) >>> ds_train = create_dataset('workspace/mnist/train') + >>> ds_eval = create_dataset('workspace/mnist/test') >>> evaluation = UncertaintyEvaluation(model=network, ... train_dataset=ds_train, ... task_type='classification', @@ -69,8 +70,10 @@ class UncertaintyEvaluation: ... epi_uncer_model_path=None, ... ale_uncer_model_path=None, ... save_model=False) - >>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data) - >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data) + >>> for eval_data in ds_eval.create_dict_iterator(output_numpy=True, num_epochs=1): + ... eval_data = Tensor(eval_data['image'], mstype.float32) + ... epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data) + ... aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data) >>> output = epistemic_uncertainty.shape >>> print(output) (32, 10) @@ -128,9 +131,11 @@ class UncertaintyEvaluation: ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model', directory=self.epi_uncer_model_path, config=config_ck) - model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()]) + model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False, + callbacks=[ckpoint_cb, LossMonitor()]) elif self.epi_uncer_model_path is None: - model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()]) + model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False, + callbacks=[LossMonitor()]) else: uncer_param_dict = load_checkpoint(self.epi_uncer_model_path) load_param_into_net(self.epi_uncer_model, uncer_param_dict) @@ -173,9 +178,11 @@ class UncertaintyEvaluation: ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model', directory=self.ale_uncer_model_path, config=config_ck) - model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()]) + model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False, + callbacks=[ckpoint_cb, LossMonitor()]) elif self.ale_uncer_model_path is None: - model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()]) + model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False, + callbacks=[LossMonitor()]) else: uncer_param_dict = load_checkpoint(self.ale_uncer_model_path) load_param_into_net(self.ale_uncer_model, uncer_param_dict) diff --git a/model_zoo/README.md b/model_zoo/README.md index 695a93ad15f..0f6ebeac943 100644 --- a/model_zoo/README.md +++ b/model_zoo/README.md @@ -64,7 +64,8 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework, - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio) - [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md) - [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc) - - [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ocean_model/README.md) + - [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md) + - [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md) - [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)