forked from mindspore-Ecosystem/mindspore
!498 fix bug in model eval and predict
Merge pull request !498 from wangnan39/fix_bug_in_model_eval_and_model_predict
This commit is contained in:
commit
f129d31bd9
|
@ -108,6 +108,7 @@ class Model:
|
|||
|
||||
self._train_network = self._build_train_network()
|
||||
self._build_eval_network(metrics, eval_network, eval_indexes)
|
||||
self._build_predict_network()
|
||||
|
||||
def _check_kwargs(self, kwargs):
|
||||
for arg in kwargs:
|
||||
|
@ -153,6 +154,12 @@ class Model:
|
|||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
def _build_predict_network(self):
|
||||
"""Build the network for prediction."""
|
||||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
|
||||
def _clear_metrics(self):
|
||||
"""Clear metrics local values."""
|
||||
for metric in self._metric_fns.values():
|
||||
|
@ -470,6 +477,7 @@ class Model:
|
|||
|
||||
dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
|
||||
for next_element in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._eval_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
|
@ -549,12 +557,9 @@ class Model:
|
|||
>>> model = Model(Net())
|
||||
>>> model.predict(input_data)
|
||||
"""
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._network = _VirtualDatasetCell(self._network)
|
||||
|
||||
self._network.set_train(False)
|
||||
self._predict_network.set_train(False)
|
||||
check_input_data(*predict_data, data_class=Tensor)
|
||||
result = self._network(*predict_data)
|
||||
result = self._predict_network(*predict_data)
|
||||
|
||||
check_output_data(result)
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue