Migrate repeat_pass.cc to IR optimizer and remove ExecTree optimizer

This commit is contained in:
Nat Sutyanyong 2021-01-28 17:12:08 -05:00
parent c16b45ab23
commit 5a7dc0accc
84 changed files with 706 additions and 1342 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -258,10 +258,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return The number of required repeats for the operator
int32_t op_total_repeats() { return op_total_repeats_; }
/// \brief Getter function
/// \return The number of required epochs for the operator
int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; }
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,10 +17,8 @@
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/log_adapter.h"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -304,6 +304,10 @@ class CsvOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *const modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CsvOp"; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,20 +16,9 @@
#include "minddata/dataset/engine/execution_tree.h"
#include <iostream>
#include <string>
#include <utility>
#include <limits>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
@ -255,97 +244,13 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return Status::OK();
}
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PreAction()
//    Compulsory transformation/action pre optimization.
//    For example, CacheOp Insertion
//
// 2. Optimize()
//    Optimization transformation/action, optional
//    For example, MapOp Fusion
//
// 3. PostAction()
//    Compulsory transformation/action post optimization.
//    For example, repeatOp inlining
//
// @param num_epochs - Total number of epochs that will be run on this tree
// @param partial - If true, skip the pre-action passes (used during migration to the IR optimizer,
//                  where those passes have already run on the IR tree)
// @return Status The status code returned
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
  num_epochs_ = num_epochs;
  partially_prepare_ = partial;
  // Pre optimization compulsory transformation
  RETURN_IF_NOT_OK(this->PreAction());
  // Post optimization compulsory transformation
  RETURN_IF_NOT_OK(this->PostAction());
  // The tree is ready to be prepared.
  tree_state_ = kDeTStatePrepare;
  // Existing transformation implementation, will be removed later
  RETURN_IF_NOT_OK(this->PrepareDeprecated());
  return Status::OK();
}
// Compulsory transformations that must run before the optional optimization passes
// (e.g. cache-related validation/insertion, epoch injection, node removal).
// Skipped entirely when partially_prepare_ is set, since the IR optimizer has
// already performed the equivalent passes.
// @return Status The status code returned
Status ExecutionTree::PreAction() {
  bool modified = false;
  std::vector<std::unique_ptr<Pass>> pre_actions;
  // Construct pre actions
  if (!partially_prepare_) {
#ifndef ENABLE_ANDROID
    pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
    pre_actions.push_back(std::make_unique<EpochInjectionPass>());
    pre_actions.push_back(std::make_unique<RemovalPass>());
  }
  MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";
  // Apply pre action passes
  for (auto &pass : pre_actions) {
    RETURN_IF_NOT_OK(pass->Run(this, &modified));
  }
  MS_LOG(INFO) << "Pre passes complete.";
  return Status::OK();
}
// Compulsory transformations that must run after the optional optimization passes
// (e.g. cache validation re-check and repeat inlining). Unlike PreAction(), these
// run unconditionally — they are still needed even when the IR optimizer prepared
// the tree.
// @return Status The status code returned
Status ExecutionTree::PostAction() {
  bool modified = false;
  OptPass post_actions;
  // Construct post actions
  MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID
  // Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
  // The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
  // This is because Python API binding to TensorOperation is still in progress.
  post_actions.push_back(std::make_unique<CacheErrorPass>());
  post_actions.push_back(std::make_unique<RepeatPass>());
#endif
  // Apply post action passes
  for (auto &pass : post_actions) {
    RETURN_IF_NOT_OK(pass->Run(this, &modified));
  }
  MS_LOG(INFO) << "Post passes complete.";
  return Status::OK();
}
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
//
// This driver is deprecated.
Status ExecutionTree::PrepareDeprecated() {
// Tree must be in pending prepare state before we can assign root to it
if (tree_state_ != kDeTStatePrepare) {
std::string err_msg =
"Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
Status ExecutionTree::Prepare() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -169,24 +169,6 @@ class ExecutionTree {
// @return the prepare flags
uint32_t PrepareFlags() const { return prepare_flags_; }
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status The status code returned
Status Prepare(int num_epochs = -1, bool partial = false);
// Compulsory transformation/action pre optimization.
// @return Status The status code returned
Status PreAction();
@ -200,7 +182,7 @@ class ExecutionTree {
// it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status The status code returned
Status PrepareDeprecated();
Status Prepare();
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
@ -239,10 +221,6 @@ class ExecutionTree {
// Getter for profiling manager, no ownership
ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print
@ -257,9 +235,7 @@ class ExecutionTree {
int32_t id_count_; // Counter for generating operator id's
uint32_t prepare_flags_; // Flags used during tree prepare
TreeState tree_state_; // Tracking the current tree state
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
// This rank_id is for numa and device_queue, one process work with only one rank_id,
// for standalone scenario, this rank_id may come from env 'CUDA_VISIBLE_DEVICES',

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -102,9 +102,11 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
node_ops->push_back(project_op);
}
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
pad_map_));
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
#else
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
in_col_names_, pad_map_));

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -84,9 +84,12 @@ void BucketBatchByLengthNode::Print(std::ostream &out) const {
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
node_ops->push_back(std::make_shared<BucketBatchByLengthOp>(
column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
element_length_function_, pad_info_, pad_to_bucket_boundary_,
drop_remainder_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
if (bucket_boundaries_[0] == 0) {
bucket_boundaries_.erase(bucket_boundaries_.begin());
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -55,10 +55,11 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
// Function to build BuildSentenceVocabNode
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<BuildSentencePieceVocabOp> build_sentence_piece_vocab_op;
build_sentence_piece_vocab_op = std::make_shared<BuildSentencePieceVocabOp>(
vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_);
node_ops->push_back(build_sentence_piece_vocab_op);
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
model_type_, params_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -54,6 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
std::shared_ptr<BuildVocabOp> build_vocab_op;
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
special_first_, num_workers_, connector_que_size_);
build_vocab_op->set_total_repeats(GetTotalRepeats());
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(build_vocab_op);
return Status::OK();
}

View File

@ -51,10 +51,24 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
"Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
lookup_op_->set_total_repeats(GetTotalRepeats());
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(lookup_op_);
return Status::OK();
}
// Visitor accepting method for IRNodePass (pre-order visit)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheLookupNode::Accept(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->Visit(shared_from_base<CacheLookupNode>(), modified);
}
// Visitor accepting method for IRNodePass (post-order visit, after children)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheLookupNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->VisitAfter(shared_from_base<CacheLookupNode>(), modified);
}
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
// CacheLookupNode should already been copied, so we just return it here
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);

View File

@ -64,6 +64,18 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
std::shared_ptr<SamplerObj> sampler_;
std::shared_ptr<DatasetOp> lookup_op_;

View File

