265 lines
8.5 KiB
Python
265 lines
8.5 KiB
Python
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import mindspore.dataset as ds
|
|
from mindspore import log as logger
|
|
|
|
|
|
def gen():
    """Yield 100 single-column rows.

    Each row is a 1-tuple holding a 0-d NumPy array with the row index,
    counting from 0 up to 99.
    """
    yield from ((np.array(index),) for index in range(100))
|
|
|
|
|
|
class Augment:
    """Minimal stand-in for an augmentation policy that tracks a loss value.

    In this file ``update`` is passed as the ``callback`` of ``sync_wait`` and
    ``preprocess`` is used as a ``map`` operation.
    """

    def __init__(self, loss):
        """Store the initial ``loss`` value."""
        self.loss = loss

    def preprocess(self, input_):
        """Return ``input_`` unchanged (identity operation)."""
        return input_

    def update(self, data):
        """Replace the stored loss with ``data["loss"]``."""
        latest_loss = data["loss"]
        self.loss = latest_loss
|
|
|
|
|
|
def test_simple_sync_wait():
    """
    Test simple sync wait: one sync_wait op in front of map and batch.

    After every batch the first element is compared with the running
    expected value, then the "policy" condition is released via sync_update.
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .sync_wait(condition_name="policy", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
        .batch(batch_size)
    )

    expected_first = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # gen() yields 0..99 in order, so each batch starts at a multiple of batch_size.
        assert row["input"][0] == expected_first
        expected_first += batch_size
        dataset.sync_update(condition_name="policy", data={"loss": expected_first})
|
|
|
|
|
|
def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: a shuffle op placed before the sync_wait op.

    The pipeline is drained once; after each batch the "policy" condition is
    released with the 1-based batch number as the loss.
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .shuffle(shuffle_size)
        .sync_wait(condition_name="policy", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
        .batch(batch_size)
    )

    batches = dataset.create_dict_iterator(num_epochs=1)
    for step, _ in enumerate(batches, start=1):
        dataset.sync_update(condition_name="policy", data={"loss": step})
|
|
|
|
|
|
def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators.

    "every batch" is released after each batch; "every 2 batches" is
    released after every second batch.
    """
    logger.info("test_two_sync")
    batch_size = 6
    aug = Augment(0)

    source = ds.GeneratorDataset(gen, column_names=["input"])
    # Original author's note: with our design, we need to have
    # step_size = shuffle size.
    first_sync = source.sync_wait(condition_name="every batch", callback=aug.update)
    mapped = first_sync.map(input_columns=["input"], operations=[aug.preprocess])
    second_sync = mapped.sync_wait(num_batch=2, condition_name="every 2 batches")
    dataset = second_sync.batch(batch_size)

    batches = dataset.create_dict_iterator(num_epochs=1)
    for step, _ in enumerate(batches, start=1):
        dataset.sync_update(condition_name="every batch", data={"loss": step})
        if step % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")
