forked from mindspore-Ecosystem/mindspore
!9648 Fix the bug of toolbox
From: @zhangxinfeng3 Reviewed-by: @sunnybeike,@zichun_ye Signed-off-by: @sunnybeike
This commit is contained in:
commit
56ce0f4a27
|
@ -61,6 +61,7 @@ class UncertaintyEvaluation:
|
|||
>>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
|
||||
>>> load_param_into_net(network, param_dict)
|
||||
>>> ds_train = create_dataset('workspace/mnist/train')
|
||||
>>> ds_eval = create_dataset('workspace/mnist/test')
|
||||
>>> evaluation = UncertaintyEvaluation(model=network,
|
||||
... train_dataset=ds_train,
|
||||
... task_type='classification',
|
||||
|
@ -69,8 +70,10 @@ class UncertaintyEvaluation:
|
|||
... epi_uncer_model_path=None,
|
||||
... ale_uncer_model_path=None,
|
||||
... save_model=False)
|
||||
>>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
|
||||
>>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
|
||||
>>> for eval_data in ds_eval.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
... eval_data = Tensor(eval_data['image'], mstype.float32)
|
||||
... epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
|
||||
... aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
|
||||
>>> output = epistemic_uncertainty.shape
|
||||
>>> print(output)
|
||||
(32, 10)
|
||||
|
@ -128,9 +131,11 @@ class UncertaintyEvaluation:
|
|||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_epi_uncer_model',
|
||||
directory=self.epi_uncer_model_path,
|
||||
config=config_ck)
|
||||
model.train(self.epochs, self.epi_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
||||
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
|
||||
callbacks=[ckpoint_cb, LossMonitor()])
|
||||
elif self.epi_uncer_model_path is None:
|
||||
model.train(self.epochs, self.epi_train_dataset, callbacks=[LossMonitor()])
|
||||
model.train(self.epochs, self.epi_train_dataset, dataset_sink_mode=False,
|
||||
callbacks=[LossMonitor()])
|
||||
else:
|
||||
uncer_param_dict = load_checkpoint(self.epi_uncer_model_path)
|
||||
load_param_into_net(self.epi_uncer_model, uncer_param_dict)
|
||||
|
@ -173,9 +178,11 @@ class UncertaintyEvaluation:
|
|||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_ale_uncer_model',
|
||||
directory=self.ale_uncer_model_path,
|
||||
config=config_ck)
|
||||
model.train(self.epochs, self.ale_train_dataset, callbacks=[ckpoint_cb, LossMonitor()])
|
||||
model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False,
|
||||
callbacks=[ckpoint_cb, LossMonitor()])
|
||||
elif self.ale_uncer_model_path is None:
|
||||
model.train(self.epochs, self.ale_train_dataset, callbacks=[LossMonitor()])
|
||||
model.train(self.epochs, self.ale_train_dataset, dataset_sink_mode=False,
|
||||
callbacks=[LossMonitor()])
|
||||
else:
|
||||
uncer_param_dict = load_checkpoint(self.ale_uncer_model_path)
|
||||
load_param_into_net(self.ale_uncer_model, uncer_param_dict)
|
||||
|
|
|
@ -64,7 +64,8 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ocean_model/README.md)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model/README.md)
|
||||
- [Molecular_Dynamics](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/molecular_dynamics/README.md)
|
||||
|
||||
- [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)
|
||||
|
||||
|
|
Loading…
Reference in New Issue