@ -48,9 +48,23 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
merge_op->set_total_repeats(GetTotalRepeats());
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(merge_op);
return Status::OK();
}
// Visitor accepting method for IRNodePass (pre-order visit)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheMergeNode::Accept(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->Visit(shared_from_base<CacheMergeNode>(), modified);
}
// Visitor accepting method for IRNodePass (post-order visit, after children)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheMergeNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->VisitAfter(shared_from_base<CacheMergeNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -53,6 +53,18 @@ class CacheMergeNode : public DatasetNode {
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore

View File

@ -53,9 +53,23 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
cache_op->SetSampler(sampler_->SamplerBuild());
cache_op->set_total_repeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op);
return Status::OK();
}
// Visitor accepting method for IRNodePass (pre-order visit)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheNode::Accept(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->Visit(shared_from_base<CacheNode>(), modified);
}
// Visitor accepting method for IRNodePass (post-order visit, after children)
// @param[in] p The pass performing the visit
// @param[out] modified Set by the pass if it mutates this node
// @return Status of the node visit
Status CacheNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
  // Downcast shared pointer then call visitor
  return p->VisitAfter(shared_from_base<CacheNode>(), modified);
}
} // namespace dataset
} // namespace mindspore

View File

@ -55,6 +55,18 @@ class CacheNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
std::shared_ptr<SamplerObj> sampler_;
};

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -119,12 +119,16 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
}
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
std::shared_ptr<ConcatOp> op;
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
op = std::make_shared<ConcatOp>(connector_que_size_);
} else {
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
children_flag_and_nums_, children_start_end_index_));
op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_,
children_start_end_index_);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -219,7 +219,9 @@ DatasetNode::DatasetNode()
dataset_size_(-1),
mappable_(kNotADataSource),
nary_op_(false),
descendant_of_cache_(false) {
descendant_of_cache_(false),
total_repeats_(-1),
num_epochs_(1) {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/consumers/tree_consumer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
@ -292,6 +293,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status of the function
virtual Status to_json(nlohmann::json *out_json);
/// \brief Setter function, set the number of total repeats for the operator
void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; }
/// \brief Setter function, set the number of epochs for the operator
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
/// \brief Getter function
/// \return The number of required repeats for the operator
int32_t GetTotalRepeats() const { return total_repeats_; }
/// \brief Getter function
/// \return The number of epochs for the operator
int32_t GetNumEpochs() const { return num_epochs_; }
/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; }
protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
@ -301,6 +320,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
int32_t total_repeats_; // Number of times required to run this operator
int32_t num_epochs_; // Number of epochs
// Establish a parent-child relationship between this node and the input node.
// Used only in the constructor of the class and its derived classes.
void AddChild(std::shared_ptr<DatasetNode> child);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,6 +44,8 @@ void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" +
// Function to build the EpochCtrlOp
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
new_op_->set_total_repeats(GetTotalRepeats());
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op_);
op_ = new_op_;
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,7 +44,10 @@ void FilterNode::Print(std::ostream &out) const {
}
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,7 +38,8 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
output_columns_(output_columns),
project_columns_(project_columns),
DatasetNode(std::move(cache)),
callbacks_(callbacks) {
callbacks_(callbacks),
under_a_cache_(false) {
this->AddChild(child);
}
@ -64,6 +65,17 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
[](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
// This is temporary code.
// Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
// to TensorOp, we need to check the randomness here.
// When TensorOperation captures the randomness behaviour, remove this code and the member "under_a_cache_"
// and the temporary code in CacheValidation pre pass in IR optimizer.
if (under_a_cache_) {
auto itr = std::find_if(tensor_ops.begin(), tensor_ops.end(), [](const auto &it) { return !it->Deterministic(); });
if (itr != tensor_ops.end()) {
RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
}
}
// This parameter will be removed with next rebase
std::vector<std::string> col_orders;
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
@ -74,9 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
if (!project_columns_.empty()) {
auto project_op = std::make_shared<ProjectOp>(project_columns_);
project_op->set_total_repeats(GetTotalRepeats());
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(project_op);
}
map_op->set_total_repeats(GetTotalRepeats());
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(map_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -79,6 +79,9 @@ class MapNode : public DatasetNode {
/// \brief setter to set all tensor operations
void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations);
/// \brief indicate this Map will be cached
void Cached() { under_a_cache_ = true; }
/// \brief Getter functions
/// \brief Getter of tensor operations
/// \return Vector of operations the Map node will process
@ -95,12 +98,11 @@ class MapNode : public DatasetNode {
private:
std::vector<std::shared_ptr<TensorOperation>> operations_;
private:
std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_;
std::vector<std::string> project_columns_;
std::vector<std::shared_ptr<DSCallback>> callbacks_;
bool under_a_cache_;
};
} // namespace dataset

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,7 +53,10 @@ Status ProjectNode::ValidateParams() {
}
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<ProjectOp>(columns_));
auto op = std::make_shared<ProjectOp>(columns_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,7 +58,10 @@ Status RenameNode::ValidateParams() {
}
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,6 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + st
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
new_op->set_total_repeats(GetTotalRepeats());
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(new_op);
op_ = new_op;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,8 +44,11 @@ void ShuffleNode::Print(std::ostream &out) const {
// Function to build the ShuffleOp
Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_));
auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
rows_per_buffer_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,7 +39,10 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" +
// Function to build the SkipOp
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -72,9 +72,11 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Argument that is not exposed to user in the API.
std::set<std::string> extensions = {};
node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, extensions, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto album_op = std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_,
extensions, std::move(schema), std::move(sampler_->SamplerBuild()));
album_op->set_total_repeats(GetTotalRepeats());
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(album_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -67,9 +67,12 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
// label is like this:0 1 0 0 1......
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto celeba_op =
std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_,
extensions_, std::move(schema), std::move(sampler_->SamplerBuild()));
celeba_op->set_total_repeats(GetTotalRepeats());
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(celeba_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -64,9 +64,12 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto cifar_op =
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -62,9 +62,12 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
dataset_dir_, connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto cifar_op =
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -193,9 +193,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
clue_op->set_total_repeats(GetTotalRepeats());
clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(clue_op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -122,7 +122,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -130,10 +130,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
csv_op->set_total_repeats(GetTotalRepeats());
csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(csv_op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -91,7 +91,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
if (reset_ancestor_ != nullptr) {
reset_ancestor_->op_->AddToEoeList(op);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -70,9 +70,12 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto op = std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->SamplerBuild()));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -94,7 +94,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
manifest_op->set_total_repeats(GetTotalRepeats());
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(manifest_op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -169,6 +169,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
}
RETURN_IF_NOT_OK(mindrecord_op->Init());
mindrecord_op->set_total_repeats(GetTotalRepeats());
mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(mindrecord_op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,9 +58,11 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema),
std::move(sampler_->SamplerBuild())));
auto op = std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_->SamplerBuild()));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -109,7 +109,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema_));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -98,9 +98,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
text_file_op->set_total_repeats(GetTotalRepeats());
text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
// Add TextFileOp
node_ops->push_back(text_file_op);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -140,9 +140,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
rows_per_buffer_, &shuffle_op));
shuffle_op->set_total_repeats(GetTotalRepeats());
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(shuffle_op);
}
tf_reader_op->set_total_repeats(GetTotalRepeats());
tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
// Add TFReaderOp
node_ops->push_back(tf_reader_op);
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -113,7 +113,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
voc_op =
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
voc_op->set_total_repeats(GetTotalRepeats());
voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(voc_op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -47,7 +47,10 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
// The reason for this is because having it otherwise can lead to blocking issues
// See barrier_op.h for more details
int32_t rows_per_buffer = 1;
node_ops->push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
auto op = std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,7 +40,10 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s
// Function to build the TakeOp
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -100,8 +100,11 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
device_id_ = 0;
RETURN_IF_NOT_OK(this->GetShardId(&device_id_));
node_ops->push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, create_data_info_queue_));
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
total_batch_, create_data_info_queue_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,7 +58,10 @@ Status ZipNode::ValidateParams() {
}
Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
node_ops->push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
auto op = std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
return Status::OK();
}

