forked from mindspore-Ecosystem/mindspore
Migrate repeat_pass.cc to IR optimizer and remove ExecTree optimizer
This commit is contained in:
parent c16b45ab23
commit 5a7dc0accc
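This change moves the repeat/epoch bookkeeping from the ExecTree optimizer's repeat pass into the IR optimizer: the counts are now stored on each IR DatasetNode, and every node's Build() copies them onto the dataset op it creates before handing that op to the execution tree. The per-file edits below all follow the same pattern. The sketch below uses SomeNode/SomeOp as stand-ins for the concrete types and is a condensed illustration, not code taken verbatim from the diff:

    Status SomeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
      // Previously the op was pushed directly:
      //   node_ops->push_back(std::make_shared<SomeOp>(...));
      // Now the op is created first, the repeat counts computed on the IR node are
      // stamped onto it, and only then is it appended to the ExecTree op list.
      auto op = std::make_shared<SomeOp>(/* constructor arguments unchanged */);
      op->set_total_repeats(GetTotalRepeats());
      op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
      node_ops->push_back(op);
      return Status::OK();
    }
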
@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -258,10 +258,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return The number of required repeats for the operator
int32_t op_total_repeats() { return op_total_repeats_; }

/// \brief Getter function
/// \return The number of required epochs for the operator
int32_t op_total_epochs() { return op_total_repeats_ / op_num_repeats_per_epoch_; }

/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t op_num_repeats_per_epoch() const { return op_num_repeats_per_epoch_; }

@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,10 +17,8 @@
|
|||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/log_adapter.h"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -304,6 +304,10 @@ class CsvOp : public ParallelOp {
|
|||
// @return - Status of the node visit.
|
||||
Status Accept(NodePass *p, bool *const modified) override;
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
std::string Name() const override { return "CsvOp"; }
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
|
|
|
@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,20 +16,9 @@
#include "minddata/dataset/engine/execution_tree.h"
#include <iostream>
#include <string>
#include <utility>
#include <limits>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
@ -255,97 +244,13 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return Status::OK();
}

// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status The status code returned
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs;
partially_prepare_ = partial;

// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PreAction());

// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PostAction());

// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;

// Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK();
}

Status ExecutionTree::PreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
if (!partially_prepare_) {
#ifndef ENABLE_ANDROID
pre_actions.push_back(std::make_unique<CacheErrorPass>());
#endif
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
}

MS_LOG(INFO) << "Running " << pre_actions.size() << " pre pass loops.";

// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Pre passes complete.";
return Status::OK();
}

Status ExecutionTree::PostAction() {
bool modified = false;
OptPass post_actions;
// Construct pre actions
MS_LOG(INFO) << "Running post pass loops.";
#ifndef ENABLE_ANDROID
// Calling CacheErrorPass again. This is a temporary fix until the TensorOperation is properly done in Pybind.
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<RepeatPass>());
#endif

// Apply post action passes
for (auto &pass : post_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Post passes complete.";

return Status::OK();
}

// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
//
// This driver is deprecated.
Status ExecutionTree::PrepareDeprecated() {
// Tree must be in pending prepare state before we can assign root to it
if (tree_state_ != kDeTStatePrepare) {
std::string err_msg =
"Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
Status ExecutionTree::Prepare() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;

if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");

@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -169,24 +169,6 @@ class ExecutionTree {
|
|||
// @return the prepare flags
|
||||
uint32_t PrepareFlags() const { return prepare_flags_; }
|
||||
|
||||
// The driver of the prepare phase of the execution tree.
|
||||
// Prepare phase consists of three sub phases
|
||||
//
|
||||
// 1. PreAction()
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// For example, CacheOp Insertion
|
||||
//
|
||||
// 2. Optimize()
|
||||
// Optimization transformation/action, optional
|
||||
// For example, MapOp Fusion
|
||||
//
|
||||
// 3. PostAction()
|
||||
// Compulsory transformation/action post optimization.
|
||||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status The status code returned
|
||||
Status Prepare(int num_epochs = -1, bool partial = false);
|
||||
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// @return Status The status code returned
|
||||
Status PreAction();
|
||||
|
@ -200,7 +182,7 @@ class ExecutionTree {
|
|||
// it ready for execution.
|
||||
// @param Total number of epochs that will be run on this tree
|
||||
// @return Status The status code returned
|
||||
Status PrepareDeprecated();
|
||||
Status Prepare();
|
||||
|
||||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||
// node actions during a tree walk.
|
||||
|
@ -239,10 +221,6 @@ class ExecutionTree {
|
|||
// Getter for profiling manager, no ownership
|
||||
ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
|
||||
|
||||
// Getter function to get the total number of epochs to be run on this tree.
|
||||
// @return total number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
|
||||
private:
|
||||
// A helper functions for doing the recursive printing
|
||||
// @param dataset_op - The dataset op to print
|
||||
|
@ -257,9 +235,7 @@ class ExecutionTree {
|
|||
int32_t id_count_; // Counter for generating operator id's
|
||||
uint32_t prepare_flags_; // Flags used during tree prepare
|
||||
TreeState tree_state_; // Tracking the current tree state
|
||||
int32_t num_epochs_; // Total number of epochs to run for this tree
|
||||
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
|
||||
bool partially_prepare_; // Temp: during migration to IR, if true, run remaining passes.
|
||||
#if defined(ENABLE_GPUQUE) || defined(ENABLE_TDTQUE)
|
||||
// This rank_id is for numa and device_queue, one process work with only one rank_id,
|
||||
// for standalone scenario, this rank_id may come from env 'CUDA_VISIBLE_DEVICES',
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -102,9 +102,11 @@ Status BatchNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
node_ops->push_back(project_op);
|
||||
}
|
||||
|
||||
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
|
||||
pad_map_));
|
||||
auto op = std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, pad_map_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
#else
|
||||
node_ops->push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
in_col_names_, pad_map_));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -84,9 +84,12 @@ void BucketBatchByLengthNode::Print(std::ostream &out) const {
|
|||
|
||||
Status BucketBatchByLengthNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
bucket_boundaries_.insert(bucket_boundaries_.begin(), 0);
|
||||
node_ops->push_back(std::make_shared<BucketBatchByLengthOp>(
|
||||
column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_,
|
||||
pad_to_bucket_boundary_, drop_remainder_, connector_que_size_));
|
||||
auto op = std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
element_length_function_, pad_info_, pad_to_bucket_boundary_,
|
||||
drop_remainder_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
if (bucket_boundaries_[0] == 0) {
|
||||
bucket_boundaries_.erase(bucket_boundaries_.begin());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -55,10 +55,11 @@ void BuildSentenceVocabNode::Print(std::ostream &out) const {
|
|||
|
||||
// Function to build BuildSentenceVocabNode
|
||||
Status BuildSentenceVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
std::shared_ptr<BuildSentencePieceVocabOp> build_sentence_piece_vocab_op;
|
||||
build_sentence_piece_vocab_op = std::make_shared<BuildSentencePieceVocabOp>(
|
||||
vocab_, col_names_, vocab_size_, character_coverage_, model_type_, params_, connector_que_size_);
|
||||
node_ops->push_back(build_sentence_piece_vocab_op);
|
||||
auto op = std::make_shared<BuildSentencePieceVocabOp>(vocab_, col_names_, vocab_size_, character_coverage_,
|
||||
model_type_, params_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -54,6 +54,8 @@ Status BuildVocabNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node
|
|||
std::shared_ptr<BuildVocabOp> build_vocab_op;
|
||||
build_vocab_op = std::make_shared<BuildVocabOp>(vocab_, columns_, freq_range_, top_k_, special_tokens_,
|
||||
special_first_, num_workers_, connector_que_size_);
|
||||
build_vocab_op->set_total_repeats(GetTotalRepeats());
|
||||
build_vocab_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(build_vocab_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -51,10 +51,24 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
"Internal error. Attempt to create a cache lookup node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
|
||||
lookup_op_->set_total_repeats(GetTotalRepeats());
|
||||
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(lookup_op_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheLookupNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<CacheLookupNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheLookupNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<CacheLookupNode>(), modified);
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
|
||||
// CacheLookupNode should already been copied, so we just return it here
|
||||
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
|
||||
|
|
|
@ -64,6 +64,18 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::shared_ptr<DatasetOp> lookup_op_;
|
||||
|
|
|
@ -48,9 +48,23 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
|
||||
merge_op->set_total_repeats(GetTotalRepeats());
|
||||
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(merge_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheMergeNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<CacheMergeNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheMergeNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<CacheMergeNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,6 +53,18 @@ class CacheMergeNode : public DatasetNode {
|
|||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,9 +53,23 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
|||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
|
||||
cache_op->SetSampler(sampler_->SamplerBuild());
|
||||
cache_op->set_total_repeats(GetTotalRepeats());
|
||||
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cache_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<CacheNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status CacheNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<CacheNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,18 @@ class CacheNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -119,12 +119,16 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
}
|
||||
|
||||
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
std::shared_ptr<ConcatOp> op;
|
||||
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_));
|
||||
op = std::make_shared<ConcatOp>(connector_que_size_);
|
||||
} else {
|
||||
node_ops->push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(),
|
||||
children_flag_and_nums_, children_start_end_index_));
|
||||
op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_,
|
||||
children_start_end_index_);
|
||||
}
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -219,7 +219,9 @@ DatasetNode::DatasetNode()
|
|||
dataset_size_(-1),
|
||||
mappable_(kNotADataSource),
|
||||
nary_op_(false),
|
||||
descendant_of_cache_(false) {
|
||||
descendant_of_cache_(false),
|
||||
total_repeats_(-1),
|
||||
num_epochs_(1) {
|
||||
// Fetch some default value from config manager
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,6 +27,7 @@
|
|||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/filter_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/project_op.h"
|
||||
|
@ -292,6 +293,24 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
/// \return Status of the function
virtual Status to_json(nlohmann::json *out_json);

/// \brief Setter function, set the number of total repeats for the operator
void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; }

/// \brief Setter function, set the number of epochs for the operator
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }

/// \brief Getter function
/// \return The number of required repeats for the operator
int32_t GetTotalRepeats() const { return total_repeats_; }

/// \brief Getter function
/// \return The number of epochs for the operator
int32_t GetNumEpochs() const { return num_epochs_; }

/// \brief Getter function
/// \return The number of repeats per epoch for the operator
int32_t GetNumRepeatsPerEpoch() const { return total_repeats_ / num_epochs_; }

protected:
std::vector<std::shared_ptr<DatasetNode>> children_;
DatasetNode *parent_; // used to record the only one parent of an IR node after parsing phase
@ -301,6 +320,8 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
int32_t total_repeats_; // Number of times required to run this operator
int32_t num_epochs_; // Number of epochs
// Establish a parent-child relationship between this node and the input node.
// Used only in the constructor of the class and its derived classes.
void AddChild(std::shared_ptr<DatasetNode> child);
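
The setters and getters added above hold the values that the IR-side pass computes before Build() copies them onto the ExecTree op. A minimal illustration of how the three quantities relate, using made-up numbers (the actual assignment logic lives in the migrated repeat pass, which is not shown in this hunk):

    // Assumption: a pipeline repeated 3 times per epoch and run for 2 epochs.
    node->SetNumEpochs(2);
    node->SetTotalRepeats(6);                           // 3 repeats per epoch * 2 epochs
    int32_t per_epoch = node->GetNumRepeatsPerEpoch();  // 6 / 2 == 3
    int32_t total = node->GetTotalRepeats();            // 6
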
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -44,6 +44,8 @@ void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" +
|
|||
// Function to build the EpochCtrlOp
|
||||
Status EpochCtrlNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto new_op_ = std::make_shared<EpochCtrlOp>(repeat_count_);
|
||||
new_op_->set_total_repeats(GetTotalRepeats());
|
||||
new_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(new_op_);
|
||||
op_ = new_op_;
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -44,7 +44,10 @@ void FilterNode::Print(std::ostream &out) const {
|
|||
}
|
||||
|
||||
Status FilterNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_));
|
||||
auto op = std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,7 +38,8 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
|
|||
output_columns_(output_columns),
|
||||
project_columns_(project_columns),
|
||||
DatasetNode(std::move(cache)),
|
||||
callbacks_(callbacks) {
|
||||
callbacks_(callbacks),
|
||||
under_a_cache_(false) {
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
|
@ -64,6 +65,17 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
|
||||
[](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
|
||||
|
||||
// This is temporary code.
|
||||
// Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
|
||||
// to TensorOp, we need to check the randomness here.
|
||||
// When TensorOperation captures the randomness behaviour, remove this code and the member "under_a_cache_"
|
||||
// and the temporary code in CacheValidation pre pass in IR optimizer.
|
||||
if (under_a_cache_) {
|
||||
auto itr = std::find_if(tensor_ops.begin(), tensor_ops.end(), [](const auto &it) { return !it->Deterministic(); });
|
||||
if (itr != tensor_ops.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
|
||||
}
|
||||
}
|
||||
// This parameter will be removed with next rebase
|
||||
std::vector<std::string> col_orders;
|
||||
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
|
||||
|
@ -74,9 +86,12 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
|
||||
if (!project_columns_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
project_op->set_total_repeats(GetTotalRepeats());
|
||||
project_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(project_op);
|
||||
}
|
||||
|
||||
map_op->set_total_repeats(GetTotalRepeats());
|
||||
map_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(map_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -79,6 +79,9 @@ class MapNode : public DatasetNode {
|
|||
/// \brief setter to set all tensor operations
|
||||
void setOperations(const std::vector<std::shared_ptr<TensorOperation>> &operations);
|
||||
|
||||
/// \brief indicate this Map will be cached
|
||||
void Cached() { under_a_cache_ = true; }
|
||||
|
||||
/// \brief Getter functions
|
||||
/// \brief Getter of tensor operations
|
||||
/// \return Vector of operations the Map node will process
|
||||
|
@ -95,12 +98,11 @@ class MapNode : public DatasetNode {
|
|||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
|
||||
private:
|
||||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks_;
|
||||
bool under_a_cache_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -53,7 +53,10 @@ Status ProjectNode::ValidateParams() {
|
|||
}
|
||||
|
||||
Status ProjectNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<ProjectOp>(columns_));
|
||||
auto op = std::make_shared<ProjectOp>(columns_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,7 +58,10 @@ Status RenameNode::ValidateParams() {
|
|||
}
|
||||
|
||||
Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_));
|
||||
auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,6 +40,8 @@ void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + st
|
|||
|
||||
Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto new_op = std::make_shared<RepeatOp>(repeat_count_);
|
||||
new_op->set_total_repeats(GetTotalRepeats());
|
||||
new_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(new_op);
|
||||
op_ = new_op;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -44,8 +44,11 @@ void ShuffleNode::Print(std::ostream &out) const {
|
|||
|
||||
// Function to build the ShuffleOp
|
||||
Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
|
||||
rows_per_buffer_));
|
||||
auto op = std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
|
||||
rows_per_buffer_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,7 +39,10 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" +
|
|||
|
||||
// Function to build the SkipOp
|
||||
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
|
||||
auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -72,9 +72,11 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
// Argument that is not exposed to user in the API.
|
||||
std::set<std::string> extensions = {};
|
||||
|
||||
node_ops->push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, extensions, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto album_op = std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_,
|
||||
extensions, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
album_op->set_total_repeats(GetTotalRepeats());
|
||||
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(album_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -67,9 +67,12 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
// label is like this:0 1 0 0 1......
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
|
||||
node_ops->push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
decode_, usage_, extensions_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto celeba_op =
|
||||
std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_,
|
||||
extensions_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
celeba_op->set_total_repeats(GetTotalRepeats());
|
||||
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(celeba_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -64,9 +64,12 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto cifar_op =
|
||||
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
cifar_op->set_total_repeats(GetTotalRepeats());
|
||||
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cifar_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -62,9 +62,12 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops->push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
|
||||
dataset_dir_, connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto cifar_op =
|
||||
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
cifar_op->set_total_repeats(GetTotalRepeats());
|
||||
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cifar_op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -193,9 +193,12 @@ Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
|
||||
clue_op->set_total_repeats(GetTotalRepeats());
|
||||
clue_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(clue_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -122,7 +122,8 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
std::shared_ptr<CocoOp> op =
|
||||
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -130,10 +130,12 @@ Status CSVNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
|
||||
csv_op->set_total_repeats(GetTotalRepeats());
|
||||
csv_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(csv_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -91,7 +91,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
|
|||
if (reset_ancestor_ != nullptr) {
|
||||
reset_ancestor_->op_->AddToEoeList(op);
|
||||
}
|
||||
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -70,9 +70,12 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops->push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto op = std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild()));
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -94,7 +94,8 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
manifest_op =
|
||||
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
|
||||
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
|
||||
|
||||
manifest_op->set_total_repeats(GetTotalRepeats());
|
||||
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(manifest_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -169,6 +169,8 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
}
|
||||
|
||||
RETURN_IF_NOT_OK(mindrecord_op->Init());
|
||||
mindrecord_op->set_total_repeats(GetTotalRepeats());
|
||||
mindrecord_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(mindrecord_op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,9 +58,11 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
node_ops->push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
|
||||
connector_que_size_, std::move(schema),
|
||||
std::move(sampler_->SamplerBuild())));
|
||||
auto op = std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -109,7 +109,8 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema_));
|
||||
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -98,9 +98,12 @@ Status TextFileNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
|
||||
text_file_op->set_total_repeats(GetTotalRepeats());
|
||||
text_file_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
// Add TextFileOp
|
||||
node_ops->push_back(text_file_op);
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -140,9 +140,12 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
shuffle_op->set_total_repeats(GetTotalRepeats());
|
||||
shuffle_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(shuffle_op);
|
||||
}
|
||||
|
||||
tf_reader_op->set_total_repeats(GetTotalRepeats());
|
||||
tf_reader_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
// Add TFReaderOp
|
||||
node_ops->push_back(tf_reader_op);
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -113,7 +113,8 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
voc_op =
|
||||
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
|
||||
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
|
||||
|
||||
voc_op->set_total_repeats(GetTotalRepeats());
|
||||
voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(voc_op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -47,7 +47,10 @@ Status SyncWaitNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
// The reason for this is because having it otherwise can lead to blocking issues
|
||||
// See barrier_op.h for more details
|
||||
int32_t rows_per_buffer = 1;
|
||||
node_ops->push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_));
|
||||
auto op = std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,7 +40,10 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s
|
|||
|
||||
// Function to build the TakeOp
|
||||
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<TakeOp>(take_count_, connector_que_size_));
|
||||
auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -100,8 +100,11 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
device_id_ = 0;
|
||||
RETURN_IF_NOT_OK(this->GetShardId(&device_id_));
|
||||
|
||||
node_ops->push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, create_data_info_queue_));
|
||||
auto op = std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
|
||||
total_batch_, create_data_info_queue_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,7 +58,10 @@ Status ZipNode::ValidateParams() {
|
|||
}
|
||||
|
||||
Status ZipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
node_ops->push_back(std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_));
|
||||
auto op = std::make_shared<ZipOp>(rows_per_buffer_, connector_que_size_);
|
||||
op->set_total_repeats(GetTotalRepeats());
|
||||
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -2,27 +2,17 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

set(DATASET_ENGINE_OPT_SRC_FILES
optional/tensor_op_fusion_pass.cc
pass.cc
post/auto_worker_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/epoch_ctrl_pass.cc
pre/node_removal_pass.cc
)

# This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer.
# When the migration is complete, we will remove these files.
set(DATASET_ENGINE_OPT_SRC_FILES
${DATASET_ENGINE_OPT_SRC_FILES}
optional/tensor_op_fusion_pass.cc
pre/cache_error_pass.cc
post/repeat_pass.cc
pre/cache_transform_pass.cc
pre/epoch_injection_pass.cc
util/printer_pass.cc
pre/removal_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc
pre/epoch_ctrl_pass.cc
pre/getter_pass.cc
pre/input_validation_pass.cc
pre/node_removal_pass.cc
)

if(ENABLE_PYTHON)

@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,11 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||
|
@ -187,6 +192,26 @@ Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified)
|
|||
Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
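These new default overloads forward the cache-specific node types to the generic DatasetNode overload, so a concrete IR pass only has to override the node types it actually handles. A standalone sketch of that fallback dispatch pattern (simplified stand-in classes, not the real IRNodePass/DatasetNode API):

#include <iostream>
#include <memory>

struct NodeLike {};                  // stand-in for DatasetNode
struct CacheNodeLike : NodeLike {};  // stand-in for CacheNode

struct PassLike {                    // stand-in for IRNodePass
  virtual ~PassLike() = default;
  virtual void Visit(std::shared_ptr<NodeLike> node) { std::cout << "generic visit\n"; }
  // Default: forward to the generic overload, mirroring the static_pointer_cast above.
  virtual void Visit(std::shared_ptr<CacheNodeLike> node) {
    Visit(std::static_pointer_cast<NodeLike>(node));
  }
};

struct CacheAwarePass : PassLike {
  using PassLike::Visit;  // keep the generic overload visible
  void Visit(std::shared_ptr<CacheNodeLike> node) override { std::cout << "cache-specific visit\n"; }
};

int main() {
  PassLike plain;
  plain.Visit(std::make_shared<CacheNodeLike>());  // falls back: "generic visit"
  CacheAwarePass aware;
  aware.Visit(std::make_shared<CacheNodeLike>());  // "cache-specific visit"
  aware.Visit(std::make_shared<NodeLike>());       // "generic visit"
  return 0;
}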
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,11 @@ namespace dataset {
|
|||
class BatchNode;
|
||||
class BucketBatchByLengthNode;
|
||||
class BuildVocabNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CacheLookupNode;
|
||||
class CacheMergeNode;
|
||||
class CacheNode;
|
||||
#endif
|
||||
class ConcatNode;
|
||||
class EpochCtrlNode;
|
||||
class FilterNode;
|
||||
|
@ -199,6 +204,14 @@ class IRNodePass : public IRPass {
|
|||
virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified);
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified);
|
||||
#endif
|
||||
virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
|
||||
virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,15 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -31,10 +32,10 @@ RepeatPass::RepeatPass()
|
|||
: num_repeats_(1), num_epochs_(1), is_merge_(false), is_cached_(false), cache_lookup_(nullptr) {}
|
||||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
|
||||
Status RepeatPass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
|
||||
// If this is an infinite repeat under infinite repeat/epoch, adjust current num_repeats_.
|
||||
// Otherwise, after multiplication it would become positive and this repeat wouldn't run infinitely.
|
||||
if (node->num_repeats() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
||||
if (node->Count() == DatasetOp::kInfiniteRepeat && num_repeats_ < 0) {
|
||||
num_repeats_ = -num_repeats_;
|
||||
}
|
||||
// This RepeatOp and its descendant nodes should be repeated for another num_repeats() times.
|
||||
|
@ -49,14 +50,14 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modi
|
|||
// num_repeats_ is originally 4, after repeat2(2), num_repeats_ becomes 8 (2*4),
|
||||
// meaning repeat2 and map op should be set to read 8 times (2*4).
|
||||
// Then, after repeat1(3), num_repeats_ becomes 24 (3*2*4), meaning repeat1 and tfreader op should repeat 24 times.
|
||||
num_repeats_ *= node->num_repeats();
|
||||
num_repeats_ *= node->Count();
|
||||
return Status::OK();
|
||||
}
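To make the arithmetic concrete, here is a small self-contained check of the multiplier bookkeeping; it assumes a hypothetical pipeline leaf --> repeat(4) --> repeat(2) --> epoch_ctrl(3) and only mirrors the counters, not the actual tree walk:

#include <cassert>

int main() {
  int num_repeats = 1;
  int num_epochs = 1;
  num_epochs = 3;  num_repeats *= 3;      // Visit(EpochCtrlNode), epoch_ctrl(3)
  num_repeats *= 2;                       // Visit(RepeatNode), repeat(2)
  num_repeats *= 4;                       // Visit(RepeatNode), repeat(4)
  assert(num_repeats == 24);              // total reads of the leaf across all epochs
  assert(num_repeats / num_epochs == 8);  // reads of the leaf per epoch
  return 0;
}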
|
||||
|
||||
// Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
|
||||
Status RepeatPass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
|
||||
// Get the total number of epochs from the EpochCtrlOp parameter
|
||||
num_epochs_ = node->num_repeats();
|
||||
num_epochs_ = node->Count();
|
||||
// Every node below this EpochCtrlOp should be repeated for num_epochs_ times.
|
||||
// For example: tfreader --> epoch ctrl(3)
|
||||
// num_repeats_ is originally 1 (default initialization), after this epoch ctrl(3), num_repeats_ becomes 3 (1*3),
|
||||
|
@ -65,115 +66,108 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const m
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Identifies the subtree below this node as being in a cache merge path
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) {
|
||||
Status RepeatPass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
|
||||
// Turn on the flag that we're under a merge op
|
||||
is_merge_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Identifies the subtree below this node as being cached
|
||||
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
Status RepeatPass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
|
||||
// Turn on the flag that we're under a cache op
|
||||
is_cached_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
|
||||
// If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup
|
||||
// and set its total repeats. It is important that the op is removed from the save area,
|
||||
// because the merge op above us may also take action on it later for a different case when
|
||||
// there is no repeat in the merge leg.
|
||||
if (is_merge_ && cache_lookup_) {
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
cache_lookup_->SetTotalRepeats(num_repeats_);
|
||||
cache_lookup_->SetNumEpochs(num_epochs_);
|
||||
cache_lookup_.reset();
|
||||
}
|
||||
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
AddToCachedNodeStack(node);
|
||||
}
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
node->SetTotalRepeats(num_repeats_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
// We finish the walk of this RepeatOp's descendant nodes.
|
||||
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
|
||||
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
|
||||
// so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
|
||||
num_repeats_ /= node->num_repeats();
|
||||
num_repeats_ /= node->Count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Hooks up any identified eoe nodes under this repeat.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) {
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
|
||||
node->SetTotalRepeats(num_repeats_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
// We finish the walk of this EpochCtrl's descendant nodes.
|
||||
num_repeats_ /= node->num_repeats();
|
||||
num_repeats_ /= node->Count();
|
||||
return Status::OK();
|
||||
}
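Continuing the same assumed pipeline (leaf --> repeat(4) --> repeat(2) --> epoch_ctrl(3)), the post-order VisitAfter calls then hand out the totals on the way back up: the leaf keeps 24 total repeats (8 per epoch); repeat(4) is also set to 24 before num_repeats_ is divided back to 6; repeat(2) is set to 6 and divides it to 3; and epoch_ctrl is set to 3, with the final division restoring num_repeats_ to 1.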
|
||||
|
||||
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
// for use with a controlling repeat above it.
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) {
|
||||
// If we are under a cache op, then save ourselves to the cached op stack.
|
||||
if (is_cached_) {
|
||||
AddToCachedNodeStack(node);
|
||||
}
|
||||
// Set total repeats and total epochs for the node
|
||||
node->SetTotalRepeats(num_repeats_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// CacheOp removes previous leaf ops and replaces them with itself
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
|
||||
is_cached_ = false;
|
||||
|
||||
// If we are a cache within a repeat path of the tree, then adjust the total repeats and total epochs for cached ops
|
||||
// so that those cached nodes become 1-time use (up to eoe), never repeated. Instead,
|
||||
// the repeating behaviours shall be invoked against the cache op.
|
||||
std::shared_ptr<DatasetOp> cached_op = PopFromCachedOpStack();
|
||||
while (cached_op != nullptr) {
|
||||
int32_t cached_op_total_repeats = cached_op->op_total_repeats() / num_repeats_;
|
||||
cached_op->set_total_repeats(cached_op_total_repeats);
|
||||
std::shared_ptr<DatasetNode> cached_node = PopFromCachedNodeStack();
|
||||
while (cached_node != nullptr) {
|
||||
int32_t cached_op_total_repeats = cached_node->GetTotalRepeats() / num_repeats_;
|
||||
cached_node->SetTotalRepeats(cached_op_total_repeats);
|
||||
// Cached ops will only be executed on the first epoch, therefore, num_epochs_ = 1
|
||||
cached_op->set_num_repeats_per_epoch(cached_op_total_repeats);
|
||||
cached_op = PopFromCachedOpStack();
|
||||
cached_node->SetNumEpochs(1);
|
||||
cached_node = PopFromCachedNodeStack();
|
||||
}
|
||||
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
return Status::OK();
|
||||
}
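As a worked example, for an assumed pipeline leaf --> cache --> repeat(4) --> epoch_ctrl(3): at the leaf, num_repeats_ is 12, so the generic VisitAfter first records the leaf with 12 total repeats and pushes it onto the cached-node stack; when the cache's VisitAfter runs, the leaf's total is divided back to 12 / 12 = 1 with a single epoch, while the cache node itself keeps 12 total repeats (4 per epoch). Only the cache then replays data across later repeats and epochs; the cached subtree is read once.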
|
||||
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
|
||||
// If we are under a cache op, then save ourselves to the cached op stack.
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
}
|
||||
// Set total repeats and total epochs for the node
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
return Status::OK();
|
||||
}
|
||||
// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
// for use with a controlling repeat above it.
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) {
|
||||
// If we are under a cache op, then save ourselves to the cached op stack.
|
||||
if (is_cached_) {
|
||||
AddToCachedOpStack(node);
|
||||
}
|
||||
// Set total repeats and total epochs for the node
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
node->SetTotalRepeats(num_repeats_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Turns off the tracking for operations under merge op
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) {
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
|
||||
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
|
||||
// would not have been consumed yet. In that case, we need to set its total repeats for it.
|
||||
if (cache_lookup_) {
|
||||
cache_lookup_->set_total_repeats(num_repeats_);
|
||||
cache_lookup_->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
cache_lookup_->SetTotalRepeats(num_repeats_);
|
||||
cache_lookup_->SetNumEpochs(num_epochs_);
|
||||
}
|
||||
node->set_total_repeats(num_repeats_);
|
||||
node->set_num_repeats_per_epoch(num_repeats_ / num_epochs_);
|
||||
node->SetTotalRepeats(num_repeats_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
|
||||
is_merge_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Saves the lookup up in case it needs to be referenced by a repeat
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) {
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
|
||||
if (!node->IsLeaf()) {
|
||||
// By definition, the CacheLookup must be a leaf op. Make that clear here.
|
||||
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
|
||||
|
@ -184,29 +178,30 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const mo
|
|||
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
|
||||
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
|
||||
// add the lookup to the eoe stack
|
||||
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
|
||||
cache_lookup_ = std::static_pointer_cast<DatasetNode>(node);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
|
||||
// Set total repeats and total epochs for the DeviceQueueOp
|
||||
node->set_total_repeats(num_epochs_);
|
||||
node->set_num_repeats_per_epoch(1);
|
||||
Status RepeatPass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
|
||||
// Set total repeats and total epochs for the TransferNode
|
||||
node->SetTotalRepeats(num_epochs_);
|
||||
node->SetNumEpochs(num_epochs_);
|
||||
return Status::OK();
|
||||
}
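In other words, the transfer (device queue) node is driven purely by the epoch count: its total repeats equals num_epochs_, which works out to one repeat per epoch and matches the set_num_repeats_per_epoch(1) that the old DeviceQueueOp version used.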
|
||||
|
||||
// Adds an operator to the cached operator stack save area
|
||||
void RepeatPass::AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op) { cached_op_stacks_.push(dataset_op); }
|
||||
void RepeatPass::AddToCachedNodeStack(std::shared_ptr<DatasetNode> node) { cached_node_stacks_.push(node); }
|
||||
|
||||
// Pops an operator from the cached operator stack save area
|
||||
std::shared_ptr<DatasetOp> RepeatPass::PopFromCachedOpStack() {
|
||||
std::shared_ptr<DatasetOp> top_op = nullptr;
|
||||
if (!cached_op_stacks_.empty()) {
|
||||
top_op = cached_op_stacks_.top();
|
||||
cached_op_stacks_.pop();
|
||||
std::shared_ptr<DatasetNode> RepeatPass::PopFromCachedNodeStack() {
|
||||
std::shared_ptr<DatasetNode> top_node = nullptr;
|
||||
if (!cached_node_stacks_.empty()) {
|
||||
top_node = cached_node_stacks_.top();
|
||||
cached_node_stacks_.pop();
|
||||
}
|
||||
return top_op;
|
||||
return top_node;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,12 +25,11 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class RepeatPass repeat_pass.h
|
||||
/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
|
||||
/// to the eoe-producing (typically leaf) nodes underneath it.
|
||||
class RepeatPass : public NodePass {
|
||||
/// \class RepeatPass
|
||||
/// \brief This is a post pass that calculates the number of repeats the pipeline needs in order to fetch the data.
|
||||
class RepeatPass : public IRNodePass {
|
||||
public:
|
||||
using op_stack = std::stack<std::shared_ptr<DatasetOp>>;
|
||||
using op_stack = std::stack<std::shared_ptr<DatasetNode>>;
|
||||
|
||||
/// \brief Constructor
|
||||
RepeatPass();
|
||||
|
@ -40,93 +39,91 @@ class RepeatPass : public NodePass {
|
|||
|
||||
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
|
||||
Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override;
|
||||
Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Identifies the subtree below this node as being in a cache merge path
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override;
|
||||
Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
Status Visit(std::shared_ptr<CacheNode> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief CacheOp removes previous leaf ops and replaces them with itself
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief CacheNode removes previous leaf ops and replaces them with itself
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Turns of the tracking for operations under merge op
|
||||
/// \brief Turns off the tracking for operations under merge op
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Saves the lookup up in case it needs to be referenced by a repeat
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Set the epoch count for DeviceQueue
|
||||
/// \brief Sets the epoch count for TransferNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Special case for GeneratorOp
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
/// for use with a controlling repeat above it.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \param[in,out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;
|
||||
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) override;
|
||||
|
||||
private:
|
||||
/// \brief Adds an operator to the cached operator stack save area
|
||||
/// \param op - The dataset op to work add to cached stack
|
||||
/// \brief Adds an operator to the cached stack save area
|
||||
/// \param node - The dataset node to add to cached stack
|
||||
/// \return Status The status code returned
|
||||
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||
void AddToCachedNodeStack(std::shared_ptr<DatasetNode> node);
|
||||
|
||||
/// \brief Pops an operator from the cached operator stack save area
|
||||
/// \return shared_ptr to the popped operator
|
||||
std::shared_ptr<DatasetOp> PopFromCachedOpStack();
|
||||
/// \brief Pops an operator from the cached stack save area
|
||||
/// \return shared_ptr to the popped dataset node
|
||||
std::shared_ptr<DatasetNode> PopFromCachedNodeStack();
|
||||
|
||||
bool is_merge_; // T/F if we are processing under a cache merge op
|
||||
bool is_cached_; // T/F is we are processing under a cache op
|
||||
int32_t num_repeats_; // A multiplier to the total number of repeats
|
||||
int32_t num_epochs_; // To save the total number of epochs
|
||||
op_stack cached_op_stacks_; // A save area for ops under a cache op
|
||||
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
|
||||
bool is_merge_; // T/F if we are processing under a cache merge node
|
||||
bool is_cached_;              // T/F if we are processing under a cache node
|
||||
int32_t num_repeats_; // A multiplier to the total number of repeats
|
||||
int32_t num_epochs_; // To save the total number of epochs
|
||||
op_stack cached_node_stacks_; // A save area for operators under a cache node
|
||||
std::shared_ptr<DatasetNode> cache_lookup_; // A save area for a cache lookup node
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,189 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_error_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
CacheErrorPass::CacheErrorPass() : is_cached_(false), is_mappable_(false) {}
|
||||
|
||||
// Identifies the subtree below this node as being cached
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
// Turn on the flag that we're under a merge op
|
||||
is_cached_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if ZipOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"ZipOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if MapOp with non-deterministic TensorOps exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
auto tfuncs = node->TFuncs();
|
||||
for (size_t i = 0; i < tfuncs.size(); i++) {
|
||||
if (!tfuncs[i]->Deterministic()) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache.");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if ConcatOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"ConcatOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if TakeOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"TakeOp/SplitOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if SkipOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"SkipOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns an error if SkipOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"BatchOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Returns an error if FilterOp exists under a cache
|
||||
Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) {
|
||||
if (is_cached_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"FilterOp is currently not supported as a descendant operator under a cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
|
||||
// Turn on the flag that this is a tree with mappable leaf dataset
|
||||
is_mappable_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
// Turn off the flag that we're under a merge op
|
||||
is_cached_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Currently, returns an error if RepeatOp exists under a cache
|
||||
// Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem.
|
||||
Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) {
|
||||
if (is_cached_ && is_mappable_) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
|
||||
"Repeat is not supported as a descendant operator under a mappable cache.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,167 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_CACHE_ERROR_PASS_
|
||||
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class CacheErrorPass cache_error_pass.h
|
||||
/// \brief This is a NodePass who's job is to catch invalid tree configurations related to cache and generate failures.
|
||||
class CacheErrorPass : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheErrorPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheErrorPass() = default;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if ZipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if ConcatOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if TakeOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if SkipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Returns an error if SkipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
/// \brief Returns an error if FilterOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies the subtree above this node as not being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Identifies and block repeat under cache scenarios
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *const modified) override;
|
||||
|
||||
private:
|
||||
bool is_cached_;
|
||||
bool is_mappable_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_POST_CACHE_ERROR_PASS_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -25,7 +25,6 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -114,11 +113,18 @@ Status CacheValidationPass::Visit(std::shared_ptr<MapNode> node, bool *const mod
|
|||
}
|
||||
// If Map is created to be cached, set the flag indicating we found an operation with a cache.
|
||||
is_cached_ = true;
|
||||
|
||||
// This is temporary code.
|
||||
// Because the randomness of its tensor operations is not known in TensorOperation form until we convert them
|
||||
// to TensorOp, we need to check the randomness in MapNode::Build().
|
||||
// By marking this MapNode as being under a cache, we can check the randomness of its tensor operations without the need
|
||||
// to walk the IR tree again.
|
||||
node->Cached();
|
||||
|
||||
auto tfuncs = node->TensorOperations();
|
||||
for (size_t i = 0; i < tfuncs.size(); i++) {
|
||||
if (tfuncs[i]->IsRandomOp()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"MapNode with non-deterministic operations is not supported as a descendant of cache.");
|
||||
RETURN_STATUS_UNEXPECTED("MapNode containing random operation is not supported as a descendant of cache.");
|
||||
}
|
||||
}
|
||||
}
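For example, a cache placed over a map that applies a random operation such as RandomCrop would be rejected here, or later in MapNode::Build() once its operations have been converted to TensorOp form.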
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node,
|
||||
bool *const modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
|
||||
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
|
||||
injection_point_ = node->child(0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
EpochInjectionPass::EpochInjectionPass() {}
|
||||
|
||||
// Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass started.";
|
||||
|
||||
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
|
||||
// The finder can make updates to the EpochInjectionPass object.
|
||||
EpochInjectionPass::InjectionFinder finder(tree->root());
|
||||
RETURN_IF_NOT_OK(finder.Run(tree, modified));
|
||||
|
||||
// The first injection logic is to check if we should inject the epoch control op as the root node.
|
||||
// Do not inject the op if the number of epochs is 1.
|
||||
int32_t num_epochs = tree->num_epochs();
|
||||
std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point();
|
||||
if (num_epochs != 1 && epoch_inject_node != nullptr) {
|
||||
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
|
||||
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
|
||||
RETURN_IF_NOT_OK(epoch_inject_node->InsertAsParent(epoch_ctrl_op));
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,88 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class EpochInjectionPass epoch_injection_pass.h
|
||||
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
|
||||
/// parsing.
|
||||
class EpochInjectionPass : public TreePass {
|
||||
/// \class InjectionFinder
|
||||
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
|
||||
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
|
||||
/// it may need to inject.
|
||||
class InjectionFinder : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit InjectionFinder(std::shared_ptr<DatasetOp> node);
|
||||
|
||||
/// \brief Destructor
|
||||
~InjectionFinder() = default;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Register the DeviceQueueOp for further action.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
std::shared_ptr<DatasetOp> injection_point() { return injection_point_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<DatasetOp> injection_point_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
EpochInjectionPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~EpochInjectionPass() = default;
|
||||
|
||||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Identifies the subtree below this node as a cached descendant tree.
|
||||
Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Resets the tracking of the cache within the tree
|
||||
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
|
||||
is_caching_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||
if (is_caching_) {
|
||||
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
RemovalPass::RemovalPass() {}
|
||||
|
||||
// Walk the tree to collect the nodes to remove, then removes them.
|
||||
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *const modified) {
|
||||
MS_LOG(INFO) << "Pre pass: removal pass started.";
|
||||
// Create the removal node pass which can identify which nodes need to be removed.
|
||||
std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>();
|
||||
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||
|
||||
// Then, execute the removal of any nodes that were set up for removal
|
||||
for (auto node : removal_nodes->nodes_to_remove()) {
|
||||
RETURN_IF_NOT_OK(node->Remove());
|
||||
}
|
||||
MS_LOG(INFO) << "Pre pass: removal pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,90 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class RemovalPass removal_pass.h
|
||||
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||
/// nodes should be removed, and then removes them.
|
||||
class RemovalPass : public TreePass {
|
||||
/// \class RemovalNodes
/// \brief This is a NodePass whose job is to identify which nodes should be removed.
///     It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
RemovalNodes();
|
||||
|
||||
/// \brief Destructor
|
||||
~RemovalNodes() = default;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return All the nodes to be removed
|
||||
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; }
|
||||
|
||||
private:
|
||||
bool is_caching_;
|
||||
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RemovalPass();
|
||||
|
||||
/// \brief Destructor
|
||||
~RemovalPass() = default;
|
||||
|
||||
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on.
/// \param[inout] modified Indicator of whether the tree was modified.
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *const modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
|
|
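As a quick usage recap of this header, the nested collector is driven by the outer pass in two phases, exactly as removal_pass.cc above shows: run the NodePass to gather candidates, then detach them. The sketch below restates that flow with illustrative local names.

// Restates RemovalPass::RunOnTree from the .cc file above; variable names are illustrative.
RemovalPass::RemovalNodes collector;
bool modified = false;
RETURN_IF_NOT_OK(collector.Run(tree, &modified));   // phase 1: visit the tree, collect nodes to remove
for (auto node : collector.nodes_to_remove()) {
  RETURN_IF_NOT_OK(node->Remove());                 // phase 2: detach each collected node
}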
@ -1,121 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/util/printer_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting DatasetOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting BatchOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting MapOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ProjectOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting RenameOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting SkipOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ShuffleOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting MindRecordOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting TFReaderOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting FilterOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting GeneratorOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting TakeOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ZipOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting DeviceQueueOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ImageFolderOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrinterPass::RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) {
|
||||
*modified = false;
|
||||
std::cout << "Visiting ImageFolderOp" << '\n';
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,68 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class PrinterPass : public NodePass {
|
||||
public:
|
||||
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<BatchOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *const modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ZipOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *const modified) override;
|
||||
|
||||
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *const modified) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
|
|
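The PrinterPass deleted above was a purely diagnostic NodePass: running it over a tree printed one "Visiting ..." line per operator. Below is a short sketch of how it was driven, following the same Run(tree, modified) call used by RemovalNodes earlier in this diff; the surrounding tree setup is assumed.

// Sketch only: drive the (now removed) PrinterPass over an already-built ExecutionTree.
PrinterPass printer;
bool modified = false;
RETURN_IF_NOT_OK(printer.Run(tree, &modified));  // logs "Visiting <Op>" for each node visited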
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@
|
|||
#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
|
||||
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
|
||||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
|
||||
#endif
|
||||
|
@ -94,6 +95,7 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
|
|||
#ifdef ENABLE_PYTHON
|
||||
actions.emplace_back(std::make_unique<GeneratorNodePass>());
|
||||
#endif
|
||||
actions.emplace_back(std::make_unique<RepeatPass>());
|
||||
|
||||
// We will gradually move RepeatPass from ExecutionTree::PrepareTreePostAction to here.
|
||||
|
||||
|
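RepeatPass is now queued with the other IR post-pass actions. This hunk does not show how the queue is drained, so the loop below is only an assumed sketch of the usual pattern; the Run call and its signature are assumptions, not the actual TreeAdapter code.

// Assumed sketch of consuming the post-pass action list; not the real TreeAdapter implementation.
for (auto &action : actions) {
  bool modified = false;
  RETURN_IF_NOT_OK(action->Run(ir, &modified));  // each IR pass may rewrite the DatasetNode tree
}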
@ -133,7 +135,7 @@ Status TreeAdapter::BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
|
||||
Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
|
||||
// This will evolve in the long run
|
||||
tree_ = std::make_unique<ExecutionTree>();
|
||||
// disable profiling if this is only a getter pass
|
||||
|
@ -146,7 +148,7 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epoc
|
|||
// Note: We will gradually move the pre pass, optimizer pass, and post pass
// from the ExecutionTree to the IR tree.
|
||||
// Prepare the tree
|
||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs, true));
|
||||
RETURN_IF_NOT_OK(tree_->Prepare());
|
||||
|
||||
// After the tree is prepared, the col_name_id_map can safely be obtained
|
||||
column_name_map_ = tree_->root()->column_name_id_map();
|
||||
|
@ -192,7 +194,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
|||
// Remember the root node
|
||||
root_ir_ = root_ir;
|
||||
|
||||
RETURN_IF_NOT_OK(Build(root_ir_, num_epochs));
|
||||
RETURN_IF_NOT_OK(Build(root_ir_));
|
||||
tree_state_ = kCompileStateReady;
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -83,7 +83,7 @@ class TreeAdapter {
|
|||
Status PostPass(std::shared_ptr<DatasetNode> ir);
|
||||
|
||||
// Build an Execution tree
|
||||
Status Build(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs);
|
||||
Status Build(std::shared_ptr<DatasetNode> root_ir);
|
||||
|
||||
// This RECURSIVE function walks the (optimized) IR tree in DFS to build its corresponding Execution tree.
|
||||
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,22 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/album_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -89,7 +79,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) {
|
|||
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
|
||||
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
|
||||
std::vector<std::string> column_names = {"image", "label", "id"};
|
||||
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false), Repeat(2)});
|
||||
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file, column_names, false);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
ASSERT_OK(tree->Prepare());
|
||||
ASSERT_OK(tree->Launch());
|
||||
DatasetIterator di(tree);
|
||||
|
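With the ExecTree repeat pass removed, these C++ tests seed the repeat counters on each operator themselves. The pattern is restated with comments below; the relationship between the two counters (total = repeats per epoch × epochs) is inferred from the tests, not quoted from documentation.

// Single-epoch test under Repeat(2): the leaf replays its data twice in total
// and twice within the one epoch, so both counters are set to 2.
op1->set_total_repeats(2);          // repeats over the whole run
op1->set_num_repeats_per_epoch(2);  // repeats within a single epoch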
@ -111,7 +105,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchema) {
|
|||
TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaNoOrder) {
|
||||
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
|
||||
std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
|
||||
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
|
||||
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
ASSERT_OK(tree->Prepare());
|
||||
ASSERT_OK(tree->Launch());
|
||||
DatasetIterator di(tree);
|
||||
|
@ -134,7 +132,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithSchemaFloat) {
|
|||
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
|
||||
// add the priority column
|
||||
std::string schema_file = datasets_root_path_ + "/testAlbum/floatSchema.json";
|
||||
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
|
||||
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
ASSERT_OK(tree->Launch());
|
||||
DatasetIterator di(tree);
|
||||
|
@ -159,7 +161,11 @@ TEST_F(MindDataTestAlbum, TestSequentialAlbumWithFullSchema) {
|
|||
std::string folder_path = datasets_root_path_ + "/testAlbum/images";
|
||||
// add the priority column
|
||||
std::string schema_file = datasets_root_path_ + "/testAlbum/fullSchema.json";
|
||||
auto tree = Build({AlbumSchema(16, 2, 32, folder_path, schema_file), Repeat(2)});
|
||||
auto op1 = AlbumSchema(16, 2, 32, folder_path, schema_file);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
ASSERT_OK(tree->Prepare());
|
||||
ASSERT_OK(tree->Launch());
|
||||
DatasetIterator di(tree);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,14 +13,11 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "securec.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
@ -112,7 +109,12 @@ TEST_F(MindDataTestBatchOp, TestSimpleBatch) {
|
|||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, true, 99)});
|
||||
auto op1 = TFReader(schema_file);
|
||||
auto op2 = Repeat(2);
|
||||
auto op3 = Batch(7, true, 99);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -157,7 +159,12 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropTrue) {
|
|||
TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({TFReader(schema_file), Repeat(2), Batch(7, false, 99)});
|
||||
auto op1 = TFReader(schema_file);
|
||||
auto op2 = Repeat(2);
|
||||
auto op3 = Batch(7, false, 99);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -209,7 +216,14 @@ TEST_F(MindDataTestBatchOp, TestRepeatBatchDropFalse) {
|
|||
TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({TFReader(schema_file), Batch(7, false, 99), Repeat(2)});
|
||||
auto op1 = TFReader(schema_file);
|
||||
auto op2 = Batch(7, false, 99);
|
||||
auto op3 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
op2->set_total_repeats(2);
|
||||
op2->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
@ -255,7 +269,14 @@ TEST_F(MindDataTestBatchOp, TestBatchDropFalseRepeat) {
|
|||
TEST_F(MindDataTestBatchOp, TestBatchDropTrueRepeat) {
|
||||
std::string schema_file = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
bool success = false;
|
||||
auto tree = Build({TFReader(schema_file), Batch(5, true, 99), Repeat(2)});
|
||||
auto op1 = TFReader(schema_file);
|
||||
auto op2 = Batch(5, true, 99);
|
||||
auto op3 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
op2->set_total_repeats(2);
|
||||
op2->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -293,15 +293,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Assign tree relations and root
|
||||
myCacheOp->set_total_repeats(numRepeats);
|
||||
myCacheOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
// Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats.
|
||||
myRandomDataOp->set_total_repeats(1);
|
||||
myRandomDataOp->set_num_repeats_per_epoch(1);
|
||||
rc = myCacheOp->AddChild(myRandomDataOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = myTree->AssignRoot(myRepeatOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration";
|
||||
rc = myTree->Prepare(1);
|
||||
rc = myTree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// quick check to see what tree looks like
|
||||
|
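The comment in this hunk captures the cache-specific rule: the CacheOp is what gets replayed, while the operator beneath it is read only once to populate the cache. Restated as a short sketch that mirrors the test above (numRepeats comes from that test).

// Repeat bookkeeping around a cache, as configured in the test above.
myCacheOp->set_total_repeats(numRepeats);         // the cache replays its contents
myCacheOp->set_num_repeats_per_epoch(numRepeats);
myRandomDataOp->set_total_repeats(1);             // the source below the cache is read once to fill it
myRandomDataOp->set_num_repeats_per_epoch(1);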
@ -412,15 +417,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Assign tree relations and root
|
||||
myCacheOp->set_total_repeats(numRepeats);
|
||||
myCacheOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
// Always set to 1 under a CacheOp because we read from it only once. The CacheOp is the one that repeats.
|
||||
myRandomDataOp->set_total_repeats(1);
|
||||
myRandomDataOp->set_num_repeats_per_epoch(1);
|
||||
rc = myCacheOp->AddChild(myRandomDataOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = myTree->AssignRoot(myRepeatOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
MS_LOG(INFO) << "Launching tree and begin iteration";
|
||||
rc = myTree->Prepare(1);
|
||||
rc = myTree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
std::cout << *myClient << std::endl;
|
||||
|
@ -502,14 +512,20 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
|||
rc = myTree->AssignRoot(myRepeatOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
myMergeOp->set_total_repeats(numRepeats);
|
||||
myMergeOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myMergeOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
myLookupOp->set_total_repeats(numRepeats);
|
||||
myLookupOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myMergeOp->AddChild(myLookupOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
so->set_total_repeats(numRepeats);
|
||||
so->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myMergeOp->AddChild(so);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
rc = myTree->Prepare(1);
|
||||
rc = myTree->Prepare();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = myTree->Launch();
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,14 +13,11 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
@ -98,7 +95,11 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
|
|||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
|
||||
uint32_t count = 0;
|
||||
auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
|
||||
auto op1 = Celeba(16, 2, 32, dir);
|
||||
auto op2 = Repeat(2);
|
||||
auto tree = Build({op1, op2});
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,8 +39,6 @@ using mindspore::MsLogLevel::ERROR;
|
|||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
std::shared_ptr<RepeatOp> Repeat(int repeatCnt);
|
||||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,8 +45,6 @@ using mindspore::LogStream;
|
|||
|
||||
std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2);
|
||||
|
||||
std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
|
||||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
class MindDataTestCocoOp : public UT::DatasetOpTesting {
|
||||
|
@ -261,4 +259,4 @@ TEST_F(MindDataTestCocoOp, TestCocoPanoptic) {
|
|||
}
|
||||
|
||||
ASSERT_EQ(row_count, 2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,7 +13,6 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -29,7 +28,6 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -82,7 +80,11 @@ class MindDataTestImageFolderSampler : public UT::DatasetOpTesting {
|
|||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeat) {
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2)});
|
||||
auto op1 = ImageFolder(16, 2, 32, folder_path, false);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
int32_t res[] = {0, 1, 2, 3};
|
||||
Status rc = tree->Launch();
|
||||
|
@ -166,7 +168,12 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) {
|
|||
|
||||
TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) {
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false), Repeat(2), Batch(11)});
|
||||
auto op1 = ImageFolder(16, 2, 32, folder_path, false);
|
||||
auto op2 = Repeat(2);
|
||||
auto op3 = Batch(11);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
int32_t res[4][11] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
|
@ -297,7 +304,11 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) {
|
|||
int64_t num_samples = 0;
|
||||
std::shared_ptr<SamplerRT> sampler = std::make_shared<DistributedSamplerRT>(num_samples, 11, 10, false);
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data";
|
||||
auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)});
|
||||
auto op1 = ImageFolder(16, 2, 32, folder_path, false, std::move(sampler));
|
||||
auto op2 = Repeat(4);
|
||||
op1->set_total_repeats(4);
|
||||
op1->set_num_repeats_per_epoch(4);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,7 @@
|
|||
#include "common/common.h"
|
||||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
@ -166,6 +167,10 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
|
|||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(2).Build(&repeat_op);
|
||||
// start build then launch tree
|
||||
leaf->set_total_repeats(2);
|
||||
leaf->set_num_repeats_per_epoch(2);
|
||||
map_op->set_total_repeats(2);
|
||||
map_op->set_num_repeats_per_epoch(2);
|
||||
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
|
||||
rc = tree->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -213,8 +218,15 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
|
|||
// config RepeatOp
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(2).Build(&repeat_op);
|
||||
// config EpochCtrlOp
|
||||
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
|
||||
rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op);
|
||||
// start build then launch tree
|
||||
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
|
||||
leaf->set_total_repeats(-2);
|
||||
leaf->set_num_repeats_per_epoch(2);
|
||||
map_op->set_total_repeats(-2);
|
||||
map_op->set_num_repeats_per_epoch(2);
|
||||
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op});
|
||||
rc = tree->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = tree->Launch();
|
||||
|
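The multi-epoch callback test pairs an unbounded EpochCtrlOp(-1) with negative total-repeat counts. The convention appears to be that when the number of epochs is open-ended the total cannot be finite, so it is marked negative while the per-epoch count stays meaningful; this reading is inferred from the test itself, not from library documentation.

// Inferred convention (assumption): unbounded epochs => negative total repeats,
// while per-epoch still records the two repeats inside each epoch.
leaf->set_total_repeats(-2);
leaf->set_num_repeats_per_epoch(2);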
@ -271,8 +283,15 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
|
|||
// config RepeatOp
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(2).Build(&repeat_op);
|
||||
// config EpochCtrlOp
|
||||
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
|
||||
rc = EpochCtrlOp::Builder(-1).Build(&epoch_ctrl_op);
|
||||
// start build then launch tree
|
||||
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
|
||||
leaf->set_total_repeats(-2);
|
||||
leaf->set_num_repeats_per_epoch(2);
|
||||
map_op->set_total_repeats(-2);
|
||||
map_op->set_num_repeats_per_epoch(2);
|
||||
std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op, epoch_ctrl_op});
|
||||
rc = tree->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = tree->Launch();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,7 +58,11 @@ class MindDataTestManifest : public UT::DatasetOpTesting {
|
|||
|
||||
TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) {
|
||||
std::string file = datasets_root_path_ + "/testManifestData/cpp.json";
|
||||
auto tree = Build({Manifest(16, 2, 32, file), Repeat(2)});
|
||||
auto op1 = Manifest(16, 2, 32, file);
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
uint32_t res[] = {0, 1, 0, 1};
|
||||
Status rc = tree->Launch();
|
||||
|
@ -148,7 +152,11 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) {
|
|||
int64_t num_samples = 1;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
|
||||
auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)});
|
||||
auto op1 = Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {});
|
||||
auto op2 = Repeat(4);
|
||||
op1->set_total_repeats(4);
|
||||
op1->set_num_repeats_per_epoch(4);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
Status rc = tree->Launch();
|
||||
if (rc.IsError()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,7 +17,6 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
|
@ -416,6 +415,8 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
|
|||
rc = my_map_op->AddChild(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tfreader_op->set_total_repeats(num_repeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -471,9 +472,13 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) {
|
|||
rc = my_tree_->AssociateNode(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_map_op->set_total_repeats(num_repeats);
|
||||
my_map_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tfreader_op->set_total_repeats(num_repeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_map_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -548,9 +553,13 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) {
|
|||
rc = my_tree_->AssociateNode(my_map_resize_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tfreader_op->set_total_repeats(num_repeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_map_decode_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_map_decode_op->set_total_repeats(num_repeats);
|
||||
my_map_decode_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_map_decode_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -611,7 +620,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
|
|||
rc = map_resize_builder.Build(&map_resize_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
|
||||
auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
|
||||
image_folder_op->set_total_repeats(num_repeats);
|
||||
image_folder_op->set_num_repeats_per_epoch(num_repeats);
|
||||
map_decode_map->set_total_repeats(num_repeats);
|
||||
map_decode_map->set_num_repeats_per_epoch(num_repeats);
|
||||
my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
|
||||
rc = my_tree_->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
|
@ -656,7 +670,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
|
|||
rc = map_resize_builder.Build(&map_resize_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
auto my_tree_2 = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
|
||||
image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
|
||||
image_folder_op->set_total_repeats(num_repeats);
|
||||
image_folder_op->set_num_repeats_per_epoch(num_repeats);
|
||||
map_decode_map->set_total_repeats(num_repeats);
|
||||
map_decode_map->set_num_repeats_per_epoch(num_repeats);
|
||||
auto my_tree_2 = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
|
||||
|
||||
rc = my_tree_2->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -714,7 +733,12 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) {
|
|||
rc = map_resize_builder.Build(&map_resize_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_tree_ = Build({ImageFolder(16, 2, 32, folder_path, false), map_decode_map, repeat_op, map_resize_op});
|
||||
auto image_folder_op = ImageFolder(16, 2, 32, folder_path, false);
|
||||
image_folder_op->set_total_repeats(num_repeats);
|
||||
image_folder_op->set_num_repeats_per_epoch(num_repeats);
|
||||
map_decode_map->set_total_repeats(num_repeats);
|
||||
map_decode_map->set_num_repeats_per_epoch(num_repeats);
|
||||
my_tree_ = Build({image_folder_op, map_decode_map, repeat_op, map_resize_op});
|
||||
rc = my_tree_->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree_->Launch();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -370,9 +370,12 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) {
|
|||
rc = my_tree->AssociateNode(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_mindrecord_op->set_total_repeats(num_repeats);
|
||||
my_mindrecord_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_mindrecord_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
||||
// Set children/root layout.
|
||||
rc = my_tree->AssignRoot(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -452,6 +455,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) {
|
|||
rc = my_tree->AssociateNode(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
my_mindrecord_op->set_total_repeats(num_repeats);
|
||||
my_mindrecord_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_mindrecord_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -78,7 +78,11 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) {
|
|||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)});
|
||||
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
|
||||
auto op2 = Repeat(2);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2});
|
||||
tree->Prepare();
|
||||
uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
Status rc = tree->Launch();
|
||||
|
@ -108,7 +112,12 @@ TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) {
|
|||
int64_t num_samples = 10;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
|
||||
auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)});
|
||||
auto op1 = CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler));
|
||||
auto op2 = Repeat(2);
|
||||
auto op3 = Batch(5);
|
||||
op1->set_total_repeats(2);
|
||||
op1->set_num_repeats_per_epoch(2);
|
||||
auto tree = Build({op1, op2, op3});
|
||||
tree->Prepare();
|
||||
uint32_t res[4][5] = { {0, 0, 0, 0, 0 },
|
||||
{0, 0, 0, 0, 0 },
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -35,7 +35,7 @@ class MindDataTestRandomDataOp : public UT::DatasetOpTesting {
|
|||
|
||||
// Test info:
|
||||
// - Simple test with a user-provided schema generated purely from DataSchema C API
|
||||
// - has an interation loop
|
||||
// - has an interaction loop
|
||||
//
|
||||
// Tree: single node tree with RandomDataOp
|
||||
//
|
||||
|
@ -213,7 +213,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic3) {
|
|||
|
||||
// Test info:
|
||||
// - JSON schema input; it's a fairly simple one
|
||||
// - has an interation loop
|
||||
// - has an interaction loop
|
||||
//
|
||||
// Tree: RepeatOp over RandomDataOp
|
||||
//
|
||||
|
@ -253,6 +253,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) {
|
|||
rc = myTree->AssociateNode(myRepeatOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
myRandomDataOp->set_total_repeats(numRepeats);
|
||||
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myRandomDataOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -290,7 +292,7 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic4) {
|
|||
|
||||
// Test info:
|
||||
// - JSON schema input; it's a fairly simple one
|
||||
// - has an interation loop
|
||||
// - has an interaction loop
|
||||
// - same as MindDataTestRandomDataOpBasic4 except that this one will have parallel workers
|
||||
//
|
||||
// Tree: RepeatOp over RandomDataOp
|
||||
|
@ -331,6 +333,8 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpBasic5) {
|
|||
rc = myTree->AssociateNode(myRepeatOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
myRandomDataOp->set_total_repeats(numRepeats);
|
||||
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myRandomDataOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -418,9 +422,13 @@ TEST_F(MindDataTestRandomDataOp, RandomDataOpTree1) {
|
|||
rc = myTree->AssociateNode(myRepeatOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
myShuffleOp->set_total_repeats(numRepeats);
|
||||
myShuffleOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myRepeatOp->AddChild(myShuffleOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
||||
myRandomDataOp->set_total_repeats(numRepeats);
|
||||
myRandomDataOp->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = myShuffleOp->AddChild(myRandomDataOp);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
|
|
|
@ -75,6 +75,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceFromDatasetFuntions) {
|
|||
rc = spv_op->AddChild(file_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
file_op->set_total_repeats(1);
|
||||
file_op->set_num_repeats_per_epoch(1);
|
||||
rc = tree->AssignRoot(spv_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = tree->Prepare();
|
||||
|
@ -147,6 +149,8 @@ TEST_F(MindDataTestSentencePieceVocabOp, TestSentencePieceTokenizerFuntions) {
|
|||
rc = spv_op->AddChild(file_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
file_op->set_total_repeats(1);
|
||||
file_op->set_num_repeats_per_epoch(1);
|
||||
rc = tree->AssignRoot(spv_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = tree->Prepare();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -300,8 +300,12 @@ TEST_F(MindDataTestShuffleOp, TestRepeatShuffle) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Set children/root layout.
|
||||
my_shuffle_op->set_total_repeats(numRepeats);
|
||||
my_shuffle_op->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = my_repeat_op->AddChild(my_shuffle_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
my_tfreader_op->set_total_repeats(numRepeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(numRepeats);
|
||||
rc = my_shuffle_op->AddChild(my_tfreader_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_repeat_op);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,7 +20,6 @@
|
|||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "common/common.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -330,11 +329,14 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderRepeat) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// RepeatOp
|
||||
std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(3);
|
||||
uint32_t num_repeats = 3;
|
||||
std::shared_ptr<RepeatOp> my_repeat_op = std::make_shared<RepeatOp>(num_repeats);
|
||||
rc = my_tree->AssociateNode(my_repeat_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// Set children/root layout.
|
||||
my_tfreader_op->set_total_repeats(num_repeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_repeat_op);
|
||||
|
@ -705,7 +707,7 @@ TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
|
|||
std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
|
||||
std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
|
||||
std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
|
||||
std::string nonexistent_file = "this/file/doesnt/exist";
|
||||
std::string nonexistent_file = "this/file/not/exist";
|
||||
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder builder;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,8 +45,6 @@ using mindspore::LogStream;
|
|||
|
||||
std::shared_ptr<BatchOp> Batch(int batch_size = 1, bool drop = false, int rows_per_buf = 2);
|
||||
|
||||
std::shared_ptr<RepeatOp> Repeat(int repeat_cnt);
|
||||
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
|
||||
class MindDataTestVOCOp : public UT::DatasetOpTesting {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -141,6 +141,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
|
|||
MS_LOG(INFO) << "UT test TestZipRepeat.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
uint32_t num_repeats = 3;
|
||||
std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data";
|
||||
std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data";
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
|
@ -169,17 +170,23 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) {
|
|||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(zip_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
my_tfreader_op->set_total_repeats(num_repeats);
|
||||
my_tfreader_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
my_tfreader_op2->set_total_repeats(num_repeats);
|
||||
my_tfreader_op2->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = zip_op->AddChild(std::move(my_tfreader_op2));
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
||||
// Builder(num_of_repeats)
|
||||
std::shared_ptr<RepeatOp> my_repeat_op;
|
||||
rc = RepeatOp::Builder(3).Build(&my_repeat_op);
|
||||
rc = RepeatOp::Builder(num_repeats).Build(&my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_repeat_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
zip_op->set_total_repeats(num_repeats);
|
||||
zip_op->set_num_repeats_per_epoch(num_repeats);
|
||||
rc = my_repeat_op->AddChild(zip_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssignRoot(my_repeat_op);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -138,7 +138,7 @@ def test_cache_map_basic3():
|
|||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_basic4():
|
||||
"""
|
||||
Test Map with non-deterministic TensorOps above cache
|
||||
Test Map containing random operation above cache
|
||||
|
||||
repeat
|
||||
|
|
||||
|
@ -374,7 +374,7 @@ def test_cache_map_failure4():
|
|||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_map_failure5():
|
||||
"""
|
||||
Test Map with non-deterministic TensorOps under cache (failure)
|
||||
Test Map containing random operation under cache (failure)
|
||||
|
||||
repeat
|
||||
|
|
||||
|
@ -406,7 +406,7 @@ def test_cache_map_failure5():
|
|||
num_iter = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)
|
||||
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_failure5 Ended.\n')
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -2087,7 +2087,7 @@ def test_cache_nomap_failure4():
|
|||
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
|
||||
def test_cache_nomap_failure5():
|
||||
"""
|
||||
Test Map with non-deterministic TensorOps under cache (failure)
|
||||
Test Map containing random operation under cache (failure)
|
||||
|
||||
repeat
|
||||
|
|
||||
|
@ -2118,7 +2118,7 @@ def test_cache_nomap_failure5():
|
|||
num_iter = 0
|
||||
for _ in data.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert "MapNode with non-deterministic operations is not supported as a descendant of cache" in str(e.value)
|
||||
assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
|
||||
|
||||
assert num_iter == 0
|
||||
logger.info('test_cache_nomap_failure5 Ended.\n')