Add Cifar op to pull mode. Improve warning message and set seed for debug mode.

ivanshan_8170 2022-12-01 10:15:15 -05:00
parent 946a20478a
commit b0113b1431
16 changed files with 874 additions and 18 deletions
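For context, a minimal end-to-end sketch of what this change enables, written against the Python APIs exercised in the new tests below; the dataset path mirrors the repository's test data layout and is illustrative only.

import mindspore.dataset as ds

# Enable dataset pipeline debug mode; the pipeline then runs via the pull-based path.
ds.config.set_debug_mode(True)

# CifarOp now implements pull mode, so a Cifar10 pipeline can be iterated in debug mode.
data = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=10, shuffle=False)
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(row["image"].shape, row["label"])

# Restore the previous setting, as the new tests do in their teardown.
ds.config.set_debug_mode(False)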

View File

@ -311,6 +311,10 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
Status AddNewWorkers(int32_t num_new_workers) override;
Status RemoveWorkers(int32_t num_workers) override;
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
};
} // namespace dataset
} // namespace mindspore

View File

@ -80,6 +80,11 @@ class ConcatOp : public PipelineOp {
/// \return bool
bool IgnoreSample();
protected:
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
private:
Status Verify(int32_t id, const TensorRow &new_row);

View File

@ -335,6 +335,23 @@ Status DatasetOp::PrepareOperatorPullBased() {
// Generate the column name map for the current op.
RETURN_IF_NOT_OK(this->ComputeColMap());
// Check whether the operator is implemented in pull mode.
std::string message;
ImplementedPullMode is_implemented = PullModeImplementationStatus();
if (is_implemented == ImplementedPullMode::NotImplemented) {
message = Name() + " is not implemented yet in pull mode.";
if (IsLeaf()) {
message = "Leaf node " + message;
if (GlobalContext::config_manager()->get_debug_mode()) {
RETURN_STATUS_UNEXPECTED(message);
}
}
} else if (is_implemented == ImplementedPullMode::DisabledDebugMode) {
message = "In debug mode, " + Name() + " is disabled for debugging purposes.";
}
if (!message.empty()) {
MS_LOG(WARNING) << message;
}
return Status::OK();
}

View File

@ -386,6 +386,11 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// Launch the Op
virtual Status Launch() { return Status::OK(); }
enum ImplementedPullMode { NotImplemented = 0, Implemented, DisabledDebugMode };
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
virtual ImplementedPullMode PullModeImplementationStatus() const { return ImplementedPullMode::NotImplemented; }
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<SamplerRT> sampler_; // Some leaf ops might have a sampler

View File

@ -217,6 +217,10 @@ class MapOp : public ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow> {
Status AddNewWorkers(int32_t num_new_workers) override;
Status RemoveWorkers(int32_t num_workers) override;
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
private:
Status RebuildMapErrorMsg(const TensorRow &input_row, const std::string &op_name, Status *rc);
};

View File

@ -267,11 +267,6 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
Status ShuffleOp::GetNextRowPullMode(TensorRow *const row) {
RETURN_UNEXPECTED_IF_NULL(row);
RETURN_UNEXPECTED_IF_NULL(child_[0]);
if (GlobalContext::config_manager()->get_debug_mode()) {
MS_LOG(WARNING) << "In debug mode, shuffle operation is disabled for debugging purposes.";
} else {
MS_LOG(WARNING) << "Shuffle operation has not been implemented yet in pull mode.";
}
return child_[0]->GetNextRowPullMode(row);
}
} // namespace dataset

View File

@ -99,6 +99,11 @@ class ShuffleOp : public PipelineOp {
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
protected:
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::DisabledDebugMode; }
private:
// Private function to add a new row to the shuffle buffer.
// @return Status The status code returned

View File

@ -58,6 +58,10 @@ class SkipOp : public PipelineOp {
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
private:
int32_t max_skips_; // The number of skips that the user requested
int32_t skip_count_; // A counter for the current number of executed skips

View File

@ -385,5 +385,11 @@ Status CifarOp::ComputeColMap() {
return Status::OK();
}
Status CifarOp::InitPullMode() {
RETURN_IF_NOT_OK(cifar_raw_data_block_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
"Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this), nullptr, id()));
return PrepareData();
}
} // namespace dataset
} // namespace mindspore

View File

@ -109,6 +109,10 @@ class CifarOp : public MappableLeafOp {
/// @return - Status
Status ComputeColMap() override;
/// Initialize pull mode; calls PrepareData() internally
/// @return Status The status code returned
Status InitPullMode() override;
CifarType cifar_type_;
std::string folder_path_;
std::unique_ptr<DataSchema> data_schema_;

View File

@ -111,7 +111,7 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, p
Status SendWaitFlagToWorker(int32_t worker_id) override;
Status SendQuitFlagToWorker(int32_t worker_id) override;
/// \brief In pull mode, gets the next row
/// \param row[out] - Fetched TensorRow
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
@ -119,6 +119,10 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, p
/// Initialize pull mode; calls PrepareData() internally
/// @return Status The status code returned
virtual Status InitPullMode() { return PrepareData(); }
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
};
} // namespace dataset
} // namespace mindspore

View File

@ -66,6 +66,10 @@ class TakeOp : public PipelineOp {
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
/// \brief Gets the implementation status for the operator in pull mode
/// \return implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes

View File

@ -44,6 +44,18 @@ bool DebugModePass::RemoveCacheAndOffload(std::shared_ptr<DatasetNode> node) {
return ret;
}
Status SetSeed() {
// Debug mode requires deterministic results. Set the seed if the user has not done so.
uint32_t seed = GlobalContext::config_manager()->seed();
if (seed == std::mt19937::default_seed) {
const uint32_t kSeedValue = 1;
MS_LOG(WARNING) << "Debug mode is enabled. Setting seed to " << std::to_string(kSeedValue)
<< " to ensure deterministic results.";
GlobalContext::config_manager()->set_seed(kSeedValue);
}
return Status::OK();
}
Status DebugModePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
*modified = RemoveCacheAndOffload(node);
if (node->GetOffload() == ManualOffloadMode::kEnabled) {
@ -52,11 +64,13 @@ Status DebugModePass::Visit(std::shared_ptr<MapNode> node, bool *const modified)
node->SetOffload(ManualOffloadMode::kDisabled);
*modified = true;
}
RETURN_IF_NOT_OK(SetSeed());
return Status::OK();
}
Status DebugModePass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = RemoveCacheAndOffload(node);
RETURN_IF_NOT_OK(SetSeed());
return Status::OK();
}
} // namespace dataset
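A sketch of the user-visible effect of SetSeed, using the Python config API that the new tests rely on; the seed value 1 mirrors kSeedValue above, and saving/restoring the original seed mirrors the tests' setup and teardown functions.

import mindspore.dataset as ds

original_seed = ds.config.get_seed()  # save so the global seed can be restored afterwards
ds.config.set_debug_mode(True)
# If the global seed is still the mt19937 default when a pipeline is built in debug mode,
# DebugModePass sets it to 1 so that random operations give deterministic results.
# ... build and iterate a dataset pipeline here ...
ds.config.set_debug_mode(False)
ds.config.set_seed(original_seed)  # restore, as teardown_function() does in the tests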

View File

@ -95,6 +95,84 @@ TEST_F(MindDataTestPipeline, TestGetNextPullBasedMappableCelebA) {
iter->Stop();
}
/// Feature: PullBasedIterator GetNextRowPullMode
/// Description: Test PullBasedIterator on Cifar10
/// Expectation: Output is the same as the normal iterator
TEST_F(MindDataTestPipeline, TestGetNextPullBasedMappableCifar10) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetNextPullBasedMappableCifar10.";
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, "all", std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<PullIterator> iter = ds->CreatePullBasedIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::vector<mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_EQ(row.size(), 2);
// order: image, label
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row[0];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
auto label = row[1];
MS_LOG(INFO) << "Tensor label shape: " << label.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: PullBasedIterator GetNextRowPullMode
/// Description: Test PullBasedIterator on Cifar100
/// Expectation: Output is the same as the normal iterator
TEST_F(MindDataTestPipeline, TestGetNextPullBasedMappableCifar100) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetNextPullBasedMappableCifar100.";
// Create a Cifar100 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar100Data/";
std::shared_ptr<Dataset> ds = Cifar100(folder_path, "all", std::make_shared<RandomSampler>(false, 10));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<PullIterator> iter = ds->CreatePullBasedIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::vector<mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
EXPECT_EQ(row.size(), 3);
// order: image, coarse_label, fine_label
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row[0];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
auto coarse_label = row[1];
MS_LOG(INFO) << "Tensor coarse_label shape: " << coarse_label.Shape();
auto fine_label = row[2];
MS_LOG(INFO) << "Tensor fine_label shape: " << fine_label.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
/// Feature: PullBasedIterator GetNextRowPullMode
/// Description: Test PullBasedIterator on Cityscapes
/// Expectation: Output is the same as the normal iterator

View File

@ -15,27 +15,33 @@
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
from mindspore.dataset.vision import Inter
from mindspore import log as logger
# Need to run all these tests in separate processes since
# the global configuration setting of debug_mode may impact other tests running in parallel.
pytestmark = pytest.mark.forked
DATA_DIR_10 = "../data/dataset/testCifar10Data"
DEBUG_MODE = False
SEED_VAL = 0 # seed will be set internally in debug mode; save the original seed value so it can be restored
def setup_function():
global DEBUG_MODE
global SEED_VAL
DEBUG_MODE = ds.config.get_debug_mode()
SEED_VAL = ds.config.get_seed()
ds.config.set_debug_mode(True)
def teardown_function():
ds.config.set_debug_mode(DEBUG_MODE)
ds.config.set_seed(SEED_VAL)
@pytest.mark.forked
def test_pipeline_debug_mode_tuple():
"""
Feature: Pipeline debug mode.
@ -60,7 +66,6 @@ def test_pipeline_debug_mode_tuple():
assert num_row == 2
@pytest.mark.forked
def test_pipeline_debug_mode_dict():
"""
Feature: Pipeline debug mode.
@ -85,7 +90,6 @@ def test_pipeline_debug_mode_dict():
assert num_row == 2
@pytest.mark.forked
def test_pipeline_debug_mode_minddata():
"""
Feature: Pipeline debug mode.
@ -104,18 +108,16 @@ def test_pipeline_debug_mode_minddata():
def test_pipeline_debug_mode_not_support():
"""
Feature: Pipeline debug mode.
Description: Test creating tuple iterator with op have not supported in pull mode.
Expectation: Successful with no data generated.
Description: Test creating a tuple iterator with an op that is not supported in pull mode.
Expectation: A RuntimeError is raised in debug mode.
"""
logger.info("test_pipeline_debug_mode_not_support")
data = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
num_rows = 0
for _ in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
assert num_rows == 0
with pytest.raises(RuntimeError) as error_info:
data.create_tuple_iterator(num_epochs=1, output_numpy=True)
assert "dataset pipeline" in str(error_info.value)
@pytest.mark.forked
def test_pipeline_debug_mode_map_pyfunc():
"""
Feature: Pipeline debug mode.
@ -133,7 +135,6 @@ def test_pipeline_debug_mode_map_pyfunc():
assert num_rows == 4
@pytest.mark.forked
def test_pipeline_debug_mode_batch_pyfunc():
"""
Feature: Pipeline debug mode.
@ -141,6 +142,7 @@ def test_pipeline_debug_mode_batch_pyfunc():
Expectation: Successful.
"""
logger.info("test_pipeline_debug_mode_batch_pyfunc")
def add_one(batch_info):
return batch_info.get_batch_num() + 1
@ -153,7 +155,6 @@ def test_pipeline_debug_mode_batch_pyfunc():
assert num_rows == 5
@pytest.mark.forked
def test_pipeline_debug_mode_concat():
"""
Feature: Pipeline debug mode.
@ -175,7 +176,77 @@ def test_pipeline_debug_mode_concat():
assert num_rows == 12
def test_pipeline_debug_mode_map_random():
"""
Feature: Pipeline debug mode.
Description: Test creating dict iterator with map with random augmentation operations.
Expectation: Successful.
"""
logger.info("test_pipeline_debug_mode_map_random")
# The explicit intent of this test is to not set the seed, allowing debug mode support to set it
# (when the default seed is in use).
data = ds.CelebADataset("../data/dataset/testCelebAData/", decode=True, num_shards=1, shard_id=0)
transforms_list = [vision.CenterCrop(64), vision.RandomRotation(30)]
random_apply = transforms.RandomApply(transforms_list, prob=0.6)
data = data.map(operations=[random_apply], input_columns=["image"])
expected_shape = [(2268, 4032, 3), (2268, 4032, 3), (64, 64, 3), (2268, 4032, 3)]
index = 0
for item in data.create_dict_iterator(num_epochs=1):
assert len(item) == 2
assert item["image"].shape == expected_shape[index]
index += 1
assert index == 4
def test_pipeline_debug_mode_shuffle():
"""
Feature: Pipeline debug mode.
Description: Test creating dict iterator with Shuffle.
Expectation: Shuffle is disabled, but the same number of rows is produced as when not in debug mode.
"""
logger.info("test_pipeline_debug_mode_shuffle")
buffer_size = 5
data = ds.MnistDataset("../data/dataset/testMnistData", num_samples=20)
data = data.shuffle(buffer_size=buffer_size)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
assert num_rows == 20
def test_pipeline_debug_mode_imdb_shuffle():
"""
Feature: Pipeline debug mode.
Description: Verify that shuffle is disabled with IMDBDataset.
Expectation: The data is processed successfully in the same order.
"""
logger.info("test_pipeline_debug_mode_imdb_shuffle")
buffer_size = 5
# apply dataset operations
data1 = ds.IMDBDataset("../data/dataset/testIMDBDataset", shuffle=True)
data1 = data1.shuffle(buffer_size=buffer_size)
# Verify dataset size
data1_size = data1.get_dataset_size()
logger.info("dataset size is: {}".format(data1_size))
assert data1_size == 8
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
# in this example, each dictionary has keys "text" and "label"
logger.info("text is {}".format(item["text"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
assert num_iter == 8
if __name__ == '__main__':
setup_function()
test_pipeline_debug_mode_tuple()
test_pipeline_debug_mode_dict()
test_pipeline_debug_mode_minddata()
@ -183,3 +254,7 @@ if __name__ == '__main__':
test_pipeline_debug_mode_map_pyfunc()
test_pipeline_debug_mode_batch_pyfunc()
test_pipeline_debug_mode_concat()
test_pipeline_debug_mode_shuffle()
test_pipeline_debug_mode_map_random()
test_pipeline_debug_mode_imdb_shuffle()
teardown_function()

View File

@ -0,0 +1,632 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Cifar10 and Cifar100 dataset operations in debug mode
"""
import os
import pytest
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as ds
from mindspore import log as logger
pytestmark = pytest.mark.forked
DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"
DEBUG_MODE = False
SEED_VAL = 0 # seed will be set internally in debug mode; save the original seed value so it can be restored
def setup_function():
global DEBUG_MODE
global SEED_VAL
DEBUG_MODE = ds.config.get_debug_mode()
SEED_VAL = ds.config.get_seed()
ds.config.set_debug_mode(True)
def teardown_function():
ds.config.set_debug_mode(DEBUG_MODE)
ds.config.set_seed(SEED_VAL)
def load_cifar(path, kind="cifar10"):
"""
load Cifar10/100 data
"""
raw = np.empty(0, dtype=np.uint8)
for file_name in os.listdir(path):
if file_name.endswith(".bin"):
with open(os.path.join(path, file_name), mode='rb') as file:
raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0)
if kind == "cifar10":
raw = raw.reshape(-1, 3073)
labels = raw[:, 0]
images = raw[:, 1:]
elif kind == "cifar100":
raw = raw.reshape(-1, 3074)
labels = raw[:, :2]
images = raw[:, 2:]
else:
raise ValueError("Invalid parameter value")
images = images.reshape(-1, 3, 32, 32)
images = images.transpose(0, 2, 3, 1)
return images, labels
def visualize_dataset(images, labels):
"""
Helper function to visualize the dataset samples
"""
num_samples = len(images)
for i in range(num_samples):
plt.subplot(1, num_samples, i + 1)
plt.imshow(images[i])
plt.title(labels[i])
plt.show()
### Testcases for Cifar10Dataset Op ###
def test_cifar10_content_check():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset with content check on image readings in pull mode
Expectation: The dataset is processed as expected
"""
logger.info("Test debug mode Cifar10Dataset Op with content check")
data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False)
images, labels = load_cifar(DATA_DIR_10)
num_iter = 0
# in this example, each dictionary has keys "image" and "label"
for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(d["image"], images[i])
np.testing.assert_array_equal(d["label"], labels[i])
num_iter += 1
assert num_iter == 100
def test_cifar10_basic():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset with some basic arguments and methods in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test Cifar10Dataset Op")
# case 0: test loading the whole dataset
data0 = ds.Cifar10Dataset(DATA_DIR_10)
num_iter0 = 0
for _ in data0.create_dict_iterator(num_epochs=1):
num_iter0 += 1
assert num_iter0 == 10000
# case 1: test num_samples
data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
num_iter1 = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter1 += 1
assert num_iter1 == 100
# case 2: test batch with drop_remainder=False
data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
assert data2.get_dataset_size() == 100
assert data2.get_batch_size() == 1
data2 = data2.batch(batch_size=7) # drop_remainder is default to be False
assert data2.get_dataset_size() == 15
assert data2.get_batch_size() == 7
num_iter2 = 0
for _ in data2.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 15
# case 3: test batch with drop_remainder=True
data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
assert data3.get_dataset_size() == 100
assert data3.get_batch_size() == 1
data3 = data3.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
assert data3.get_dataset_size() == 14
assert data3.get_batch_size() == 7
num_iter3 = 0
for _ in data3.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 14
def test_cifar10_pk_sampler():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset with PKSampler in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test debug mode Cifar10Dataset Op with PKSampler")
golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
sampler = ds.PKSampler(3)
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
num_iter = 0
label_list = []
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
label_list.append(item["label"])
num_iter += 1
np.testing.assert_array_equal(golden, label_list)
assert num_iter == 30
def test_cifar10_sequential_sampler():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset with SequentialSampler in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test debug mode Cifar10Dataset Op with SequentialSampler")
num_samples = 30
sampler = ds.SequentialSampler(num_samples=num_samples)
data1 = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
data2 = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples)
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_equal(item1["label"], item2["label"])
num_iter += 1
assert num_iter == num_samples
def test_cifar10_exception():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test error cases Cifar10Dataset in debug mode
Expectation: Throw correct error as expected
"""
logger.info("Test error cases for Cifar10Dataset in debug mode")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.Cifar10Dataset(DATA_DIR_10, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.Cifar10Dataset(DATA_DIR_10, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=5)
error_msg_6 = "num_parallel_workers exceeds"
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=0)
with pytest.raises(ValueError, match=error_msg_6):
ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=256)
error_msg_7 = r"cifar\(.bin\) files are missing"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar10_visualize(plot=False):
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset visualization results in debug mode
Expectation: Results are presented as expected in debug mode
"""
logger.info("Test debug mode Cifar10Dataset visualization")
data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert image.shape == (32, 32, 3)
assert image.dtype == np.uint8
assert label.dtype == np.uint32
num_iter += 1
assert num_iter == 10
if plot:
visualize_dataset(image_list, label_list)
### Testcases for Cifar100Dataset Op ###
def test_cifar100_content_check():
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test Cifar100Dataset image readings with content check in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test debug mode Cifar100Dataset with content check")
data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False)
images, labels = load_cifar(DATA_DIR_100, kind="cifar100")
num_iter = 0
# in this example, each dictionary has keys "image", "coarse_label" and "fine_label"
for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
np.testing.assert_array_equal(d["image"], images[i])
np.testing.assert_array_equal(d["coarse_label"], labels[i][0])
np.testing.assert_array_equal(d["fine_label"], labels[i][1])
num_iter += 1
assert num_iter == 100
def test_cifar100_basic():
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test Cifar100Dataset basic arguments and features in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test Cifar100Dataset basic in debug mode")
# case 1: test num_samples
data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter1 = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter1 += 1
assert num_iter1 == 100
# case 2: test batch with drop_remainder=True
data2 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
data2 = data2.batch(batch_size=3, drop_remainder=True)
assert data2.get_dataset_size() == 33
assert data2.get_batch_size() == 3
num_iter2 = 0
for _ in data2.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 33
# case 3: test batch with drop_remainder=False
data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
assert data3.get_dataset_size() == 100
assert data3.get_batch_size() == 1
data3 = data3.batch(batch_size=3)
assert data3.get_dataset_size() == 34
assert data3.get_batch_size() == 3
num_iter3 = 0
for _ in data3.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 34
def test_cifar100_pk_sampler():
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test Cifar100Dataset with PKSampler in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test Cifar100Dataset with PKSampler in deubg mode")
golden = [i for i in range(20)]
sampler = ds.PKSampler(1)
data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler)
num_iter = 0
label_list = []
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
label_list.append(item["coarse_label"])
num_iter += 1
np.testing.assert_array_equal(golden, label_list)
assert num_iter == 20
def test_cifar100_exception():
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test error cases for Cifar100Dataset in debug mode
Expectation: Throw correct error as expected
"""
logger.info("Test error cases for Cifar100Dataset in debug mode")
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_1):
ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3))
error_msg_2 = "sampler and sharding cannot be specified at the same time"
with pytest.raises(RuntimeError, match=error_msg_2):
ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
with pytest.raises(RuntimeError, match=error_msg_3):
ds.Cifar100Dataset(DATA_DIR_100, num_shards=10)
error_msg_4 = "shard_id is specified but num_shards is not"
with pytest.raises(RuntimeError, match=error_msg_4):
ds.Cifar100Dataset(DATA_DIR_100, shard_id=0)
error_msg_5 = "Input shard_id is not within the required interval"
with pytest.raises(ValueError, match=error_msg_5):
ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1)
with pytest.raises(ValueError, match=error_msg_5):
ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=5)
error_msg_7 = r"cifar\(.bin\) files are missing"
with pytest.raises(RuntimeError, match=error_msg_7):
ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
for _ in ds1.__iter__():
pass
def test_cifar100_visualize(plot=False):
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test Cifar100Dataset visualization results in debug mode
Expectation: Results are presented as expected
"""
logger.info("Test Cifar100Dataset visualization in debug mode")
data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False)
num_iter = 0
image_list, label_list = [], []
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
coarse_label = item["coarse_label"]
fine_label = item["fine_label"]
image_list.append(image)
label_list.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label))
assert isinstance(image, np.ndarray)
assert image.shape == (32, 32, 3)
assert image.dtype == np.uint8
assert coarse_label.dtype == np.uint32
assert fine_label.dtype == np.uint32
num_iter += 1
assert num_iter == 10
if plot:
visualize_dataset(image_list, label_list)
def test_cifar_usage():
"""
Feature: Pipeline debug mode with Cifar10Dataset and Cifar100Dataset
Description: Test Cifar10Dataset and Cifar100Dataset usage flag in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test Cifar100Dataset usage flag in defbug mode")
# flag, if True, test cifar10 else test cifar100
def test_config(usage, flag=True, cifar_path=None):
if cifar_path is None:
cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
try:
data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_rows += 1
except (ValueError, TypeError, RuntimeError) as e:
return str(e)
return num_rows
# test the usage of CIFAR10
assert test_config("train") == 10000
assert test_config("all") == 10000
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
assert "Cifar10Dataset API can't read the data file (interface mismatch or no data found)" in test_config("test")
# test the usage of CIFAR100
assert test_config("test", False) == 10000
assert test_config("all", False) == 10000
assert "Cifar100Dataset API can't read the data file" in test_config("train", False)
assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)
# change this directory to the folder that contains all cifar10 files
all_cifar10 = None
if all_cifar10 is not None:
assert test_config("train", True, all_cifar10) == 50000
assert test_config("test", True, all_cifar10) == 10000
assert test_config("all", True, all_cifar10) == 60000
assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000
# change this directory to the folder that contains all cifar100 files
all_cifar100 = None
if all_cifar100 is not None:
assert test_config("train", False, all_cifar100) == 50000
assert test_config("test", False, all_cifar100) == 10000
assert test_config("all", False, all_cifar100) == 60000
assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
def test_cifar_exception_file_path():
"""
Feature: Pipeline debug mode with Cifar10Dataset and Cifar100Dataset
Description: Test that Cifar10Dataset and Cifar100Dataset report the corresponding data file when a map operation fails in debug mode
Expectation: Error is raised as expected
"""
def exception_func(item):
raise Exception("Error occur!")
with pytest.raises(RuntimeError) as error_info:
data = ds.Cifar10Dataset(DATA_DIR_10)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
data = ds.Cifar10Dataset(DATA_DIR_10)
data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["coarse_label"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value)
with pytest.raises(RuntimeError) as error_info:
data = ds.Cifar100Dataset(DATA_DIR_100)
data = data.map(operations=exception_func, input_columns=["fine_label"], num_parallel_workers=1)
num_rows = 0
for _ in data.create_dict_iterator(num_epochs=1):
num_rows += 1
assert "map operation: [PyFunc] failed. The corresponding data file is" in str(error_info.value)
def test_cifar10_pk_sampler_get_dataset_size():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset get_dataset_size in debug mode
Expectation: The dataset is processed as expected
"""
sampler = ds.PKSampler(3)
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
num_iter = 0
ds_sz = data.get_dataset_size()
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert ds_sz == num_iter == 30
def test_cifar10_with_chained_sampler_get_dataset_size():
"""
Feature: Pipeline debug mode with Cifar10Dataset
Description: Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size
Expectation: The dataset is processed as expected
"""
sampler = ds.SequentialSampler(start_index=0, num_samples=5)
child_sampler = ds.PKSampler(4)
sampler.add_child(child_sampler)
data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
num_iter = 0
ds_sz = data.get_dataset_size()
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert ds_sz == num_iter == 5
def test_cifar100ops():
"""
Feature: Pipeline debug mode with Cifar100Dataset
Description: Test Cifar100Dataset with take and skip operations in debug mode
Expectation: The dataset is processed as expected
"""
logger.info("Test Cifar100Dataset operations in debug mode")
# case 1: test num_samples
data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter1 = 0
for _ in data1.create_dict_iterator(num_epochs=1):
num_iter1 += 1
assert num_iter1 == 100
# take 30
num_iter2 = 0
data2 = data1.take(30)
for _ in data2.create_dict_iterator(num_epochs=1):
num_iter2 += 1
assert num_iter2 == 30
# take default 0
data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter3 = 0
data3 = data3.take()
for _ in data3.create_dict_iterator(num_epochs=1):
num_iter3 += 1
assert num_iter3 == 100
# take more than dataset size
data4 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter4 = 0
data4 = data4.take(1000)
for _ in data4.create_dict_iterator(num_epochs=1):
num_iter4 += 1
assert num_iter4 == 100
# take -5
data5 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
with pytest.raises(ValueError) as error_info:
data5 = data5.take(-5)
for _ in data5.create_dict_iterator(num_epochs=1):
pass
assert "count should be either -1 or within the required interval" in str(error_info.value)
# skip 0
data6 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter6 = 0
data6 = data6.skip(0)
for _ in data6.create_dict_iterator(num_epochs=1):
num_iter6 += 1
assert num_iter6 == 100
# skip more than dataset size
data7 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter7 = 0
data7 = data7.skip(1000)
for _ in data7.create_dict_iterator(num_epochs=1):
num_iter7 += 1
assert num_iter7 == 0
# skip -5
data8 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
with pytest.raises(ValueError) as error_info:
data8 = data8.skip(-5)
for _ in data8.create_dict_iterator(num_epochs=1):
pass
assert "Input count is not within the required interval of" in str(error_info.value)
if __name__ == '__main__':
setup_function()
test_cifar10_content_check()
test_cifar10_basic()
test_cifar10_pk_sampler()
test_cifar10_sequential_sampler()
test_cifar10_exception()
test_cifar10_visualize(plot=False)
test_cifar100_content_check()
test_cifar100_basic()
test_cifar100_pk_sampler()
test_cifar100_exception()
test_cifar100_visualize(plot=False)
test_cifar_usage()
test_cifar_exception_file_path()
test_cifar10_with_chained_sampler_get_dataset_size()
test_cifar10_pk_sampler_get_dataset_size()
test_cifar100ops()
teardown_function()