View File

@ -2,27 +2,17 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
set(DATASET_ENGINE_OPT_SRC_FILES
optional/tensor_op_fusion_pass.cc
pass.cc
post/auto_worker_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/epoch_ctrl_pass.cc
pre/node_removal_pass.cc
)
# This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer.
# When the migration is complete, we will remove these files.
set(DATASET_ENGINE_OPT_SRC_FILES
${DATASET_ENGINE_OPT_SRC_FILES}
optional/tensor_op_fusion_pass.cc
pre/cache_error_pass.cc
post/repeat_pass.cc
pre/cache_transform_pass.cc
pre/epoch_injection_pass.cc
util/printer_pass.cc
pre/removal_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/epoch_ctrl_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/node_removal_pass.cc
)
if(ENABLE_PYTHON)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,11 @@
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
#endif
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
@ -187,6 +192,26 @@ Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified)
Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#ifndef ENABLE_ANDROID
Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}
Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
}
#endif
Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,11 @@ namespace dataset {
class BatchNode;
class BucketBatchByLengthNode;
class BuildVocabNode;
#ifndef ENABLE_ANDROID
class CacheLookupNode;
class CacheMergeNode;
class CacheNode;
#endif
class ConcatNode;
class EpochCtrlNode;
class FilterNode;
@ -199,6 +204,14 @@ class IRNodePass : public IRPass {
virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified);
#ifndef ENABLE_ANDROID
virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified);
#endif
virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,15 +14,16 @@
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include <memory>
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
namespace mindspore {
namespace dataset {
@ -31,10 +32,10 @@ RepeatPass::RepeatPass()
: num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
Status RepeatPass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
num_repeats_ = -num_repeats_;
}
// This RepeatOp and its descendent nodes should be repeated for another num_repeats() times.
@ -49,14 +50,14 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modi
// num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
// meaning repeat2 and map op should be set to read 8 times (2*4).
// Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
num_repeats_ *= node->num_repeats();
num_repeats_ *= node->Count();
return Status::OK();
}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
Status RepeatPass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
// Get the total number of epochs from the EpochCtrlOp parameter
num_epochs_ = node->num_repeats();
num_epochs_ = node->Count();
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
// For example: tfreader --> epoch ctrl(3)
// num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
@ -65,115 +66,108 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const m
return Status::OK();
}
#ifndef ENABLE_ANDROID
// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) {
Status RepeatPass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
// Turn on the flag that we're under a merge op
is_merge_ = true;
return Status::OK();
}
// Identifies the subtree below this node as being cached
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
Status RepeatPass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
// Turn on the flag that we're under a merge op
is_cached_ = true;
return Status::OK();
}
#endif
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
Status RepeatPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
// and set its total repeats. It is important that the op is removed from the save area,
// because the merge op above us may also take action on it later for a different case when
// there is no repeat in the merge leg.
if (is_merge_ && cache_lookup_) {
cache_lookup_->set_total_repeats(num_repeats_);
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
cache_lookup_->SetTotalRepeats(num_repeats_);
cache_lookup_->SetNumEpochs(num_epochs_);
cache_lookup_.reset();
}
if (is_cached_) {
AddToCachedOpStack(node);
AddToCachedNodeStack(node);
}
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
node->SetTotalRepeats(num_repeats_);
node->SetNumEpochs(num_epochs_);
// We finish the walk of this RepeatOp's descendent nodes.
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
// so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
num_repeats_ /= node->num_repeats();
num_repeats_ /= node->Count();
return Status::OK();
}
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
Status RepeatPass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
node->SetTotalRepeats(num_repeats_);
node->SetNumEpochs(num_epochs_);
// We finish the walk of this EpochCtrl's descendent nodes.
num_repeats_ /= node->num_repeats();
num_repeats_ /= node->Count();
return Status::OK();
}
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
// for use with a controlling repeat above it.
Status RepeatPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
// If we are under a cache op, then save ourselves to the cached op stack.
if (is_cached_) {
AddToCachedNodeStack(node);
}
// Set total repeats and total epochs for the node
node->SetTotalRepeats(num_repeats_);
node->SetNumEpochs(num_epochs_);
return Status::OK();
}
#ifndef ENABLE_ANDROID
// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
Status RepeatPass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
is_cached_ = false;
// if we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops.
// So that those cached nodes become 1-time use (up to eoe), never repeated. Instead
// the repeating behaviours shall be invoked against the cache op.
std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack();
while (cached_op != nullptr) {
int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_;
cached_op->set_total_repeats(cached_op_total_repeats);
std::shared_ptr<DatasetNode> cached_node = PopFromCachedNodeStack();
while (cached_node != nullptr) {
int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_;
cached_node->SetTotalRepeats(cached_op_total_repeats);
// Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
cached_op->set_num_repeats_per_epoch(cached_op_total_repeats);
cached_op = PopFromCachedOpStack();
cached_node->SetNumEpochs(1);
cached_node = PopFromCachedNodeStack();
}
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
return Status::OK();
}
Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
// If we are under a cache op, then save ourselves to the cached op stack.
if (is_cached_) {
AddToCachedOpStack(node);
}
// Set total repeats and total epochs for the node
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
return Status::OK();
}
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
// for use with a controlling repeat above it.
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) {
// If we are under a cache op, then save ourselves to the cached op stack.
if (is_cached_) {
AddToCachedOpStack(node);
}
// Set total repeats and total epochs for the node
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
node->SetTotalRepeats(num_repeats_);
node->SetNumEpochs(num_epochs_);
return Status::OK();
}
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) {
Status RepeatPass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to set its total repeats for it.
if (cache_lookup_) {
cache_lookup_->set_total_repeats(num_repeats_);
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
cache_lookup_->SetTotalRepeats(num_repeats_);
cache_lookup_->SetNumEpochs(num_epochs_);
}
node->set_total_repeats(num_repeats_);
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
node->SetTotalRepeats(num_repeats_);
node->SetNumEpochs(num_epochs_);
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
is_merge_ = false;
return Status::OK();
}
// Saves the lookup up in case it needs to be referenced by a repeat
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) {
Status RepeatPass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
if (!node->IsLeaf()) {
// By definition, the CacheLookup must be a leaf op. Make that clear here.
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
@ -184,29 +178,30 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const mo
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
// add the lookup to the eoe stack
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
cache_lookup_ = std::static_pointer_cast<DatasetNode>(node);
return Status::OK();
}
#endif
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
// Set total repeats and total epochs for the DeviceQueueOp
node->set_total_repeats(num_epochs_);
node->set_num_repeats_per_epoch(1);
Status RepeatPass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
// Set total repeats and total epochs for the TransferNode
node->SetTotalRepeats(num_epochs_);
node->SetNumEpochs(num_epochs_);
return Status::OK();
}
// Adds an operator to the cached operator stack save area
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
void RepeatPass::AddToCachedNodeStack(std::shared_ptr<DatasetNode> node) { cached_node_stacks_.push(node); }
// Pops an operator from the cached operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromCachedOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!cached_op_stacks_.empty()) {
top_op = cached_op_stacks_.top();
cached_op_stacks_.pop();
std::shared_ptr<DatasetNode> RepeatPass::PopFromCachedNodeStack() {
std::shared_ptr<DatasetNode> top_node = nullptr;
if (!cached_node_stacks_.empty()) {
top_node = cached_node_stacks_.top();
cached_node_stacks_.pop();
}
return top_op;
return top_node;
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,12 +25,11 @@
namespace mindspore {
namespace dataset {
/// \class RepeatPass repeat_pass.h
/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
/// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
/// \class RepeatPass
/// \brief This is a post pass that calculate the number of repeats the pipeline needs to fetch the data.
class RepeatPass : public IRNodePass {
public:
using op_stack = std::stack<std::shared_ptr<DatasetOp>>;
using op_stack = std::stack<std::shared_ptr<DatasetNode>>;
/// \brief Constructor
RepeatPass();
@ -40,93 +39,91 @@ class RepeatPass : public NodePass {
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
#ifndef ENABLE_ANDROID
/// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
Status Visit(std::shared_ptr<CacheNode> node, bool *const modified) override;
#endif
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself
#ifndef ENABLE_ANDROID
/// \brief CacheNode removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) override;
/// \brief Turns of the tracking for operations under merge op
/// \brief Turns off the tracking for operations under merge op
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
/// \brief Saves the lookup up in case it needs to be referenced by a repeat
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) override;
#endif
/// \brief Set the epoch count for DeviceQueue
/// \brief Sets the epoch count for TransferNode
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
/// \brief Special case for GeneratorOp
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;
private:
/// \brief Adds an operator to the cached operator stack save area
/// \param op - The dataset op to work add to cached stack
/// \brief Adds an operator to the cached stack save area
/// \param node - The dataset node to add to cached stack
/// \return Status The status code returned
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);
void AddToCachedNodeStack(std::shared_ptr<DatasetNode> node);
/// \brief Pops an operator from the cached operator stack save area
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromCachedOpStack();
/// \brief Pops an operator from the cached stack save area
/// \return shared_ptr to the popped dataset node
std::shared_ptr<DatasetNode> PopFromCachedNodeStack();
bool is_merge_; // T/F if we are processing under a cache merge op
bool is_cached_; // T/F is we are processing under a cache op
int32_t num_repeats_; // A multiplier to the total number of repeats
int32_t num_epochs_; // To save the total number of epochs
op_stack cached_op_stacks_; // A save area for ops under a cache op
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
bool is_merge_; // T/F if we are processing under a cache merge node
bool is_cached_; // T/F is we are processing under a cache node
int32_t num_repeats_; // A multiplier to the total number of repeats
int32_t num_epochs_; // To save the total number of epochs
op_stack cached_node_stacks_; // A save area for operators under a cache node
std::shared_ptr<DatasetNode> cache_lookup_; // A save area for a cache lookup node
};
} // namespace dataset
} // namespace mindspore

