forked from mindspore-Ecosystem/mindspore
213 lines
6.6 KiB
Python
213 lines
6.6 KiB
Python
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import mindspore.dataset as ds
|
|
from mindspore import log as logger
|
|
import time
|
|
import numpy as np
|
|
|
|
|
|
def gen():
|
|
for i in range(100):
|
|
yield np.array(i),
|
|
|
|
|
|
class Augment:
|
|
def __init__(self, loss):
|
|
self.loss = loss
|
|
|
|
def preprocess(self, input):
|
|
return input
|
|
|
|
def update(self, data):
|
|
self.loss = data["loss"]
|
|
|
|
|
|
def test_simple_sync_wait():
|
|
"""
|
|
Test simple sync wait: test sync in dataset pipeline
|
|
"""
|
|
logger.info("test_simple_sync_wait")
|
|
batch_size = 4
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
count = 0
|
|
for data in dataset.create_dict_iterator():
|
|
assert (data["input"][0] == count)
|
|
count += batch_size
|
|
data = {"loss": count}
|
|
dataset.sync_update(condition_name="policy", data=data)
|
|
|
|
|
|
def test_simple_shuffle_sync():
|
|
"""
|
|
Test simple shuffle sync: test shuffle before sync
|
|
"""
|
|
logger.info("test_simple_shuffle_sync")
|
|
shuffle_size = 4
|
|
batch_size = 10
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset = dataset.shuffle(shuffle_size)
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
count = 0
|
|
for data in dataset.create_dict_iterator():
|
|
count += 1
|
|
# time.sleep(0.5)
|
|
data = {"loss": count}
|
|
dataset.sync_update(condition_name="policy", data=data)
|
|
|
|
|
|
def test_two_sync():
|
|
"""
|
|
Test two sync: dataset pipeline with with two sync_operators
|
|
"""
|
|
logger.info("test_two_sync")
|
|
batch_size = 6
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
# notice that with our design, we need to have step_size = shuffle size
|
|
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
|
|
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
count = 0
|
|
for data in dataset.create_dict_iterator():
|
|
count += 1
|
|
data = {"loss": count}
|
|
dataset.sync_update(condition_name="every batch", data=data)
|
|
if count % 2 == 0:
|
|
dataset.sync_update(condition_name="every 2 batches")
|
|
|
|
|
|
def test_sync_epoch():
|
|
"""
|
|
Test sync wait with epochs: test sync with epochs in dataset pipeline
|
|
"""
|
|
logger.info("test_sync_epoch")
|
|
batch_size = 30
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
dataset = dataset.batch(batch_size, drop_remainder=True)
|
|
|
|
for epochs in range(3):
|
|
aug.update({"loss": 0})
|
|
count = 0
|
|
for data in dataset.create_dict_iterator():
|
|
assert (data["input"][0] == count)
|
|
count += batch_size
|
|
data = {"loss": count}
|
|
dataset.sync_update(condition_name="policy", data=data)
|
|
|
|
|
|
def test_multiple_iterators():
|
|
"""
|
|
Test sync wait with multiple iterators: will start multiple
|
|
"""
|
|
logger.info("test_sync_epoch")
|
|
batch_size = 30
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
dataset = dataset.batch(batch_size, drop_remainder=True)
|
|
# 2nd dataset
|
|
dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
|
|
dataset2 = dataset2.batch(batch_size, drop_remainder=True)
|
|
|
|
for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
|
|
assert (item1["input"][0] == item2["input"][0])
|
|
data1 = {"loss": item1["input"][0]}
|
|
data2 = {"loss": item2["input"][0]}
|
|
dataset.sync_update(condition_name="policy", data=data1)
|
|
dataset2.sync_update(condition_name="policy", data=data2)
|
|
|
|
|
|
def test_sync_exception_01():
|
|
"""
|
|
Test sync: with shuffle in sync mode
|
|
"""
|
|
logger.info("test_sync_exception_01")
|
|
shuffle_size = 4
|
|
batch_size = 10
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
try:
|
|
dataset = dataset.shuffle(shuffle_size)
|
|
except BaseException as e:
|
|
assert "shuffle" in str(e)
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
|
|
def test_sync_exception_02():
|
|
"""
|
|
Test sync: with duplicated condition name
|
|
"""
|
|
logger.info("test_sync_exception_02")
|
|
batch_size = 6
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
aug = Augment(0)
|
|
# notice that with our design, we need to have step_size = shuffle size
|
|
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
try:
|
|
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
|
|
except BaseException as e:
|
|
assert "name" in str(e)
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_simple_sync_wait()
|
|
test_simple_shuffle_sync()
|
|
test_two_sync()
|
|
test_sync_exception_01()
|
|
test_sync_exception_02()
|
|
test_sync_epoch()
|
|
test_multiple_iterators()
|