modify endofsequence for multi-machine

This commit is contained in:
wuweikang 2020-09-29 15:16:52 +08:00
parent 32351635ef
commit a32811e160
3 changed files with 4 additions and 2 deletions

View File

@ -548,6 +548,8 @@ class Model:
>>> model.train(2, dataset) >>> model.train(2, dataset)
""" """
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
if sink_size == -1:
sink_size = train_dataset.get_dataset_size()
check_int(sink_size) check_int(sink_size)
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

View File

@ -21,7 +21,7 @@ from mindspore import Tensor
class MindData: class MindData:
""" Stub for MindData """ """ Stub for MindData """
def __init__(self, size=None, batch_size=None, repeat_count=1, def __init__(self, size=1, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexs=()): np_types=None, output_shapes=None, input_indexs=()):
self._size = size self._size = size
self._batch_size = batch_size self._batch_size = batch_size

View File

@ -113,7 +113,7 @@ class TrainOneStepWithLossScaleCell(nn.Cell):
class DatasetLenet(MindData): class DatasetLenet(MindData):
def __init__(self, predict, label, length=3): def __init__(self, predict, label, length=3):
super(DatasetLenet, self).__init__() super(DatasetLenet, self).__init__(size=length)
self.predict = predict self.predict = predict
self.label = label self.label = label
self.index = 0 self.index = 0