View File

@ -1,189 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
namespace mindspore {
namespace dataset {
// Constructor
// Both flags start false: no CacheOp ancestor seen yet, and no mappable leaf dataset seen yet.
CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {}
// Identifies the subtree below this node as being cached
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
// Turn on the flag that we're under a cache op; the PreRunOnNode visitors below use it to reject
// unsupported descendants. (RunOnNode(CacheOp) clears it again on the way back up.)
is_cached_ = true;
return Status::OK();
}
// Returns an error if ZipOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) {
// is_cached_ is set by PreRunOnNode(CacheOp) when a cache is an ancestor of this node
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"ZipOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
// Returns an error if MapOp with non-deterministic TensorOps exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) {
  // Maps outside a cached subtree are unrestricted.
  if (!is_cached_) {
    return Status::OK();
  }
  // A cached map must produce identical output on every pass, so every one of its
  // tensor functions has to be deterministic.
  for (const auto &tfunc : node->TFuncs()) {
    if (!tfunc->Deterministic()) {
      return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
                    "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache.");
    }
  }
  return Status::OK();
}
// Returns an error if ConcatOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) {
// Reject only when a CacheOp ancestor has set is_cached_
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"ConcatOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
// Returns an error if TakeOp exists under a cache
// NOTE: SplitOp is implemented via TakeOp, which is why the message mentions both.
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) {
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"TakeOp/SplitOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
// Returns an error if SkipOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) {
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"SkipOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
// Returns an error if BatchOp exists under a cache
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) {
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"BatchOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
#ifdef ENABLE_PYTHON
// Returns an error if FilterOp exists under a cache (only compiled when Python support is enabled)
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) {
if (is_cached_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"FilterOp is currently not supported as a descendant operator under a cache.");
}
return Status::OK();
}
#endif
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Records that the pipeline's leaf is mappable; RunOnNode(RepeatOp) checks this flag later
Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
// Turn on the flag that this is a tree with mappable leaf dataset
is_mappable_ = true;
return Status::OK();
}
// Post-visit of CacheOp: the cached subtree has been fully walked, so reset the flag
Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
// Turn off the flag that we're under a cache op (set in PreRunOnNode(CacheOp))
is_cached_ = false;
return Status::OK();
}
// Currently, returns an error if RepeatOp exists under a cache
// Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem.
Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
// Only rejected for mappable leaves: both flags must have been set by earlier visitors
if (is_cached_ && is_mappable_) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
"Repeat is not supported as a descendant operator under a mappable cache.");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,167 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
#include <memory>
#include <stack>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class CacheErrorPass cache_error_pass.h
/// \brief This is a NodePass whose job is to catch invalid tree configurations related to cache and generate failures.
class CacheErrorPass : public NodePass {
public:
/// \brief Constructor
CacheErrorPass();
/// \brief Destructor
~CacheErrorPass() = default;
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
/// \brief Returns an error if ZipOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override;
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override;
/// \brief Returns an error if ConcatOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) override;
/// \brief Returns an error if TakeOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override;
/// \brief Returns an error if SkipOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override;
/// \brief Returns an error if BatchOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override;
#ifdef ENABLE_PYTHON
/// \brief Returns an error if FilterOp exists under a cache
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override;
#endif
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
/// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
/// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
private:
bool is_cached_;    // true while the walk is inside a cached subtree
bool is_mappable_;  // true once a mappable leaf dataset has been visited
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,7 +25,6 @@
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#include "minddata/dataset/include/transforms.h"
namespace mindspore {
namespace dataset {
@ -114,11 +113,18 @@ Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *const mod
}
// If Map is created to be cached, set the flag indicating we found an operation with a cache.
is_cached_ = true;
// This is temporary code.
// Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
// to TensorOp, we need to check the randomness in MapNode::Build().
// By setting this MapNode is under a cache, we will check the randomness of its tensor operations without the need
// to walk the IR tree again.
node->Cached();
auto tfuncs = node->TensorOperations();
for (size_t i = 0; i < tfuncs.size(); i++) {
if (tfuncs[i]->IsRandomOp()) {
RETURN_STATUS_UNEXPECTED(
"MapNode with non-deterministic operations is not supported as a descendant of cache.");
RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
}
}
}

View File

@ -1,78 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
namespace mindspore {
namespace dataset {
// constructor
// The tree root is the default injection point unless a visitor overrides or clears it.
EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {}
#ifndef ENABLE_ANDROID
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) {
// Clearing the injection point suppresses EpochCtrlOp injection entirely for this tree
injection_point_ = nullptr;
return Status::OK();
}
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node,
bool *const modified) {
// Clearing the injection point suppresses EpochCtrlOp injection entirely for this tree
injection_point_ = nullptr;
return Status::OK();
}
#endif
// Registers the child of the DeviceQueueOp as the place to inject the epoch control op
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
injection_point_ = node->child(0);
return Status::OK();
}
// constructor (no state to initialize; all work happens in RunOnTree)
EpochInjectionPass::EpochInjectionPass() {}
// Runs an injection pass to inject in operators needed at the pre pass stage
Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
  MS_LOG(INFO) << "Pre pass: Injection pass started.";

  // Walk the tree first so the finder can decide where (and whether) an
  // epoch control op should be injected. The finder may update this pass.
  EpochInjectionPass::InjectionFinder finder(tree->root());
  RETURN_IF_NOT_OK(finder.Run(tree, modified));

  // Inject an EpochCtrlOp as the parent of the identified node. Skip injection
  // when the pipeline runs a single epoch or the finder suppressed the point.
  std::shared_ptr<DatasetOp> inject_at = finder.injection_point();
  int32_t num_epochs = tree->num_epochs();
  if (inject_at != nullptr && num_epochs != 1) {
    std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
    RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
    RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
    RETURN_IF_NOT_OK(inject_at->InsertAsParent(epoch_ctrl_op));
  }

  MS_LOG(INFO) << "Pre pass: Injection pass complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,88 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class EpochInjectionPass epoch_injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
