diff --git a/mindspore/train/model.py b/mindspore/train/model.py index c5a3a1147b..844480d20d 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -16,6 +16,7 @@ from collections.abc import Iterable import os +import math import numpy as np from mindspore import log as logger @@ -402,7 +403,7 @@ class Model: if sink_size == -1: epoch_num = epoch else: - epoch_num = epoch * sink_size // train_dataset.get_dataset_size() + epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) dataset_helper, train_network = self._exec_preprocess(self._train_network, is_train=True,