modify endofsequence for multi-machine
This commit is contained in:
parent
32351635ef
commit
a32811e160
|
@ -548,6 +548,8 @@ class Model:
|
|||
>>> model.train(2, dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
if sink_size == -1:
|
||||
sink_size = train_dataset.get_dataset_size()
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore import Tensor
|
|||
class MindData:
|
||||
""" Stub for MindData """
|
||||
|
||||
def __init__(self, size=None, batch_size=None, repeat_count=1,
|
||||
def __init__(self, size=1, batch_size=None, repeat_count=1,
|
||||
np_types=None, output_shapes=None, input_indexs=()):
|
||||
self._size = size
|
||||
self._batch_size = batch_size
|
||||
|
|
|
@ -113,7 +113,7 @@ class TrainOneStepWithLossScaleCell(nn.Cell):
|
|||
|
||||
class DatasetLenet(MindData):
|
||||
def __init__(self, predict, label, length=3):
|
||||
super(DatasetLenet, self).__init__()
|
||||
super(DatasetLenet, self).__init__(size=length)
|
||||
self.predict = predict
|
||||
self.label = label
|
||||
self.index = 0
|
||||
|
|
Loading…
Reference in New Issue