///     parsing.
class EpochInjectionPass : public TreePass {
/// \class InjectionFinder
/// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
///     operators that need to be injected. It is run first by the main injection pass to find out what operators
///     it may need to inject.
class InjectionFinder : public NodePass {
public:
/// \brief Constructor
explicit InjectionFinder(std::shared_ptr<DatasetOp> node);
/// \brief Destructor
~InjectionFinder() = default;
#ifndef ENABLE_ANDROID
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) override;
/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *const modified) override;
#endif
/// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited
/// \param[in,out] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
/// \brief Getter for the chosen injection point (nullptr means: do not inject)
std::shared_ptr<DatasetOp> injection_point() { return injection_point_; }
private:
std::shared_ptr<DatasetOp> injection_point_;  // where EpochCtrlOp will be inserted as parent
};
public:
/// \brief Constructor
EpochInjectionPass();
/// \brief Destructor
~EpochInjectionPass() = default;
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[in,out] tree The tree to operate on.
/// \param[in,out] modified Indicator if the tree was modified.
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,7 +15,6 @@
*/
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"

View File

@ -1,75 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
#ifndef ENABLE_ANDROID
// Identifies the subtree below this node as a cached descendant tree.
Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
  // Entering a CacheOp on the way down: everything beneath it is the cached subtree.
  MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
  is_caching_ = true;
  // This visitor only records state; the node itself is untouched.
  *modified = false;
  return Status::OK();
}
// Resets the tracking of the cache within the tree
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
  // Leaving the CacheOp on the way back up: the cached subtree is fully visited.
  MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
  is_caching_ = false;
  // Pure bookkeeping; the node is not altered.
  *modified = false;
  return Status::OK();
}
#endif
// Perform ShuffleOp removal check.
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) {
  *modified = false;
  // Outside a cached subtree a shuffle is meaningful; keep it.
  if (!is_caching_) {
    return Status::OK();
  }
  // A cache ancestor makes this shuffle redundant, so flag it for removal.
  MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
  nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node));
  return Status::OK();
}
// constructor — nothing to initialize beyond the base class
RemovalPass::RemovalPass() = default;
// Walk the tree to collect the nodes to remove, then removes them.
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
  MS_LOG(INFO) << "Pre pass: removal pass started.";
  // Pass 1: walk the tree and mark the nodes that qualify for removal.
  auto finder = std::make_unique<RemovalPass::RemovalNodes>();
  RETURN_IF_NOT_OK(finder->Run(tree, modified));
  // Pass 2: detach every flagged node from the tree.
  for (const auto &victim : finder->nodes_to_remove()) {
    RETURN_IF_NOT_OK(victim->Remove());
  }
  MS_LOG(INFO) << "Pre pass: removal pass complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,90 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class RemovalPass removal_pass.h
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
/// nodes should be removed, and then removes them.
class RemovalPass : public TreePass {
  /// \class RemovalNodes
  /// \brief This is a NodePass whose job is to identify which nodes should be removed.
  ///     It works in conjunction with the removal_pass.
  class RemovalNodes : public NodePass {
   public:
    /// \brief Constructor
    RemovalNodes();

    /// \brief Destructor
    ~RemovalNodes() = default;

#ifndef ENABLE_ANDROID
    /// \brief Identifies the subtree below this node as a cached descendant tree.
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;

