From a32811e16060277b259611a638aabdfd41c2366b Mon Sep 17 00:00:00 2001 From: wuweikang Date: Tue, 29 Sep 2020 15:16:52 +0800 Subject: [PATCH] modify endofsequence for multi-machine --- mindspore/train/model.py | 2 ++ tests/dataset_mock.py | 2 +- tests/ut/python/parallel/test_loss_scale.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 32da5a9e012..dd2fdfab670 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -548,6 +548,8 @@ class Model: >>> model.train(2, dataset) """ check_bool(dataset_sink_mode) + if sink_size == -1: + sink_size = train_dataset.get_dataset_size() check_int(sink_size) if sink_size < -1 or sink_size == 0: raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py index 0ee0a90813c..ee9af3c40f3 100644 --- a/tests/dataset_mock.py +++ b/tests/dataset_mock.py @@ -21,7 +21,7 @@ from mindspore import Tensor class MindData: """ Stub for MindData """ - def __init__(self, size=None, batch_size=None, repeat_count=1, + def __init__(self, size=1, batch_size=None, repeat_count=1, np_types=None, output_shapes=None, input_indexs=()): self._size = size self._batch_size = batch_size diff --git a/tests/ut/python/parallel/test_loss_scale.py b/tests/ut/python/parallel/test_loss_scale.py index 88997160f8e..9649679474d 100644 --- a/tests/ut/python/parallel/test_loss_scale.py +++ b/tests/ut/python/parallel/test_loss_scale.py @@ -113,7 +113,7 @@ class TrainOneStepWithLossScaleCell(nn.Cell): class DatasetLenet(MindData): def __init__(self, predict, label, length=3): - super(DatasetLenet, self).__init__() + super(DatasetLenet, self).__init__(size=length) self.predict = predict self.label = label self.index = 0