|
|
|
|
|
|
def test_sync_epoch():
    """
    Test sync wait with epochs: the same synced pipeline is iterated 3 times.

    Before each pass the loss held by the callback object is reset to 0; each
    pass checks the first element of every batch and then releases "policy".
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .sync_wait(condition_name="policy", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
        .batch(batch_size, drop_remainder=True)
    )

    for _ in range(3):
        aug.update({"loss": 0})
        expected_first = 0
        for row in dataset.create_dict_iterator(num_epochs=1):
            assert row["input"][0] == expected_first
            expected_first += batch_size
            dataset.sync_update(condition_name="policy", data={"loss": expected_first})
|
|
|
|
|
|
def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: two independent pipelines, each
    with its own sync_wait op, are iterated in lockstep.

    Both pipelines read the same generator, so the first element of each
    pair of batches must match. After each pair, every pipeline's "policy"
    condition is released with that pipeline's own first element as the loss.
    """
    # Log this test's own name (the original logged "test_sync_epoch").
    logger.info("test_multiple_iterators")
    batch_size = 30

    # 1st dataset
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # 2nd dataset: a separate Augment under its own name, so it is clear the
    # two pipelines do not share a callback object.
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
    aug2 = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug2.update)
    dataset2 = dataset2.map(input_columns=["input"], operations=[aug2.preprocess])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    iterator1 = dataset.create_dict_iterator(num_epochs=1)
    iterator2 = dataset2.create_dict_iterator(num_epochs=1)
    for item1, item2 in zip(iterator1, iterator2):
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)
|
|
|
|
|
|
def test_sync_exception_01():
    """
    Test sync: adding a shuffle op after a sync_wait op.

    Expects a RuntimeError whose message contains
    "No shuffle after sync operators".
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .sync_wait(condition_name="policy", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
    )

    with pytest.raises(RuntimeError) as err:
        dataset.shuffle(shuffle_size)
    # Checked after the with-block so the message assertion is reachable.
    assert "No shuffle after sync operators" in str(err.value)
|
|
|
|
|
|
def test_sync_exception_02():
    """
    Test sync: registering a second sync_wait with a duplicated condition name.

    Expects a RuntimeError whose message contains
    "Condition name is already in use".
    """
    logger.info("test_sync_exception_02")
    aug = Augment(0)

    source = ds.GeneratorDataset(gen, column_names=["input"])
    synced = source.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = synced.map(input_columns=["input"], operations=[aug.preprocess])

    with pytest.raises(RuntimeError) as err:
        # "every batch" is already used by the first sync_wait above.
        dataset.sync_wait(num_batch=2, condition_name="every batch")
    assert "Condition name is already in use" in str(err.value)
|
|
|
|
|
|
def test_sync_exception_03():
    """
    Test sync: sync_wait called with a negative num_batch.

    Expects a ValueError whose message contains
    "num_batch need to be greater than 0.".
    """
    logger.info("test_sync_exception_03")
    aug = Augment(0)
    dataset = ds.GeneratorDataset(gen, column_names=["input"])

    # Try to add a sync_wait op with num_batch < 0.
    with pytest.raises(ValueError) as err:
        dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
    assert "num_batch need to be greater than 0." in str(err.value)
|
|
|
|
|
|
def test_sync_exception_04():
    """
    Test sync: sync_update called with a negative num_batch.

    Expects a RuntimeError whose message contains
    "Sync_update batch size can only be positive".
    """
    logger.info("test_sync_exception_04")
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .sync_wait(condition_name="every batch", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
    )

    with pytest.raises(RuntimeError) as err:
        rows = dataset.create_dict_iterator(num_epochs=1)
        for step, _ in enumerate(rows, start=1):
            # num_batch=-1 is the invalid argument under test.
            dataset.sync_update(condition_name="every batch", num_batch=-1, data={"loss": step})
    assert "Sync_update batch size can only be positive" in str(err.value)
|
|
|
|
|
|
def test_sync_exception_05():
    """
    Test sync: sync_update called with a condition name that was never
    registered ("every" instead of "every batch").

    Expects a RuntimeError whose message contains "Condition name not found".
    """
    logger.info("test_sync_exception_05")
    aug = Augment(0)

    dataset = (
        ds.GeneratorDataset(gen, column_names=["input"])
        .sync_wait(condition_name="every batch", callback=aug.update)
        .map(input_columns=["input"], operations=[aug.preprocess])
    )

    step = 0
    with pytest.raises(RuntimeError) as err:
        for _ in dataset.create_dict_iterator(num_epochs=1):
            dataset.disable_sync()
            step += 1
            payload = {"loss": step}
            dataset.disable_sync()
            # Only "every batch" was registered; "every" is the unknown name under test.
            dataset.sync_update(condition_name="every", data=payload)
    assert "Condition name not found" in str(err.value)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every test in this module directly, in the original order.
    for _test_case in (
        test_simple_sync_wait,
        test_simple_shuffle_sync,
        test_two_sync,
        test_sync_exception_01,
        test_sync_exception_02,
        test_sync_exception_03,
        test_sync_exception_04,
        test_sync_exception_05,
        test_sync_epoch,
        test_multiple_iterators,
    ):
        _test_case()
|