    /// \brief Resets the tracking of the cache within the tree
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
#endif

    /// \brief Perform ShuffleOp removal check
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override;

    /// \brief Getter
    /// \return All the nodes to be removed
    std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; }

   private:
    bool is_caching_;  // true while visiting the descendants of a CacheOp
    std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_;  // nodes flagged for removal by the walk
  };

 public:
  /// \brief Constructor
  RemovalPass();

  /// \brief Destructor
  ~RemovalPass() = default;

  /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
  /// \param[inout] tree The tree to operate on.
  /// \param[inout] modified Indicator of whether the tree was modified.
  /// \return Status The status code returned
  Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_

View File

@ -1,121 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/util/printer_pass.h"
namespace mindspore {
namespace dataset {
Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting DatasetOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting BatchOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting MapOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting ProjectOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting RenameOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting SkipOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting ShuffleOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
#ifndef ENABLE_ANDROID
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting MindRecordOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting TFReaderOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
#endif
#ifdef ENABLE_PYTHON
Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting FilterOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting GeneratorOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
#endif
Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting TakeOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting ZipOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting DeviceQueueOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
  // Trace-only visitor: report the node type, leave the tree untouched.
  std::cout << "Visiting ImageFolderOp";
  std::cout << '\n';
  *modified = false;
  return Status::OK();
}
/// \brief Trace visitor for AlbumOp: prints the node type; the tree is never modified.
/// \param[in] node The AlbumOp node being visited
/// \param[inout] modified Set to false (this pass makes no changes)
/// \return Status The status code returned
Status PrinterPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
  *modified = false;
  // Bug fix: this visitor previously printed "Visiting ImageFolderOp" (copy-paste error).
  std::cout << "Visiting AlbumOp" << '\n';
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,68 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#include <memory>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class PrinterPass
/// \brief Debugging pass: each visitor prints the concrete type of the node being
///     visited and sets *modified to false — the tree itself is never changed.
class PrinterPass : public NodePass {
 public:
  /// Each overload below prints "Visiting <OpType>" for its node type.
  /// \param[in] node The node being visited
  /// \param[inout] modified Set to false (no change is made)
  /// \return Status The status code returned
  Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override;

#ifndef ENABLE_ANDROID
  Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override;
#endif

#ifdef ENABLE_PYTHON
  Status RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
#endif

