fix getdatasetsize error II
This commit is contained in:
parent
669a8969c7
commit
57eab288cd
|
@ -16,6 +16,7 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
|
@ -402,7 +403,7 @@ class Model:
|
|||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
epoch_num = epoch * sink_size // train_dataset.get_dataset_size()
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
|
|
Loading…
Reference in New Issue