!1703 Adding sync_wait input check

Merge pull request !1703 from EricZ/master
mindspore-ci-bot 2020-05-30 05:29:38 +08:00 committed by Gitee
commit 0e3dd8149a
2 changed files with 107 additions and 8 deletions


@@ -1055,6 +1055,11 @@ class Dataset:
return self.input[0].get_sync_notifiers()
return {}
def disable_sync(self):
if self.input:
return self.input[0].disable_sync()
return {}
def is_sync(self):
if self.input:
return self.input[0].is_sync()
@@ -1062,16 +1067,23 @@ class Dataset:
def sync_update(self, condition_name, num_batch=None, data=None):
"""
Release a blocking condition and trigger the callback with the given data.
Args:
condition_name (str): The condition name that is used to toggle sending next row.
num_batch (int or None): The number of batches (rows) that are released.
When num_batch is None, it will default to the number specified by the
sync_wait operator (default=None).
data (dict or None): The data passed to the callback (default=None).
"""
if isinstance(num_batch, int) and num_batch <= 0:
# before raising, disable all sync_wait operators in the pipeline so it cannot block
self.disable_sync()
raise RuntimeError("Sync_update batch size can only be positive, got: {}".format(num_batch))
notifiers_dict = self.get_sync_notifiers()
if condition_name not in notifiers_dict:
# before raising, disable all sync_wait operators in the pipeline so it cannot block
self.disable_sync()
raise RuntimeError("Condition name not found")
if num_batch is not None:
num_batch *= self.get_batch_size()
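As a usage sketch of the path these checks guard (modeled on the tests below; the generator, condition name, and loss payload are illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(10):
        yield (np.array([i]),)

dataset = ds.GeneratorDataset(gen, column_names=["input"])
# block before each batch until sync_update releases the condition
dataset = dataset.sync_wait(condition_name="policy")
dataset = dataset.batch(2)
count = 0
for data in dataset.create_dict_iterator():
    count += 1
    # a non-positive num_batch or an unknown condition_name now disables
    # every sync_wait in the pipeline and raises, instead of deadlocking
    dataset.sync_update(condition_name="policy", data={"loss": count})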
@@ -1439,7 +1451,6 @@ class BatchDataset(DatasetOp):
for input_dataset in dataset.input:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
class BatchInfo(CBatchInfo):
"""
The information object associated with the current batch of tensors.
@@ -1472,10 +1483,13 @@ class BlockReleasePair:
callback (function): The callback function that will be called when release is called.
"""
def __init__(self, init_release_rows, callback=None):
if isinstance(init_release_rows, int) and init_release_rows <= 0:
raise ValueError("release_rows need to be greater than 0.")
self.row_count = -init_release_rows
self.cv = threading.Condition()
self.callback = callback
self.default_rows = init_release_rows
self.disable = False
def __deepcopy__(self, memodict):
if id(self) in memodict:
@@ -1491,13 +1505,18 @@ class BlockReleasePair:
self.cv.notify_all()
def update_batched_size(self, batch_size):
# sanity check
if isinstance(batch_size, int) and batch_size <= 0:
raise ValueError("batch_size need to be greater than 0.")
# should only use before the pipeline creates
self.row_count *= batch_size
self.default_rows *= batch_size
def block_func(self):
with self.cv:
self.cv.wait_for(lambda: self.row_count < 0)
# if disable is true, the predicate always evaluates to true
self.cv.wait_for(lambda: (self.row_count < 0 or self.disable))
self.row_count += 1
return True
@@ -1510,6 +1529,12 @@ class BlockReleasePair:
self.callback(data)
self.cv.notify_all()
def disable_lock(self):
with self.cv:
self.disable = True
self.cv.notify_all()
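The pair above is a standard condition-variable gate. As a self-contained sketch of the same mechanism (an illustrative class, not part of the MindSpore API): block waits while the released-row budget is exhausted, release refills it, and disable_lock short-circuits the predicate so blocked consumers wake up:

import threading

class MiniBlockRelease:
    # illustrative stand-in for BlockReleasePair's core mechanism
    def __init__(self, init_release_rows):
        if isinstance(init_release_rows, int) and init_release_rows <= 0:
            raise ValueError("init_release_rows needs to be greater than 0.")
        self.row_count = -init_release_rows  # negative means rows may still pass
        self.cv = threading.Condition()
        self.disable = False

    def block(self):
        with self.cv:
            # wait until rows are released, unless syncing was disabled
            self.cv.wait_for(lambda: self.row_count < 0 or self.disable)
            self.row_count += 1

    def release(self, num_rows):
        with self.cv:
            self.row_count -= num_rows
            self.cv.notify_all()

    def disable_lock(self):
        with self.cv:
            self.disable = True
            self.cv.notify_all()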
class SyncWaitDataset(DatasetOp):
"""
The result of adding a blocking condition to the input Dataset.
@@ -1530,6 +1555,9 @@ class SyncWaitDataset(DatasetOp):
input_dataset.output.append(self)
# set to the default value, waiting for the batch to update it
self._condition_name = condition_name
if isinstance(num_batch, int) and num_batch <= 0:
raise ValueError("num_batch need to be greater than 0.")
self._pair = BlockReleasePair(num_batch, callback)
if self._condition_name in self.input[0].get_sync_notifiers():
raise RuntimeError("Condition name is already in use")
@@ -1549,8 +1577,14 @@ class SyncWaitDataset(DatasetOp):
return args
def update_sync_batch_size(self, batch_size):
if isinstance(batch_size, int) and batch_size <= 0:
raise ValueError("num_batch need to be greater than 0.")
self._pair.update_batched_size(batch_size)
def disable_sync(self):
logger.info("Disabling Sync")
self._pair.disable_lock()
@staticmethod
def _is_ancestor_of_batch(dataset):
"""

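Why sync_update disables all sync_wait operators before raising: an upstream sync_wait would otherwise keep the iterator blocked while the exception propagates. A sketch of the failure path, mirroring test_sync_exception_04 below (the condition name assumes the pipeline from the earlier sketch):

try:
    for _ in dataset.create_dict_iterator():
        # invalid release size: sync_update disables every sync_wait in
        # the pipeline before raising, so this loop cannot hang
        dataset.sync_update(condition_name="policy", num_batch=-1)
except RuntimeError as e:
    assert "positive" in str(e)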

@@ -47,7 +47,6 @@ def test_simple_sync_wait():
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
assert data["input"][0] == count
@@ -75,7 +74,6 @@ def test_simple_shuffle_sync():
count = 0
for data in dataset.create_dict_iterator():
count += 1
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
@@ -190,7 +188,6 @@ def test_sync_exception_02():
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
@@ -202,11 +199,79 @@ def test_sync_exception_02():
dataset = dataset.batch(batch_size)
def test_sync_exception_03():
"""
Test sync: create sync_wait with an invalid (negative) num_batch
"""
logger.info("test_sync_exception_03")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# try to create dataset with batch_size < 0
try:
dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
except Exception as e:
assert "num_batch" in str(e)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
def test_sync_exception_04():
"""
Test sync: sync_update with a negative num_batch
"""
logger.info("test_sync_exception_04")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# try to create dataset with batch_size < 0
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
count = 0
try:
for item in dataset.create_dict_iterator():
count += 1
data = {"loss": count}
# dataset.disable_sync()
dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
except Exception as e:
assert "batch" in str(e)
def test_sync_exception_05():
"""
Test sync: sync_update with a condition name that was never registered
"""
logger.info("test_sync_exception_05")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
count = 0
aug = Augment(0)
# try to create dataset with batch_size < 0
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
try:
for item in dataset.create_dict_iterator():
dataset.disable_sync()
count += 1
data = {"loss": count}
dataset.sync_update(condition_name="every", data=data)
except Exception as e:
assert "name" in str(e)
if __name__ == "__main__":
test_simple_sync_wait()
test_simple_shuffle_sync()
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_exception_03()
test_sync_exception_04()
test_sync_exception_05()
test_sync_epoch()
test_multiple_iterators()