  Status RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;

  Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@
#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#endif
@ -94,6 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
#ifdef ENABLE_PYTHON
actions.emplace_back(std::make_unique<GeneratorNodePass>());
#endif
actions.emplace_back(std::make_unique<RepeatPass>());
// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.
@ -133,7 +135,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std
return Status::OK();
}
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
// This will evolve in the long run
tree_ = std::make_unique<ExecutionTree>();
// disable profiling if this is only a getter pass
@ -146,7 +148,7 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epoc
// Note: We will gradually move the pre pass, optimizer pass, and post pass
// on ExecutionTree to perform on IR tree.
// Prepare the tree
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true));
RETURN_IF_NOT_OK(tree_->Prepare());
// After the tree is prepared, the col_name_id_map can safely be obtained
column_name_map_ = tree_->root()->column_name_id_map();
@ -192,7 +194,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
// Remember the root node
root_ir_ = root_ir;
RETURN_IF_NOT_OK(Build(root_ir_, num_epochs));
RETURN_IF_NOT_OK(Build(root_ir_));
tree_state_ = kCompileStateReady;
return Status::OK();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -83,7 +83,7 @@ class TreeAdapter {
Status PostPass(std::shared_ptr<DatasetNode> ir);
// Build an Execution tree
Status Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs);
Status Build(std::shared_ptr<DatasetNode> root_ir);
// This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree.
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,22 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "common/common.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/album_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
@ -89,7 +79,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false), Repeat(2)});
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
ASSERT_OK(tree->Prepare());
ASSERT_OK(tree->Launch());
DatasetIterator di(tree);
@ -111,7 +105,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) {
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaNoOrder) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
ASSERT_OK(tree->Prepare());
ASSERT_OK(tree->Launch());
DatasetIterator di(tree);
@ -134,7 +132,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaFloat) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
// add the priority column
std::string schema_file = datasets_root_path_ + "/testAlbum/floatSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
tree->Prepare();
ASSERT_OK(tree->Launch());
DatasetIterator di(tree);
@ -159,7 +161,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithFullSchema) {
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
// add the priority column
std::string schema_file = datasets_root_path_ + "/testAlbum/fullSchema.json";
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
ASSERT_OK(tree->Prepare());
ASSERT_OK(tree->Launch());
DatasetIterator di(tree);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,14 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <string>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "utils/ms_utils.h"
#include "gtest/gtest.h"
#include "minddata/dataset/core/global_context.h"
#include "utils/log_adapter.h"
#include "securec.h"
#include "minddata/dataset/util/status.h"
@ -112,7 +109,12 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)});
auto op1 = TFReader(schema_file);
auto op2 = Repeat(2);
auto op3 = Batch(7, true, 99);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -157,7 +159,12 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)});
auto op1 = TFReader(schema_file);
auto op2 = Repeat(2);
auto op3 = Batch(7, false, 99);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -209,7 +216,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)});
auto op1 = TFReader(schema_file);
auto op2 = Batch(7, false, 99);
auto op3 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
op2->set_total_repeats(2);
op2->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {
@ -255,7 +269,14 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
bool success = false;
auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)});
auto op1 = TFReader(schema_file);
auto op2 = Batch(5, true, 99);
auto op3 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
op2->set_total_repeats(2);
op2->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -293,15 +293,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
ASSERT_TRUE(rc.IsOk());
// Assign tree relations and root
myCacheOp->set_total_repeats(numRepeats);
myCacheOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myCacheOp);
ASSERT_TRUE(rc.IsOk());
// Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats.
myRandomDataOp->set_total_repeats(1);
myRandomDataOp->set_num_repeats_per_epoch(1);
rc = myCacheOp->AddChild(myRandomDataOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare(1);
rc = myTree->Prepare();
ASSERT_TRUE(rc.IsOk());
// quick check to see what tree looks like
@ -412,15 +417,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
ASSERT_TRUE(rc.IsOk());
// Assign tree relations and root
myCacheOp->set_total_repeats(numRepeats);
myCacheOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myCacheOp);
ASSERT_TRUE(rc.IsOk());
// Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats.
myRandomDataOp->set_total_repeats(1);
myRandomDataOp->set_num_repeats_per_epoch(1);
rc = myCacheOp->AddChild(myRandomDataOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
MS_LOG(INFO) << "Launching tree and begin iteration";
rc = myTree->Prepare(1);
rc = myTree->Prepare();
ASSERT_TRUE(rc.IsOk());
std::cout << *myClient << std::endl;
@ -502,14 +512,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = myTree->AssignRoot(myRepeatOp);
ASSERT_TRUE(rc.IsOk());
myMergeOp->set_total_repeats(numRepeats);
myMergeOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myMergeOp);
ASSERT_TRUE(rc.IsOk());
myLookupOp->set_total_repeats(numRepeats);
myLookupOp->set_num_repeats_per_epoch(numRepeats);
rc = myMergeOp->AddChild(myLookupOp);
ASSERT_TRUE(rc.IsOk());
so->set_total_repeats(numRepeats);
so->set_num_repeats_per_epoch(numRepeats);
rc = myMergeOp->AddChild(so);
ASSERT_TRUE(rc.IsOk());
rc = myTree->Prepare(1);
rc = myTree->Prepare();
ASSERT_TRUE(rc.IsOk());
rc = myTree->Launch();
ASSERT_TRUE(rc.IsOk());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,14 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "common/common.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/util/status.h"
@ -98,7 +95,11 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
auto op1 = Celeba(16, 2, 32, dir);
auto op2 = Repeat(2);
auto tree = Build({op1, op2});
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,8 +39,6 @@ using mindspore::MsLogLevel::ERROR;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,8 +45,6 @@ using mindspore::LogStream;
std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2);
std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
class MindDataTestCocoOp : public UT::DatasetOpTesting {
@ -261,4 +259,4 @@ TEST_F(MindDataTestCocoOp, TestCocoPanoptic) {
}
ASSERT_EQ(row_count, 2);
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
@ -29,7 +28,6 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/status.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
@ -82,7 +80,11 @@ class MindDataTestImageFolderSampler : public UT::DatasetOpTesting {
TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeat) {
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2)});
auto op1 = ImageFolder(16, 2, 32, folder_path, false);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
tree->Prepare();
int32_t res[] = {0, 1, 2, 3};
Status rc = tree->Launch();
@ -166,7 +168,12 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) {
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2), Batch(11)});
auto op1 = ImageFolder(16, 2, 32, folder_path, false);
auto op2 = Repeat(2);
auto op3 = Batch(11);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
int32_t res[4][11] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
@ -297,7 +304,11 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
int64_t num_samples = 0;
std::shared_ptr<SamplerRT> sampler = std::make_shared<DistributedSamplerRT>(num_samples, 11, 10, false);
std::string folder_path = datasets_root_path_ + "/testPK/data";
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)});
auto op1 = ImageFolder(16, 2, 32, folder_path, false, std::move(sampler));
auto op2 = Repeat(4);
op1->set_total_repeats(4);
op1->set_num_repeats_per_epoch(4);
auto tree = Build({op1, op2});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,6 +20,7 @@
#include "common/common.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/tree_adapter.h"
#include "minddata/dataset/include/datasets.h"
@ -166,6 +167,10 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
// start build then launch tree
leaf->set_total_repeats(2);
leaf->set_num_repeats_per_epoch(2);
map_op->set_total_repeats(2);
map_op->set_num_repeats_per_epoch(2);
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
@ -213,8 +218,15 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op);
// start build then launch tree
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
leaf->set_total_repeats(-2);
leaf->set_num_repeats_per_epoch(2);
map_op->set_total_repeats(-2);
map_op->set_num_repeats_per_epoch(2);
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();
@ -271,8 +283,15 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
// config RepeatOp
std::shared_ptr<RepeatOp> repeat_op;
rc = RepeatOp::Builder(2).Build(&repeat_op);
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op);
// start build then launch tree
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
leaf->set_total_repeats(-2);
leaf->set_num_repeats_per_epoch(2);
map_op->set_total_repeats(-2);
map_op->set_num_repeats_per_epoch(2);
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op});
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->Launch();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,7 +58,11 @@ class MindDataTestManifest : public UT::DatasetOpTesting {
TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) {
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
auto tree = Build({Manifest(16, 2, 32, file), Repeat(2)});
auto op1 = Manifest(16, 2, 32, file);
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
tree->Prepare();
uint32_t res[] = {0, 1, 0, 1};
Status rc = tree->Launch();
@ -148,7 +152,11 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
int64_t num_samples = 1;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)});
auto op1 = Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {});
auto op2 = Repeat(4);
op1->set_total_repeats(4);
op1->set_num_repeats_per_epoch(4);
auto tree = Build({op1, op2});
tree->Prepare();
Status rc = tree->Launch();
if (rc.IsError()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@
#include <memory>
#include <vector>
#include "common/common.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/tensor.h"
@ -416,6 +415,8 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
rc = my_map_op->AddChild(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
my_tfreader_op->set_total_repeats(num_repeats);
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
@ -471,9 +472,13 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) {
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());
my_map_op->set_total_repeats(num_repeats);
my_map_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_map_op);
EXPECT_TRUE(rc.IsOk());
my_tfreader_op->set_total_repeats(num_repeats);
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
rc = my_map_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
@ -548,9 +553,13 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) {
rc = my_tree_->AssociateNode(my_map_resize_op);
EXPECT_TRUE(rc.IsOk());
my_tfreader_op->set_total_repeats(num_repeats);
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
rc = my_map_decode_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
my_map_decode_op->set_total_repeats(num_repeats);
my_map_decode_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_map_decode_op);
EXPECT_TRUE(rc.IsOk());
@ -611,7 +620,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
rc = map_resize_builder.Build(&map_resize_op);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
image_folder_op->set_total_repeats(num_repeats);
image_folder_op->set_num_repeats_per_epoch(num_repeats);
map_decode_map->set_total_repeats(num_repeats);
map_decode_map->set_num_repeats_per_epoch(num_repeats);
my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
rc = my_tree_->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();
@ -656,7 +670,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
rc = map_resize_builder.Build(&map_resize_op);
EXPECT_TRUE(rc.IsOk());
auto my_tree_2 = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
image_folder_op->set_total_repeats(num_repeats);
image_folder_op->set_num_repeats_per_epoch(num_repeats);
map_decode_map->set_total_repeats(num_repeats);
map_decode_map->set_num_repeats_per_epoch(num_repeats);
auto my_tree_2 = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
rc = my_tree_2->Prepare();
EXPECT_TRUE(rc.IsOk());
@ -714,7 +733,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) {
rc = map_resize_builder.Build(&map_resize_op);
EXPECT_TRUE(rc.IsOk());
my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
image_folder_op->set_total_repeats(num_repeats);
image_folder_op->set_num_repeats_per_epoch(num_repeats);
map_decode_map->set_total_repeats(num_repeats);
map_decode_map->set_num_repeats_per_epoch(num_repeats);
my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
rc = my_tree_->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->Launch();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -370,9 +370,12 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
rc = my_tree->AssociateNode(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
my_mindrecord_op->set_total_repeats(num_repeats);
my_mindrecord_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_mindrecord_op);
EXPECT_TRUE(rc.IsOk());
// Set children/root layout.
rc = my_tree->AssignRoot(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
@ -452,6 +455,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
rc = my_tree->AssociateNode(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
my_mindrecord_op->set_total_repeats(num_repeats);
my_mindrecord_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_mindrecord_op);
EXPECT_TRUE(rc.IsOk());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -78,7 +78,11 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
auto op2 = Repeat(2);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2});
tree->Prepare();
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
Status rc = tree->Launch();
@ -108,7 +112,12 @@ TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
int64_t num_samples = 10;
int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
auto op2 = Repeat(2);
auto op3 = Batch(5);
op1->set_total_repeats(2);
op1->set_num_repeats_per_epoch(2);
auto tree = Build({op1, op2, op3});
tree->Prepare();
uint32_t res[4][5] = { {0, 0, 0, 0, 0 },
{0, 0, 0, 0, 0 },

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -35,7 +35,7 @@ class MindDataTestRandomDataOp : public UT::DatasetOpTesting {
// Test info:
// - Simple test with a user-provided schema generated purely from DataSchema C API
// - has an interation loop
// - has an interaction loop
//
// Tree: single node tree with RandomDataOp
//
@ -213,7 +213,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic3) {
// Test info:
// - json schema input it's a fairly simple one
// - has an interation loop
// - has an interaction loop
//
// Tree: RepeatOp over RandomDataOp
//
@ -253,6 +253,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) {
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
myRandomDataOp->set_total_repeats(numRepeats);
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
@ -290,7 +292,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) {
// Test info:
// - json schema input it's a fairly simple one
// - has an interation loop
// - has an interaction loop
// - same as MindDataTestRandomDataOpBasic4 except that this one will have parallel workers
//
// Tree: RepeatOp over RandomDataOp
@ -331,6 +333,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic5) {
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
myRandomDataOp->set_total_repeats(numRepeats);
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());
@ -418,9 +422,13 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpTree1) {
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
myShuffleOp->set_total_repeats(numRepeats);
myShuffleOp->set_num_repeats_per_epoch(numRepeats);
rc = myRepeatOp->AddChild(myShuffleOp);
EXPECT_TRUE(rc.IsOk());
myRandomDataOp->set_total_repeats(numRepeats);
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
rc = myShuffleOp->AddChild(myRandomDataOp);
EXPECT_TRUE(rc.IsOk());

View File

@ -75,6 +75,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) {
rc = spv_op->AddChild(file_op);
ASSERT_TRUE(rc.IsOk());
file_op->set_total_repeats(1);
file_op->set_num_repeats_per_epoch(1);
rc = tree->AssignRoot(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->Prepare();
@ -147,6 +149,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) {
rc = spv_op->AddChild(file_op);
ASSERT_TRUE(rc.IsOk());
file_op->set_total_repeats(1);
file_op->set_num_repeats_per_epoch(1);
rc = tree->AssignRoot(spv_op);
ASSERT_TRUE(rc.IsOk());
rc = tree->Prepare();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -300,8 +300,12 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
EXPECT_TRUE(rc.IsOk());
// Set children/root layout.
my_shuffle_op->set_total_repeats(numRepeats);
my_shuffle_op->set_num_repeats_per_epoch(numRepeats);
rc = my_repeat_op->AddChild(my_shuffle_op);
EXPECT_TRUE(rc.IsOk());
my_tfreader_op->set_total_repeats(numRepeats);
my_tfreader_op->set_num_repeats_per_epoch(numRepeats);
rc = my_shuffle_op->AddChild(my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_repeat_op);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/data_schema.h"
#include "common/common.h"
#include "utils/ms_utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
@ -330,11 +329,14 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderRepeat) {
ASSERT_TRUE(rc.IsOk());
// RepeatOp
std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(3);
uint32_t num_repeats = 3;
std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(num_repeats);
rc = my_tree->AssociateNode(my_repeat_op);
ASSERT_TRUE(rc.IsOk());
// Set children/root layout.
my_tfreader_op->set_total_repeats(num_repeats);
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_repeat_op);
@ -705,7 +707,7 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
std::string nonexistent_file = "this/file/doesnt/exist";
std::string nonexistent_file = "this/file/not/exist";
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,8 +45,6 @@ using mindspore::LogStream;
std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2);
std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
class MindDataTestVOCOp : public UT::DatasetOpTesting {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -141,6 +141,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
MS_LOG(INFO) << "UT test TestZipRepeat.";
auto my_tree = std::make_shared<ExecutionTree>();
uint32_t num_repeats = 3;
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
std::shared_ptr<TFReaderOp> my_tfreader_op;
@ -169,17 +170,23 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(zip_op);
EXPECT_TRUE(rc.IsOk());
my_tfreader_op->set_total_repeats(num_repeats);
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
rc = zip_op->AddChild(std::move(my_tfreader_op));
EXPECT_TRUE(rc.IsOk());
my_tfreader_op2->set_total_repeats(num_repeats);
my_tfreader_op2->set_num_repeats_per_epoch(num_repeats);
rc = zip_op->AddChild(std::move(my_tfreader_op2));
EXPECT_TRUE(rc.IsOk());
// Builder(num_of_repeats)
std::shared_ptr<RepeatOp> my_repeat_op;
rc = RepeatOp::Builder(3).Build(&my_repeat_op);
rc = RepeatOp::Builder(num_repeats).Build(&my_repeat_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_repeat_op);
EXPECT_TRUE(rc.IsOk());
zip_op->set_total_repeats(num_repeats);
zip_op->set_num_repeats_per_epoch(num_repeats);
rc = my_repeat_op->AddChild(zip_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree->AssignRoot(my_repeat_op);

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -138,7 +138,7 @@ def test_cache_map_basic3():
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic4():
"""
Test Map with non-deterministic TensorOps above cache
Test Map containing random operation above cache
repeat
|
@ -374,7 +374,7 @@ def test_cache_map_failure4():
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure5():
"""
Test Map with non-deterministic TensorOps under cache (failure)
Test Map containing random operation under cache (failure)
repeat
|
@ -406,7 +406,7 @@ def test_cache_map_failure5():
num_iter = 0
for _ in data.create_dict_iterator():
num_iter += 1
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
assert num_iter == 0
logger.info('test_cache_failure5 Ended.\n')

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -2087,7 +2087,7 @@ def test_cache_nomap_failure4():
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure5():
"""
Test Map with non-deterministic TensorOps under cache (failure)
Test Map containing random operation under cache (failure)
repeat
|
@ -2118,7 +2118,7 @@ def test_cache_nomap_failure5():
num_iter = 0
for _ in data.create_dict_iterator():
num_iter += 1
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
assert num_iter == 0
logger.info('test_cache_nomap_failure5 Ended.\n')