!47950 [MD] Debug Mode - rename op and zip op support

Merge pull request !47950 from cathwong/ckw_debug_mode1
i-robot 2023-01-18 18:28:20 +00:00 committed by Gitee
commit e0c11be277
8 changed files with 366 additions and 39 deletions

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -296,7 +296,8 @@ class ConfigManager {
bool fast_recovery() const { return fast_recovery_; }
// setter function
// @param debug_mode_flag - Indicate whether the debug mode is on
// @param debug_mode_flag - Set whether debug mode is on. When enabled, the dataset pipeline runs synchronously and
// sequentially.
void set_debug_mode(const bool debug_mode_flag) { debug_mode_flag_ = debug_mode_flag; }
// getter function

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vect
// destructor
RenameOp::~RenameOp() {}
// Gets a row from the child operator and projects the row.
// Gets a row from the child operator
Status RenameOp::GetNextRow(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(row);
RETURN_IF_NOT_OK(child_[0]->GetNextRow(row));
@@ -41,6 +41,16 @@ Status RenameOp::GetNextRow(TensorRow *row) {
return Status::OK();
}
// In pull mode, gets a row from the child operator
Status RenameOp::GetNextRowPullMode(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(row);
RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
if (row->eoe()) {
UpdateRepeatAndEpochCounter();
}
return Status::OK();
}
Status RenameOp::operator()() { RETURN_STATUS_UNEXPECTED("[Internal ERROR] RenameOp is an inlined operator."); }
// Rename core functionality to compute the new column name id map.
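For context, a minimal sketch of how the new pull-mode path is exercised from Python; this assumes debug mode is enabled and reuses the Cifar10 test data path from the tests below (adjust the path for your environment):

```python
import mindspore.dataset as ds

# Debug mode drives the pipeline through GetNextRowPullMode instead of the
# threaded GetNextRow path.
ds.config.set_debug_mode(True)

data = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=2)
data = data.rename(input_columns=["image", "label"], output_columns=["img", "lbl"])
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    assert "img" in row and "lbl" in row
```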

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -59,11 +59,21 @@ class RenameOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kRenameOp; }
// Gets a row from the child node and projects that row. The caller is typically our parent node.
// @param row - output pointer to the projected row.
// @param worker_id - The worker id
/// \brief Gets the next row
/// \param row[out] - Fetched TensorRow
/// \return Status The status code returned
Status GetNextRow(TensorRow *row) override;
/// \brief In pull mode, gets the next row
/// \param row[out] - Fetched TensorRow
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
protected:
/// \brief Gets the implementation status for operator in pull mode
/// \return Implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
protected:
// Rename core functionality
// Computing the assignment of the new column name map.
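The PullModeImplementationStatus override above is how an operator opts in to debug mode. Conceptually, the debug-mode consumer checks each operator's status before iterating and rejects any that has not opted in (compare the GeneratorOp error asserted in the tests below). A rough Python sketch of that check, not the actual C++:

```python
# Hypothetical sketch: op.name() and pull_mode_implementation_status() mirror
# the C++ API but are assumptions here, not a real Python interface.
def validate_pull_mode(ops):
    for op in ops:
        if op.pull_mode_implementation_status() != "Implemented":
            raise RuntimeError(f"{op.name()} is not implemented yet in pull mode.")
```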

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,12 +28,16 @@ ZipOp::ZipOp() : PipelineOp(0) {}
ZipOp::~ZipOp() {}
// fetches next zipped (merged) row
Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child) const {
Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, bool is_pull_mode) const {
*new_zip_row = {};
// iterate over all iterators and generate a row
for (size_t i = 0; i < child_.size(); ++i) {
TensorRow new_row;
RETURN_IF_NOT_OK(child_[i]->GetNextRow(&new_row));
if (!is_pull_mode) {
RETURN_IF_NOT_OK(child_[i]->GetNextRow(&new_row));
} else {
RETURN_IF_NOT_OK(child_[i]->GetNextRowPullMode(&new_row));
}
if (new_row.eoe() || new_row.eof()) {
*new_zip_row = new_row;
*skip_child = static_cast<int32_t>(i);
@@ -47,7 +51,7 @@ Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child
}
// drain end of epoch messages from iterator for this epoch
Status ZipOp::drainPipeline(int32_t skip_child) const {
Status ZipOp::drainPipeline(int32_t skip_child, bool is_pull_mode) const {
for (size_t con = 0; con < child_.size(); ++con) {
if (con == skip_child) {
continue;
@@ -55,7 +59,11 @@ Status ZipOp::drainPipeline(int32_t skip_child) const {
MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
TensorRow row;
while (!row.eoe()) {
RETURN_IF_NOT_OK(child_[con]->GetNextRow(&row));
if (!is_pull_mode) {
RETURN_IF_NOT_OK(child_[con]->GetNextRow(&row));
} else {
RETURN_IF_NOT_OK(child_[con]->GetNextRowPullMode(&row));
}
}
}
// at this point, no connector contains end-of-epoch messages; the next iteration should be clean
@@ -120,11 +128,23 @@ Status ZipOp::operator()() { RETURN_STATUS_UNEXPECTED("[Internal ERROR] ZipOp is
Status ZipOp::GetNextRow(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(row);
int32_t skip_child = -1;
RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child));
RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child, false));
if (row->eoe()) {
UpdateRepeatAndEpochCounter();
MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
RETURN_IF_NOT_OK(drainPipeline(skip_child));
RETURN_IF_NOT_OK(drainPipeline(skip_child, false));
}
return Status::OK();
}
Status ZipOp::GetNextRowPullMode(TensorRow *row) {
RETURN_UNEXPECTED_IF_NULL(row);
int32_t skip_child = -1;
RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child, true));
if (row->eoe()) {
UpdateRepeatAndEpochCounter();
MS_LOG(DEBUG) << "Zip operator in pull mode is now draining child inputs.";
RETURN_IF_NOT_OK(drainPipeline(skip_child, true));
}
return Status::OK();
}
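With ZipOp's pull-mode path in place, a zipped pipeline can run end to end under debug mode. A minimal sketch mirroring the tests below (the Cifar10 path is the test tree's; adjust as needed):

```python
import mindspore.dataset as ds

ds.config.set_debug_mode(True)  # route iteration through GetNextRowPullMode

data1 = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=4)
data2 = ds.Cifar10Dataset("../data/dataset/testCifar10Data", num_samples=4)
# Rename the second branch so the zipped row has no duplicate column names.
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image2", "label2"])
data = ds.zip((data1, data2))
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    assert sorted(row.keys()) == ["image", "image2", "label", "label2"]
```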

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -64,18 +64,34 @@ class ZipOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return kZipOp; }
/// \brief Gets the next row
/// \param row[out] - Fetched TensorRow
/// \return Status The status code returned
Status GetNextRow(TensorRow *row) override;
/// \brief In pull mode, gets the next row
/// \param row[out] - Fetched TensorRow
/// \return Status The status code returned
Status GetNextRowPullMode(TensorRow *const row) override;
protected:
/// \brief Gets the implementation status for operator in pull mode
/// \return Implementation status
ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
private:
// Special handle case where an empty row has been received from child iterator
// @note - we need to drain eoe signals from all children connectors.
// @details - when this function is called, then we encountered eoe at child iterator
// we have to drain rows from other child iterators until we hit eoe from all other child iterators
Status drainPipeline(int32_t skip_child) const;
/// \brief Drain eoe signals from all children connectors.
/// \note Handles the special case where an empty row has been received from a child iterator.
/// When this function is called, an eoe has been encountered at one child iterator,
/// and we need to drain rows from the other child iterators until we hit eoe from all of them.
/// \param[in] skip_child - identifier for child to be skipped
/// \param[in] is_pull_mode - an indicator to identify if in pull mode or not
Status drainPipeline(int32_t skip_child, bool is_pull_mode) const;
// Merges one row from each child iterator
// \param[in] new_zip_row - input and output, will be a non-empty row if all rows from child connectors are non-empty
// \param[in] skip_child - input and output, identifier for child to be skipped
// \param[in] is_pull_mode - an indicator to identify if in pull mode or not
// @details Merges rows from the child iterators; this is the main functionality of ZipOp.
// This function takes one row and fills it with tensors from the rows fetched
// from the child iterators.
@@ -84,7 +100,7 @@ class ZipOp : public PipelineOp {
// 1 a T
// \ | /
// 1, a, T
Status getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child) const;
Status getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, bool is_pull_mode) const;
// Computing the assignment of the column name map.
// @return - Status
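To make drainPipeline's contract concrete, here is a rough Python sketch of the drain loop documented above; the child iterators and the "eoe" sentinel are stand-ins for the C++ child connectors:

```python
# Stand-in model: each child iterator yields data rows, then the string "eoe".
def drain_pipeline(children, skip_child):
    for i, child in enumerate(children):
        if i == skip_child:
            continue  # this child already delivered the eoe that started the drain
        row = None
        while row != "eoe":
            row = next(child)  # discard rows until this child also reaches eoe
```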

View File

@@ -842,8 +842,10 @@ def get_fast_recovery():
def set_debug_mode(debug_mode_flag):
"""
Set the debug_mode flag of the dataset pipeline
Notes:
Set the debug_mode flag of the dataset pipeline. When enabled, the dataset pipeline is run synchronously and
sequentially with a single thread.
Note:
1. If both debug_mode and auto_offload are enabled, then during runtime, auto_offload is forcibly disabled.
2. If debug_mode is enabled and a dataset pipeline has a Map operation with offload set, then offload is
ignored.
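A short usage sketch of this API (only the execution mode changes; the pipeline itself is written as usual, here with the MindDataset test path used below):

```python
import mindspore.dataset as ds

# Enable debug mode before creating any iterator; iteration then runs
# synchronously and sequentially in a single thread.
ds.config.set_debug_mode(True)

data = ds.MindDataset("../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0")
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    pass
```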

View File

@@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -79,11 +79,12 @@ def test_pipeline_debug_mode_dict():
center_crop = vision.CenterCrop(crop_size)
resize_op = vision.Resize(resize_size, Inter.LINEAR) # Bilinear mode
data = data.map(operations=[center_crop, resize_op], input_columns=["image"])
data = data.rename(input_columns=["image"], output_columns=["image_out"])
data = data.batch(2)
num_row = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
assert len(item) == 2
assert item["image"].shape == (2, 24, 24, 3)
assert item["image_out"].shape == (2, 24, 24, 3)
assert item["attr"].shape == (2, 40)
num_row += 1
assert num_row == 2
@@ -93,7 +94,7 @@ def test_pipeline_debug_mode_minddata():
"""
Feature: Pipeline debug mode.
Description: Test iterator with MindDataset in debug mode.
Expectation:Successful.
Expectation: Successful.
"""
logger.info("test_pipeline_debug_mode_minddata")
data = ds.MindDataset("../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0")
@@ -108,13 +109,15 @@ def test_pipeline_debug_mode_not_support():
"""
Feature: Pipeline debug mode.
Description: Test creating tuple iterator with op not supported in pull mode.
Expectation: raise exception for debug mode.
Expectation: Exception raised for unsupported op in debug mode.
"""
logger.info("test_pipeline_debug_mode_not_support")
data = ds.NumpySlicesDataset(data=[[0, 1, 2]], column_names=["data"])
with pytest.raises(RuntimeError) as error_info:
data.create_tuple_iterator(num_epochs=1, output_numpy=True)
assert "dataset pipeline" in str(error_info.value)
assert "Leaf node GeneratorOp is not implemented yet in pull mode." in str(error_info.value)
def test_pipeline_debug_mode_map_pyfunc():
@@ -154,6 +157,37 @@ def test_pipeline_debug_mode_batch_pyfunc():
assert num_rows == 5
def generator_md():
"""
Create a dataset with [0, 1, 2, 3, 4]
"""
for i in range(5):
yield (np.array([i]),)
# Note: Generator op is not yet supported in pull mode
@pytest.mark.skip(reason="Unsupported in pull mode")
def test_pipeline_debug_mode_skip_take():
"""
Feature: Pipeline debug mode.
Description: Test skip op followed by a take op
Expectation: Output is equal to the expected output
"""
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Skip the first 2 elements; ds1 becomes [2, 3, 4]
ds1 = ds1.skip(2)
# Take the first 2 elements; ds1 becomes [2, 3]
ds1 = ds1.take(2)
buf = []
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0][0])
assert len(buf) == 2
assert buf == [2, 3]
def test_pipeline_debug_mode_concat():
"""
Feature: Pipeline debug mode.
@@ -186,6 +220,63 @@ def test_pipeline_debug_mode_concat():
assert sample_count == num_repeat * num_epoch * sample_row
# Note: TFRecordDataset op is not yet supported in pull mode
@pytest.mark.skip(reason="Unsupported in pull mode")
def test_pipeline_debug_mode_tfrecord_rename_zip():
"""
Feature: Pipeline debug mode.
Description: Test rename op and zip op followed by repeat
Expectation: Output is the same as expected output
"""
tf_data_dir = ["../data/dataset/testTFBert5Rows2/5TFDatas.data"]
tf_schema_dir = "../data/dataset/testTFBert5Rows2/datasetSchema.json"
data1 = ds.TFRecordDataset(tf_data_dir, tf_schema_dir, shuffle=False)
data2 = ds.TFRecordDataset(tf_data_dir, tf_schema_dir, shuffle=False)
data2 = data2.rename(input_columns=["input_ids", "segment_ids"], output_columns=["masks", "seg_ids"])
data = ds.zip((data1, data2))
data = data.repeat(3)
num_iter = 0
for _, item in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
logger.info("item[mask] is {}".format(item["masks"]))
np.testing.assert_equal(item["masks"], item["input_ids"])
logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
np.testing.assert_equal(item["segment_ids"], item["seg_ids"])
# need to consume the data in the buffer
num_iter += 1
logger.info("Number of data in data: {}".format(num_iter))
assert num_iter == 15
def test_pipeline_debug_mode_imagefolder_rename_zip():
"""
Feature: Pipeline debug mode.
Description: Test ImageFolderDataset with rename op and zip op
Expectation: Output is the same as expected output
"""
# Apply dataset operations
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=6)
data2 = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=10)
# Rename the columns in data2 to avoid a column name conflict
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
data2 = data2.skip(4)
data3 = ds.zip((data1, data2))
num_iter = 0
for item in data3.create_dict_iterator(num_epochs=1): # each data is a dictionary
# in this example, each dictionary has keys "image" and "label"
logger.info("image is {}".format(item["image"]))
logger.info("label is {}".format(item["label"]))
num_iter += 1
logger.info("Number of data in data: {}".format(num_iter))
assert num_iter == 6
def test_pipeline_debug_mode_map_random():
"""
Feature: Pipeline debug mode.
@@ -298,7 +389,10 @@ if __name__ == '__main__':
test_pipeline_debug_mode_not_support()
test_pipeline_debug_mode_map_pyfunc()
test_pipeline_debug_mode_batch_pyfunc()
test_pipeline_debug_mode_skip_take()
test_pipeline_debug_mode_concat()
test_pipeline_debug_mode_tfrecord_rename_zip()
test_pipeline_debug_mode_imagefolder_rename_zip()
test_pipeline_debug_mode_shuffle()
test_pipeline_debug_mode_map_random()
test_pipeline_debug_mode_imdb_shuffle()

View File

@@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,10 +22,8 @@ import matplotlib.pyplot as plt
import mindspore.dataset as ds
from mindspore import log as logger
pytestmark = pytest.mark.forked
DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"
@@ -137,7 +135,7 @@ def test_cifar10_basic():
num_iter2 += 1
assert num_iter2 == 15
# case 5: test batch with drop_remainder=True
# case 3: test batch with drop_remainder=True
data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
assert data3.get_dataset_size() == 100
assert data3.get_batch_size() == 1
@@ -460,6 +458,7 @@ def test_cifar_exception_file_path():
Description: Test Cifar10Dataset and Cifar100Dataset with invalid file path in debug mode
Expectation: Error is raised as expected
"""
def exception_func(item):
raise Exception("Error occur!")
@@ -578,7 +577,6 @@ def test_cifar100ops():
# take -5
data5 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter4 = 0
with pytest.raises(ValueError) as error_info:
data5 = data5.take(-5)
for _ in data4.create_dict_iterator(num_epochs=1):
@@ -602,15 +600,63 @@ def test_cifar100ops():
assert num_iter7 == 0
# skip -5
data5 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
num_iter4 = 0
data8 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
with pytest.raises(ValueError) as error_info:
data5 = data5.skip(-5)
for _ in data4.create_dict_iterator(num_epochs=1):
data8 = data8.skip(-5)
for _ in data8.create_dict_iterator(num_epochs=1):
pass
assert "Input count is not within the required interval of" in str(error_info.value)
### Focused debug mode testcases with Cifar10Dataset ###
def test_pipeline_debug_mode_cifar10_rename_zip(plot=False):
"""
Feature: Pipeline debug mode.
Description: Test Cifar10Dataset with rename op and zip op
Expectation: Output is the same as expected output
"""
# Apply dataset operations
data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=6)
data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=6)
# Rename the columns in data2 to avoid a column name conflict
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image2", "label2"])
data3 = ds.zip((data1, data2))
num_iter = 0
image_list, image_list2, label_list, label_list2 = [], [], [], []
for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
image = item["image"]
label = item["label"]
image_list.append(image)
label_list.append("label {}".format(label))
assert isinstance(image, np.ndarray)
assert image.shape == (32, 32, 3)
assert image.dtype == np.uint8
assert label.dtype == np.uint32
image2 = item["image2"]
label2 = item["label2"]
image_list2.append(image2)
label_list2.append("label {}".format(label2))
assert isinstance(image2, np.ndarray)
assert image2.shape == (32, 32, 3)
assert image2.dtype == np.uint8
assert label2.dtype == np.uint32
assert label == label2
np.testing.assert_equal(image, image2)
num_iter += 1
assert num_iter == 6
if plot:
visualize_dataset(image_list, label_list)
visualize_dataset(image_list2, label_list2)
def test_pipeline_debug_mode_multi_epoch_cifar10():
"""
Feature: Pipeline debug mode.
@@ -642,7 +688,131 @@ def test_pipeline_debug_mode_multi_epoch_cifar10():
assert epoch_count == num_epoch
logger.debug("total epochs: ", epoch_count)
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
logger.debug("total sample: ", sample_count)
logger.debug("total samples: ", sample_count)
# Note: Pull mode has an issue with this scenario (batch followed by repeat)
@pytest.mark.skip(reason="Unsupported in pull mode")
def test_pipeline_debug_mode_multi_epoch_cifar10_batch_repeat():
"""
Feature: Pipeline debug mode.
Description: Test creating tuple iterator in cifar10 dataset with batch then repeat and multiple epochs.
Expectation: Output is equal to the expected output
"""
logger.info("test_pipeline_debug_mode_multi_epoch_cifar10")
data_dir_10 = "../data/dataset/testCifar10Data"
num_repeat = 2
batch_size = 20
limit_dataset = 100
# apply dataset operations
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
# Add batch then repeat
data1 = data1.batch(batch_size, True)
data1 = data1.repeat(num_repeat)
num_epoch = 5
iter1 = data1.create_tuple_iterator(num_epochs=num_epoch)
epoch_count = 0
sample_count = 0
for _ in range(num_epoch):
row_count = 0
for _ in iter1:
# in this example, each row has columns "image" and "label"
row_count += 1
assert row_count == int(limit_dataset * num_repeat / batch_size)
logger.debug("row_count: ", row_count)
epoch_count += 1
sample_count += row_count
assert epoch_count == num_epoch
logger.debug("total epochs: ", epoch_count)
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
logger.debug("total samples: ", sample_count)
def test_pipeline_debug_mode_multi_epoch_cifar10_zip():
"""
Feature: Pipeline debug mode.
Description: Test creating tuple iterator in cifar10 dataset with zip op and multiple epochs.
Expectation: Output is equal to the expected output
"""
logger.info("test_pipeline_debug_mode_multi_epoch_cifar10_zip")
data_dir_10 = "../data/dataset/testCifar10Data"
num_repeat = 5
batch_size = 10
limit_dataset = 20
# apply dataset operations
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
data2 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
# Rename the columns in data2 to avoid a column name conflict
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image2", "label2"])
data3 = ds.zip((data1, data2))
# Add batch after repeat
data3 = data3.repeat(num_repeat)
data3 = data3.batch(batch_size, True)
num_epoch = 2
iter1 = data3.create_tuple_iterator(num_epochs=num_epoch)
epoch_count = 0
sample_count = 0
for _ in range(num_epoch):
row_count = 0
for _ in iter1:
# in this example, each row has columns "image" and "label"
row_count += 1
assert row_count == int(limit_dataset * num_repeat / batch_size)
logger.debug("row_count: ", row_count)
epoch_count += 1
sample_count += row_count
assert epoch_count == num_epoch
logger.debug("total epochs: ", epoch_count)
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
logger.debug("total samples: ", sample_count)
# Note: Pull mode has an issue with this scenario (batch followed by repeat)
@pytest.mark.skip(reason="Unsupported in pull mode")
def test_pipeline_debug_mode_multi_epoch_cifar10_zip_batch_repeat():
"""
Feature: Pipeline debug mode.
Description: Test creating tuple iterator in cifar10 dataset with zip op, then batch and repeat, and multiple epochs.
Expectation: Output is equal to the expected output
"""
logger.info("test_pipeline_debug_mode_multi_epoch_cifar10_zip")
data_dir_10 = "../data/dataset/testCifar10Data"
num_repeat = 5
batch_size = 10
limit_dataset = 20
# apply dataset operations
data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
data2 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
# Rename the columns in data2 to avoid a column name conflict
data2 = data2.rename(input_columns=["image", "label"], output_columns=["image2", "label2"])
data3 = ds.zip((data1, data2))
# Add batch then repeat
data3 = data3.batch(batch_size, True)
data3 = data3.repeat(num_repeat)
num_epoch = 2
iter1 = data3.create_tuple_iterator(num_epochs=num_epoch)
epoch_count = 0
sample_count = 0
for _ in range(num_epoch):
row_count = 0
for _ in iter1:
# in this example, each row has columns "image" and "label"
row_count += 1
assert row_count == int(limit_dataset * num_repeat / batch_size)
logger.debug("row_count: ", row_count)
epoch_count += 1
sample_count += row_count
assert epoch_count == num_epoch
logger.debug("total epochs: ", epoch_count)
assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
logger.debug("total samples: ", sample_count)
if __name__ == '__main__':
@@ -663,5 +833,9 @@ if __name__ == '__main__':
test_cifar10_with_chained_sampler_get_dataset_size()
test_cifar10_pk_sampler_get_dataset_size()
test_cifar100ops()
test_pipeline_debug_mode_cifar10_rename_zip(plot=False)
test_pipeline_debug_mode_multi_epoch_cifar10()
test_pipeline_debug_mode_multi_epoch_cifar10_batch_repeat()
test_pipeline_debug_mode_multi_epoch_cifar10_zip()
test_pipeline_debug_mode_multi_epoch_cifar10_zip_batch_repeat()
teardown_function()