# mindspore/tests/ut/python/dataset/test_sync_wait.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
def gen():
    """Generate 100 single-column rows; row ``i`` is the tuple ``(np.array(i),)``."""
    yield from ((np.array(value),) for value in range(100))
class Augment:
    """Tiny policy stub used by the sync tests.

    ``update`` is passed as the ``callback`` of ``sync_wait`` and remembers the
    latest ``"loss"`` value it is handed; ``preprocess`` is used as an identity
    ``map`` operation.
    """

    def __init__(self, loss):
        # Most recent loss value supplied through ``update``.
        self.loss = loss

    def preprocess(self, input_):
        """Return ``input_`` unchanged (identity transform)."""
        return input_

    def update(self, data):
        """Record ``data["loss"]``; a missing key raises ``KeyError``."""
        latest_loss = data["loss"]
        self.loss = latest_loss
def test_simple_sync_wait():
    """
    Test simple sync wait: a single "policy" sync point before map/batch; every
    batch must start at the expected index and each iteration signals the
    condition with the running count.
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    policy = Augment(0)
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    pipeline = pipeline.sync_wait(condition_name="policy", callback=policy.update)
    pipeline = pipeline.map(input_columns=["input"], operations=[policy.preprocess])
    pipeline = pipeline.batch(batch_size)

    expected_first = 0
    for row in pipeline.create_dict_iterator(num_epochs=1):
        assert row["input"][0] == expected_first
        expected_first += batch_size
        # Signal the "policy" condition; the payload reaches policy.update.
        feedback = {"loss": expected_first}
        pipeline.sync_update(condition_name="policy", data=feedback)
def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: a shuffle placed *before* the sync_wait is
    accepted, and the pipeline can be drained with one sync_update per batch.
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10
    policy = Augment(0)
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    pipeline = pipeline.shuffle(shuffle_size)
    pipeline = pipeline.sync_wait(condition_name="policy", callback=policy.update)
    pipeline = pipeline.map(input_columns=["input"], operations=[policy.preprocess])
    pipeline = pipeline.batch(batch_size)

    step = 0
    for _row in pipeline.create_dict_iterator(num_epochs=1):
        step += 1
        feedback = {"loss": step}
        pipeline.sync_update(condition_name="policy", data=feedback)
def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators, one
    ("every batch") signalled on every iteration and one ("every 2 batches",
    num_batch=2) signalled on every other iteration.
    """
    logger.info("test_two_sync")
    batch_size = 6
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # The loop below calls sync_update for each condition at a rate matching
    # that condition's num_batch (default for "every batch", 2 for the other).
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
    dataset = dataset.batch(batch_size)
    count = 0
    for _ in dataset.create_dict_iterator(num_epochs=1):
        count += 1
        dataset.sync_update(condition_name="every batch", data={"loss": count})
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")
def test_sync_epoch():
    """
    Test sync wait with epochs: drain the same synced pipeline three times
    (one epoch per iterator), resetting the policy loss to 0 before each pass
    and checking every batch starts at the expected index.
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    policy = Augment(0)
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    pipeline = pipeline.sync_wait(condition_name="policy", callback=policy.update)
    pipeline = pipeline.map(input_columns=["input"], operations=[policy.preprocess])
    pipeline = pipeline.batch(batch_size, drop_remainder=True)

    for _epoch in range(3):
        policy.update({"loss": 0})
        expected_first = 0
        for row in pipeline.create_dict_iterator(num_epochs=1):
            assert row["input"][0] == expected_first
            expected_first += batch_size
            feedback = {"loss": expected_first}
            pipeline.sync_update(condition_name="policy", data=feedback)
def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: build two independent pipelines,
    each with its own "policy" sync_wait and Augment instance, iterate them in
    lockstep, and check they yield the same batches.
    """
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # 2nd dataset: same condition name, but a separate Augment instance so the
    # two callbacks do not share state.
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
    aug2 = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug2.update)
    dataset2 = dataset2.map(input_columns=["input"], operations=[aug2.preprocess])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    iter1 = dataset.create_dict_iterator(num_epochs=1)
    iter2 = dataset2.create_dict_iterator(num_epochs=1)
    for item1, item2 in zip(iter1, iter2):
        assert item1["input"][0] == item2["input"][0]
        # Each pipeline is signalled through its own dataset object.
        dataset.sync_update(condition_name="policy", data={"loss": item1["input"][0]})
        dataset2.sync_update(condition_name="policy", data={"loss": item2["input"][0]})
def test_sync_exception_01():
    """
    Test sync: calling shuffle on a pipeline that already contains a sync_wait
    raises RuntimeError.
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    policy = Augment(0)
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    pipeline = pipeline.sync_wait(condition_name="policy", callback=policy.update)
    pipeline = pipeline.map(input_columns=["input"], operations=[policy.preprocess])

    with pytest.raises(RuntimeError) as err:
        pipeline.shuffle(shuffle_size)
    assert "No shuffle after sync operators" in str(err.value)
def test_sync_exception_02():
    """
    Test sync: registering a second sync_wait with a condition name that is
    already used in the pipeline raises RuntimeError.
    """
    logger.info("test_sync_exception_02")
    policy = Augment(0)
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    pipeline = pipeline.sync_wait(condition_name="every batch", callback=policy.update)
    pipeline = pipeline.map(input_columns=["input"], operations=[policy.preprocess])

    with pytest.raises(RuntimeError) as err:
        pipeline.sync_wait(num_batch=2, condition_name="every batch")
    assert "Condition name is already in use" in str(err.value)
def test_sync_exception_03():
    """
    Test sync: sync_wait with a negative num_batch raises ValueError.
    """
    logger.info("test_sync_exception_03")
    pipeline = ds.GeneratorDataset(gen, column_names=["input"])
    policy = Augment(0)

    # num_batch=-1 is invalid for sync_wait itself.
    with pytest.raises(ValueError) as err:
        pipeline.sync_wait(condition_name="every batch", num_batch=-1, callback=policy.update)
    assert "num_batch need to be greater than 0." in str(err.value)
def test_sync_exception_04():
    """
    Test sync: sync_update with a negative num_batch raises RuntimeError while
    iterating the pipeline.
    """
    logger.info("test_sync_exception_04")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # The pipeline itself is valid; the invalid value is passed to sync_update below.
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    count = 0
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1):
            count += 1
            data = {"loss": count}
            # Every call passes num_batch=-1, which must be rejected.
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    assert "Sync_update batch size can only be positive" in str(e.value)
def test_sync_exception_05():
    """
    Test sync: sync_update with a condition name that was never registered
    raises RuntimeError.
    """
    logger.info("test_sync_exception_05")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    count = 0
    aug = Augment(0)
    # Only "every batch" is registered; sync_update below targets "every".
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator(num_epochs=1):
            dataset.disable_sync()
            count += 1
            data = {"loss": count}
            # NOTE(review): second disable_sync() looks redundant (nothing in
            # between touches sync state); kept to preserve the original call sequence.
            dataset.disable_sync()
            dataset.sync_update(condition_name="every", data=data)
    assert "Condition name not found" in str(e.value)
if __name__ == "__main__":
    # Run every test in the same order as before when executed as a script.
    for case in (
            test_simple_sync_wait,
            test_simple_shuffle_sync,
            test_two_sync,
            test_sync_exception_01,
            test_sync_exception_02,
            test_sync_exception_03,
            test_sync_exception_04,
            test_sync_exception_05,
            test_sync_epoch,
            test_multiple_iterators,
    ):
        case()