forked from mindspore-Ecosystem/mindspore
fix comments
This commit is contained in:
parent 1c5c98b184
commit 64cebf13b2
@@ -49,7 +49,7 @@ class PythonPullBasedIteratorConsumer : public PullBasedIteratorConsumer {
 public:
  /// Constructor which will call the base class default constructor.
  /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
-  explicit PythonPullBasedIteratorConsumer(int32_t num_epochs = -1) : PullBasedIteratorConsumer() {}
+  explicit PythonPullBasedIteratorConsumer(int32_t num_epochs = -1) : PullBasedIteratorConsumer(num_epochs) {}

  ~PythonPullBasedIteratorConsumer() = default;

  /// Returns the next row in a vector format
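The one-line change above fixes a dropped argument: the old constructor delegated to the base class default constructor, so the epoch count requested by the caller was silently replaced by the default. A minimal standalone sketch of that pitfall (hypothetical Base/Derived names, not MindSpore code):

#include <iostream>

struct Base {
  explicit Base(int num_epochs = -1) : num_epochs_(num_epochs) {}
  int num_epochs_;
};

struct DerivedBuggy : Base {
  // Bug: num_epochs is accepted but never forwarded; Base falls back to -1.
  explicit DerivedBuggy(int num_epochs = -1) : Base() {}
};

struct DerivedFixed : Base {
  // Fix: forward the argument to the base constructor.
  explicit DerivedFixed(int num_epochs = -1) : Base(num_epochs) {}
};

int main() {
  std::cout << DerivedBuggy(5).num_epochs_ << '\n';  // prints -1, not 5
  std::cout << DerivedFixed(5).num_epochs_ << '\n';  // prints 5
}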
@@ -678,6 +678,12 @@ int64_t BatchOp::GetTreeBatchSize() {

Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  if (eoe_received_) {
    UpdateRepeatAndEpochCounter();
    *row = TensorRow(TensorRow::kFlagEOE);
    eoe_received_ = false;
    return Status::OK();
  }
  std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  int32_t cur_batch_size = 0;

@@ -685,6 +691,16 @@ Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
  for (int i = 0; i < cur_batch_size; i++) {
    TensorRow new_row;
    RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(&new_row));
    if (new_row.eoe()) {
      if (!drop_) {
        eoe_received_ = true;
      } else {
        *row = new_row;
        UpdateRepeatAndEpochCounter();
        return Status::OK();
      }
      break;
    }
    if (!new_row.empty()) {
      table->emplace_back(new_row);
      if (table->size() == static_cast<size_t>(cur_batch_size)) {
@@ -315,6 +315,9 @@ class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatc
  /// \brief Gets the implementation status for operator in pull mode
  /// \return implementation status
  ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }

 private:
  bool eoe_received_ = false;
};
}  // namespace dataset
}  // namespace mindspore
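The new eoe_received_ flag gives BatchOp a one-pull delay at epoch boundaries: with drop_ == false, the partially filled batch is emitted first and the EOE row follows on the next call. A self-contained toy sketch of that hand-off (ints for rows, std::nullopt for EOE; hypothetical names, not the MindSpore classes):

#include <cstddef>
#include <optional>
#include <vector>

// Rows are ints here; std::nullopt stands in for an EOE (end-of-epoch) marker.
using Row = std::optional<int>;

struct BatchSketch {
  std::vector<Row> source;    // upstream rows, nullopt = EOE
  std::size_t pos = 0;
  bool eoe_received = false;  // deferred epoch boundary, as in BatchOp

  // Returns a batch; an empty vector stands in for the EOE row.
  std::vector<int> NextBatch(std::size_t batch_size) {
    if (eoe_received) {        // emit the deferred EOE first
      eoe_received = false;
      return {};
    }
    std::vector<int> batch;
    while (batch.size() < batch_size && pos < source.size()) {
      Row r = source[pos++];
      if (!r.has_value()) {    // EOE mid-batch: defer it, flush the partial batch
        eoe_received = true;
        break;
      }
      batch.push_back(*r);
    }
    return batch;
  }
};

With source = {1, 2, 3, EOE} and batch_size = 2, successive calls yield {1, 2}, then {3} (the partial batch), then the EOE marker, matching the drop_ == false path above.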
@@ -79,5 +79,24 @@ Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
}

int64_t EpochCtrlOp::GetTreeRepeatCount() { return child_[0]->GetTreeRepeatCount(); }

Status EpochCtrlOp::GetNextRowPullMode(TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("[Internal ERROR] EpochCtrlOp can't be the leaf node(first operator) of pipeline.");
  }

  // `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
  RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));

  // Only intercept EOE for EoeReceived processing; after that, the EOE is forwarded to the next op.
  // Other TensorRows containing data or EOF will simply be forwarded.
  // EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean-up.
  if (row->eoe()) {
    RETURN_IF_NOT_OK(EoeReceived(0));
  }

  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -49,6 +49,11 @@ class EpochCtrlOp : public RepeatOp {
  Status EoeReceived(int32_t worker_id) override;

  int64_t GetTreeRepeatCount() override;

  /// \brief In pull mode, gets the next row
  /// \param[out] row Fetched TensorRow
  /// \return Status The status code returned
  Status GetNextRowPullMode(TensorRow *const row) override;
};
}  // namespace dataset
}  // namespace mindspore
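Because EpochCtrlOp forwards the EOE after bookkeeping rather than eating it, a pull-based consumer sees exactly one EOE per epoch and can drive multi-epoch iteration with a plain loop. An illustrative consumer-side sketch, not code from this commit; it is written against the types in this diff (DatasetOp, TensorRow, Status, RETURN_IF_NOT_OK) and so compiles only inside the codebase:

// Illustrative epoch loop over a pull-based tree root; `root` is assumed to
// expose GetNextRowPullMode with the contract shown in the hunks above.
Status DrainEpochs(DatasetOp *root, int32_t num_epochs) {
  for (int32_t epoch = 0; epoch < num_epochs; ++epoch) {
    TensorRow row;
    RETURN_IF_NOT_OK(root->GetNextRowPullMode(&row));
    while (!row.eoe() && !row.eof()) {
      // ... consume `row` ...
      RETURN_IF_NOT_OK(root->GetNextRowPullMode(&row));
    }
    if (row.eof()) {
      break;  // tree exhausted before the requested number of epochs
    }
  }
  return Status::OK();
}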
@@ -484,7 +484,11 @@ std::vector<int32_t> MapOp::GetMPWorkerPIDs() const {
Status MapOp::GetNextRowPullMode(TensorRow *const row) {
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(&new_row));
  if (new_row.eoe()) {
    UpdateRepeatAndEpochCounter();
  }
  if (new_row.empty()) {
    (*row) = std::move(new_row);
    return Status::OK();
  }
  auto column_name_id_map = child_[0]->column_name_id_map();
@@ -130,5 +130,28 @@ Status RepeatOp::Reset() {
}

int64_t RepeatOp::GetTreeRepeatCount() { return num_repeats_; }

Status RepeatOp::GetNextRowPullMode(TensorRow *const row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED(
      "[Internal ERROR] Pipeline init failed, RepeatOp can't be the leaf node(first operator) of pipeline.");
  }
  RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
  // Loop until a non-EOE row is received
  while (row->eoe()) {
    MS_LOG(INFO) << "RepeatOp::GetNextRowPullMode eoe received.";
    RETURN_IF_NOT_OK(EoeReceived(0));
    if (state_ == OpState::kDeOpIdle) {
      return Status::OK();
    }
    RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
  }
  // Check whether the last row is an EOF
  if (row->eof()) {
    RETURN_IF_NOT_OK(EofReceived(0));
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -97,6 +97,11 @@ class RepeatOp : public PipelineOp {

  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;  // List of operators that can generate EOE underneath this repeat.

  /// \brief In pull mode, gets the next row
  /// \param[out] row Fetched TensorRow
  /// \return Status The status code returned
  Status GetNextRowPullMode(TensorRow *const row) override;

 protected:
  // The number of repeats that the user requested.
  // Note that num_repeats_ is different from op_total_repeats_ or op_num_repeats_per_epoch_ in the base DatasetOp class.
@@ -107,6 +112,10 @@ class RepeatOp : public PipelineOp {
  // Note that repeat_count_ is different from op_current_repeats_ in the base DatasetOp class
  // because it counts the repeats in the current epoch, whereas op_current_repeats_ counts the global total repeats.
  int32_t repeat_count_;

  /// \brief Gets the implementation status for operator in pull mode
  /// \return implementation status
  ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
};
}  // namespace dataset
}  // namespace mindspore
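RepeatOp, by contrast, absorbs one EOE per completed pass and restarts its child; only the EOE of the final pass, when the op goes idle, escapes upward. A self-contained toy model of that loop (hypothetical names; -1 stands in for an EOE row, ints for data rows):

#include <iostream>

// Toy pull source: emits rows 0..n-1, then an EOE marker (-1), then restarts.
struct ChildSketch {
  int n, next = 0;
  int Pull() { return next < n ? next++ : (next = 0, -1); }  // -1 = EOE
};

// RepeatOp-like wrapper: swallows the EOE and restarts the child until
// `repeats` passes are done, then lets the final EOE through.
struct RepeatSketch {
  ChildSketch child;
  int repeats, done = 0;
  int Pull() {
    int row = child.Pull();
    while (row == -1) {   // EOE from the child
      if (++done == repeats) {
        return -1;        // final EOE escapes to the consumer
      }
      row = child.Pull(); // restart the next pass
    }
    return row;
  }
};

int main() {
  RepeatSketch op{ChildSketch{2}, 3};
  for (int i = 0; i < 7; ++i) {
    std::cout << op.Pull() << ' ';  // prints: 0 1 0 1 0 1 -1
  }
}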
@@ -78,6 +78,9 @@ Status SkipOp::GetNextRowPullMode(TensorRow *const row) {
  bool eoe_received = false;
  while (skip_count_ < max_skips_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRowPullMode(row));
    if (row->eof()) {
      return Status::OK();
    }
    if (row->eoe() && !once_only_) {
      eoe_received = true;
      break;
@@ -22,7 +22,11 @@
namespace mindspore {
namespace dataset {
MappableLeafOp::MappableLeafOp(int32_t num_wkrs, int32_t queue_size, std::shared_ptr<SamplerRT> sampler)
-    : ParallelOp(num_wkrs, queue_size, std::move(sampler)), sample_ids_(nullptr), curr_row_(0), prepared_data_{false} {}
+    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
+      sample_ids_(nullptr),
+      curr_row_(0),
+      prepared_data_{false},
+      eof_handled_{false} {}

#ifdef ENABLE_PYTHON
Status MappableLeafOp::ImageDecrypt(const std::string &path, std::shared_ptr<Tensor> *tensor,

@@ -110,6 +114,7 @@ Status MappableLeafOp::operator()() {
Status MappableLeafOp::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";
  RETURN_IF_NOT_OK(sampler_->ResetSampler());
  curr_row_ = 0;
  return Status::OK();
}
@@ -166,15 +171,21 @@ Status MappableLeafOp::GetNextRowPullMode(TensorRow *const row) {
    RETURN_IF_NOT_OK(InitPullMode());
    prepared_data_ = true;
  }
  if (eof_handled_) {
    *row = TensorRow(TensorRow::kFlagEOF);
    return Status::OK();
  }
+  TensorRow sample_row;
  if (sample_ids_ == nullptr) {
    RETURN_IF_NOT_OK(this->InitSampler());
-    TensorRow sample_row;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    CHECK_FAIL_RETURN_UNEXPECTED(sample_row.size() > 0, "GetNextRowPullMode: Expect at least one sample in sampler.");
    sample_ids_ = sample_row[0];
  }
  if (curr_row_ + 1 > sample_ids_->Size()) {
    *row = TensorRow(TensorRow::kFlagEOE);
    RETURN_IF_NOT_OK(ResetAndUpdateRepeat());
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
    return Status::OK();
  }
  int64_t key;
@@ -183,5 +194,15 @@ Status MappableLeafOp::GetNextRowPullMode(TensorRow *const row) {
  curr_row_++;
  return Status::OK();
}

Status MappableLeafOp::ResetAndUpdateRepeat() {
  if (!IsLastIteration()) {
    RETURN_IF_NOT_OK(Reset());
    UpdateRepeatAndEpochCounter();
  } else {
    eof_handled_ = true;
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
@@ -85,6 +85,7 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, p
  TensorPtr sample_ids_;  // sample id pointer for pull mode
  uint32_t curr_row_;     // current row number count for pull mode
  bool prepared_data_;    // flag to indicate whether the data is prepared before LoadTensorRow for pull mode
  bool eof_handled_;      // true once the last iteration is done, so subsequent pulls return EOF

  /// Initialize Sampler, calls sampler->Init() within
  /// @return Status The status code returned
@@ -126,6 +127,11 @@ class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, p
  /// \return Status The status code returned
  virtual Status LoadTensorRowPullMode(row_id_type row_id, TensorRow *row) { return LoadTensorRow(row_id, row); }

  /// \brief Reset the op and update the repeat and epoch counters if the condition is met;
  ///     on the last iteration, mark the EOF as handled instead.
  /// \return Status The status code returned
  virtual Status ResetAndUpdateRepeat();

  /// \brief Gets the implementation status for operator in pull mode
  /// \return implementation status
  ImplementedPullMode PullModeImplementationStatus() const override { return ImplementedPullMode::Implemented; }
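Putting these pieces together, the leaf's pull-mode behavior is a small state machine: walk the sampled ids, emit EOE when they are exhausted, reset the sampler unless this was the last iteration, and answer EOF from then on. A self-contained toy model (ints for rows, -1/-2 for the EOE/EOF flags; hypothetical names, not the real classes):

#include <cstddef>
#include <iostream>
#include <vector>

struct LeafSketch {
  std::vector<int> sample_ids;  // stands in for sample_ids_
  std::size_t curr_row = 0;     // stands in for curr_row_
  int iterations_left;          // stands in for the IsLastIteration() bookkeeping
  bool eof_handled = false;     // stands in for eof_handled_

  int Pull() {
    if (eof_handled) return -2;            // EOF on every pull after the end
    if (curr_row >= sample_ids.size()) {   // sampler exhausted for this pass
      if (--iterations_left > 0) {
        curr_row = 0;                      // Reset() + UpdateRepeatAndEpochCounter()
      } else {
        eof_handled = true;                // ResetAndUpdateRepeat(): last iteration
      }
      return -1;                           // EOE marks the epoch boundary
    }
    return sample_ids[curr_row++];         // LoadTensorRowPullMode(row_id, ...)
  }
};

int main() {
  LeafSketch leaf{{10, 11}, 0, 2};
  for (int i = 0; i < 7; ++i) std::cout << leaf.Pull() << ' ';
  // prints: 10 11 -1 10 11 -1 -2
}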
@@ -17,6 +17,9 @@
#include "minddata/dataset/engine/tree_adapter_lite.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/opt/pass.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/debug_mode_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"

@@ -73,6 +76,10 @@ Status TreeAdapterLite::BuildTree(std::shared_ptr<DatasetNode> root_ir) {
Status TreeAdapterLite::GetNextRow(TensorRow *const row) {
  RETURN_UNEXPECTED_IF_NULL(root_);
  RETURN_IF_NOT_OK(root_->GetNextRowPullMode(row));
  if (row->eof()) {
    std::string err = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs.";
    RETURN_STATUS_UNEXPECTED(err);
  }
  RETURN_UNEXPECTED_IF_NULL(row);
  return Status::OK();
}
@@ -97,6 +104,20 @@ Status TreeAdapterLite::PrePass(std::shared_ptr<DatasetNode> ir) const {
  return Status::OK();
}

Status TreeAdapterLite::PostPass(std::shared_ptr<DatasetNode> ir) const {
  RETURN_UNEXPECTED_IF_NULL(ir);
  // Vector of actions in post-pass phase
  std::vector<std::unique_ptr<IRPass>> actions;
#ifndef ENABLE_ANDROID
  MS_LOG(INFO) << "Running repeat pass.";
  (void)actions.emplace_back(std::make_unique<RepeatPass>());
  bool modified = false;
  RETURN_IF_NOT_OK(actions[0]->Run(ir, &modified));
  MS_LOG(INFO) << "Repeat pass completed.";
#endif
  return Status::OK();
}

Status TreeAdapterLite::Compile(const std::shared_ptr<DatasetNode> &input_ir, int32_t num_epochs) {
  RETURN_UNEXPECTED_IF_NULL(input_ir);
  input_ir_ = input_ir;

@@ -119,6 +140,9 @@ Status TreeAdapterLite::Compile(const std::shared_ptr<DatasetNode> &input_ir, in
  // Pre-pass of the IR tree
  RETURN_IF_NOT_OK(PrePass(root_ir));
  MS_LOG(INFO) << "Plan after PrePass:" << '\n' << *root_ir << '\n';

  RETURN_IF_NOT_OK(PostPass(root_ir));
  MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';
  root_ir_ = root_ir;

  RETURN_IF_NOT_OK(BuildTree(root_ir));
@@ -51,6 +51,9 @@ class TreeAdapterLite {
  // Run the mandatory pass checking the syntax and semantics of the IR tree
  Status PrePass(std::shared_ptr<DatasetNode> ir) const;

  // Run the mandatory pass augmenting the IR tree
  Status PostPass(std::shared_ptr<DatasetNode> ir) const;

  std::shared_ptr<DatasetNode> input_ir_;
  std::shared_ptr<DatasetNode> root_ir_;
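Taken together, Compile runs the pre- and post-passes (now including RepeatPass) and builds the execution tree, while GetNextRow pulls from the root and treats EOF as a user error. An illustrative caller, not code from this commit: it assumes TreeAdapterLite is default-constructible and that root_ir was produced by the IR builder (not shown in this diff), and it compiles only inside the codebase:

Status RunOneEpoch(const std::shared_ptr<DatasetNode> &root_ir) {
  TreeAdapterLite adapter;
  RETURN_IF_NOT_OK(adapter.Compile(root_ir, /*num_epochs=*/1));
  TensorRow row;
  RETURN_IF_NOT_OK(adapter.GetNextRow(&row));
  while (!row.eoe()) {
    // ... hand `row` to the consumer ...
    RETURN_IF_NOT_OK(adapter.GetNextRow(&row));
  }
  return Status::OK();  // pulling past the last epoch would raise the EOF error above
}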
@@ -24,7 +24,6 @@ from mindspore import log as logger
# the global configuration setting of debug_mode may impact other tests running in parallel.
pytestmark = pytest.mark.forked

DATA_DIR_10 = "../data/dataset/testCifar10Data"
DEBUG_MODE = False
SEED_VAL = 0  # seed will be set internally in debug mode, save original seed value to restore.
@@ -163,17 +162,28 @@ def test_pipeline_debug_mode_concat():
    """
    logger.info("test_pipeline_debug_mode_concat")
    data_dir = "../data/dataset/testCelebAData/"
    num_repeat = 3
    data1 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
    data2 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
    data3 = ds.CelebADataset(data_dir, decode=True, num_shards=1, shard_id=0)
    data4 = data1.concat(data2)
    data5 = data3 + data4
    num_rows = 0
    for item1 in data5.create_tuple_iterator(num_epochs=1):
        assert len(item1) == 2
        assert item1[0].shape == (2268, 4032, 3)
        num_rows += 1
    assert num_rows == 12
    data5 = data5.repeat(num_repeat)
    num_epoch = 2
    epoch_count = 0
    sample_row = 12
    sample_count = 0
    for _ in range(num_epoch):
        num_rows = 0
        for item1 in data5.create_tuple_iterator(num_epochs=1):
            assert len(item1) == 2
            assert item1[0].shape == (2268, 4032, 3)
            num_rows += 1
        epoch_count += 1
        sample_count += num_rows
        assert num_rows == sample_row * num_repeat
    assert epoch_count == num_epoch
    assert sample_count == num_repeat * num_epoch * sample_row


def test_pipeline_debug_mode_map_random():
@@ -223,28 +233,63 @@ def test_pipeline_debug_mode_imdb_shuffle():
    Expectation: The data is processed successfully in the same order.
    """
    logger.info("test_pipeline_debug_mode_imdb_shuffle")
    buffer_size = 5

    # apply dataset operations
    data1 = ds.IMDBDataset("../data/dataset/testIMDBDataset", shuffle=True)
    data1 = data1.shuffle(buffer_size=buffer_size)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 8

    expect_output = [["train_pos_1.txt", 1], ["train_pos_0.txt", 1], ["train_neg_0.txt", 0], ["test_pos_1.txt", 1],
                     ["test_neg_1.txt", 0], ["test_pos_0.txt", 1], ["test_neg_0.txt", 0], ["train_neg_1.txt", 0]]
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # in this example, each dictionary has keys "text" and "label"
        logger.info("text is {}".format(item["text"]))
        assert item["text"] == expect_output[num_iter][0]
        logger.info("label is {}".format(item["label"]))
        assert item["label"] == expect_output[num_iter][1]
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 8


def test_pipeline_debug_mode_multi_epoch_map_pyfunc():
    """
    Feature: Pipeline debug mode.
    Description: Test creating dict iterator with map(PyFunc) with num_epochs > 1.
    Expectation: Successful.
    """
    logger.info("test_pipeline_debug_mode_multi_epoch_map_pyfunc")
    data = ds.CelebADataset("../data/dataset/testCelebAData/", sampler=ds.SequentialSampler(),
                            decode=True)
    num_repeat = 5
    sample_row = 4
    data = data.repeat(num_repeat)
    data = data.map(operations=[(lambda x: x - 1), (lambda x: x * 2)], input_columns=["image"])
    num_epoch = 7
    epoch_count = 0
    sample_count = 0
    iter1 = data.create_dict_iterator(num_epochs=num_epoch)
    for _ in range(num_epoch):
        num_rows = 0
        for item in iter1:
            assert len(item) == 2
            assert item["image"].shape == (2268, 4032, 3)
            num_rows += 1
        assert num_rows == sample_row * num_repeat
        sample_count += num_rows
        epoch_count += 1
    assert epoch_count == num_epoch
    assert sample_count == num_repeat * num_epoch * sample_row

    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError, match=err_msg):
        iter1.__next__()


if __name__ == '__main__':
    setup_function()
    test_pipeline_debug_mode_tuple()
@@ -257,4 +302,5 @@ if __name__ == '__main__':
    test_pipeline_debug_mode_shuffle()
    test_pipeline_debug_mode_map_random()
    test_pipeline_debug_mode_imdb_shuffle()
    test_pipeline_debug_mode_multi_epoch_map_pyfunc()
    teardown_function()
@@ -611,6 +611,40 @@ def test_cifar100ops():
    assert "Input count is not within the required interval of" in str(error_info.value)


def test_pipeline_debug_mode_multi_epoch_cifar10():
    """
    Feature: Pipeline debug mode.
    Description: Test creating tuple iterator in cifar10 dataset with multi epochs.
    Expectation: Output is equal to the expected output.
    """
    logger.info("test_pipeline_debug_mode_multi_epoch_cifar10")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)
    num_epoch = 5
    iter1 = data1.create_tuple_iterator(num_epochs=num_epoch)
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # in this example, each row has columns "image" and "label"
            row_count += 1
        assert row_count == int(limit_dataset * num_repeat / batch_size)
        logger.debug("row_count: {}".format(row_count))
        epoch_count += 1
        sample_count += row_count
    assert epoch_count == num_epoch
    logger.debug("total epochs: {}".format(epoch_count))
    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
    logger.debug("total sample: {}".format(sample_count))


if __name__ == '__main__':
    setup_function()
    test_cifar10_content_check()
@@ -629,4 +663,5 @@ if __name__ == '__main__':
    test_cifar10_with_chained_sampler_get_dataset_size()
    test_cifar10_pk_sampler_get_dataset_size()
    test_cifar100ops()
    test_pipeline_debug_mode_multi_epoch_cifar10()
    teardown_function()